# pardi-speech / app.py — Hugging Face Space by mehdi999; last commit 8d935ec ("chore: retrigger build").
import os
import gradio as gr
import numpy as np
import torch
import soundfile as sf
import spaces
from huggingface_hub import login
from pardi_speech import PardiSpeech, VelocityHeadSamplingParams # présent dans ce repo
# Hub repository holding the checkpoint; overridable through the MODEL_REPO_ID env var.
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")
# Optional Hub token; a login is attempted only when it is set.
HF_TOKEN = os.getenv("HF_TOKEN")

if HF_TOKEN:
    # Best-effort authentication: a failure is reported but does not stop the app.
    try:
        login(token=HF_TOKEN)
        print("✅ Logged to Hugging Face Hub.")
    except Exception as err:
        print("⚠️ HF login failed:", err)

# Lazily-created model instance (populated by _load_model) and the output
# sampling rate, which stays at 24000 until a loaded model reports its own.
_pardi = None
_sampling_rate = 24000
# --- Safety patch: initialise a default FLA state when prefill receives None ---
# Any failure in this section (missing module, missing attribute, ...) is caught
# below: the patch is then skipped and a warning is printed.
try:
    from tts.model.simple_gla import SimpleGLADecoder
    # Keep the unpatched method so the wrapper can delegate to it.
    _old_prefill = SimpleGLADecoder.prefill
    def _prefill_with_default(self, encoder_output, decoder_input, cache=None, crossatt_mask=None):
        """Call the original ``prefill``, substituting a placeholder cache when none is usable.

        ``cache`` is replaced by ``{"last_state": {"conv_state": (None, None, None)}}``
        when it is None, or when it is a dict whose ``"last_state"`` entry is
        missing or None (any other keys of such a dict are dropped). Caches of
        other types are passed through untouched. ``encoder_output`` and
        ``decoder_input`` are forwarded positionally; ``cache`` and
        ``crossatt_mask`` are forwarded as keywords.

        NOTE(review): assumes the FLA layers treat ``conv_state=(None, None, None)``
        as "no previous state" — confirm against the decoder implementation.
        """
        # No usable cache supplied: build a minimal structure understood by FLA.
        if cache is None or (isinstance(cache, dict) and cache.get("last_state") is None):
            cache = {"last_state": {"conv_state": (None, None, None)}}
        return _old_prefill(self, encoder_output, decoder_input, cache=cache, crossatt_mask=crossatt_mask)
    # Install the wrapper on the class so every decoder instance uses it.
    SimpleGLADecoder.prefill = _prefill_with_default
    print("🔧 Patched SimpleGLADecoder.prefill (default conv_state)")
except Exception as e:
    print("⚠️ FLA prefill patch skipped:", e)
def _normalize_text(s: str, lang_hint: str = "fr") -> str:
s = (s or "").strip().lower()
try:
import re
from num2words import num2words
def repl(m): return num2words(int(m.group()), lang=lang_hint)
s = re.sub(r"\d+", repl, s)
except Exception:
pass
return s
def _load_model(device: str = "cuda"):
    """Return the process-wide PardiSpeech instance, loading it on first use.

    The model is fetched from ``MODEL_REPO_ID`` with ``map_location=device``
    only on the first call. Later calls return the cached instance whatever
    ``device`` they pass. A successful load also refreshes the module-level
    ``_sampling_rate``, which is the model's ``sampling_rate`` attribute or 24000
    when it has none.
    """
    global _pardi, _sampling_rate
    if _pardi is not None:
        return _pardi

    model = PardiSpeech.from_pretrained(MODEL_REPO_ID, map_location=device)
    _pardi = model
    _sampling_rate = getattr(model, "sampling_rate", 24000)
    print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).")
    return _pardi
def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
arr = arr.astype(np.float32)
if arr.ndim == 2:
arr = arr.mean(axis=1)
return arr
def _make_sampling_params(steps, cfg, cfg_ref, temperature):
    """Build VelocityHeadSamplingParams, passing ``temperature`` when the class accepts it.

    The inference notebook's class takes ``(cfg_ref, cfg, num_steps)`` without
    ``temperature``. The temperature-aware call is tried first, with a fallback on
    TypeError, so the UI slider is applied whenever the installed class supports
    it. Trying the plain call first would silently ignore the slider for any class
    where ``temperature`` is optional.

    Returns:
        ``(params, temperature_applied)``.
    """
    common = {"cfg_ref": float(cfg_ref), "cfg": float(cfg), "num_steps": int(steps)}
    try:
        return VelocityHeadSamplingParams(**common, temperature=float(temperature)), True
    except TypeError:
        return VelocityHeadSamplingParams(**common), False


def _encode_prefix(pardi, ref_audio, ref_text, device):
    """Encode reference audio into a ``(transcript, tokens)`` prefix.

    ``ref_audio`` is either a file path (read with soundfile) or a
    ``(sample_rate, samples)`` tuple from ``gr.Audio(type="numpy")``. The audio
    is down-mixed to mono, resampled to ``pardi.sampling_rate`` when needed and
    passed to ``pardi.patchvae.encode``. The first encoded item is returned.
    """
    import torchaudio

    if isinstance(ref_audio, str):
        wav, sr = sf.read(ref_audio)
    else:
        sr, wav = ref_audio
    wav = _to_mono_float32(np.array(wav))
    wav_t = torch.from_numpy(wav).to(device)
    if sr != pardi.sampling_rate:
        wav_t = torchaudio.functional.resample(wav_t, sr, pardi.sampling_rate)
    with torch.inference_mode():
        prefix_tokens = pardi.patchvae.encode(wav_t.unsqueeze(0))
    return (ref_text or "", prefix_tokens[0])


