Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import soundfile as sf | |
| import spaces | |
| from huggingface_hub import login | |
| from pardi_speech import PardiSpeech, VelocityHeadSamplingParams # présent dans ce repo | |
# Hub repository the model is fetched from; overridable via the MODEL_REPO_ID
# environment variable.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
# Optional Hub access token; the login step is skipped when it is unset or empty.
HF_TOKEN = os.getenv("HF_TOKEN")

# Module-level cache filled in by _load_model(): the model instance and the
# sampling rate reported for it (24 kHz until a model has been loaded).
_pardi = None
_sampling_rate = 24000

if HF_TOKEN:
    # A failed login is reported but does not stop the module from loading.
    try:
        login(token=HF_TOKEN)
        print("✅ Logged to Hugging Face Hub.")
    except Exception as err:
        print("⚠️ HF login failed:", err)
# --- Safety patch: supply a default FLA state when none is given at prefill ---
try:
    from tts.model.simple_gla import SimpleGLADecoder

    # Keep a handle on the unpatched method so the wrapper can delegate to it.
    _old_prefill = SimpleGLADecoder.prefill

    def _prefill_with_default(self, encoder_output, decoder_input, cache=None, crossatt_mask=None):
        """Call the original ``prefill`` with a placeholder cache when needed.

        The placeholder is used when ``cache`` is ``None``, or when it is a
        dict whose ``"last_state"`` entry is missing or ``None``; any other
        cache is forwarded unchanged.
        """
        use_placeholder = cache is None
        if not use_placeholder and isinstance(cache, dict):
            use_placeholder = cache.get("last_state") is None
        if use_placeholder:
            # Minimal structure: a "last_state" dict holding an all-None conv_state triple.
            cache = {"last_state": {"conv_state": (None, None, None)}}
        return _old_prefill(
            self,
            encoder_output,
            decoder_input,
            cache=cache,
            crossatt_mask=crossatt_mask,
        )

    SimpleGLADecoder.prefill = _prefill_with_default
    print("🔧 Patched SimpleGLADecoder.prefill (default conv_state)")
except Exception as exc:
    # Any failure (including a missing module) just skips the patch.
    print("⚠️ FLA prefill patch skipped:", exc)
def _normalize_text(s: str, lang_hint: str = "fr") -> str:
    """Trim and lower-case ``s``, spelling out digit runs as words when possible.

    Args:
        s: Raw input text. ``None`` or an empty string yields ``""``.
        lang_hint: Language code forwarded to ``num2words`` as ``lang``.

    Returns:
        The stripped, lower-cased text with every run of digits replaced by
        its ``num2words`` spelling. If ``num2words`` cannot be imported or the
        conversion raises, the digits are left untouched and a warning is
        printed (best effort: this function never raises for that reason).
    """
    import re

    s = (s or "").strip().lower()
    # Nothing to convert: skip the optional dependency entirely.
    if not re.search(r"\d", s):
        return s
    try:
        from num2words import num2words

        return re.sub(r"\d+", lambda m: num2words(int(m.group()), lang=lang_hint), s)
    except Exception as exc:
        # Deliberately best-effort, but no longer silent: report why the
        # numbers were left as digits so the problem can be diagnosed.
        print(f"⚠️ Number normalization skipped (lang={lang_hint!r}):", exc)
        return s
def _load_model(device: str = "cuda"):
    """Return the shared PardiSpeech instance, loading it on first call.

    The first call loads ``MODEL_REPO_ID`` with ``map_location=device``,
    stores the instance in the module-level ``_pardi`` and records the
    model's ``sampling_rate`` attribute (24000 if absent) in
    ``_sampling_rate``. Later calls return the cached instance as-is;
    ``device`` is not consulted again.
    """
    global _pardi, _sampling_rate
    if _pardi is not None:
        return _pardi

    _pardi = PardiSpeech.from_pretrained(MODEL_REPO_ID, map_location=device)
    _sampling_rate = getattr(_pardi, "sampling_rate", 24000)
    print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).")
    return _pardi
