import os
import queue
import random
import time
from threading import Thread
from typing import Any, Literal, override

import fastrtc
import gradio as gr
import httpx
import librosa
import numpy as np

from api_schema import (
    AbortController,
    AssistantStyle,
    ChatAudioBytes,
    ChatRequestBody,
    ChatResponseItem,
    ModelNameResponse,
    PresetOptions,
    SamplerConfig,
    TokenizedConversation,
    TokenizedMessage,
)
from webrtc_vad import VADStreamHandler

# Runtime configuration, read once from the environment at import time.
HF_TOKEN = os.getenv("HF_TOKEN")
SERVER_LIST = os.getenv("SERVER_LIST")
TURN_KEY_ID = os.getenv("TURN_KEY_ID")
TURN_KEY_API_TOKEN = os.getenv("TURN_KEY_API_TOKEN")
CONCURRENCY_LIMIT = os.getenv("CONCURRENCY_LIMIT")

# Validate with explicit raises rather than `assert`: assertions are removed
# when Python runs with -O, which would let a missing variable slip through
# and fail later with a far less helpful error.
if SERVER_LIST is None:
    raise RuntimeError("SERVER_LIST environment variable is required.")
if TURN_KEY_ID is None or TURN_KEY_API_TOKEN is None:
    raise RuntimeError(
        "TURN_KEY_ID and TURN_KEY_API_TOKEN environment variables are required "
    )

# SERVER_LIST is a comma-separated list of server URLs; blank entries
# (e.g. from a trailing comma) are dropped.
deployment_server = [
    server_url
    for raw_url in SERVER_LIST.split(",")
    if (server_url := raw_url.strip())
]
if not deployment_server:
    raise RuntimeError("SERVER_LIST must contain at least one server URL.")
default_concurrency_limit = 32

# An unparsable CONCURRENCY_LIMIT silently falls back to the default.
try:
    concurrency_limit = (
        int(CONCURRENCY_LIMIT)
        if CONCURRENCY_LIMIT is not None
        else default_concurrency_limit
    )
except ValueError:
    concurrency_limit = default_concurrency_limit


def chat_server_url(pathname: str = "/") -> httpx.URL:
    """Return *pathname* joined onto a randomly chosen deployment server URL."""
    host = random.choice(deployment_server)
    return httpx.URL(host).join(pathname)


def auth_headers() -> dict[str, str]:
    """Return the bearer ``Authorization`` header, or ``{}`` when HF_TOKEN is unset."""
    if HF_TOKEN is None:
        return {}
    return {"Authorization": f"Bearer {HF_TOKEN}"}


def get_cloudflare_turn_credentials(
    ttl: int = 3600,  # 1 hour
) -> dict[str, Any]:
    """Request ICE server credentials from the Cloudflare TURN endpoint.

    Args:
        ttl: Value sent as the ``ttl`` field of the request body.

    Returns:
        The decoded JSON body of a successful response.

    Raises:
        RuntimeError: If the response status is not a success code.
    """
    with httpx.Client() as client:
        response = client.post(
            f"https://rtc.live.cloudflare.com/v1/turn/keys/{TURN_KEY_ID}/credentials/generate-ice-servers",
            headers={
                "Authorization": f"Bearer {TURN_KEY_API_TOKEN}",
                "Content-Type": "application/json",
            },
            json={"ttl": ttl},
        )
        if response.is_success:
            return response.json()
        raise RuntimeError(
            f"Failed to get TURN credentials: {response.status_code} {response.text}"
        )


