Spaces:
Running
Running
| from fastapi import Depends, HTTPException, Query | |
| from fastapi.responses import StreamingResponse | |
| import io | |
| from pydantic import BaseModel | |
| import soundfile as sf | |
| from fastapi.responses import FileResponse | |
| from modules.normalization import text_normalize | |
| from modules.api import utils as api_utils | |
| from modules.api.Api import APIManager | |
| from modules.synthesize_audio import synthesize_audio | |
class TTSParams(BaseModel):
    # Query-string parameters accepted by the TTS endpoint. The model is
    # injected into the route handler through ``Depends()``, so every field
    # below is read from the request's query string.

    # --- What to say and how it should sound ---
    text: str = Query(default=..., description="Text to synthesize")
    spk: str = Query(
        default="female2",
        description="Specific speaker by speaker name or speaker seed",
    )
    style: str = Query(default="chat", description="Specific style by style name")

    # --- Sampling controls ---
    temperature: float = Query(
        default=0.3,
        description="Temperature for sampling (may be overridden by style or spk)",
    )
    top_P: float = Query(
        default=0.5,
        description="Top P for sampling (may be overridden by style or spk)",
    )
    top_K: int = Query(
        default=20,
        description="Top K for sampling (may be overridden by style or spk)",
    )
    seed: int = Query(
        default=42,
        description="Seed for generate (may be overridden by style or spk)",
    )

    # --- Output container ---
    format: str = Query(
        default="mp3", description="Response audio format: [mp3,wav]"
    )

    # --- Inference prompts ---
    prompt1: str = Query(default="", description="Text prompt for inference")
    prompt2: str = Query(default="", description="Text prompt for inference")
    prefix: str = Query(default="", description="Text prefix for inference")

    # --- Batching / sentence splitting ---
    # Both values arrive as strings; the route handler converts them with int().
    bs: str = Query(default="8", description="Batch size for inference")
    thr: str = Query(default="100", description="Threshold for sentence spliter")
async def synthesize_tts(params: TTSParams = Depends()):
    """Synthesize speech for ``params.text`` and stream it back as audio.

    The request is validated first, then the text is normalized, the
    speaker/style presets are resolved, and the audio is generated as WAV.
    When ``params.format`` is ``"mp3"`` the WAV buffer is converted with
    ``api_utils.wav_to_mp3`` before being returned.

    Args:
        params: Query parameters for the request (see ``TTSParams``).

    Returns:
        A ``StreamingResponse`` whose media type is ``audio/mpeg`` for MP3
        output and ``audio/wav`` for WAV output.

    Raises:
        HTTPException: 422 when a parameter fails validation; 500 when any
            other exception is raised while handling the request.
    """
    import logging

    try:
        # Validate text
        if not params.text.strip():
            raise HTTPException(
                status_code=422, detail="Text parameter cannot be empty"
            )

        # Validate temperature
        if not (0 <= params.temperature <= 1):
            raise HTTPException(
                status_code=422, detail="Temperature must be between 0 and 1"
            )

        # Validate top_P
        if not (0 <= params.top_P <= 1):
            raise HTTPException(status_code=422, detail="top_P must be between 0 and 1")

        # Validate top_K
        if params.top_K <= 0:
            raise HTTPException(
                status_code=422, detail="top_K must be a positive integer"
            )
        if params.top_K > 100:
            raise HTTPException(
                status_code=422, detail="top_K must be less than or equal to 100"
            )

        # Validate format
        if params.format not in ["mp3", "wav"]:
            raise HTTPException(
                status_code=422,
                detail="Invalid format. Supported formats are mp3 and wav",
            )

        # Validate bs / thr: they are declared as strings, so a non-numeric
        # value must be reported as a client error rather than a 500.
        try:
            batch_size = int(params.bs)
            threshold = int(params.thr)
        except ValueError:
            raise HTTPException(
                status_code=422, detail="bs and thr must be integers"
            )

        text = text_normalize(params.text, is_end=False)

        calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)

        # A truthy request value wins; the preset from calc_params is only
        # consulted when the request value is falsy (0 or "").
        # NOTE(review): the field descriptions say these "may be overridden by
        # style or spk", which reads like the opposite precedence - confirm.
        spk = calc_params.get("spk", params.spk)
        seed = params.seed or calc_params.get("seed", params.seed)
        temperature = params.temperature or calc_params.get(
            "temperature", params.temperature
        )
        prefix = params.prefix or calc_params.get("prefix", params.prefix)
        prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
        prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)

        sample_rate, audio_data = synthesize_audio(
            text,
            temperature=temperature,
            top_P=params.top_P,
            top_K=params.top_K,
            spk=spk,
            infer_seed=seed,
            prompt1=prompt1,
            prompt2=prompt2,
            prefix=prefix,
            batch_size=batch_size,
            spliter_threshold=threshold,
        )

        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sample_rate, format="wav")
        buffer.seek(0)

        # Compare the *requested* format, not the ``format`` builtin, and
        # advertise the media type that matches the bytes actually sent.
        if params.format == "mp3":
            # presumably returns a file-like object positioned at the start
            # of the MP3 data - verify against api_utils.wav_to_mp3
            buffer = api_utils.wav_to_mp3(buffer)
            media_type = "audio/mpeg"
        else:
            media_type = "audio/wav"

        return StreamingResponse(buffer, media_type=media_type)

    except Exception as e:
        # Log every failure (including validation errors), then let
        # HTTPExceptions through unchanged and wrap anything else as a 500.
        logging.exception(e)

        if isinstance(e, HTTPException):
            raise e
        else:
            raise HTTPException(status_code=500, detail=str(e))
def setup(api_manager: APIManager):
    """Register the TTS handler as ``GET /v1/tts`` on ``api_manager``."""
    # ``api_manager.get`` is used decorator-style: it returns a registrar
    # that is then applied to the handler function.
    register_route = api_manager.get("/v1/tts", response_class=FileResponse)
    register_route(synthesize_tts)