Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import src.generate as generate | |
| import src.process as process | |
| import src.tts as tts | |
# ------------------- UI printing functions -------------------
def clear_all():
    """Reset the app to its initial, empty state.

    Returns one value per component wired to the Clear button, in this order:
    target, user_transcript, score_html, result_html, diff_html, tts_text,
    clone_status, tts_audio. The seven text-like fields are blanked with ``""``;
    the cloned-audio player receives ``None`` so it shows no clip.
    """
    blank_text_fields = ("",) * 7
    return (*blank_text_fields, None)
def make_result_html(pass_threshold, passed, ratio):
    """Build the pass/fail summary and the similarity score label.

    Args:
        pass_threshold: Minimum similarity, as a fraction in [0, 1], needed to pass.
        passed: Whether the reading met the threshold.
        ratio: Similarity between target and transcript, as a fraction.

    Returns:
        ``(summary, score)`` as plain-text strings (no HTML despite the name).
    """
    # round() rather than int(): 0.57 * 100 == 56.99999999999999, which int()
    # would truncate to 56 and misreport the threshold to the user.
    threshold_pct = round(pass_threshold * 100)
    if passed:
        summary = f"✅ Correct (≥ {threshold_pct}%)"
    else:
        summary = f"❌ Not a match (need ≥ {threshold_pct}%)"
    score = f"Similarity: {ratio * 100:.1f}%"
    return summary, score
def make_alignment_html(ref_tokens, hyp_tokens, alignments):
    """Render a word-level diff between the target and the recognized transcript.

    Args:
        ref_tokens: Tokens of the target sentence.
        hyp_tokens: Tokens of the ASR transcript of the user's audio.
        alignments: Iterable of ``(op, i1, i2, j1, j2)`` spans. ``op`` is one of
            ``"equal"``, ``"delete"``, ``"insert"`` or ``"replace"``, and the
            indices slice ``ref_tokens[i1:i2]`` / ``hyp_tokens[j1:j2]`` (the
            same shape as ``difflib.SequenceMatcher.get_opcodes()``). Spans
            with any other ``op`` are skipped.

    Returns:
        An HTML ``<div>``. Target words that were missed are struck through on
        a red background; extra or replacement words from the transcript are
        shown on a green background.
    """
    from html import escape

    missing_open = ' <span style="background:#ffe0e0;text-decoration:line-through;">'
    extra_open = ' <span style="background:#e0ffe0;">'
    close = "</span>"

    def _words(tokens, start, stop):
        # Escape ASR / sentence text so characters such as "<" or "&" render
        # literally instead of being interpreted as markup by gr.HTML.
        return escape(" ".join(tokens[start:stop]), quote=False)

    out = []
    for op, i1, i2, j1, j2 in alignments:
        ref_string = _words(ref_tokens, i1, i2)
        hyp_string = _words(hyp_tokens, j1, j2)
        if op == "equal":
            out.append(" " + ref_string)
        elif op == "delete":
            out.append(missing_open + ref_string + close)
        elif op == "insert":
            out.append(extra_open + hyp_string + close)
        elif op == "replace":
            out.append(missing_open + ref_string + close)
            out.append(extra_open + hyp_string + close)
    return '<div style="line-height:1.6;font-size:1rem;">' + "".join(out).strip() + "</div>"
def make_html(sentence_match):
    """Turn a sentence-match result into the three values the UI displays.

    Args:
        sentence_match: Object exposing ``target_tokens``, ``user_tokens``,
            ``alignments``, ``pass_threshold``, ``passed`` and ``ratio``
            (in this file it is a ``process.SentenceMatcher``).

    Returns:
        ``(score_text, summary_text, diff_markup)``, in the order the Score,
        Result and diff components are wired.
    """
    m = sentence_match
    summary_text, score_text = make_result_html(m.pass_threshold, m.passed, m.ratio)
    diff_markup = make_alignment_html(m.target_tokens, m.user_tokens, m.alignments)
    return score_text, summary_text, diff_markup
