Spaces: Running on CPU Upgrade
| import os | |
| import queue | |
| import random | |
| import time | |
| from threading import Thread | |
| from typing import Any, Literal, override | |
| import fastrtc | |
| import gradio as gr | |
| import httpx | |
| import librosa | |
| import numpy as np | |
| from api_schema import ( | |
| AbortController, | |
| AssistantStyle, | |
| ChatAudioBytes, | |
| ChatRequestBody, | |
| ChatResponseItem, | |
| ModelNameResponse, | |
| PresetOptions, | |
| SamplerConfig, | |
| TokenizedConversation, | |
| TokenizedMessage, | |
| ) | |
| from webrtc_vad import VADStreamHandler | |
# Runtime configuration, read once from the environment at import time.
HF_TOKEN = os.getenv("HF_TOKEN")
SERVER_LIST = os.getenv("SERVER_LIST")
TURN_KEY_ID = os.getenv("TURN_KEY_ID")
TURN_KEY_API_TOKEN = os.getenv("TURN_KEY_API_TOKEN")
CONCURRENCY_LIMIT = os.getenv("CONCURRENCY_LIMIT")

# Validate with explicit raises instead of `assert`: asserts are stripped when
# Python runs with -O, which would let the app start with missing settings.
if SERVER_LIST is None:
    raise RuntimeError("SERVER_LIST environment variable is required.")
if TURN_KEY_ID is None or TURN_KEY_API_TOKEN is None:
    raise RuntimeError(
        "TURN_KEY_ID and TURN_KEY_API_TOKEN environment variables are required."
    )

# SERVER_LIST is a comma-separated list of backend base URLs; blank entries
# (e.g. from a trailing comma) are ignored.
deployment_server = [
    server_url.strip() for server_url in SERVER_LIST.split(",") if server_url.strip()
]
if not deployment_server:
    raise RuntimeError("SERVER_LIST must contain at least one server URL.")

default_concurrency_limit = 32
try:
    concurrency_limit = (
        int(CONCURRENCY_LIMIT)
        if CONCURRENCY_LIMIT is not None
        else default_concurrency_limit
    )
except ValueError:
    # Keep the app usable with a bad value, but make the fallback visible.
    print(
        f"Invalid CONCURRENCY_LIMIT {CONCURRENCY_LIMIT!r}; "
        f"using {default_concurrency_limit}"
    )
    concurrency_limit = default_concurrency_limit
if concurrency_limit < 1:
    # Zero or negative limits would make the chat stream unusable.
    print(
        f"CONCURRENCY_LIMIT must be positive (got {concurrency_limit}); "
        f"using {default_concurrency_limit}"
    )
    concurrency_limit = default_concurrency_limit
def chat_server_url(pathname: str = "/") -> httpx.URL:
    """Pick one backend at random and resolve ``pathname`` against it.

    Every entry of ``deployment_server`` is equally likely, which spreads
    requests across the configured servers. The result comes from
    ``httpx.URL.join``, so a ``pathname`` starting with ``/`` replaces any
    path component already present in the server URL.
    """
    base_url = httpx.URL(random.choice(deployment_server))
    return base_url.join(pathname)
def auth_headers() -> dict[str, str]:
    """Build the Authorization header used for backend requests.

    When ``HF_TOKEN`` is set it is sent as a bearer token. Otherwise an empty
    dict is returned, so callers can always splat the result into their own
    header dict.
    """
    return {} if HF_TOKEN is None else {"Authorization": f"Bearer {HF_TOKEN}"}
def get_cloudflare_turn_credentials(
    ttl: int = 3600,  # 1 hour
) -> dict[str, Any]:
    """Request short-lived ICE server credentials from Cloudflare's TURN API.

    Args:
        ttl: Requested credential lifetime in seconds, sent as ``{"ttl": ttl}``.

    Returns:
        The decoded JSON body of a successful (2xx) response.

    Raises:
        Exception: When the response status is not 2xx. The message includes
            the status code and the response body.
    """
    endpoint = f"https://rtc.live.cloudflare.com/v1/turn/keys/{TURN_KEY_ID}/credentials/generate-ice-servers"
    request_headers = {
        "Authorization": f"Bearer {TURN_KEY_API_TOKEN}",
        "Content-Type": "application/json",
    }
    with httpx.Client() as client:
        reply = client.post(endpoint, headers=request_headers, json={"ttl": ttl})
        if not reply.is_success:
            raise Exception(
                f"Failed to get TURN credentials: {reply.status_code} {reply.text}"
            )
        return reply.json()