def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
    """Convert an audio array to mono float32, scaling integer PCM to [-1, 1].

    Args:
        arr: Audio samples, either 1-D or 2-D. A 2-D array is averaged over
            axis 1 (assumes a ``(samples, channels)`` layout — confirm
            against the caller if another layout is possible).

    Returns:
        A float32 array. Floating-point input is only cast; integer input
        (e.g. the int16 PCM that Gradio's ``type="numpy"`` audio typically
        delivers) is additionally divided by the dtype's full-scale value so
        that it lands in [-1, 1] like floating-point audio does.
    """
    arr = np.asarray(arr)
    if np.issubdtype(arr.dtype, np.integer):
        info = np.iinfo(arr.dtype)
        if info.min < 0:
            # Signed PCM: full scale is |min| (32768 for int16).
            arr = arr.astype(np.float32) / float(-info.min)
        else:
            # Unsigned PCM is centred on the midpoint (128 for uint8).
            half_range = (float(info.max) + 1.0) / 2.0
            arr = (arr.astype(np.float32) - half_range) / half_range
    else:
        arr = arr.astype(np.float32)
    if arr.ndim == 2:
        arr = arr.mean(axis=1)
    return arr
def synthesize(
    text: str,
    ref_audio,
    ref_text: str,
    steps: int,
    cfg: float,
    cfg_ref: float,
    temperature: float,
    max_seq_len: int,
    seed: int,
    lang_hint: str
):
    """Synthesize speech for ``text`` and return ``(sample_rate, waveform)``.

    Args:
        text: Text to speak; passed through ``_normalize_text`` first.
        ref_audio: Optional reference audio used as a prefix. Either a file
            path (read with ``soundfile``) or a ``(sample_rate, samples)``
            tuple. ``None`` disables the prefix.
        ref_text: Text paired with the prefix audio; ``None`` becomes ``""``.
        steps: Forwarded to ``VelocityHeadSamplingParams`` as ``num_steps``.
        cfg: Forwarded to ``VelocityHeadSamplingParams`` as ``cfg``.
        cfg_ref: Forwarded to ``VelocityHeadSamplingParams`` as ``cfg_ref``.
        temperature: Forwarded to ``VelocityHeadSamplingParams`` when its
            constructor accepts a ``temperature`` keyword; otherwise unused.
        max_seq_len: Forwarded to ``text_to_speech`` as ``max_seq_len``.
        seed: Seed given to ``torch.manual_seed``; ``None`` is treated as 0.
        lang_hint: Language code passed to ``_normalize_text``.

    Returns:
        ``(_sampling_rate, wav)`` where ``wav`` is the first generated
        waveform as a NumPy array on the CPU.

    Raises:
        gr.Error: If ``text_to_speech`` raises; the original exception is
            chained as the cause.
    """
    # NOTE(review): `spaces` is imported by this file but no `spaces.GPU`
    # decorator is visible on this function — confirm whether the hosting
    # hardware requires it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Presumably an emptied numeric field arrives as None; fall back to 0
    # instead of crashing in int().
    torch.manual_seed(int(seed) if seed is not None else 0)
    pardi = _load_model(device)
    # _load_model() has refreshed the module-level rate (with its 24 kHz
    # fallback); use that single value for both resampling and the return.
    target_sr = _sampling_rate
    txt = _normalize_text(text, lang_hint=lang_hint)

    # --- VelocityHeadSamplingParams signature ---
    # Some versions of the class take (cfg_ref, cfg, num_steps) only, others
    # may also accept `temperature`. Try WITH temperature first so the UI
    # value is honoured whenever it is supported, and fall back to the
    # three-argument form when the keyword is rejected.
    sampling_kwargs = {
        "cfg_ref": float(cfg_ref),
        "cfg": float(cfg),
        "num_steps": int(steps),
    }
    try:
        vel_params = VelocityHeadSamplingParams(
            **sampling_kwargs,
            temperature=float(temperature),
        )
    except TypeError:
        vel_params = VelocityHeadSamplingParams(**sampling_kwargs)

    # Optional prefix: (reference text, encoded reference audio).
    prefix = None
    if ref_audio is not None:
        if isinstance(ref_audio, str):
            wav, sr = sf.read(ref_audio)
        else:
            sr, wav = ref_audio
        wav = _to_mono_float32(np.array(wav))
        wav_t = torch.from_numpy(wav).to(device)
        if sr != target_sr:
            import torchaudio

            wav_t = torchaudio.functional.resample(wav_t, sr, target_sr)
        # Add a leading dimension before encoding (presumably the batch axis).
        wav_t = wav_t.unsqueeze(0)
        with torch.inference_mode():
            prefix_tokens = pardi.patchvae.encode(wav_t)
        prefix = (ref_text or "", prefix_tokens[0])

    print(f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}")
    try:
        with torch.inference_mode():
            wavs, _ = pardi.text_to_speech(
                [txt],
                prefix,
                max_seq_len=int(max_seq_len),
                velocity_head_sampling_params=vel_params,
            )
    except Exception as e:
        import sys
        import traceback

        print("❌ text_to_speech failed:", e, file=sys.stderr)
        traceback.print_exc()
        # Surface the failure in the UI while keeping the original cause.
        raise gr.Error(f"Synthèse échouée: {type(e).__name__}: {e}") from e
    wav = wavs[0].detach().cpu().numpy()
    return (_sampling_rate, wav)
