File size: 6,715 Bytes
92ec5fe
4af42e5
 
92ec5fe
fd1f480
4af42e5
3d734f0
fd1f480
6d29905
92ec5fe
9f2e2fc
5997b2e
fd1f480
 
6b8706f
fd1f480
 
831395c
fd1f480
 
 
 
2dc4aff
72dc299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92ec5fe
fd1f480
92ec5fe
fd1f480
92ec5fe
fd1f480
 
92ec5fe
 
 
 
fd1f480
 
 
 
 
 
 
 
92ec5fe
fd1f480
92ec5fe
 
fd1f480
92ec5fe
fd1f480
92ec5fe
 
 
 
 
 
 
 
 
 
fd1f480
92ec5fe
fd1f480
 
92ec5fe
fd1f480
 
92ec5fe
 
fd1f480
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92ec5fe
fd1f480
 
 
 
 
 
 
 
 
 
 
 
 
f6f4ba0
fd1f480
 
3d734f0
fd1f480
58c000c
fd1f480
 
 
 
 
 
 
 
92ec5fe
fd1f480
 
 
 
 
 
 
4af42e5
 
0a0019f
92ec5fe
fd1f480
 
 
92ec5fe
fd1f480
92ec5fe
 
 
 
fd1f480
92ec5fe
 
 
 
 
 
 
 
fd1f480
 
92ec5fe
 
 
 
 
fd1f480
92ec5fe
 
fd1f480
 
92ec5fe
4af42e5
 
 
fd1f480
 
39d8b01
c6af930
8d935ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import gradio as gr
import numpy as np
import torch
import soundfile as sf
import spaces

from huggingface_hub import login
from pardi_speech import PardiSpeech, VelocityHeadSamplingParams  # présent dans ce repo

# Model repository on the Hugging Face Hub; overridable through the environment.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "theodorr/pardi-speech-enfr-forbidden")

# Optional Hub authentication: only attempted when HF_TOKEN is set. A failed
# login is reported but does not stop the app from starting.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception as err:
        print("⚠️ HF login failed:", err)
    else:
        print("✅ Logged to Hugging Face Hub.")

# Lazily-populated model singleton (filled by _load_model) and the output
# sampling rate used until the model reports its own.
_pardi = None
_sampling_rate = 24000



# --- Safety patch: give SimpleGLADecoder.prefill a default FLA state when none is usable ---
# Wraps the original method so that a missing cache (None), or a dict cache whose
# "last_state" entry is None, is replaced by
# {"last_state": {"conv_state": (None, None, None)}} before delegating to the
# original implementation. The patch is applied to the class itself, so it affects
# every SimpleGLADecoder instance in the process. If the import (or anything else
# in this try block) raises, the exception is printed and startup continues.
try:
    from tts.model.simple_gla import SimpleGLADecoder
    # Keep a reference to the unpatched method so the wrapper can delegate to it.
    _old_prefill = SimpleGLADecoder.prefill

    def _prefill_with_default(self, encoder_output, decoder_input, cache=None, crossatt_mask=None):
        """Call the original prefill, substituting a default cache when none is usable."""
        # If no cache was supplied, build a minimal structure intended to be accepted by FLA.
        # NOTE(review): when `cache` is a dict whose "last_state" is None, the whole
        # dict is replaced, so any other keys it held are dropped — confirm intended.
        # NOTE(review): the meaning of the 3-tuple of Nones for "conv_state" comes from
        # the FLA / SimpleGLA internals, which are not visible here — confirm.
        if cache is None or (isinstance(cache, dict) and cache.get("last_state") is None):
            cache = {"last_state": {"conv_state": (None, None, None)}}
        return _old_prefill(self, encoder_output, decoder_input, cache=cache, crossatt_mask=crossatt_mask)

    SimpleGLADecoder.prefill = _prefill_with_default
    print("🔧 Patched SimpleGLADecoder.prefill (default conv_state)")
except Exception as e:
    print("⚠️ FLA prefill patch skipped:", e)

def _normalize_text(s: str, lang_hint: str = "fr") -> str:
    """Trim and lower-case *s*, spelling out runs of digits as words when possible.

    Each run of digits is converted with ``num2words`` in language *lang_hint*.
    The conversion is best-effort:

    * if ``num2words`` is not installed, the digits are left untouched;
    * if a particular number cannot be converted (unsupported language, value
      out of range, ...), only that number is kept verbatim and the remaining
      numbers are still converted.

    Args:
        s: Input text; ``None`` is treated as the empty string.
        lang_hint: Language code forwarded to ``num2words``.

    Returns:
        The normalized text.
    """
    import re

    s = (s or "").strip().lower()
    try:
        from num2words import num2words
    except ImportError:
        # Optional dependency missing: keep digits as they are.
        return s

    def _spell(match):
        digits = match.group()
        try:
            return num2words(int(digits), lang=lang_hint)
        except Exception:
            # Keep just this number as digits instead of abandoning the
            # normalization of the whole string.
            return digits

    return re.sub(r"\d+", _spell, s)

