RepeatAfterMe / app.py
frimelle's picture
frimelle HF Staff
try to recreate the code
0d9ff36
import html
from typing import Optional, Tuple

import gradio as gr

import src.generate as generate
import src.process as process
import src.tts as tts
# ------------------- UI printing functions -------------------
def clear_all():
    """Reset every component wired to the Clear button.

    Produces eight values, positionally matching the Clear button's outputs:
    target, user_transcript, score_html, result_html, diff_html, tts_text and
    clone_status each become "", and the final slot (tts_audio) becomes None.
    """
    blank_text_fields = ("",) * 7
    return (*blank_text_fields, None)
def make_result_html(pass_threshold, passed, ratio):
    """Build the pass/fail summary and the similarity score text.

    Args:
        pass_threshold: Minimum similarity needed to pass, as a fraction
            (the UI slider ranges 0.50-1.00).
        passed: Truthy when the reading met the threshold.
        ratio: Similarity between target and transcript, as a fraction.

    Returns:
        ``(summary, score)`` strings, e.g.
        ``("✅ Correct (≥ 85%)", "Similarity: 91.3%")``.
    """
    # round(), not int(): many slider steps are not exact in binary floating
    # point (0.57 * 100 == 56.99999999999999), and truncation would show 56%.
    threshold_pct = round(pass_threshold * 100)
    if passed:
        summary = f"✅ Correct (≥ {threshold_pct}%)"
    else:
        summary = f"❌ Not a match (need ≥ {threshold_pct}%)"
    score = f"Similarity: {ratio * 100:.1f}%"
    return summary, score
def make_alignment_html(ref_tokens, hyp_tokens, alignments):
    """Render a word-level diff between the target and the recognized speech.

    Args:
        ref_tokens: Target-sentence tokens, sliced with ``i1:i2``.
        hyp_tokens: Recognized (user) tokens, sliced with ``j1:j2``.
        alignments: Iterable of ``(op, i1, i2, j1, j2)`` spans where ``op`` is
            "equal", "delete", "insert" or "replace" (the same shape as
            ``difflib.SequenceMatcher.get_opcodes()``; presumably produced by
            it in process.SentenceMatcher — confirm). Unknown ops are skipped.

    Returns:
        An HTML ``<div>``. Matched words appear as plain text. Target words
        that are missing or replaced are struck through on red. Extra or
        substituted recognized words are shown on green. All token text is
        HTML-escaped.
    """
    missing_open = ' <span style="background:#ffe0e0;text-decoration:line-through;">'
    extra_open = ' <span style="background:#e0ffe0;">'
    close = "</span>"

    pieces = []
    for op, i1, i2, j1, j2 in alignments:
        # Escape: tokens come from ASR output and the sentence generator and
        # must never be interpreted as markup ("<", "&" would break the diff).
        ref_text = html.escape(" ".join(ref_tokens[i1:i2]))
        hyp_text = html.escape(" ".join(hyp_tokens[j1:j2]))
        if op == "equal":
            pieces.append(" " + ref_text)
        elif op == "delete":
            pieces.append(missing_open + ref_text + close)
        elif op == "insert":
            pieces.append(extra_open + hyp_text + close)
        elif op == "replace":
            # Show what was expected (struck out) followed by what was heard.
            pieces.append(missing_open + ref_text + close)
            pieces.append(extra_open + hyp_text + close)

    body = "".join(pieces).strip()
    return '<div style="line-height:1.6;font-size:1rem;">' + body + "</div>"
def make_html(sentence_match):
    """Turn a sentence-match result into the three values the results panel shows.

    Args:
        sentence_match: Object exposing ``target_tokens``, ``user_tokens``,
            ``alignments``, ``pass_threshold``, ``passed`` and ``ratio``
            (a ``process.SentenceMatcher`` in this app).

    Returns:
        ``(score_text, result_text, diff_html)`` in that order.
    """
    m = sentence_match
    diff_markup = make_alignment_html(m.target_tokens, m.user_tokens, m.alignments)
    verdict, similarity = make_result_html(m.pass_threshold, m.passed, m.ratio)
    return similarity, verdict, diff_markup