class ConversationManager:
    """State of one conversation: token history, turn counter and recorded audio."""

    # Sample rate (Hz) that every recorded chunk is converted to.
    TARGET_SAMPLE_RATE = 24000

    def __init__(self, assistant_style: AssistantStyle | None = None):
        self.conversation = TokenizedConversation(messages=[])
        self.turn = 0
        self.assistant_style = assistant_style
        self.last_access_time = time.monotonic()
        self.collected_audio_chunks: list[np.ndarray] = []

    def new_turn(self) -> "ConversationAbortController":
        """Advance the turn counter and return a controller bound to the new turn.

        Controllers created for earlier turns stop reporting alive, because
        their recorded turn no longer matches ``self.turn``.
        """
        self.turn += 1
        self.last_access_time = time.monotonic()
        return ConversationAbortController(self)

    def is_idle(self, idle_timeout: float) -> bool:
        """Return True if more than *idle_timeout* seconds passed since last access."""
        return time.monotonic() - self.last_access_time > idle_timeout

    def append_audio_chunk(self, audio_chunk: tuple[int, np.ndarray]):
        """Record one ``(sample_rate, samples)`` chunk at the target sample rate.

        NOTE(review): the 32768 scaling implies int16 PCM input - confirm
        against the callers.
        """
        sr, audio_data = audio_chunk
        target_sr = self.TARGET_SAMPLE_RATE
        if sr != target_sr:
            resampled = librosa.resample(
                audio_data.astype(np.float32) / 32768.0,
                orig_sr=sr,
                target_sr=target_sr,
            )
            # Clip before the cast: resampling can overshoot the original
            # peak, and casting an out-of-range float to int16 does not
            # saturate.
            audio_data = np.clip(resampled * 32767.0, -32768.0, 32767.0).astype(
                np.int16
            )
        if audio_data.ndim > 1:
            # [channels, samples] -> [samples,]
            # Not Gradio style
            audio_data = audio_data.mean(axis=0).astype(np.int16)
        self.collected_audio_chunks.append(audio_data)

    def all_collected_audio(self) -> tuple[int, np.ndarray]:
        """Return ``(sample_rate, samples)`` for all recorded chunks concatenated."""
        sr = self.TARGET_SAMPLE_RATE
        if not self.collected_audio_chunks:
            # np.concatenate raises on an empty list; return empty audio instead.
            return sr, np.zeros(0, dtype=np.int16)
        return sr, np.concatenate(self.collected_audio_chunks)

    def chat(
        self,
        chat_id: int,
        input_audio: tuple[int, np.ndarray],
        global_sampler_config: SamplerConfig | None = None,
        local_sampler_config: SamplerConfig | None = None,
    ):
        """Stream one chat turn from the server.

        A worker thread posts the request to ``/audio-chat`` and parses the
        ``data: `` lines of the streamed response; this generator relays the
        parsed items to the caller.

        Yields:
            Each ``ChatResponseItem`` received, or ``None`` roughly every
            0.1 s while nothing is available, so the caller regains control.

        Raises:
            RuntimeError: If the worker thread failed (non-200 status,
                transport error, malformed payload, ...). The original
                exception is chained as the cause.
        """
        controller = self.new_turn()
        chat_queue = queue.Queue[ChatResponseItem | Exception | None]()
        # Must be bound before the worker starts, since the worker reads it.
        consumer_alive_time = time.monotonic()

        def chat_task():
            try:
                url = chat_server_url("/audio-chat")
                req = ChatRequestBody(
                    conversation=self.conversation,
                    input_audio=ChatAudioBytes.from_audio(input_audio),
                    assistant_style=self.assistant_style,
                    global_sampler_config=global_sampler_config,
                    local_sampler_config=local_sampler_config,
                )
                first_output = True
                with httpx.Client() as client:
                    with client.stream(
                        method="POST",
                        url=url,
                        content=req.model_dump_json(),
                        headers={
                            "Content-Type": "application/json",
                            **auth_headers(),
                        },
                    ) as response:
                        if response.status_code != 200:
                            raise RuntimeError(f"Error {response.status_code}")
                        for line in response.iter_lines():
                            if not controller.is_alive():
                                print(f"[{chat_id=}] Streaming aborted by user")
                                break
                            # The consumer refreshes consumer_alive_time on
                            # every loop iteration; a stale value means
                            # nobody is reading the queue any more.
                            if time.monotonic() - consumer_alive_time > 1.0:
                                print(
                                    f"[{chat_id=}] Streaming aborted due to inactivity"
                                )
                                break
                            if not line.startswith("data: "):
                                continue
                            line = line.removeprefix("data: ")
                            if line.strip() == "[DONE]":
                                print(f"[{chat_id=}] Streaming finished by server")
                                break
                            chunk = ChatResponseItem.model_validate_json(line)
                            if chunk.tokenized_input is not None:
                                self.conversation.messages.append(
                                    chunk.tokenized_input,
                                )
                            if chunk.token_chunk is not None:
                                if first_output:
                                    # First tokens of the reply open a new
                                    # assistant message; later ones extend it.
                                    self.conversation.messages.append(
                                        TokenizedMessage(
                                            role="assistant",
                                            content=chunk.token_chunk,
                                        )
                                    )
                                    first_output = False
                                else:
                                    self.conversation.messages[-1].append(
                                        chunk.token_chunk,
                                    )
                            chat_queue.put(chunk)
            except Exception as exc:
                # Thread boundary: hand the failure to the consumer instead
                # of letting it die silently with the thread.
                print(f"[{chat_id=}] Streaming failed: {exc!r}")
                chat_queue.put(exc)
            finally:
                # Always send the end-of-stream sentinel, otherwise the
                # consumer loop below would never terminate.
                chat_queue.put(None)

        Thread(target=chat_task, daemon=True).start()

        while True:
            consumer_alive_time = time.monotonic()
            try:
                item = chat_queue.get(timeout=0.1)
            except queue.Empty:
                yield None
                continue
            if item is None:
                break
            if isinstance(item, Exception):
                if isinstance(item, RuntimeError):
                    raise item
                raise RuntimeError(f"{type(item).__name__}: {item}") from item
            yield item
            self.last_access_time = time.monotonic()


