# mimo_audio_chat/app.py
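"""Gradio + fastrtc front-end for Xiaomi MiMo-Audio voice chat.

Streams microphone audio to a pool of backend inference servers and plays
the generated speech back to the browser over WebRTC.
"""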
import argparse
import queue
import time
from threading import Thread
from typing import Literal, override
import os
import fastrtc
from fastrtc import get_cloudflare_turn_credentials_async
import gradio as gr
import httpx
import numpy as np
from pydantic import BaseModel
import random
from api_schema import (
AbortController,
AssistantStyle,
ChatAudioBytes,
ChatRequestBody,
ChatResponseItem,
ModelNameResponse,
PresetOptions,
SamplerConfig,
TokenizedConversation,
TokenizedMessage,
)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    print(
        "⚠️ [WARNING] HF_TOKEN environment variable not found.\n"
        "WebRTC connections may fail on Hugging Face Spaces because the TURN service cannot be used.\n"
        "πŸ’‘ Solution: go to your Hugging Face Space β†’ Settings β†’ Secrets and add a secret "
        "named HF_TOKEN containing a personal access token with at least 'read' permission."
    )
else:
print("βœ… [INFO] HF_TOKEN detected. WebRTC will use Hugging Face TURN service for connectivity.")
# Backend pool configuration: URL_PREFIX and NUM_SERVER are required env vars.
# For example (hypothetical values), URL_PREFIX=https://myorg-mimo-audio- with
# NUM_SERVER=2 expands to https://myorg-mimo-audio-1.hf.space and ...-2.hf.space.
url_prefix = os.getenv("URL_PREFIX")
server_number = int(os.getenv("NUM_SERVER"))
deployment_server = [f"{url_prefix}{i}.hf.space" for i in range(1, server_number + 1)]
class Args(BaseModel):
host: str
port: int
concurrency_limit: int
share: bool
debug: bool
chat_server: str
tag: str | None = None
@classmethod
def parse_args(cls):
parser = argparse.ArgumentParser(description="Xiaomi MiMo-Audio Chat")
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8087)
parser.add_argument("--concurrency-limit", type=int, default=40)
parser.add_argument("--share", action="store_true")
parser.add_argument("--debug", action="store_true")
parser.add_argument(
"-S",
"--chat-server",
dest="chat_server",
type=str,
default="deployment_docker_1",
)
parser.add_argument("--tag", type=str)
args = parser.parse_args()
return cls.model_validate(vars(args))
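
    # Example: python app.py --port 8087 --concurrency-limit 40 --tag demo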
    def chat_server_url(self):
        # Pick a backend at random to spread load across the deployment pool.
        return random.choice(deployment_server)
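
# Per-session state: the tokenized conversation history, a turn counter used
# to abort superseded streams, and every audio chunk heard or produced so far.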
class ConversationManager:
def __init__(self, assistant_style: AssistantStyle | None = None):
self.conversation = TokenizedConversation(messages=[])
self.turn = 0
self.assistant_style = assistant_style
self.last_access_time = time.monotonic()
self.collected_audio_chunks: list[np.ndarray] = []
def new_turn(self):
self.turn += 1
self.last_access_time = time.monotonic()
return ConversationAbortController(self)
def is_idle(self, idle_timeout: float) -> bool:
return time.monotonic() - self.last_access_time > idle_timeout
def append_audio_chunk(self, audio_chunk: tuple[int, np.ndarray]):
sr, audio_data = audio_chunk
assert sr == 24000, "Only 24kHz audio is supported"
        if audio_data.ndim > 1:
            # Downmix channel-first [channels, samples] audio to mono [samples,].
            # (fastrtc uses channel-first, unlike Gradio's [samples, channels].)
            audio_data = audio_data.mean(axis=0).astype(np.int16)
self.collected_audio_chunks.append(audio_data)
def all_collected_audio(self) -> tuple[int, np.ndarray]:
sr = 24000
audio_data = np.concatenate(self.collected_audio_chunks)
return sr, audio_data
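
    # One chat turn, producer/consumer style: a daemon thread POSTs the request
    # and parses the server-sent-event stream into a queue, while this
    # generator drains it, yielding None as a keep-alive probe whenever the
    # queue is momentarily empty.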
def chat(
self,
url: httpx.URL,
chat_id: int,
input_audio: tuple[int, np.ndarray],
global_sampler_config: SamplerConfig | None = None,
local_sampler_config: SamplerConfig | None = None,
):
controller = self.new_turn()
chat_queue = queue.Queue[ChatResponseItem | None]()
def chat_task():
req = ChatRequestBody(
conversation=self.conversation,
input_audio=ChatAudioBytes.from_audio(input_audio),
assistant_style=self.assistant_style,
global_sampler_config=global_sampler_config,
local_sampler_config=local_sampler_config,
)
first_output = True
with httpx.Client() as client:
                headers = {
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {HF_TOKEN}",  # forward the HF token to the backend
                }
with client.stream(
method="POST",
url=url,
content=req.model_dump_json(),
headers=headers,
) as response:
if response.status_code != 200:
raise RuntimeError(f"Error {response.status_code}")
for line in response.iter_lines():
if not controller.is_alive():
print(f"[{chat_id=}] Streaming aborted by user")
break
if time.monotonic() - consumer_alive_time > 1.0:
print(f"[{chat_id=}] Streaming aborted due to inactivity")
break
if not line.startswith("data: "):
continue
line = line.removeprefix("data: ")
if line.strip() == "[DONE]":
print(f"[{chat_id=}] Streaming finished by server")
break
chunk = ChatResponseItem.model_validate_json(line)
if chunk.tokenized_input is not None:
self.conversation.messages.append(
chunk.tokenized_input,
)
if chunk.token_chunk is not None:
if first_output:
self.conversation.messages.append(
TokenizedMessage(
role="assistant",
content=chunk.token_chunk,
)
)
first_output = False
else:
self.conversation.messages[-1].append(
chunk.token_chunk,
)
chat_queue.put(chunk)
chat_queue.put(None)
        # Initialize before the worker starts so chat_task never reads an
        # unbound name; the consumer loop below refreshes it every iteration.
        consumer_alive_time = time.monotonic()
        Thread(target=chat_task, daemon=True).start()
while True:
consumer_alive_time = time.monotonic()
try:
item = chat_queue.get(timeout=0.1)
if item is None:
break
yield item
self.last_access_time = time.monotonic()
except queue.Empty:
yield None
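
# Abort handle scoped to one turn: the stream counts as alive only while the
# manager has not started a newer turn; abort() invalidates it immediately.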
class ConversationAbortController(AbortController):
manager: ConversationManager
cur_turn: int | None
def __init__(self, manager: ConversationManager):
self.manager = manager
self.cur_turn = manager.turn
@override
def is_alive(self) -> bool:
return self.manager.turn == self.cur_turn
def abort(self) -> None:
self.cur_turn = None
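
# Best-effort global counter for correlating log lines (not thread-safe).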
chat_id_counter = 0
def new_chat_id():
global chat_id_counter
chat_id = chat_id_counter
chat_id_counter += 1
return chat_id
def main():
args = Args.parse_args()
print("Starting WebRTC server")
conversations: dict[str, ConversationManager] = {}
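
    # Reaper thread: once a minute, drop sessions idle for over 30 minutes.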
def cleanup_idle_conversations():
idle_timeout = 30 * 60.0 # 30 minutes
while True:
time.sleep(60)
to_delete = []
for webrtc_id, manager in conversations.items():
if manager.is_idle(idle_timeout):
to_delete.append(webrtc_id)
for webrtc_id in to_delete:
print(f"Cleaning up idle conversation {webrtc_id}")
del conversations[webrtc_id]
Thread(target=cleanup_idle_conversations, daemon=True).start()
def get_preset_list(category: Literal["character", "voice"]) -> list[str]:
url = httpx.URL(args.chat_server_url()).join(f"/preset/{category}")
        headers = {
            "Authorization": f"Bearer {HF_TOKEN}",  # forward the HF token to the backend
        }
with httpx.Client() as client:
response = client.get(url, headers=headers)
if response.status_code == 200:
return PresetOptions.model_validate_json(response.text).options
return ["[default]"]
def get_model_name() -> str:
url = httpx.URL(args.chat_server_url()).join("/model-name")
        headers = {
            "Authorization": f"Bearer {HF_TOKEN}",  # forward the HF token to the backend
        }
with httpx.Client() as client:
response = client.get(url, headers=headers)
if response.status_code == 200:
return ModelNameResponse.model_validate_json(response.text).model_name
return "unknown"
def load_initial_data():
model_name = get_model_name()
title = f"Xiaomi MiMo-Audio WebRTC (model: {model_name})"
if args.tag is not None:
title = f"{args.tag} - {title}"
character_choices = get_preset_list("character")
voice_choices = get_preset_list("voice")
return (
gr.update(value=f"# {title}"),
gr.update(choices=character_choices),
gr.update(choices=voice_choices),
)
def response(
input_audio: tuple[int, np.ndarray],
webrtc_id: str,
preset_character: str | None,
preset_voice: str | None,
custom_character_prompt: str | None,
):
        # The backend's /gc endpoint is deprecated, so no per-request cleanup
        # request is sent here.
nonlocal conversations
if webrtc_id not in conversations:
            if custom_character_prompt is not None:
                # Treat a whitespace-only prompt as "no custom prompt".
                custom_character_prompt = custom_character_prompt.strip() or None
conversations[webrtc_id] = ConversationManager(
assistant_style=AssistantStyle(
preset_character=preset_character,
custom_character_prompt=custom_character_prompt,
preset_voice=preset_voice,
)
)
manager = conversations[webrtc_id]
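        # fastrtc delivers channel-first audio, i.e. shape (channels, samples).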
sr, audio_data = input_audio
chat_id = new_chat_id()
print(f"WebRTC {webrtc_id} [{chat_id=}]: Input {audio_data.shape[1] / sr}s")
# Record input audio
manager.append_audio_chunk(input_audio)
output_text = ""
status_text = "βŒ›οΈ Preparing..."
text_active = False
audio_active = False
collected_audio: tuple[int, np.ndarray] | None = None
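
        # Snapshot the side-channel outputs (transcript, status, recording).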
def additional_outputs():
return fastrtc.AdditionalOutputs(
output_text,
status_text,
collected_audio,
)
yield additional_outputs()
try:
url = httpx.URL(args.chat_server_url()).join("/audio-chat")
for chunk in manager.chat(
url,
chat_id,
input_audio,
):
if chunk is None:
# Test if consumer is still alive
yield None
continue
if chunk.text_chunk is not None:
text_active = True
output_text += chunk.text_chunk
if chunk.end_of_transcription:
text_active = False
if chunk.audio_chunk is not None:
audio_active = True
audio = chunk.audio_chunk.to_audio()
manager.append_audio_chunk(audio)
yield audio
if chunk.end_of_stream:
audio_active = False
if text_active and audio_active:
status_text = "πŸ’¬+πŸ”Š Mixed"
elif text_active:
status_text = "πŸ’¬ Text"
elif audio_active:
status_text = "πŸ”Š Audio"
if chunk.stop_reason is not None:
status_text = f"βœ… Finished: {chunk.stop_reason}"
yield additional_outputs()
except RuntimeError as e:
status_text = f"❌ Error: {e}"
yield additional_outputs()
collected_audio = manager.all_collected_audio()
yield additional_outputs()
title = "Xiaomi MiMo-Audio WebRTC"
if args.tag is not None:
title = f"{args.tag} - {title}"
with gr.Blocks(title=title) as demo:
title_markdown = gr.Markdown(f"# {title}")
with gr.Row():
with gr.Column():
chat = fastrtc.WebRTC(
label="WebRTC Chat",
modality="audio",
mode="send-receive",
full_screen=False,
                    # On HF Spaces, Cloudflare TURN credentials are obtained via HF_TOKEN.
                    rtc_configuration=get_cloudflare_turn_credentials_async,
                )
output_text = gr.Textbox(label="Output", lines=3, interactive=False)
status_text = gr.Textbox(label="Status", lines=1, interactive=False)
with gr.Accordion("Advanced", open=False):
collected_audio = gr.Audio(
label="Full Audio",
type="numpy",
format="wav",
interactive=False,
)
with gr.Column():
with gr.Accordion("Settings Help"):
gr.Markdown(
"- `Preset Prompt` controls the response style.\n"
"- `Preset Voice` controls the speaking tone.\n"
"- `Custom Prompt` lets you define the response style in natural language (overrides `Preset Prompt`).\n"
"- For best results, choose prompts and voices that match your language.\n"
"- To apply new settings, end the current conversation and start a new one."
)
preset_character_dropdown = gr.Dropdown(
label="😊 Preset Prompt",
choices=["[default]"],
)
preset_voice_dropdown = gr.Dropdown(
label="🎀 Preset Voice",
choices=["[default]"],
)
custom_character_prompt = gr.Textbox(
label="πŸ› οΈ Custom Prompt",
placeholder="For example: You are Xiaomi MiMo-Audio, a large language model trained by Xiaomi. You are chatting with a user over voice.",
lines=2,
interactive=True,
)
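
        # ReplyOnPause buffers the incoming stream, runs Silero VAD, and calls
        # `response` with the complete utterance once the speaker pauses.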
chat.stream(
fastrtc.ReplyOnPause(
response,
input_sample_rate=24000,
output_sample_rate=24000,
model_options=fastrtc.SileroVadOptions(
threshold=0.7,
min_silence_duration_ms=1000,
),
),
inputs=[
chat,
preset_character_dropdown,
preset_voice_dropdown,
custom_character_prompt,
],
concurrency_limit=args.concurrency_limit,
outputs=[chat],
)
        chat.on_additional_outputs(
            lambda *outputs: outputs,  # pass side-channel outputs through unchanged
            outputs=[output_text, status_text, collected_audio],
            concurrency_limit=args.concurrency_limit,
            show_progress="hidden",
        )
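
        # Populate the title and dropdowns from the backend on page load.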
demo.load(
load_initial_data,
inputs=[],
outputs=[title_markdown, preset_character_dropdown, preset_voice_dropdown],
)
demo.queue(
default_concurrency_limit=args.concurrency_limit,
)
demo.launch()
if __name__ == "__main__":
main()