class ConversationManager:
    """Conversation state for one WebRTC connection.

    The manager holds:
    - the tokenized message history sent to the backend;
    - a turn counter used to cancel superseded turns;
    - the assistant style chosen when the conversation was created;
    - every audio chunk recorded so far (user input and assistant output, in
      the order they were appended).
    """

    # Sample rate (Hz) that recorded audio is normalized to.
    TARGET_SAMPLE_RATE = 24000

    def __init__(self, assistant_style: AssistantStyle | None = None):
        self.conversation = TokenizedConversation(messages=[])
        # Bumped by new_turn(). A ConversationAbortController is alive only
        # while the turn it captured still equals this value.
        self.turn = 0
        self.assistant_style = assistant_style
        # time.monotonic() of the most recent activity; read by is_idle().
        self.last_access_time = time.monotonic()
        # Mono chunks at TARGET_SAMPLE_RATE. They are int16 after resampling or
        # mixdown; chunks already mono at the target rate keep their dtype.
        self.collected_audio_chunks: list[np.ndarray] = []

    def new_turn(self):
        """Start a new turn and return an abort controller bound to it.

        Incrementing ``turn`` makes controllers from earlier turns report
        ``is_alive() == False``, which stops their streaming workers.
        """
        self.turn += 1
        self.last_access_time = time.monotonic()
        return ConversationAbortController(self)

    def is_idle(self, idle_timeout: float) -> bool:
        """Return True if there has been no activity for ``idle_timeout`` seconds."""
        return time.monotonic() - self.last_access_time > idle_timeout

    def append_audio_chunk(self, audio_chunk: tuple[int, np.ndarray]):
        """Record a chunk, resampled to TARGET_SAMPLE_RATE and mixed down to mono.

        Args:
            audio_chunk: ``(sample_rate, samples)``. Multi-channel samples are
                expected as ``[channels, samples]``, not Gradio's
                ``[samples, channels]``. When resampling is needed, samples
                are treated as int16-scaled PCM.
        """
        sr, audio_data = audio_chunk
        target_sr = self.TARGET_SAMPLE_RATE
        if sr != target_sr:
            audio_data = librosa.resample(
                audio_data.astype(np.float32) / 32768.0,
                orig_sr=sr,
                target_sr=target_sr,
            )
            # Resampling can overshoot [-1, 1]. Clip so the int16 cast cannot
            # wrap loud samples around to the opposite sign.
            audio_data = (np.clip(audio_data, -1.0, 1.0) * 32767.0).astype(np.int16)
        if audio_data.ndim > 1:
            # [channels, samples] -> [samples,]
            # Not Gradio style
            audio_data = audio_data.mean(axis=0).astype(np.int16)
        self.collected_audio_chunks.append(audio_data)

    def all_collected_audio(self) -> tuple[int, np.ndarray]:
        """Return ``(TARGET_SAMPLE_RATE, all recorded samples concatenated)``.

        Returns an empty int16 array if nothing has been recorded yet, because
        np.concatenate rejects an empty list.
        """
        sr = self.TARGET_SAMPLE_RATE
        if not self.collected_audio_chunks:
            return sr, np.zeros(0, dtype=np.int16)
        return sr, np.concatenate(self.collected_audio_chunks)

    def chat(
        self,
        chat_id: int,
        input_audio: tuple[int, np.ndarray],
        global_sampler_config: SamplerConfig | None = None,
        local_sampler_config: SamplerConfig | None = None,
    ):
        """Stream one assistant turn for ``input_audio`` from a backend server.

        Starting the generator begins a new turn, which aborts any turn still
        streaming on this conversation. A daemon worker thread then:
        - POSTs the request to ``/audio-chat``;
        - parses the ``data: `` lines of the response;
        - appends the tokenized input and assistant tokens to
          ``self.conversation``;
        - hands each parsed chunk to this generator through a queue.

        Args:
            chat_id: Used only in log messages.
            input_audio: ``(sample_rate, samples)`` of the user's utterance.
            global_sampler_config: Forwarded unchanged in the request body.
            local_sampler_config: Forwarded unchanged in the request body.

        Yields:
            Each ``ChatResponseItem`` as it arrives, or ``None`` after every
            0.1 s without data, so the caller can check it is still connected.
            The worker stops if this generator is not resumed for over 1 s.

        Raises:
            RuntimeError: If the backend answers with a non-200 status, or the
                request fails in the worker thread (the original exception is
                chained as the cause).
        """
        controller = self.new_turn()
        chat_queue = queue.Queue[ChatResponseItem | None]()
        # Refreshed on every consumer iteration; the worker treats a value older
        # than 1 s as "consumer gone". Bound before the worker starts so the
        # worker can never read it unbound.
        consumer_alive_time = time.monotonic()
        # Failure raised inside the worker. It is re-raised here after the
        # worker's final None sentinel has been received.
        worker_error: Exception | None = None

        def stream_response():
            url = chat_server_url("/audio-chat")
            req = ChatRequestBody(
                conversation=self.conversation,
                input_audio=ChatAudioBytes.from_audio(input_audio),
                assistant_style=self.assistant_style,
                global_sampler_config=global_sampler_config,
                local_sampler_config=local_sampler_config,
            )
            first_output = True
            with httpx.Client() as client:
                with client.stream(
                    method="POST",
                    url=url,
                    content=req.model_dump_json(),
                    headers={"Content-Type": "application/json", **auth_headers()},
                ) as response:
                    if response.status_code != 200:
                        raise RuntimeError(f"Error {response.status_code}")
                    for line in response.iter_lines():
                        if not controller.is_alive():
                            print(f"[{chat_id=}] Streaming aborted by user")
                            break
                        if time.monotonic() - consumer_alive_time > 1.0:
                            print(f"[{chat_id=}] Streaming aborted due to inactivity")
                            break
                        if not line.startswith("data: "):
                            continue
                        line = line.removeprefix("data: ")
                        if line.strip() == "[DONE]":
                            print(f"[{chat_id=}] Streaming finished by server")
                            break
                        chunk = ChatResponseItem.model_validate_json(line)
                        if chunk.tokenized_input is not None:
                            self.conversation.messages.append(
                                chunk.tokenized_input,
                            )
                        if chunk.token_chunk is not None:
                            if first_output:
                                # First assistant tokens open a new message.
                                self.conversation.messages.append(
                                    TokenizedMessage(
                                        role="assistant",
                                        content=chunk.token_chunk,
                                    )
                                )
                                first_output = False
                            else:
                                # Later tokens extend that same message.
                                self.conversation.messages[-1].append(
                                    chunk.token_chunk,
                                )
                        chat_queue.put(chunk)

        def chat_task():
            nonlocal worker_error
            try:
                stream_response()
            except Exception as exc:
                print(f"[{chat_id=}] Streaming failed: {exc!r}")
                worker_error = exc
            finally:
                # Always send the end sentinel; without it the consumer below
                # would keep yielding None forever after a worker failure.
                chat_queue.put(None)

        Thread(target=chat_task, daemon=True).start()
        while True:
            consumer_alive_time = time.monotonic()
            try:
                item = chat_queue.get(timeout=0.1)
            except queue.Empty:
                yield None
                continue
            if item is None:
                break
            yield item
            self.last_access_time = time.monotonic()

        # The worker stores worker_error before putting the sentinel, so the
        # queue hand-off guarantees we see it here.
        if worker_error is not None:
            if isinstance(worker_error, RuntimeError):
                raise worker_error
            raise RuntimeError(f"Chat request failed: {worker_error!r}") from worker_error