# ------------------- Core Check (English-only) -------------------
def get_user_transcript(
    audio_path: Optional[str],
    target_sentence: str,
    model_id: str,
    device_pref: str,
) -> Tuple[str, str]:
    """Validate the inputs, then transcribe the recording with ASR.

    Args:
        audio_path: Path to the recorded clip (the microphone ``gr.Audio``
            component is created with ``type="filepath"``), or ``None`` when
            nothing has been recorded yet.
        target_sentence: The sentence the user was asked to read.
        model_id: ASR model identifier, passed through to ``process.run_asr``.
        device_pref: Device preference ("auto", "cpu" or "cuda" in the UI),
            passed through to ``process.run_asr``.

    Returns:
        ``(error_message, transcript)``. On failure, ``error_message`` is a
        user-facing explanation and ``transcript`` is ``""``. On success,
        ``error_message`` is ``""`` and ``transcript`` is the ASR output.
    """
    if not target_sentence:
        return "Please generate a sentence first.", ""
    if audio_path is None:
        return "Please start, record, then stop the audio recording before trying to transcribe.", ""
    user_transcript = process.run_asr(audio_path, model_id, device_pref)
    # run_asr reports failure by returning the exception rather than raising it.
    if isinstance(user_transcript, Exception):
        return f"Transcription failed: {user_transcript}", ""
    return "", user_transcript
def transcribe_check(audio_path, target_sentence, model_id, device_pref, pass_threshold):
    """Handle the "Transcribe & Check" button.

    Transcribes the recording and scores it against the target sentence.
    Returns ``(transcript, score, result, diff)`` for the four result widgets.
    On a validation or ASR error, the message goes to the result slot and the
    other three are empty strings.
    """
    problem, heard = get_user_transcript(audio_path, target_sentence, model_id, device_pref)
    if problem:
        return heard, "", problem, ""

    matcher = process.SentenceMatcher(target_sentence, heard, pass_threshold)
    score_text, verdict, diff_markup = make_html(matcher)
    return heard, score_text, verdict, diff_markup
# ------------------- Voice cloning gate -------------------
def clone_if_pass(
    audio_path,  # ref voice (the same recorded clip)
    target_sentence,  # sentence user was supposed to say
    user_transcript,  # what ASR heard
    tts_text,  # what we want to synthesize (in cloned voice)
    pass_threshold,  # must meet or exceed this
    tts_model_id,  # e.g., "coqui/XTTS-v2"
    tts_language,  # e.g., "en"
):
    """Clone the user's voice and speak ``tts_text``, only if they passed the reading check.

    The recorded clip serves both as evidence that the user read
    ``target_sentence`` and as the reference voice for zero-shot cloning.

    Returns:
        ``(audio, status)``. ``audio`` is a ``(sample_rate, waveform)`` tuple
        suitable for a ``gr.Audio`` output, or ``None`` when cloning is
        refused or fails. ``status`` is a user-facing message.
    """
    # Basic validations (cheap checks before any model work).
    if audio_path is None:
        return None, "Record audio first (reference voice is required)."
    if not target_sentence:
        return None, "Generate a target sentence first."
    if not user_transcript:
        return None, "Transcribe first to verify the sentence."
    # Treat whitespace-only text as empty: there is nothing to speak.
    if not tts_text or not tts_text.strip():
        return None, "Enter the sentence to synthesize."

    # Recompute pass/fail rather than trusting the result label.
    # NOTE(review): user_transcript itself arrives from the client (the
    # Transcription textbox), so a direct API call can submit a transcript that
    # matches the target whatever the audio says. Re-running ASR on audio_path
    # here would make this gate authoritative.
    match = process.SentenceMatcher(target_sentence, user_transcript, pass_threshold)
    if not match.passed:
        # round(), not int(): e.g. 0.57 * 100 == 56.99999999999999 would show as 56%.
        return None, (
            f"❌ Cloning blocked: your reading did not reach the threshold "
            f"({match.ratio * 100:.1f}% < {round(pass_threshold * 100)}%)."
        )

    # Run zero-shot cloning. run_tts_clone reports failure by returning the
    # exception instead of raising it; on success it yields (sample_rate, wav).
    out = tts.run_tts_clone(audio_path, tts_text, model_id=tts_model_id, language=tts_language)
    if isinstance(out, Exception):
        return None, f"Voice cloning failed: {out}"
    sr, wav = out
    # Gradio Audio can take a tuple (sr, np.array)
    return (sr, wav), f"✅ Cloned and synthesized with {tts_model_id} ({tts_language})."
