import base64
import io
import json
import queue
import random
import traceback
import wave
from argparse import ArgumentParser
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Literal, Optional

import librosa
import numpy as np
import pyrootutils
import soundfile as sf
import torch
from kui.asgi import (
    Body,
    HTTPException,
    HttpView,
    JSONResponse,
    Kui,
    OpenAPI,
    StreamResponse,
)
from kui.asgi.routing import MultimethodRoutes
from loguru import logger
from pydantic import BaseModel, Field

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# from fish_speech.models.vqgan.lit_module import VQGAN
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
from tools.llama.generate import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
)
from tools.vqgan.inference import load_model as load_decoder_model

def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes

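# Usage sketch: wav_chunk_header() returns only the WAV container header (no
# frames). The streaming path in inference() below yields this header first and
# then appends raw int16 PCM bytes as they are produced, roughly:
#     header = wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1)
#     stream = header + pcm_chunk_bytes  # pcm_chunk_bytes: hypothetical PCM chunk
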
# Define utils for web server
async def http_exception_handler(exc: HTTPException):
    return JSONResponse(
        dict(
            statusCode=exc.status_code,
            message=exc.content,
            error=HTTPStatus(exc.status_code).phrase,
        ),
        exc.status_code,
        exc.headers,
    )


async def other_exception_handler(exc: "Exception"):
    traceback.print_exc()

    status = HTTPStatus.INTERNAL_SERVER_ERROR
    return JSONResponse(
        dict(statusCode=status, message=str(exc), error=status.phrase),
        status,
    )

def load_audio(reference_audio, sr):
    if len(reference_audio) > 255 or not Path(reference_audio).exists():
        try:
            audio_data = base64.b64decode(reference_audio)
            reference_audio = io.BytesIO(audio_data)
        except base64.binascii.Error:
            raise ValueError("Invalid path or base64 string")

    audio, _ = librosa.load(reference_audio, sr=sr, mono=True)
    return audio

def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
    if enable_reference_audio and reference_audio is not None:
        # Load audios, and prepare basic info here
        reference_audio_content = load_audio(
            reference_audio, decoder_model.spec_transform.sample_rate
        )
        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
            None, None, :
        ]
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
        )
        logger.info(
            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        if isinstance(decoder_model, FireflyArchitecture):
            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
    else:
        prompt_tokens = None
        logger.info("No reference audio provided")

    return prompt_tokens

def decode_vq_tokens(
    *,
    decoder_model,
    codes,
):
    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
    logger.info(f"VQ features: {codes.shape}")

    if isinstance(decoder_model, FireflyArchitecture):
        # VQGAN Inference
        return decoder_model.decode(
            indices=codes[None],
            feature_lengths=feature_lengths,
        ).squeeze()

    raise ValueError(f"Unknown model type: {type(decoder_model)}")


routes = MultimethodRoutes(base_class=HttpView)

def get_random_paths(base_path, data, speaker, emotion):
    if base_path and data and speaker and emotion and (Path(base_path).exists()):
        if speaker in data and emotion in data[speaker]:
            files = data[speaker][emotion]
            lab_files = [f for f in files if f.endswith(".lab")]
            wav_files = [f for f in files if f.endswith(".wav")]

            if lab_files and wav_files:
                selected_lab = random.choice(lab_files)
                selected_wav = random.choice(wav_files)

                lab_path = Path(base_path) / speaker / emotion / selected_lab
                wav_path = Path(base_path) / speaker / emotion / selected_wav
                if lab_path.exists() and wav_path.exists():
                    return lab_path, wav_path

    return None, None

def load_json(json_file):
    if not json_file:
        logger.info("Not using a json file")
        return None

    try:
        with open(json_file, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError:
        logger.warning(f"ref json not found: {json_file}")
        data = None
    except Exception as e:
        logger.warning(f"Loading json failed: {e}")
        data = None

    return data

class InvokeRequest(BaseModel):
    # Default demo text (Chinese): "You're right, but Genshin Impact is an
    # open-world game independently developed by miHoYo."
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    reference_text: Optional[str] = None
    reference_audio: Optional[str] = None
    max_new_tokens: int = 1024
    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    emotion: Optional[str] = None
    format: Literal["wav", "mp3", "flac"] = "wav"
    streaming: bool = False
    ref_json: Optional[str] = "ref_data.json"
    ref_base: Optional[str] = "ref_data"
    speaker: Optional[str] = None

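# Example request body (a sketch mirroring the InvokeRequest fields above; the
# values are placeholders, not required defaults):
#     {
#         "text": "Hello world.",
#         "reference_audio": null,
#         "max_new_tokens": 1024,
#         "chunk_length": 100,
#         "top_p": 0.7,
#         "repetition_penalty": 1.2,
#         "temperature": 0.7,
#         "format": "wav",
#         "streaming": false
#     }
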
def get_content_type(audio_format):
    if audio_format == "wav":
        return "audio/wav"
    elif audio_format == "flac":
        return "audio/flac"
    elif audio_format == "mp3":
        return "audio/mpeg"
    else:
        return "application/octet-stream"

def inference(req: InvokeRequest):
    # Parse reference audio aka prompt
    prompt_tokens = None

    ref_data = load_json(req.ref_json)
    ref_base = req.ref_base

    lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)

    if lab_path and wav_path:
        with open(lab_path, "r", encoding="utf-8") as lab_file:
            ref_text = lab_file.read()
        req.reference_audio = wav_path
        req.reference_text = ref_text
        logger.info("ref_path: " + str(wav_path))
        logger.info("ref_text: " + ref_text)

    # Parse reference audio aka prompt
    prompt_tokens = encode_reference(
        decoder_model=decoder_model,
        reference_audio=req.reference_audio,
        enable_reference_audio=req.reference_audio is not None,
    )
    logger.info(f"ref_text: {req.reference_text}")

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=req.text,
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=2048,
        prompt_tokens=prompt_tokens,
        prompt_text=req.reference_text,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    if req.streaming:
        yield wav_chunk_header()

    segments = []
    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            raise result.response

        result: GenerateResponse = result.response
        if result.action == "next":
            break

        with torch.autocast(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()

        if req.streaming:
            yield (fake_audios * 32768).astype(np.int16).tobytes()
        else:
            segments.append(fake_audios)

    if req.streaming:
        return

    if len(segments) == 0:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated, please check the input text.",
        )

    fake_audios = np.concatenate(segments, axis=0)
    yield fake_audios

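# Note on the generator contract above: with req.streaming=True, inference()
# yields a WAV header followed by raw int16 PCM byte chunks; with streaming
# disabled it yields a single float numpy array holding the concatenated audio.
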
def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
    if not use_auto_rerank:
        # If auto_rerank is disabled, call the original inference function directly
        return inference(req)

    zh_model, en_model = load_model()
    max_attempts = 5
    best_wer = float("inf")
    best_audio = None

    for attempt in range(max_attempts):
        # Call the original inference function
        audio_generator = inference(req)
        fake_audios = next(audio_generator)

        asr_result = batch_asr(
            zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
        )[0]
        wer = calculate_wer(req.text, asr_result["text"])

        if wer <= 0.1 and not asr_result["huge_gap"]:
            return fake_audios

        if wer < best_wer:
            best_wer = wer
            best_audio = fake_audios

        if attempt == max_attempts - 1:
            break

    return best_audio

async def inference_async(req: InvokeRequest):
    for chunk in inference(req):
        yield chunk


async def buffer_to_async_generator(buffer):
    yield buffer

@routes.http.post("/v1/invoke")
async def api_invoke_model(
    req: Annotated[InvokeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    if args.max_text_length > 0 and len(req.text) > args.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {args.max_text_length}",
        )

    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    if req.streaming:
        return StreamResponse(
            iterable=inference_async(req),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    else:
        fake_audios = next(inference(req))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            decoder_model.spec_transform.sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )

@routes.http.get("/v1/health")
async def api_health():
    """
    Health check
    """
    return JSONResponse({"status": "ok"})

def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.2-sft",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)
    parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--use-auto-rerank", type=bool, default=True)

    return parser.parse_args()

# Define Kui app
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
    },
).routes

app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        HTTPException: http_exception_handler,
        Exception: other_exception_handler,
    },
    cors_config={},
)

if __name__ == "__main__":
    import threading

    import uvicorn

    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    list(
        inference(
            InvokeRequest(
                text="Hello world.",
                reference_text=None,
                reference_audio=None,
                max_new_tokens=0,
                top_p=0.7,
                repetition_penalty=1.2,
                temperature=0.7,
                emotion=None,
                format="wav",
                ref_base=None,
                ref_json=None,
            )
        )
    )

    logger.info(f"Warming up done, starting server at http://{args.listen}")
    host, port = args.listen.split(":")

    uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
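
# Client sketch (assumes the default --listen address and the invoke route
# registered above; the "requests" usage is illustrative, not part of this server):
#     import requests
#     resp = requests.post(
#         "http://127.0.0.1:8000/v1/invoke",
#         json={"text": "Hello world.", "format": "wav", "streaming": False},
#     )
#     with open("audio.wav", "wb") as f:
#         f.write(resp.content)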