import gradio as gr
import torch
import numpy as np
import librosa
import soundfile as sf  # optional: used for saving synthesized audio to disk
from transformers import pipeline, VitsModel, AutoTokenizer
from datasets import load_dataset

# ------------------------------------------------------
# 1. ASR Pipeline (English) - Wav2Vec2
# ------------------------------------------------------
asr = pipeline(
    "automatic-speech-recognition",
    model="facebook/wav2vec2-base-960h"
)

# ------------------------------------------------------
# 2. Translation Models (3 languages)
# ------------------------------------------------------
translation_models = {
    "Spanish": "Helsinki-NLP/opus-mt-en-es",
    "Chinese": "Helsinki-NLP/opus-mt-en-zh",
    # If this exact ID is unavailable on the Hub,
    # Helsinki-NLP/opus-mt-en-jap is a possible fallback.
    "Japanese": "Helsinki-NLP/opus-mt-en-ja"
}

translation_tasks = {
    "Spanish": "translation_en_to_es",
    "Chinese": "translation_en_to_zh",
    "Japanese": "translation_en_to_ja"
}

# ------------------------------------------------------
# 3. TTS Configuration
#    - Spanish: VITS-based MMS TTS
#    - Chinese & Japanese: Microsoft SpeechT5
# ------------------------------------------------------
# We'll store the language names as keys for convenience
SPANISH_KEY = "Spanish"
CHINESE_KEY = "Chinese"
JAPANESE_KEY = "Japanese"

# VITS config for Spanish only
mms_spanish_config = {
    "model_id": "facebook/mms-tts-spa",
    "architecture": "vits"
}

# ------------------------------------------------------
# 4. Create TTS Pipelines / Models Once (Caching)
# ------------------------------------------------------
translator_cache = {}
vits_model_cache = None          # for Spanish
speech_t5_pipeline_cache = None  # for Chinese/Japanese
speech_t5_speaker_embedding = None


def get_translator(lang):
    """
    Return a cached MarianMT translator for the specified language.
    """
    if lang in translator_cache:
        return translator_cache[lang]
    model_name = translation_models[lang]
    task_name = translation_tasks[lang]
    translator = pipeline(task_name, model=model_name)
    translator_cache[lang] = translator
    return translator


def load_spanish_vits():
    """
    Load and cache the Spanish VITS model + tokenizer (facebook/mms-tts-spa).
    """
    global vits_model_cache
    if vits_model_cache is not None:
        return vits_model_cache
    try:
        model_id = mms_spanish_config["model_id"]
        model = VitsModel.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        vits_model_cache = (model, tokenizer)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load Spanish TTS model {mms_spanish_config['model_id']}: {e}"
        )
    return vits_model_cache
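
# A minimal smoke test for the caching loaders above. This helper is an
# addition for illustration (the name `smoke_test_loaders` is ours, not part
# of the original app); run it manually, e.g. from a REPL, to confirm that
# repeated lookups return the same cached objects before launching the UI.
def smoke_test_loaders():
    first = get_translator(SPANISH_KEY)
    second = get_translator(SPANISH_KEY)
    assert first is second, "translator was not cached"
    model, tokenizer = load_spanish_vits()
    # A second call must return the exact same cached model object
    assert load_spanish_vits()[0] is model, "VITS model was not cached"
    print("loader caches OK")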
""" global speech_t5_pipeline_cache, speech_t5_speaker_embedding if speech_t5_pipeline_cache is not None and speech_t5_speaker_embedding is not None: return speech_t5_pipeline_cache, speech_t5_speaker_embedding try: # Create the pipeline # The pipeline is named "text-to-speech" in Transformers >= 4.29 t5_pipe = pipeline("text-to-speech", model="microsoft/speecht5_tts") except Exception as e: raise RuntimeError(f"Failed to load Microsoft SpeechT5 pipeline: {e}") # Load a default speaker embedding try: embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation") # Just pick an arbitrary index for speaker embedding speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0) except Exception as e: raise RuntimeError(f"Failed to load default speaker embedding: {e}") speech_t5_pipeline_cache = t5_pipe speech_t5_speaker_embedding = speaker_embedding return t5_pipe, speaker_embedding # ------------------------------------------------------ # 5. TTS Inference Helpers # ------------------------------------------------------ def run_vits_inference(text): """ For Spanish TTS using MMS (facebook/mms-tts-spa). """ model, tokenizer = load_spanish_vits() inputs = tokenizer(text, return_tensors="pt") with torch.no_grad(): output = model(**inputs) if not hasattr(output, "waveform"): raise RuntimeError("VITS output does not contain 'waveform'.") waveform = output.waveform.squeeze().cpu().numpy() sample_rate = 16000 return sample_rate, waveform def run_speecht5_inference(text): """ For Chinese & Japanese TTS using Microsoft SpeechT5 pipeline. """ t5_pipe, speaker_embedding = load_speech_t5_pipeline() # The pipeline returns a dict with 'audio' (numpy) and 'sampling_rate' result = t5_pipe( text, forward_params={"speaker_embeddings": speaker_embedding} ) waveform = result["audio"] sample_rate = result["sampling_rate"] return sample_rate, waveform # ------------------------------------------------------ # 6. Main Prediction Function # ------------------------------------------------------ def predict(audio, text, target_language): """ 1. Get English text (ASR if audio provided, else text). 2. Translate to target_language. 3. TTS with the chosen approach (VITS for Spanish, SpeechT5 for Chinese/Japanese). 
""" # Step 1: English text if text.strip(): english_text = text.strip() elif audio is not None: sample_rate, audio_data = audio # Convert to float32 if needed if audio_data.dtype not in [np.float32, np.float64]: audio_data = audio_data.astype(np.float32) # Stereo -> mono if len(audio_data.shape) > 1 and audio_data.shape[1] > 1: audio_data = np.mean(audio_data, axis=1) # Resample to 16k if needed if sample_rate != 16000: audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000) asr_input = {"array": audio_data, "sampling_rate": 16000} asr_result = asr(asr_input) english_text = asr_result["text"] else: return "No input provided.", "", None # Step 2: Translate translator = get_translator(target_language) try: translation_result = translator(english_text) translated_text = translation_result[0]["translation_text"] except Exception as e: return english_text, f"Translation error: {e}", None # Step 3: TTS try: if target_language == SPANISH_KEY: sr, waveform = run_vits_inference(translated_text) else: # Chinese or Japanese -> SpeechT5 sr, waveform = run_speecht5_inference(translated_text) except Exception as e: return english_text, translated_text, f"TTS error: {e}" return english_text, translated_text, (sr, waveform) # ------------------------------------------------------ # 7. Gradio Interface # ------------------------------------------------------ iface = gr.Interface( fn=predict, inputs=[ gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"), gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"), gr.Dropdown(choices=["Spanish", "Chinese", "Japanese"], value="Spanish", label="Target Language") ], outputs=[ gr.Textbox(label="English Transcription"), gr.Textbox(label="Translation (Target Language)"), gr.Audio(label="Synthesized Speech") ], title="Multimodal Language Learning Aid", description=( "1. Transcribes English speech using Wav2Vec2 (or takes English text).\n" "2. Translates to Spanish, Chinese, or Japanese (via Helsinki-NLP models).\n" "3. Synthesizes speech:\n" " - Spanish -> facebook/mms-tts-spa (VITS)\n" " - Chinese & Japanese -> microsoft/speecht5_tts (SpeechT5)\n\n" "Note: SpeechT5 is not officially trained for Japanese, so results may vary.\n" "You can also try inputting short, clear audio for best ASR results." ), allow_flagging="never" ) if __name__ == "__main__": iface.launch(server_name="0.0.0.0", server_port=7860)