# ------------------- Core Check (English-only) -------------------
def get_user_transcript(
    audio_path: "str | None",
    target_sentence: str,
    model_id: str,
    device_pref: str,
) -> "tuple[str, str]":
    """Run ASR on the recorded clip after basic input validation.

    Args:
        audio_path: Filepath of the recording (the Audio component uses
            ``type="filepath"``), or ``None`` if nothing was recorded.
        target_sentence: Sentence the user was asked to read; must be non-empty.
        model_id: ASR model identifier forwarded to ``process.run_asr``.
        device_pref: Device preference forwarded to ``process.run_asr``.

    Returns:
        ``(error_msg, transcript)``. Exactly one is meaningful: on failure
        ``error_msg`` is non-empty and ``transcript`` is ``""``; on success
        ``error_msg`` is ``""``.
    """
    if not target_sentence:
        return "Please generate a sentence first.", ""
    if audio_path is None:
        return "Please start, record, then stop the audio recording before trying to transcribe.", ""
    user_transcript = process.run_asr(audio_path, model_id, device_pref)
    # run_asr signals failure by returning the exception instead of raising it.
    if isinstance(user_transcript, Exception):
        return f"Transcription failed: {user_transcript}", ""
    return "", user_transcript
def transcribe_check(audio_path, target_sentence, model_id, device_pref, pass_threshold):
    """Transcribe the recording, score it against the target, and render results.

    Returns:
        ``(user_transcript, score_text, result_text, diff_markup)``, matching
        the outputs wired to the "Transcribe & Check" button. On a validation
        or ASR error, the message goes into the result slot and the score and
        diff are blank.
    """
    error_msg, user_transcript = get_user_transcript(audio_path, target_sentence, model_id, device_pref)
    if error_msg:
        return user_transcript, "", error_msg, ""

    matcher = process.SentenceMatcher(target_sentence, user_transcript, pass_threshold)
    score_text, result_text, diff_markup = make_html(matcher)
    return user_transcript, score_text, result_text, diff_markup
# ------------------- Voice cloning gate -------------------
def clone_if_pass(
    audio_path,       # reference voice (the same recorded clip)
    target_sentence,  # sentence the user was supposed to say
    user_transcript,  # what ASR heard
    tts_text,         # what to synthesize in the cloned voice
    pass_threshold,   # similarity must meet or exceed this
    tts_model_id,     # e.g., "coqui/XTTS-v2"
    tts_language,     # e.g., "en"
):
    """Clone the user's voice and speak ``tts_text``, but only if the reading passed.

    The transcript is re-scored against the target with the given threshold.
    If it passes, the recorded clip is used as the reference voice for
    ``tts.run_tts_clone``.

    Returns:
        ``(audio, status)``. ``audio`` is ``(sample_rate, waveform)`` on
        success (a tuple form that ``gr.Audio`` accepts) and ``None``
        otherwise. ``status`` is a human-readable message.
    """
    # Basic validations
    if audio_path is None:
        return None, "Record audio first (reference voice is required)."
    if not target_sentence:
        return None, "Generate a target sentence first."
    if not user_transcript:
        return None, "Transcribe first to verify the sentence."
    # Whitespace-only text gives the TTS model nothing to say.
    if not tts_text or not tts_text.strip():
        return None, "Enter the sentence to synthesize."

    # Re-score here instead of trusting a pass/fail flag shown in the UI.
    # NOTE(review): ``user_transcript`` itself still comes from the caller
    # (UI textbox or API), so this does not prove the transcript came from
    # ``audio_path``. Re-running ASR here would close that gap.
    sm = process.SentenceMatcher(target_sentence, user_transcript, pass_threshold)
    if not sm.passed:
        # round() rather than int(): e.g. 0.57 * 100 == 56.99999999999999.
        return None, (
            f"❌ Cloning blocked: your reading did not reach the threshold "
            f"({sm.ratio * 100:.1f}% < {round(pass_threshold * 100)}%)."
        )

    # Zero-shot cloning; the helper signals failure by returning the exception.
    out = tts.run_tts_clone(audio_path, tts_text, model_id=tts_model_id, language=tts_language)
    if isinstance(out, Exception):
        return None, f"Voice cloning failed: {out}"
    sr, wav = out
    return (sr, wav), f"✅ Cloned and synthesized with {tts_model_id} ({tts_language})."