# ------------------- UI -------------------
# Page layout. Components appear in the order they are created inside the
# Blocks context, so the statement order below is the on-screen order.
# NOTE(review): several emoji in the labels and Markdown below look
# mis-encoded (e.g. "βœ…" where "✅" was presumably intended) — confirm the
# file is saved and served as UTF-8.
with gr.Blocks(title="Say the Sentence (English)") as demo:
    # Instructions shown at the top of the page.
    gr.Markdown(
        """
# 🎀 Say the Sentence (English)
1) Generate a sentence.
2) Record yourself reading it.
3) Transcribe & check your accuracy.
4) If matched, clone your voice to speak any sentence you enter.
"""
    )
    # Read-only: filled by the "Generate sentence" button, reset by Clear.
    with gr.Row():
        target = gr.Textbox(label="Target sentence", interactive=False,
                            placeholder="Click 'Generate sentence'")
    with gr.Row():
        btn_gen = gr.Button("🎲 Generate sentence", variant="primary")
        btn_clear = gr.Button("🧹 Clear")
    # type="filepath": handlers receive the recording as a file path (None when
    # nothing has been recorded). The same clip is used for ASR and as the
    # cloning reference voice.
    with gr.Row():
        audio = gr.Audio(sources=["microphone"], type="filepath",
                         label="Record your voice")
    # Model, device and threshold settings, collapsed by default.
    with gr.Accordion("Advanced settings", open=False):
        model_id = gr.Dropdown(
            choices=[
                "openai/whisper-tiny.en",
                "openai/whisper-base.en",
                "distil-whisper/distil-small.en",
            ],
            value="openai/whisper-tiny.en",
            label="ASR model (English only)",
        )
        device_pref = gr.Radio(
            choices=["auto", "cpu", "cuda"],
            value="auto",
            label="Device preference"
        )
        # Minimum similarity (0.50-1.00) to pass; used by both the check and
        # the cloning gate.
        pass_threshold = gr.Slider(0.50, 1.00, value=0.85, step=0.01,
                                   label="Match threshold")
    with gr.Row():
        btn_check = gr.Button("βœ… Transcribe & Check", variant="primary")
    # Written by transcribe_check and read back as an input by clone_if_pass.
    with gr.Row():
        user_transcript = gr.Textbox(label="Transcription", interactive=False)
    # Results of the check.
    with gr.Row():
        score_html = gr.Label(label="Score")
        result_html = gr.Label(label="Result")
    diff_html = gr.HTML(
        label="Word-level diff (red = expected but missing / green = extra or replacement)")

    # Voice-cloning section: only produces audio if the reading check passes.
    gr.Markdown("## πŸ” Voice cloning (gated)")
    with gr.Row():
        tts_text = gr.Textbox(
            label="Text to synthesize (voice clone)",
            placeholder="Type the sentence you want the cloned voice to say",
        )
    with gr.Row():
        tts_model_id = gr.Dropdown(
            choices=[
                "coqui/XTTS-v2",
                # add others if you like, e.g., "myshell-ai/MeloTTS"
            ],
            value="coqui/XTTS-v2",
            label="TTS (voice cloning) model",
        )
        # NOTE(review): Coqui XTTS-v2 names Chinese "zh-cn", not "zh" — confirm
        # tts.run_tts_clone maps this code before handing it to the model.
        tts_language = gr.Dropdown(
            choices=["en", "de", "fr", "es", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh"],
            value="en",
            label="Language",
        )
    with gr.Row():
        btn_clone = gr.Button("πŸ” Clone voice (if passed)", variant="secondary")
    with gr.Row():
        tts_audio = gr.Audio(label="Cloned speech output", interactive=False)
        clone_status = gr.Label(label="Cloning status")

    # -------- Events --------
    # Each handler's return values map positionally onto its `outputs` list.
    btn_gen.click(fn=generate.gen_sentence_set, outputs=target)
    # clear_all returns 8 values, matching these 8 outputs in order. The
    # recording itself (audio) is not in this list, so it is not cleared.
    btn_clear.click(
        fn=clear_all,
        outputs=[target, user_transcript, score_html, result_html, diff_html, tts_text, clone_status, tts_audio]
    )
    btn_check.click(
        fn=transcribe_check,
        inputs=[audio, target, model_id, device_pref, pass_threshold],
        outputs=[user_transcript, score_html, result_html, diff_html]
    )
    # The transcript is sent back from the browser, so clone_if_pass's gate is
    # evaluated on client-supplied text rather than a fresh transcription.
    btn_clone.click(
        fn=clone_if_pass,
        inputs=[audio, target, user_transcript, tts_text, pass_threshold, tts_model_id, tts_language],
        outputs=[tts_audio, clone_status],
    )

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()