def _report_failure(what, exc):
    """Print ``what`` and the active traceback to stderr (call from inside ``except``)."""
    import sys
    import traceback

    print(f"❌ {what}:", exc, file=sys.stderr)
    traceback.print_exc()


@spaces.GPU(duration=120)
def synthesize(
    text: str,
    ref_audio,
    ref_text: str,
    steps: int,
    cfg: float,
    cfg_ref: float,
    temperature: float,
    max_seq_len: int,
    seed: int,
    lang_hint: str
):
    """Gradio callback: synthesize ``text`` and return ``(sampling_rate, waveform)``.

    Args:
        text: Text to speak. It is normalized by ``_normalize_text``
            (stripped, lowercased, digits spelled out in ``lang_hint``).
        ref_audio: Optional reference audio used as a prefix. It is a file path
            or a ``(sample_rate, samples)`` tuple from ``gr.Audio(type="numpy")``.
        ref_text: Transcript of ``ref_audio`` (may be empty).
        steps: Number of velocity-head sampling steps.
        cfg: Guidance scale.
        cfg_ref: Reference guidance scale.
        temperature: Sampling temperature. It is applied only if
            VelocityHeadSamplingParams accepts it.
        max_seq_len: Maximum number of generated audio tokens.
        seed: Torch RNG seed. A cleared field (None) falls back to 0.
        lang_hint: Language code used for number normalization.

    Returns:
        ``(_sampling_rate, wav)``, where ``wav`` is the first generated waveform
        as a NumPy array.

    Raises:
        gr.Error: if the text is empty after normalization, or if prefix
            encoding or synthesis fails. The original exception is chained.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    txt = _normalize_text(text, lang_hint=lang_hint)
    if not txt:
        raise gr.Error("Texte vide : rien à synthétiser.")

    # gr.Number yields None when the field is cleared; int(None) would raise.
    torch.manual_seed(int(seed) if seed is not None else 0)
    pardi = _load_model(device)

    vel_params, temperature_applied = _make_sampling_params(steps, cfg, cfg_ref, temperature)

    # Optional prefix (reference audio + transcript).
    prefix = None
    if ref_audio is not None:
        try:
            prefix = _encode_prefix(pardi, ref_audio, ref_text, device)
        except Exception as e:
            _report_failure("prefix encoding failed", e)
            raise gr.Error(f"Encodage du prefix échoué: {type(e).__name__}: {e}") from e

    print(
        f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, "
        f"T={temperature} (applied={temperature_applied}), max_seq_len={max_seq_len}, seed={seed}"
    )
    try:
        with torch.inference_mode():
            wavs, _ = pardi.text_to_speech(
                [txt],
                prefix,
                max_seq_len=int(max_seq_len),
                velocity_head_sampling_params=vel_params,
            )
    except Exception as e:
        _report_failure("text_to_speech failed", e)
        raise gr.Error(f"Synthèse échouée: {type(e).__name__}: {e}") from e

    wav = wavs[0].detach().cpu().numpy()
    return (_sampling_rate, wav)
def build_demo():
    """Build and return the Gradio Blocks UI for the TTS demo.

    Lays out the text input, an optional reference-audio ("prefix") section and
    the advanced sampling controls. It then enables the request queue and wires
    the "Synthétiser" button to ``synthesize``.

    Returns:
        The configured ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks(title="Lina-speech / pardi-speech Demo") as demo:
        gr.Markdown(
            "## Lina-speech (pardi-speech) – Démo TTS\n"
            "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).\n"
            "Paramètres avancés: *num_steps*, *CFG*, *température*, *max_seq_len*, *seed*."
        )
        # Main input: the text to synthesize.
        with gr.Row():
            text = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")
        # Optional prefix: reference audio plus its transcript.
        with gr.Accordion("Prefix (optionnel)", open=False):
            # type="numpy" hands synthesize a (sample_rate, samples) tuple.
            ref_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
            ref_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")
        # Sampling controls forwarded to synthesize.
        with gr.Accordion("Options avancées", open=False):
            with gr.Row():
                steps = gr.Slider(1, 50, value=10, step=1, label="num_steps")
                cfg = gr.Slider(0.5, 3.0, value=1.4, step=0.05, label="CFG (guidance)")
                cfg_ref = gr.Slider(0.5, 3.0, value=1.0, step=0.05, label="CFG (réf.)")
            with gr.Row():
                temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Température")
                max_seq_len = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (tokens audio)")
                seed = gr.Number(value=0, precision=0, label="Seed (reproductibilité)")
            lang_hint = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")
        btn = gr.Button("Synthétiser")
        out_audio = gr.Audio(label="Sortie audio", type="numpy")
        # Queue requests: by default one concurrent run per event, at most 32 waiting.
        demo.queue(default_concurrency_limit=1, max_size=32)
        # The order of `inputs` must match synthesize's positional parameters.
        btn.click(
            fn=synthesize,
            inputs=[text, ref_audio, ref_text, steps, cfg, cfg_ref, temperature, max_seq_len, seed, lang_hint],
            outputs=[out_audio]
        )
    return demo
if __name__ == "__main__":
    # Script entry point: construct the UI, keep it as the module-level `demo`,
    # then start the server.
    demo = build_demo()
    demo.launch()
# retrigger 2025-10-31T16:46:57+01:00
# retrigger 2025-10-31T17:27:54+01:00
# retrigger 2025-10-31T17:29:41+01:00