# ------------------- UI -------------------
with gr.Blocks(title="Say the Sentence (English)") as demo:
    gr.Markdown(
        """
# 🎤 Say the Sentence (English)
1) Generate a sentence.
2) Record yourself reading it.
3) Transcribe & check your accuracy.
4) If matched, clone your voice to speak any sentence you enter.
"""
    )
    with gr.Row():
        target = gr.Textbox(label="Target sentence", interactive=False,
                            placeholder="Click 'Generate sentence'")
    with gr.Row():
        btn_gen = gr.Button("🎲 Generate sentence", variant="primary")
        btn_clear = gr.Button("🧹 Clear")
    with gr.Row():
        # type="filepath": handlers receive a path string (or None), not a component.
        audio = gr.Audio(sources=["microphone"], type="filepath",
                         label="Record your voice")
    with gr.Accordion("Advanced settings", open=False):
        model_id = gr.Dropdown(
            choices=[
                "openai/whisper-tiny.en",
                "openai/whisper-base.en",
                "distil-whisper/distil-small.en",
            ],
            value="openai/whisper-tiny.en",
            label="ASR model (English only)",
        )
        device_pref = gr.Radio(
            choices=["auto", "cpu", "cuda"],
            value="auto",
            label="Device preference",
        )
        pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01,
                                   label="Match threshold")
    with gr.Row():
        btn_check = gr.Button("✅ Transcribe & Check", variant="primary")
    with gr.Row():
        user_transcript = gr.Textbox(label="Transcription", interactive=False)
    with gr.Row():
        score_html = gr.Label(label="Score")
        result_html = gr.Label(label="Result")
    diff_html = gr.HTML(
        label="Word-level diff (red = expected but missing / green = extra or replacement)")

    gr.Markdown("## 🔊 Voice cloning (gated)")
    with gr.Row():
        tts_text = gr.Textbox(
            label="Text to synthesize (voice clone)",
            placeholder="Type the sentence you want the cloned voice to say",
        )
    with gr.Row():
        tts_model_id = gr.Dropdown(
            choices=[
                "coqui/XTTS-v2",
                # add others if you like, e.g., "myshell-ai/MeloTTS"
            ],
            value="coqui/XTTS-v2",
            label="TTS (voice cloning) model",
        )
        tts_language = gr.Dropdown(
            choices=["en", "de", "fr", "es", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh"],
            value="en",
            label="Language",
        )
    with gr.Row():
        btn_clone = gr.Button("🔊 Clone voice (if passed)", variant="secondary")
    with gr.Row():
        tts_audio = gr.Audio(label="Cloned speech output", interactive=False)
        clone_status = gr.Label(label="Cloning status")

    # -------- Events --------
    btn_gen.click(fn=generate.gen_sentence_set, outputs=target)
    btn_clear.click(
        fn=clear_all,
        outputs=[target, user_transcript, score_html, result_html, diff_html, tts_text, clone_status, tts_audio],
    )
    btn_check.click(
        fn=transcribe_check,
        inputs=[audio, target, model_id, device_pref, pass_threshold],
        outputs=[user_transcript, score_html, result_html, diff_html],
    )
    # A new (or cleared) recording invalidates the earlier check. Without this,
    # a user could pass with one recording, then re-record a different voice,
    # and clone_if_pass would still accept the stale transcript in the textbox.
    audio.change(
        fn=lambda: ("", "", "", ""),
        outputs=[user_transcript, score_html, result_html, diff_html],
    )
    btn_clone.click(
        fn=clone_if_pass,
        inputs=[audio, target, user_transcript, tts_text, pass_threshold, tts_model_id, tts_language],
        outputs=[tts_audio, clone_status],
    )

if __name__ == "__main__":
    demo.launch()