def _load_model(device: str = "cuda"):
    """Return the process-wide PardiSpeech model, loading it on the first call.

    The first call downloads/loads ``MODEL_REPO_ID`` onto *device* and records the
    model's ``sampling_rate`` (falling back to 24000) in ``_sampling_rate``. Later
    calls return the cached instance unchanged, whatever *device* they pass.
    """
    global _pardi, _sampling_rate
    if _pardi is not None:
        return _pardi

    model = PardiSpeech.from_pretrained(MODEL_REPO_ID, map_location=device)
    _pardi = model
    _sampling_rate = getattr(model, "sampling_rate", 24000)
    print(f"✅ PardiSpeech loaded on {device} (sr={_sampling_rate}).")
    return model

def _to_mono_float32(arr: np.ndarray) -> np.ndarray:
    """Convert an audio array to mono float32 samples.

    Integer PCM (e.g. the int16 data Gradio returns for ``type="numpy"``) is
    rescaled to the [-1, 1] range. Signed types are divided by ``-iinfo.min``.
    Unsigned types are treated as offset-binary, centred on ``(max + 1) / 2``.
    Floating-point input is cast to float32 without rescaling.

    A 2-D array is down-mixed by averaging over axis 1. This matches the
    ``(frames, channels)`` layout used by soundfile and Gradio.

    Args:
        arr: 1-D ``(frames,)`` or 2-D ``(frames, channels)`` sample array.

    Returns:
        A float32 array; 1-D when the input was 1-D or 2-D.
    """
    arr = np.asarray(arr)
    if np.issubdtype(arr.dtype, np.integer):
        info = np.iinfo(arr.dtype)
        if arr.dtype.kind == "u":
            # Unsigned PCM (e.g. 8-bit WAV) stores silence at the mid-point.
            mid = (info.max + 1) / 2
            arr = (arr.astype(np.float32) - mid) / mid
        else:
            # Without this scaling int16 audio would reach the model at ~±32768.
            arr = arr.astype(np.float32) / float(-info.min)
    else:
        arr = arr.astype(np.float32)
    if arr.ndim == 2:
        arr = arr.mean(axis=1)
    return arr

def _build_sampling_params(steps, cfg: float, cfg_ref: float, temperature: float):
    """Create the velocity-head sampling parameters for ``text_to_speech``.

    The call is first made with ``temperature``. If ``VelocityHeadSamplingParams``
    rejects that with a ``TypeError``, it is retried without it. The inference
    notebook's version, for instance, only takes ``cfg_ref``, ``cfg`` and
    ``num_steps``. Trying the richer signature first means the temperature
    slider is honoured whenever the class accepts it, rather than being
    silently dropped.
    """
    common = {"cfg_ref": float(cfg_ref), "cfg": float(cfg), "num_steps": int(steps)}
    temp = float(temperature)
    try:
        return VelocityHeadSamplingParams(**common, temperature=temp)
    except TypeError:
        return VelocityHeadSamplingParams(**common)


def _encode_prefix(pardi, ref_audio, ref_text, device: str, target_sr: int, lang_hint: str):
    """Encode the optional reference clip into a ``(text, tokens)`` prefix.

    Returns ``None`` when no reference audio was given. ``ref_audio`` is either
    a file path, read with soundfile, or a Gradio ``(sample_rate, samples)``
    tuple. The clip is processed as follows:

    1. It is down-mixed to mono float32.
    2. It is resampled to ``target_sr`` when needed.
    3. It is encoded with ``pardi.patchvae``, and the first element of the
       result is kept.

    ``ref_text`` goes through the same normalization as the target text, so
    both parts of the prompt share one written form.
    """
    if ref_audio is None:
        return None
    import torchaudio  # only needed when a prefix is supplied

    if isinstance(ref_audio, str):
        wav, sr = sf.read(ref_audio)
    else:
        sr, wav = ref_audio
    wav_t = torch.from_numpy(_to_mono_float32(np.array(wav))).to(device)
    if sr != target_sr:
        wav_t = torchaudio.functional.resample(wav_t, sr, target_sr)
    with torch.inference_mode():
        # Add a leading batch dimension for the encoder.
        prefix_tokens = pardi.patchvae.encode(wav_t.unsqueeze(0))
    return (_normalize_text(ref_text, lang_hint=lang_hint), prefix_tokens[0])