class ConversationAbortController(AbortController):
    """Abort handle tied to one turn of a ConversationManager.

    The handle reports alive until either ``abort()`` is called on it, or the
    manager's turn counter moves past the turn captured at construction.
    """

    manager: ConversationManager
    cur_turn: int | None

    def __init__(self, manager: ConversationManager):
        self.manager, self.cur_turn = manager, manager.turn

    def is_alive(self) -> bool:
        # After abort() cur_turn is None, which never equals the int turn counter.
        return self.cur_turn == self.manager.turn

    def abort(self) -> None:
        self.cur_turn = None
from threading import Lock

# Next id handed out by new_chat_id(). In this file the ids only tag log lines.
chat_id_counter = 0
# The stream handler is registered with concurrency_limit > 1, so several chats
# may request ids at once. The read-then-increment must be atomic, otherwise
# two chats could be given the same id.
_chat_id_lock = Lock()


def new_chat_id():
    """Return the next chat id (0, 1, 2, ...); unique within the process."""
    global chat_id_counter
    with _chat_id_lock:
        chat_id = chat_id_counter
        chat_id_counter += 1
    return chat_id
| def parse_gradio_audio(gradio_audio: tuple[int, np.ndarray]): | |
| sr, audio = gradio_audio | |
| if len(audio.shape) > 1: | |
| # [samples, channels] -> [channels, samples] | |
| audio = audio.T | |
| if audio.dtype == np.int32: | |
| audio = audio.astype(np.float32) / 2**31 | |
| # [samples] or [channels, samples] | |
| return sr, audio | |
def main():
    """Build the Gradio/FastRTC voice-chat UI and launch it on 0.0.0.0:8087.

    One ConversationManager is kept per WebRTC connection id; conversations
    idle for 30 minutes are evicted. Each utterance the VAD handler passes to
    ``response`` is streamed to a backend chat server. Transcript text, status
    and the full recorded audio are mirrored into side panels.
    """
    print("Starting WebRTC server")
    # webrtc_id -> conversation state. Written by `response` (possibly from
    # several concurrent handler calls) and pruned by the cleanup thread.
    conversations: dict[str, ConversationManager] = {}

    def cleanup_idle_conversations():
        """Every 60 s, drop conversations idle for more than 30 minutes."""
        idle_timeout = 30 * 60.0  # 30 minutes
        while True:
            time.sleep(60)
            # Scan a snapshot: `response` may insert entries while we check
            # idleness. Iterating the live dict would then raise "dictionary
            # changed size during iteration" and kill this thread for good.
            to_delete = [
                webrtc_id
                for webrtc_id, manager in list(conversations.items())
                if manager.is_idle(idle_timeout)
            ]
            for webrtc_id in to_delete:
                print(f"Cleaning up idle conversation {webrtc_id}")
                conversations.pop(webrtc_id, None)

    Thread(target=cleanup_idle_conversations, daemon=True).start()

    def get_preset_list(category: Literal["character", "voice"]) -> list[str]:
        """Fetch preset names for ``category``.

        Falls back to ``["[default]"]`` on a network error or non-200 response.
        """
        url = chat_server_url(f"/preset/{category}")
        try:
            with httpx.Client() as client:
                response = client.get(url, headers=auth_headers())
        except httpx.HTTPError as e:
            print(f"Failed to fetch {category} presets: {e!r}")
            return ["[default]"]
        if response.status_code == 200:
            return PresetOptions.model_validate_json(response.text).options
        return ["[default]"]

    def get_model_name() -> str:
        """Fetch the backend model name.

        Falls back to ``"unknown"`` on a network error or non-200 response.
        """
        url = chat_server_url("/model-name")
        try:
            with httpx.Client() as client:
                response = client.get(url, headers=auth_headers())
        except httpx.HTTPError as e:
            print(f"Failed to fetch model name: {e!r}")
            return "unknown"
        if response.status_code == 200:
            return ModelNameResponse.model_validate_json(response.text).model_name
        return "unknown"

    def load_initial_data():
        """On page load, fill in the title and both preset dropdowns from the backend."""
        model_name = get_model_name()
        title = f"Xiaomi MiMo-Audio WebRTC (model: {model_name})"
        character_choices = get_preset_list("character")
        voice_choices = get_preset_list("voice")
        return (
            gr.update(value=f"# {title}"),
            gr.update(choices=character_choices),
            gr.update(choices=voice_choices),
        )

    def response(
        input_audio: tuple[int, np.ndarray],
        webrtc_id: str,
        preset_character: str | None,
        preset_voice: str | None,
        custom_character_prompt: str | None,
    ):
        """Handle one user utterance on connection ``webrtc_id``.

        Yields:
        - output audio chunks;
        - ``None`` keep-alives while waiting on the backend;
        - ``fastrtc.AdditionalOutputs(text, status, full_audio)`` updates for
          the side panels.
        """
        manager = conversations.get(webrtc_id)
        if manager is None:
            # Settings are read only when a connection's conversation is
            # created; changing them needs a new conversation (see Settings
            # Help). A None or blank prompt means "no custom prompt".
            custom_character_prompt = (custom_character_prompt or "").strip() or None
            manager = ConversationManager(
                assistant_style=AssistantStyle(
                    preset_character=preset_character,
                    custom_character_prompt=custom_character_prompt,
                    preset_voice=preset_voice,
                )
            )
            conversations[webrtc_id] = manager
        sr, audio_data = input_audio
        chat_id = new_chat_id()
        # Sample count is the last axis for both [samples] and [channels, samples].
        print(f"WebRTC {webrtc_id} [{chat_id=}]: Input {audio_data.shape[-1] / sr}s")
        # Record input audio
        manager.append_audio_chunk(input_audio)
        # NOTE(review): several UI strings in this function and in the layout
        # below look like emoji mis-decoded through a non-UTF-8 codepage.
        # Confirm against the original file and restore the intended characters.
        output_text = ""
        status_text = "βοΈ Preparing..."
        text_active = False
        audio_active = False
        collected_audio: tuple[int, np.ndarray] | None = None

        def additional_outputs():
            return fastrtc.AdditionalOutputs(
                output_text,
                status_text,
                collected_audio,
            )

        yield additional_outputs()
        try:
            for chunk in manager.chat(
                chat_id,
                input_audio,
            ):
                if chunk is None:
                    # Test if consumer is still alive
                    yield None
                    continue
                if chunk.text_chunk is not None:
                    text_active = True
                    output_text += chunk.text_chunk
                if chunk.end_of_transcription:
                    text_active = False
                if chunk.audio_chunk is not None:
                    audio_active = True
                    audio = chunk.audio_chunk.to_audio()
                    manager.append_audio_chunk(audio)
                    yield audio
                if chunk.end_of_stream:
                    audio_active = False
                if text_active and audio_active:
                    status_text = "π¬+π Mixed"
                elif text_active:
                    status_text = "π¬ Text"
                elif audio_active:
                    status_text = "π Audio"
                if chunk.stop_reason is not None:
                    status_text = f"β Finished: {chunk.stop_reason}"
                yield additional_outputs()
        except RuntimeError as e:
            status_text = f"β Error: {e}"
            yield additional_outputs()
        collected_audio = manager.all_collected_audio()
        yield additional_outputs()

    title = "Xiaomi MiMo-Audio WebRTC"
    with gr.Blocks(title=title) as demo:
        # Replaced by load_initial_data() once the model name is fetched.
        title_markdown = gr.Markdown(f"# {title}")
        with gr.Row():
            with gr.Column():
                chat = fastrtc.WebRTC(
                    label="WebRTC Chat",
                    modality="audio",
                    mode="send-receive",
                    full_screen=False,
                    rtc_configuration=get_cloudflare_turn_credentials,
                )
                output_text = gr.Textbox(label="Output", lines=3, interactive=False)
                status_text = gr.Textbox(label="Status", lines=1, interactive=False)
                with gr.Accordion("Advanced", open=True):
                    collected_audio = gr.Audio(
                        label="Full Audio",
                        type="numpy",
                        format="wav",
                        interactive=False,
                    )
            with gr.Column():
                with gr.Accordion("Settings Help"):
                    gr.Markdown(
                        "- `Preset Prompt` controls the response style.\n"
                        "- `Preset Voice` controls the speaking tone.\n"
                        "- `Custom Prompt` lets you define the response style in natural language (overrides `Preset Prompt`).\n"
                        "- For best results, choose prompts and voices that **match your language**. The default settings are optimized for **English**.\n"
                        "- To apply new settings, end the current conversation and start a new one."
                    )
                preset_character_dropdown = gr.Dropdown(
                    label="π Preset Prompt",
                    choices=["[default]"],
                )
                preset_voice_dropdown = gr.Dropdown(
                    label="π€ Preset Voice",
                    choices=["[default]"],
                )
                custom_character_prompt = gr.Textbox(
                    label="π οΈ Custom Prompt",
                    placeholder="For example: You are Xiaomi MiMo-Audio, a large language model trained by Xiaomi. You are chatting with a user over voice.",
                    lines=2,
                    interactive=True,
                )
        chat.stream(
            VADStreamHandler(response),
            inputs=[
                chat,
                preset_character_dropdown,
                preset_voice_dropdown,
                custom_character_prompt,
            ],
            concurrency_limit=concurrency_limit,
            outputs=[chat],
        )
        chat.on_additional_outputs(
            lambda *args: args,
            outputs=[output_text, status_text, collected_audio],
            concurrency_limit=concurrency_limit,
            show_progress="hidden",
        )
        demo.load(
            load_initial_data,
            inputs=[],
            outputs=[title_markdown, preset_character_dropdown, preset_voice_dropdown],
        )
    demo.launch(server_name="0.0.0.0", server_port=8087, show_api=False)
if __name__ == "__main__":
    # Script entry point: build the UI and start serving.
    main()