def build_demo():
    """Build and return the Gradio Blocks interface for the TTS demo.

    The UI contains a text box, an optional reference-audio ("prefix")
    section, an advanced-options section, and a button that runs
    ``synthesize`` and sends its result to an audio output. Queuing is
    enabled with a default concurrency limit of 1 and a maximum size of 32.
    """
    header_md = (
        "## Lina-speech (pardi-speech) – Démo TTS\n"
        "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).\n"
        "Paramètres avancés: *num_steps*, *CFG*, *température*, *max_seq_len*, *seed*."
    )
    with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
        gr.Markdown(header_md)

        with gr.Row():
            text_in = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")

        # Optional reference audio and its transcription.
        with gr.Accordion("Prefix (optionnel)", open=False):
            ref_audio_in = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
            ref_text_in = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")

        # Sampling and generation controls.
        with gr.Accordion("Options avancées", open=False):
            with gr.Row():
                steps_in = gr.Slider(1, 50, value=10, step=1, label="num_steps")
                cfg_in = gr.Slider(0.5, 3.0, value=1.4, step=0.05, label="CFG (guidance)")
                cfg_ref_in = gr.Slider(0.5, 3.0, value=1.0, step=0.05, label="CFG (réf.)")
            with gr.Row():
                temperature_in = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Température")
                max_seq_len_in = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (tokens audio)")
            seed_in = gr.Number(value=0, precision=0, label="Seed (reproductibilité)")
            lang_hint_in = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")

        run_btn = gr.Button("Synthétiser")
        audio_out = gr.Audio(label="Sortie audio", type="numpy")

        demo.queue(default_concurrency_limit=1, max_size=32)

        # Order must match the positional parameters of synthesize().
        synth_inputs = [
            text_in,
            ref_audio_in,
            ref_text_in,
            steps_in,
            cfg_in,
            cfg_ref_in,
            temperature_in,
            max_seq_len_in,
            seed_in,
            lang_hint_in,
        ]
        run_btn.click(fn=synthesize, inputs=synth_inputs, outputs=[audio_out])
    return demo
if __name__ == "__main__":
    # Script entry point: build the interface and start the Gradio server.
    demo = build_demo()
    demo.launch()
| # retrigger 2025-10-31T16:46:57+01:00 | |
| # retrigger 2025-10-31T17:27:54+01:00 | |
| # retrigger 2025-10-31T17:29:41+01:00 | |