class ConversationAbortController(AbortController):
    """Abort controller that is alive only while its manager stays on the same turn."""

    manager: ConversationManager
    cur_turn: int | None

    def __init__(self, manager: ConversationManager):
        self.manager = manager
        self.cur_turn = manager.turn

    @override
    def is_alive(self) -> bool:
        return self.manager.turn == self.cur_turn

    def abort(self) -> None:
        # Clearing the recorded turn makes is_alive() compare against None.
        self.cur_turn = None


chat_id_counter = 0


def new_chat_id():
    """Return the next chat id from the module-level counter (starting at 0)."""
    global chat_id_counter
    chat_id = chat_id_counter
    chat_id_counter += 1
    return chat_id


def parse_gradio_audio(gradio_audio: tuple[int, np.ndarray]):
    """Transpose multi-dimensional audio and scale int32 samples to float32.

    Returns:
        ``(sample_rate, audio)`` where audio is ``[samples]`` or
        ``[channels, samples]``.
    """
    sr, audio = gradio_audio
    if len(audio.shape) > 1:
        # [samples, channels] -> [channels, samples]
        audio = audio.T
    if audio.dtype == np.int32:
        audio = audio.astype(np.float32) / 2**31
    # [samples] or [channels, samples]
    return sr, audio


def main():
    """Build the Gradio UI, wire the WebRTC stream handler and launch the server."""
    print("Starting WebRTC server")

    # One ConversationManager per WebRTC connection id.
    conversations: dict[str, ConversationManager] = {}

    def cleanup_idle_conversations():
        """Once a minute, drop conversations idle for more than 30 minutes."""
        idle_timeout = 30 * 60.0  # 30 minutes
        while True:
            time.sleep(60)
            # Iterate over a snapshot: request handlers may insert into the
            # dict concurrently, and mutating a dict during iteration raises
            # RuntimeError, which would end this thread.
            to_delete = [
                webrtc_id
                for webrtc_id, manager in list(conversations.items())
                if manager.is_idle(idle_timeout)
            ]
            for webrtc_id in to_delete:
                print(f"Cleaning up idle conversation {webrtc_id}")
                conversations.pop(webrtc_id, None)

    Thread(target=cleanup_idle_conversations, daemon=True).start()

    def get_preset_list(category: Literal["character", "voice"]) -> list[str]:
        """Fetch preset options for *category*; fall back to ``["[default]"]``."""
        url = chat_server_url(f"/preset/{category}")
        try:
            with httpx.Client() as client:
                response = client.get(url, headers=auth_headers())
        except httpx.HTTPError as exc:
            # Treat an unreachable server like a non-200 response.
            print(f"Failed to fetch {category} presets: {exc!r}")
            return ["[default]"]
        if response.status_code == 200:
            return PresetOptions.model_validate_json(response.text).options
        return ["[default]"]

    def get_model_name() -> str:
        """Fetch the model name from the server; fall back to ``"unknown"``."""
        url = chat_server_url("/model-name")
        try:
            with httpx.Client() as client:
                response = client.get(url, headers=auth_headers())
        except httpx.HTTPError as exc:
            # Treat an unreachable server like a non-200 response.
            print(f"Failed to fetch model name: {exc!r}")
            return "unknown"
        if response.status_code == 200:
            return ModelNameResponse.model_validate_json(response.text).model_name
        return "unknown"

    def load_initial_data():
        """Return updates for the title and both preset dropdowns."""
        model_name = get_model_name()
        title = f"Xiaomi MiMo-Audio WebRTC (model: {model_name})"
        character_choices = get_preset_list("character")
        voice_choices = get_preset_list("voice")
        return (
            gr.update(value=f"# {title}"),
            gr.update(choices=character_choices),
            gr.update(choices=voice_choices),
        )

    def response(
        input_audio: tuple[int, np.ndarray],
        webrtc_id: str,
        preset_character: str | None,
        preset_voice: str | None,
        custom_character_prompt: str | None,
    ):
        """Handle one chunk of user speech and stream back the reply.

        Yields output audio chunks, ``fastrtc.AdditionalOutputs`` carrying
        (output text, status text, full recorded audio), and ``None`` while
        waiting for the server.
        """
        manager = conversations.get(webrtc_id)
        if manager is None:
            # The style is fixed when the conversation is first created.
            # A missing or blank custom prompt is normalised to None.
            normalized_prompt = (custom_character_prompt or "").strip() or None
            manager = ConversationManager(
                assistant_style=AssistantStyle(
                    preset_character=preset_character,
                    custom_character_prompt=normalized_prompt,
                    preset_voice=preset_voice,
                )
            )
            conversations[webrtc_id] = manager

        sr, audio_data = input_audio
        chat_id = new_chat_id()
        # shape[-1] is the sample count for both [samples] and
        # [channels, samples] layouts.
        print(f"WebRTC {webrtc_id} [{chat_id=}]: Input {audio_data.shape[-1] / sr}s")

        # Record input audio
        manager.append_audio_chunk(input_audio)

        output_text = ""
        status_text = "⌛️ Preparing..."
        text_active = False
        audio_active = False
        collected_audio: tuple[int, np.ndarray] | None = None

        def additional_outputs():
            # Reads the current values of the enclosing locals at call time.
            return fastrtc.AdditionalOutputs(
                output_text,
                status_text,
                collected_audio,
            )

        yield additional_outputs()

        try:
            for chunk in manager.chat(
                chat_id,
                input_audio,
            ):
                if chunk is None:
                    # Test if consumer is still alive
                    yield None
                    continue

                if chunk.text_chunk is not None:
                    text_active = True
                    output_text += chunk.text_chunk
                if chunk.end_of_transcription:
                    text_active = False

                if chunk.audio_chunk is not None:
                    audio_active = True
                    audio = chunk.audio_chunk.to_audio()
                    manager.append_audio_chunk(audio)
                    yield audio
                if chunk.end_of_stream:
                    audio_active = False

                if text_active and audio_active:
                    status_text = "💬+🔊 Mixed"
                elif text_active:
                    status_text = "💬 Text"
                elif audio_active:
                    status_text = "🔊 Audio"

                if chunk.stop_reason is not None:
                    status_text = f"✅ Finished: {chunk.stop_reason}"

                yield additional_outputs()
        except RuntimeError as e:
            status_text = f"❌ Error: {e}"
            yield additional_outputs()

        collected_audio = manager.all_collected_audio()
        yield additional_outputs()

    title = "Xiaomi MiMo-Audio WebRTC"

    with gr.Blocks(title=title) as demo:
        title_markdown = gr.Markdown(f"# {title}")
        with gr.Row():
            with gr.Column():
                chat = fastrtc.WebRTC(
                    label="WebRTC Chat",
                    modality="audio",
                    mode="send-receive",
                    full_screen=False,
                    rtc_configuration=get_cloudflare_turn_credentials,
                )
                output_text = gr.Textbox(label="Output", lines=3, interactive=False)
                status_text = gr.Textbox(label="Status", lines=1, interactive=False)
                with gr.Accordion("Advanced", open=True):
                    collected_audio = gr.Audio(
                        label="Full Audio",
                        type="numpy",
                        format="wav",
                        interactive=False,
                    )
            with gr.Column():
                with gr.Accordion("Settings Help"):
                    gr.Markdown(
                        "- `Preset Prompt` controls the response style.\n"
                        "- `Preset Voice` controls the speaking tone.\n"
                        "- `Custom Prompt` lets you define the response style in natural language (overrides `Preset Prompt`).\n"
                        "- For best results, choose prompts and voices that **match your language**. The default settings are optimized for **English**.\n"
                        "- To apply new settings, end the current conversation and start a new one."
                    )
                preset_character_dropdown = gr.Dropdown(
                    label="😊 Preset Prompt",
                    choices=["[default]"],
                )
                preset_voice_dropdown = gr.Dropdown(
                    label="🎤 Preset Voice",
                    choices=["[default]"],
                )
                custom_character_prompt = gr.Textbox(
                    label="🛠️ Custom Prompt",
                    placeholder="For example: You are Xiaomi MiMo-Audio, a large language model trained by Xiaomi. You are chatting with a user over voice.",
                    lines=2,
                    interactive=True,
                )

        chat.stream(
            VADStreamHandler(response),
            inputs=[
                chat,
                preset_character_dropdown,
                preset_voice_dropdown,
                custom_character_prompt,
            ],
            concurrency_limit=concurrency_limit,
            outputs=[chat],
        )
        chat.on_additional_outputs(
            lambda *args: args,
            outputs=[output_text, status_text, collected_audio],
            concurrency_limit=concurrency_limit,
            show_progress="hidden",
        )
        demo.load(
            load_initial_data,
            inputs=[],
            outputs=[title_markdown, preset_character_dropdown, preset_voice_dropdown],
        )

    demo.launch(server_name="0.0.0.0", server_port=8087, show_api=False)


if __name__ == "__main__":
    main()