# NOTE: snapshot of a Hugging Face Space (page status at capture time: "Runtime error").
import gradio as gr
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor

# Language codes offered in the UI, mapped to display names.
# NOTE(review): placeholder set — replace with the full list of languages
# the fine-tuned model actually supports.
ASR_LANGUAGES = {"eng": "English", "swh": "Swahili"}

# Must match the checkpoint used for fine-tuning so the processor's
# vocabulary and the model's per-language adapters line up.
MODEL_ID = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
def transcribe(audio_path, language):
    """Transcribe an audio file with the fine-tuned MMS CTC model.

    Args:
        audio_path: Filesystem path to the recorded/uploaded audio clip.
        language: Dropdown selection formatted as "<code> (<Name>)",
            e.g. "eng (English)"; the leading ISO code selects the
            tokenizer target language and the matching model adapter.

    Returns:
        The decoded transcription as a string.
    """
    if audio_path is None:
        # Gradio passes None when the user submits without recording.
        return "No audio provided."

    if language:
        target_lang = language.split(" ")[0]  # leading ISO-639 code
        processor.tokenizer.set_target_lang(target_lang)
        # Always load the matching adapter. The original skipped this for
        # English, so switching back to English after another language
        # silently kept the previous language's adapter active.
        model.load_adapter(target_lang)

    audio, samplerate = sf.read(audio_path)
    # Collapse multi-channel recordings to mono; the model expects a
    # single channel. NOTE(review): MMS models are trained on 16 kHz
    # audio — confirm inputs arrive at 16 kHz or add resampling upstream.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    inputs = processor(audio, sampling_rate=samplerate, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(predicted_ids)
# Gradio UI: audio in, transcribed text out.
mms_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        # type="filepath" hands transcribe() a path it can open with
        # soundfile; the default ("numpy") would pass a (rate, array)
        # tuple and crash sf.read() at inference time.
        gr.Audio(type="filepath"),
        gr.Dropdown(
            [f"{k} ({v})" for k, v in ASR_LANGUAGES.items()],
            label="Language",
            # The default must be one of the generated choices; the
            # original "eng English" did not match "eng (English)".
            value="eng (English)",
        ),
    ],
    outputs="text",
    title="Speech-to-Text Transcription",
    description="Transcribe audio input into text.",
    allow_flagging="never",
)
# Wrap the prebuilt Interface in Blocks so the app can later grow extra
# tabs/components without changing the launch code.
with gr.Blocks() as demo:
    mms_transcribe.render()

if __name__ == "__main__":
    demo.queue()  # queue requests: transcription can take several seconds
    demo.launch()