@spaces.GPU(duration=120)
def synthesize(
    text: str,
    ref_audio,
    ref_text: str,
    steps: int,
    cfg: float,
    cfg_ref: float,
    temperature: float,
    max_seq_len: int,
    seed: int,
    lang_hint: str
):
    """Gradio callback: synthesize *text*, optionally conditioned on a reference clip.

    Args:
        text: Text to speak. It is normalized with ``_normalize_text``.
        ref_audio: ``None``, a file path, or a Gradio ``(sample_rate, samples)`` tuple.
        ref_text: Transcription of the reference clip (may be empty).
        steps, cfg, cfg_ref, temperature: Velocity-head sampling settings.
        max_seq_len: Maximum number of generated audio tokens.
        seed: Torch RNG seed. ``None`` (a cleared Number field) falls back to 0.
        lang_hint: Language used for number normalization.

    Returns:
        ``(sampling_rate, waveform)`` as expected by a numpy ``gr.Audio`` output.

    Raises:
        gr.Error: If the text is empty or the synthesis itself fails.
    """
    import sys
    import traceback

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # gr.Number yields None when the field is cleared.
    seed = 0 if seed is None else int(seed)
    torch.manual_seed(seed)

    pardi = _load_model(device)
    txt = _normalize_text(text, lang_hint=lang_hint)
    if not txt:
        raise gr.Error("Le texte à synthétiser est vide.")

    # Use one sampling rate for both prefix resampling and the returned audio,
    # with the same fallback _load_model applies.
    target_sr = getattr(pardi, "sampling_rate", _sampling_rate)

    vel_params = _build_sampling_params(steps, cfg, cfg_ref, temperature)
    prefix = _encode_prefix(pardi, ref_audio, ref_text, device, target_sr, lang_hint)

    print(f"[debug] has_prefix={prefix is not None}, steps={steps}, cfg={cfg}, cfg_ref={cfg_ref}, T={temperature}, max_seq_len={max_seq_len}, seed={seed}")

    try:
        with torch.inference_mode():
            wavs, _ = pardi.text_to_speech(
                [txt],
                prefix,
                max_seq_len=int(max_seq_len),
                velocity_head_sampling_params=vel_params,
            )
    except Exception as e:
        print("❌ text_to_speech failed:", e, file=sys.stderr)
        traceback.print_exc()
        raise gr.Error(f"Synthèse échouée: {type(e).__name__}: {e}")

    # NOTE(review): assumes wavs[0] is a 1-D waveform tensor — confirm its shape.
    wav = wavs[0].detach().cpu().numpy()
    return (target_sr, wav)

def build_demo():
    """Assemble and return the Gradio Blocks UI for the TTS demo.

    The page contains, in order:

    * an intro text;
    * the text box;
    * a collapsible "prefix" section with reference audio and its transcription;
    * a collapsible advanced-options section;
    * the synthesize button;
    * the audio output.

    The queue allows one concurrent job with at most 32 pending requests. The
    button is wired to ``synthesize``.
    """
    intro = "\n".join([
        "## Lina-speech (pardi-speech) – Démo TTS",
        "Génère de l'audio à partir de texte, avec ou sans *prefix* (audio de référence).",
        "Paramètres avancés: *num_steps*, *CFG*, *température*, *max_seq_len*, *seed*.",
    ])

    with gr.Blocks(title="Lina-speech / pardi-speech Demo") as ui:
        gr.Markdown(intro)

        with gr.Row():
            text_in = gr.Textbox(label="Texte à synthétiser", lines=4, placeholder="Tape ton texte ici…")

        with gr.Accordion("Prefix (optionnel)", open=False):
            prefix_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Audio de référence")
            prefix_text = gr.Textbox(label="Texte du prefix (si connu)", placeholder="Transcription du prefix (optionnel)")

        with gr.Accordion("Options avancées", open=False):
            with gr.Row():
                steps_in = gr.Slider(1, 50, value=10, step=1, label="num_steps")
                cfg_in = gr.Slider(0.5, 3.0, value=1.4, step=0.05, label="CFG (guidance)")
                cfg_ref_in = gr.Slider(0.5, 3.0, value=1.0, step=0.05, label="CFG (réf.)")
            with gr.Row():
                temperature_in = gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Température")
                max_len_in = gr.Slider(50, 1200, value=300, step=10, label="max_seq_len (tokens audio)")
                seed_in = gr.Number(value=0, precision=0, label="Seed (reproductibilité)")
            lang_in = gr.Dropdown(choices=["fr", "en"], value="fr", label="Langue (normalisation)")

        run_button = gr.Button("Synthétiser")
        audio_out = gr.Audio(label="Sortie audio", type="numpy")

        ui.queue(default_concurrency_limit=1, max_size=32)

        # Order must match synthesize()'s positional parameters.
        synth_inputs = [
            text_in, prefix_audio, prefix_text,
            steps_in, cfg_in, cfg_ref_in,
            temperature_in, max_len_in, seed_in, lang_in,
        ]
        run_button.click(fn=synthesize, inputs=synth_inputs, outputs=[audio_out])
    return ui

if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start the Gradio app.
    demo = build_demo()
    demo.launch()
# NOTE(review): the timestamped "retrigger" comments below have no effect on the
# program; they look like no-op edits made to force the hosted app to rebuild —
# confirm and consider removing.
# retrigger 2025-10-31T16:46:57+01:00
# retrigger 2025-10-31T17:27:54+01:00
# retrigger 2025-10-31T17:29:41+01:00