Minte
ASR for Local Languages
cd5cc96
raw
history blame
2.88 kB
import traceback
import soundfile as sf
import torch
import numpy as np
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
import gradio as gr
import resampy
# Language code mapping
# Maps the language names offered in the UI dropdown to the three-letter
# codes that are passed to the model as `tgt_lang`.
LANGUAGE_CODES = {
    "Amharic": "amh",
    "Swahili": "swh",
    "Somali": "som",
    # SeamlessM4T lists Oromo as "gaz" (West Central Oromo); the
    # macrolanguage code "orm" is not one of its language codes.
    "Afan Oromo": "gaz",
    # NOTE(review): "tir" does not seem to be in the published SeamlessM4T
    # language list either -- confirm; Tigrinya requests may fail at
    # generation time.
    "Tigrinya": "tir",
    "Chichewa": "nya",
}
# --- Load ASR model ---
# Load the processor and the speech-to-text model once at import time.
# If anything goes wrong, both globals are set to None so the rest of the
# module can still be imported and can report the failure to the user.
try:
    model_id = "facebook/seamless-m4t-v2-large"
    processor = AutoProcessor.from_pretrained(model_id)
    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
    asr_model = asr_model.to("cpu")
except Exception as e:
    print("[ERROR] Failed to load ASR model:", e)
    traceback.print_exc()
    processor = asr_model = None
else:
    print("[INFO] ASR model loaded successfully.")
# --- Helper: ASR ---
def transcribe_audio(audio_file, language):
    """Transcribe an audio file into text for the selected language.

    Args:
        audio_file: Path of the audio file to transcribe. A falsy value
            (``None`` or ``""``) is treated as "no audio supplied".
        language: Display name of the language; must be a key of
            ``LANGUAGE_CODES``.

    Returns:
        The stripped transcription on success. On any failure a short,
        human-readable message is returned instead (model not loaded,
        unsupported language, missing/empty audio, or ``"ASR failed: ..."``
        with the first 50 characters of the exception text).
    """
    if asr_model is None or processor is None:
        return "ASR Model loading failed"

    # Sampling rate the audio is resampled to and that is reported to the
    # processor.
    target_sr = 16000

    try:
        # Get language code
        lang_code = LANGUAGE_CODES.get(language)
        if not lang_code:
            return f"Unsupported language: {language}"

        # Nothing recorded or uploaded: report it clearly rather than
        # surfacing a truncated file-reading exception.
        if not audio_file:
            return "No audio provided. Please record or upload a file."

        # Read and preprocess audio
        audio, sr = sf.read(audio_file)
        if audio.ndim > 1:
            # Collapse multi-channel audio to mono by averaging over axis 1.
            audio = audio.mean(axis=1)
        if audio.size == 0:
            return "Audio file contains no samples."
        if sr != target_sr:
            # Only resample when the file is not already at the target rate.
            audio = resampy.resample(audio, sr, target_sr)

        # Process with model
        inputs = processor(audios=audio, sampling_rate=target_sr, return_tensors="pt")
        with torch.no_grad():
            generated_ids = asr_model.generate(**inputs, tgt_lang=lang_code)

        # Decode the transcription
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()
    except Exception as e:
        # Top-level boundary for the UI callback: log the full traceback
        # and hand a short message back to the caller.
        print(f"[ERROR] ASR transcription failed for {language}:", e)
        traceback.print_exc()
        return f"ASR failed: {str(e)[:50]}..."
# --- Gradio UI ---
# Layout: a heading, one row with the input controls, one row with the
# transcription output, and a click handler wiring them together.
with gr.Blocks(title="🌍 Multilingual ASR") as demo:
    gr.Markdown("# 🌍 Multilingual Speech Recognition")
    gr.Markdown("Transcribe audio in Amharic, Swahili, Somali, Afan Oromo, Tigrinya, or Chichewa")

    # Input controls: audio source, language choice and the submit button.
    with gr.Row(), gr.Column():
        audio_input = gr.Audio(
            label="Record or upload audio",
            type="filepath",
            sources=["microphone", "upload"],
        )
        language_select = gr.Dropdown(
            label="Select Language",
            value="Swahili",
            choices=list(LANGUAGE_CODES),
        )
        submit_btn = gr.Button("Transcribe", variant="primary")

    # Output area for the transcription text.
    with gr.Row(), gr.Column():
        transcription_output = gr.Textbox(label="Transcription")

    # Run transcribe_audio on the selected audio and language when clicked.
    submit_btn.click(
        fn=transcribe_audio,
        outputs=transcription_output,
        inputs=[audio_input, language_select],
    )
if __name__ == "__main__":
    # Start the web server only when run as a script: listen on all
    # network interfaces ("0.0.0.0"), port 7860.
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )