import os

import gradio as gr
import numpy as np
import soundfile as sf
import spaces
import torch
from huggingface_hub import login

from pardi_speech import PardiSpeech, VelocityHeadSamplingParams  # provided in this repo

MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
HF_TOKEN = os.environ.get("HF_TOKEN")

if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("✅ Logged in to Hugging Face Hub.")
    except Exception as e:
        print("⚠️ HF login failed:", e)

# Lazily-loaded model and its sampling rate (updated on first load).
_pardi = None
_sampling_rate = 24000

# --- Safety patch: initialize a default FLA state when `cache` is None at prefill ---
try:
    from tts.model.simple_gla import SimpleGLADecoder

    _old_prefill = SimpleGLADecoder.prefill

    def _prefill_with_default(self, encoder_output, decoder_input, cache=None, crossatt_mask=None):
        # If no cache is provided, build the minimal structure FLA expects.
        if cache is None or (isinstance(cache, dict) and cache.get("last_state") is None):
            cache = {"last_state": {"conv_state": (None, None, None)}}
        return _old_prefill(self, encoder_output, decoder_input, cache=cache, crossatt_mask=crossatt_mask)

    SimpleGLADecoder.prefill = _prefill_with_default
    print("🔧 Patched SimpleGLADecoder.prefill (default conv_state)")
except Exception as e:
    print("⚠️ FLA prefill patch skipped:", e)


def _normalize_text(s: str, lang_hint: str = "fr") -> str:
    """Lowercase, strip, and spell out digits (best effort; requires num2words)."""
    s = (s or "").strip().lower()
    try:
        import re

        from num2words import num2words

        def repl(m):
            return num2words(int(m.group()), lang=lang_hint)

        s = re.sub(r"\d+", repl, s)
    except Exception:
        # num2words missing or language unsupported: keep digits as-is.
        pass
    return s


def _load_model(device: str = "cuda"):
    """Load PardiSpeech once and cache it at module level."""
    global _pardi, _sampling_rate
    if _pardi is None:
        _pardi = PardiSpeech.from_pretrained(MODEL_REPO_ID, map_location=device)
        _sampling_rate = getattr(_pardi, "sampling_rate", 24000)
        print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).")
    return _pardi


def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
    """Downmix stereo to mono and cast to float32."""
    arr = arr.astype(np.float32)
    if arr.ndim == 2:
        arr = arr.mean(axis=1)
    return arr


@spaces.GPU(duration=120)
def synthesize(
    text: str,
    ref_audio,
    ref_text: str,
    steps: int,
    cfg: float,
    cfg_ref: float,
    temperature: float,
    max_seq_len: int,
    seed: int,
    lang_hint: str,
):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(int(seed))
    pardi = _load_model(device)

    txt = _normalize_text(text, lang_hint=lang_hint)

    # --- IMPORTANT: VelocityHeadSamplingParams signature ---
    # In the inference notebook, the class expects (cfg_ref, cfg, num_steps)
    # WITHOUT 'temperature'. Try without temperature first, then fall back
    # in case the installed version accepts one.
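    # A signature-agnostic alternative (sketch, not part of the original flow):
    # filter keyword arguments through `inspect.signature` instead of relying on
    # the TypeError fallback below. Kept as comments so behavior is unchanged;
    # the assumption is that `temperature` is the only parameter that varies
    # across versions of VelocityHeadSamplingParams.
    #
    #   import inspect
    #   accepted = inspect.signature(VelocityHeadSamplingParams).parameters
    #   candidate = dict(cfg_ref=float(cfg_ref), cfg=float(cfg),
    #                    num_steps=int(steps), temperature=float(temperature))
    #   vel_params = VelocityHeadSamplingParams(
    #       **{k: v for k, v in candidate.items() if k in accepted}
    #   )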
    try:
        vel_params = VelocityHeadSamplingParams(
            cfg_ref=float(cfg_ref),
            cfg=float(cfg),
            num_steps=int(steps),
        )
    except TypeError:
        vel_params = VelocityHeadSamplingParams(
            cfg_ref=float(cfg_ref),
            cfg=float(cfg),
            num_steps=int(steps),
            temperature=float(temperature),
        )

    # Optional prefix: encode the reference audio into tokens for voice prompting.
    prefix = None
    if ref_audio is not None:
        if isinstance(ref_audio, str):
            wav, sr = sf.read(ref_audio)
        else:
            sr, wav = ref_audio  # gr.Audio(type="numpy") yields (sample_rate, data)
        wav = _to_mono_float32(np.array(wav))
        wav_t = torch.from_numpy(wav).to(device)

        import torchaudio

        # Guard against models lacking the attribute (mirrors _load_model).
        target_sr = getattr(pardi, "sampling_rate", _sampling_rate)
        if sr != target_sr:
            wav_t = torchaudio.functional.resample(wav_t, sr, target_sr)
        wav_t = wav_t.unsqueeze(0)
        with torch.inference_mode():
            prefix_tokens = pardi.patchvae.encode(wav_t)
        prefix = (ref_text or "", prefix_tokens[0])

    print(
        f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, "
        f"cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}"
    )

    try:
        with torch.inference_mode():
            wavs, _ = pardi.text_to_speech(
                [txt],
                prefix,
                max_seq_len=int(max_seq_len),
                velocity_head_sampling_params=vel_params,
            )
    except Exception as e:
        import sys
        import traceback

        print("❌ text_to_speech failed:", e, file=sys.stderr)
        traceback.print_exc()
        raise gr.Error(f"Synthesis failed: {type(e).__name__}: {e}")

    wav = wavs[0].detach().cpu().numpy()
    return (_sampling_rate, wav)


def build_demo():
    with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
        gr.Markdown(
            "## Lina-speech (pardi-speech) – TTS Demo\n"
            "Generates audio from text, with or without a *prefix* (reference audio).\n"
            "Advanced parameters: *num_steps*, *CFG*, *temperature*, *max_seq_len*, *seed*."
        )
        with gr.Row():
            text = gr.Textbox(label="Text to synthesize", lines=4, placeholder="Type your text here…")
        with gr.Accordion("Prefix (optional)", open=False):
            ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Reference audio")
            ref_text = gr.Textbox(label="Prefix text (if known)", placeholder="Prefix transcription (optional)")
        with gr.Accordion("Advanced options", open=False):
            with gr.Row():
                steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
                cfg = gr.Slider(0.5, 3.0, value=1.4, step=0.05, label="CFG (guidance)")
                cfg_ref = gr.Slider(0.5, 3.0, value=1.0, step=0.05, label="CFG (ref.)")
            with gr.Row():
                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature")
                max_seq_len = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (audio tokens)")
                seed = gr.Number(value=0, precision=0, label="Seed (reproducibility)")
                lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Language (normalization)")
        btn = gr.Button("Synthesize")
        out_audio = gr.Audio(label="Output audio", type="numpy")

        demo.queue(default_concurrency_limit=1, max_size=32)
        btn.click(
            fn=synthesize,
            inputs=[text, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
            outputs=[out_audio],
        )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
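# Usage sketch (assumptions: this file is saved as `app.py`, and `gradio`,
# `torch`, `soundfile`, `num2words`, `spaces`, and `pardi_speech` are installed;
# the checkpoint name below is a placeholder):
#
#   # Launch the demo against a different checkpoint:
#   MODEL_REPO_ID=your-org/your-pardi-checkpoint python app.py
#
#   # Smoke-test text normalization without loading the model:
#   python -c "from app import _normalize_text; print(_normalize_text('I have 3 cats', lang_hint='en'))"
#   # -> "i have three cats"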