Minte committed
Commit d191a12 · Parent: 8f055e9

Enhance multilingual ASR functionality with improved language configuration and model loading
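In short, the change swaps the single flat LANGUAGE_CODES mapping for a per-language configuration plus per-language model/processor registries, so a failed model load disables only that language instead of the whole app. A minimal standalone sketch of that dispatch pattern follows (illustrative only: the register helper is hypothetical, nothing is downloaded, and the real loading code is in the diff below):

# Hypothetical sketch of the per-language registry pattern introduced by this commit.
# It mirrors the names used in app.py but loads nothing real.
LANGUAGE_CONFIG = {
    "Swahili": {"code": "swh", "model": "facebook/seamless-m4t-v2-large", "available": True},
    "Chichewa": {"code": "nya", "model": "dmatekenya/wav2vec2-large-xls-r-300m-chichewa", "available": True},
}
models, processors = {}, {}

def register(language, loader):
    """Try to load one language's model; on failure, mark only that language unavailable."""
    try:
        models[language], processors[language] = loader()
    except Exception as exc:
        print(f"[ERROR] Failed to load model for {language}:", exc)
        LANGUAGE_CONFIG[language]["available"] = False

# A dummy loader keeps the sketch runnable without downloading anything.
register("Swahili", lambda: ("dummy-model", "dummy-processor"))
print(models, LANGUAGE_CONFIG["Swahili"]["available"])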

Files changed (1): app.py (+248 -41)
app.py CHANGED
@@ -2,89 +2,296 @@ import traceback
 import soundfile as sf
 import torch
 import numpy as np
-from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, Wav2Vec2ForCTC, Wav2Vec2Processor
 import gradio as gr
 import resampy
 
-# Language code mapping
-LANGUAGE_CODES = {
-    "Amharic": "amh",
-    "Swahili": "swh",
-    "Somali": "som",
-    "Afan Oromo": "orm",
-    "Tigrinya": "tir",
-    "Chichewa": "nya"
+# Language configuration
+LANGUAGE_CONFIG = {
+    "Amharic": {
+        "code": "amh",
+        "model": "facebook/seamless-m4t-v2-large",
+        "available": True
+    },
+    "Swahili": {
+        "code": "swh",
+        "model": "facebook/seamless-m4t-v2-large",
+        "available": True
+    },
+    "Somali": {
+        "code": "som",
+        "model": "facebook/seamless-m4t-v2-large",
+        "available": True
+    },
+    "Afan Oromo": {
+        "code": "orm",
+        "model": "osanseviero/seamless-copy",
+        "available": True
+    },
+    "Tigrinya": {
+        "code": "tir",
+        "model": "facebook/seamless-m4t-v2-large",
+        "available": False,
+        "message": "Tigrinya transcription is not currently available"
+    },
+    "Chichewa": {
+        "code": "nya",
+        "model": "dmatekenya/wav2vec2-large-xls-r-300m-chichewa",
+        "available": True
+    }
 }
 
-# --- Load ASR model ---
+# Initialize models
+models = {}
+processors = {}
+
+print("[INFO] Loading transcription models...")
+
+# Load SeamlessM4T model for Amharic, Swahili, Somali
+try:
+    seamless_model_id = "facebook/seamless-m4t-v2-large"
+    seamless_processor = AutoProcessor.from_pretrained(seamless_model_id)
+    seamless_model = AutoModelForSpeechSeq2Seq.from_pretrained(seamless_model_id).to("cpu")
+
+    for lang, config in LANGUAGE_CONFIG.items():
+        if config["available"] and config["model"] == seamless_model_id:
+            models[lang] = seamless_model
+            processors[lang] = seamless_processor
+
+    print("[SUCCESS] SeamlessM4T model loaded for Amharic, Swahili, Somali")
+except Exception as e:
+    print("[ERROR] Failed to load SeamlessM4T model:", e)
+    traceback.print_exc()
+
+# Load Afan Oromo model
 try:
-    model_id = "facebook/seamless-m4t-v2-large"
-    processor = AutoProcessor.from_pretrained(model_id)
-    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id).to("cpu")
-    print("[INFO] ASR model loaded successfully.")
+    oromo_processor = AutoProcessor.from_pretrained("osanseviero/seamless-copy")
+    oromo_model = AutoModelForSpeechSeq2Seq.from_pretrained("osanseviero/seamless-copy").to("cpu")
+    models["Afan Oromo"] = oromo_model
+    processors["Afan Oromo"] = oromo_processor
+    print("[SUCCESS] Afan Oromo model loaded successfully")
 except Exception as e:
-    print("[ERROR] Failed to load ASR model:", e)
+    print("[ERROR] Failed to load Afan Oromo model:", e)
     traceback.print_exc()
-    asr_model = None
-    processor = None
+    LANGUAGE_CONFIG["Afan Oromo"]["available"] = False
+
+# Load Chichewa model
+try:
+    chichewa_processor = Wav2Vec2Processor.from_pretrained("dmatekenya/wav2vec2-large-xls-r-300m-chichewa")
+    chichewa_model = Wav2Vec2ForCTC.from_pretrained("dmatekenya/wav2vec2-large-xls-r-300m-chichewa").to("cpu")
+    models["Chichewa"] = chichewa_model
+    processors["Chichewa"] = chichewa_processor
+    print("[SUCCESS] Chichewa model loaded successfully")
+except Exception as e:
+    print("[ERROR] Failed to load Chichewa model:", e)
+    traceback.print_exc()
+    LANGUAGE_CONFIG["Chichewa"]["available"] = False
 
 # --- Helper: ASR ---
 def transcribe_audio(audio_file, language):
-    if asr_model is None or processor is None:
-        return "ASR Model loading failed"
+    if language not in models or language not in processors:
+        return f"Model for {language} is not available"
+
+    if not LANGUAGE_CONFIG[language]["available"]:
+        if language == "Tigrinya":
+            return LANGUAGE_CONFIG[language]["message"]
+        return f"{language} transcription is currently unavailable"
 
     try:
-        # Get language code
-        lang_code = LANGUAGE_CODES.get(language)
-        if not lang_code:
-            return f"Unsupported language: {language}"
-
         # Read and preprocess audio
         audio, sr = sf.read(audio_file)
         if audio.ndim > 1:
             audio = audio.mean(axis=1)
         audio = resampy.resample(audio, sr, 16000)
 
-        # Process with model
-        inputs = processor(audios=audio, sampling_rate=16000, return_tensors="pt")
+        model = models[language]
+        processor = processors[language]
 
-        with torch.no_grad():
-            generated_ids = asr_model.generate(**inputs, tgt_lang=lang_code)
+        # Handle different model types
+        if language == "Chichewa":
+            # Wav2Vec2 processing
+            inputs = processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
+            with torch.no_grad():
+                logits = model(**inputs).logits
+            predicted_ids = torch.argmax(logits, dim=-1)
+            transcription = processor.batch_decode(predicted_ids)[0]
+
+        elif language == "Afan Oromo":
+            # Seamless-copy processing
+            inputs = processor(audios=audio, sampling_rate=16000, return_tensors="pt")
+            with torch.no_grad():
+                generated_ids = model.generate(**inputs, tgt_lang=LANGUAGE_CONFIG[language]["code"])
+            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        else:
+            # Standard SeamlessM4T processing
+            inputs = processor(audios=audio, sampling_rate=16000, return_tensors="pt")
+            with torch.no_grad():
+                generated_ids = model.generate(**inputs, tgt_lang=LANGUAGE_CONFIG[language]["code"])
+            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-        # Decode the transcription
-        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         return transcription.strip()
 
     except Exception as e:
         print(f"[ERROR] ASR transcription failed for {language}:", e)
         traceback.print_exc()
-        return f"ASR failed: {str(e)[:50]}..."
+        return f"Transcription failed: {str(e)[:100]}..."
 
-# --- Gradio UI ---
-with gr.Blocks(title="🌍 Multilingual ASR") as demo:
-    gr.Markdown("# 🌍 Multilingual Speech Recognition")
-    gr.Markdown("Transcribe audio in Amharic, Swahili, Somali, Afan Oromo, Tigrinya, or Chichewa")
+# --- Beautiful Gradio UI ---
+with gr.Blocks(
+    theme=gr.themes.Soft(
+        primary_hue="blue",
+        secondary_hue="green",
+    ),
+    title="🌍 GihonTech - Multilingual Speech Recognition",
+    css="""
+    .gradio-container {
+        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+    }
+    .header {
+        text-align: center;
+        padding: 20px;
+        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+        border-radius: 15px;
+        margin-bottom: 20px;
+        color: white;
+    }
+    .language-card {
+        background: white;
+        padding: 15px;
+        border-radius: 10px;
+        margin: 10px 0;
+        border-left: 4px solid #667eea;
+        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+    }
+    .unavailable {
+        background: #ffebee;
+        border-left: 4px solid #f44336;
+    }
+    .available {
+        background: #e8f5e8;
+        border-left: 4px solid #4caf50;
+    }
+    """
+) as demo:
 
+    # Header Section
     with gr.Row():
         with gr.Column():
-            audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or upload audio")
+            gr.HTML("""
+            <div class="header">
+                <h1>🌍 GihonTech Multilingual Speech Recognition</h1>
+                <p>Transcribe audio in multiple African languages with state-of-the-art AI models</p>
+            </div>
+            """)
+
+    # Main Content
+    with gr.Row():
+        # Input Section
+        with gr.Column(scale=1):
+            gr.Markdown("### 🎤 Upload Audio")
+
+            audio_input = gr.Audio(
+                sources=["microphone", "upload"],
+                type="filepath",
+                label="Record or Upload Audio",
+                elem_classes="audio-input"
+            )
+
             language_select = gr.Dropdown(
-                choices=list(LANGUAGE_CODES.keys()),
+                choices=list(LANGUAGE_CONFIG.keys()),
                 value="Swahili",
-                label="Select Language"
+                label="Select Language",
+                info="Choose the language of your audio"
+            )
+
+            submit_btn = gr.Button(
+                "🎯 Transcribe Audio",
+                variant="primary",
+                size="lg"
             )
+
+        # Output Section
+        with gr.Column(scale=1):
+            gr.Markdown("### 📝 Transcription Result")
+            transcription_output = gr.Textbox(
+                label="Transcribed Text",
+                placeholder="Your transcription will appear here...",
+                lines=5,
+                show_copy_button=True
+            )
+
+            # Status indicator
+            status_indicator = gr.HTML("""
+            <div style="text-align: center; padding: 10px;">
+                <span style="color: #4caf50;">✅ Ready to transcribe</span>
+            </div>
+            """)
 
-            submit_btn = gr.Button("Transcribe", variant="primary")
+    # Language Information Section
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("### 🌐 Supported Languages")
+
+            for lang, config in LANGUAGE_CONFIG.items():
+                status_class = "unavailable" if not config["available"] else "available"
+                status_text = "🔴 Not Available" if not config["available"] else "🟢 Available"
+                model_info = config["model"] if config["available"] else config.get("message", "Not available")
+
+                gr.HTML(f"""
+                <div class="language-card {status_class}">
+                    <h4>{lang} {status_text}</h4>
+                    <p><strong>Model:</strong> {model_info}</p>
+                </div>
+                """)
 
+    # Footer
     with gr.Row():
         with gr.Column():
-            transcription_output = gr.Textbox(label="Transcription")
+            gr.Markdown("""
+            ---
+            ### ℹ️ About This Service
+
+            **Powered by:**
+            - Facebook SeamlessM4T
+            - Hugging Face Transformers
+            - Specialized African Language Models
+
+            **Supported Formats:** WAV, MP3, M4A, FLAC
+            **Maximum Duration:** 30 seconds per audio
+
+            *For best results, use clear audio with minimal background noise*
+            """)
+
+    # Event handlers
+    def update_status(language):
+        config = LANGUAGE_CONFIG[language]
+        if not config["available"]:
+            if language == "Tigrinya":
+                return f'<div style="text-align: center; padding: 10px; background: #ffebee; border-radius: 5px;"><span style="color: #f44336;">⛔ {config["message"]}</span></div>'
+            return f'<div style="text-align: center; padding: 10px; background: #ffebee; border-radius: 5px;"><span style="color: #f44336;">⛔ {language} transcription is currently unavailable</span></div>'
+        return '<div style="text-align: center; padding: 10px; background: #e8f5e8; border-radius: 5px;"><span style="color: #4caf50;">✅ Ready to transcribe</span></div>'
+
+    # Connect events
+    language_select.change(
+        fn=update_status,
+        inputs=[language_select],
+        outputs=status_indicator
+    )
 
     submit_btn.click(
         fn=transcribe_audio,
         inputs=[audio_input, language_select],
         outputs=transcription_output
+    ).then(
+        fn=lambda: '<div style="text-align: center; padding: 10px; background: #e8f5e8; border-radius: 5px;"><span style="color: #4caf50;">✅ Ready to transcribe</span></div>',
+        outputs=status_indicator
     )
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860)
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        show_error=True
+    )
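For reviewers who want to exercise both decoding branches locally, a hypothetical smoke test is sketched below (not part of this commit). It assumes the Space's dependencies are installed and that app.py is importable as app, which loads the models on import; the synthetic tone will not produce meaningful text, it only checks that the SeamlessM4T generate(..., tgt_lang=...) branch, the Wav2Vec2 CTC argmax branch, and the unavailable-language path all run end to end.

# Hypothetical smoke test for transcribe_audio() (not part of this commit).
# Assumes app.py is importable as `app`; importing it downloads and loads the models.
import numpy as np
import soundfile as sf

import app  # the Space's app.py

def smoke_test(path="tone.wav"):
    sr = 16000
    t = np.linspace(0, 1.0, sr, endpoint=False)
    sf.write(path, 0.1 * np.sin(2 * np.pi * 220 * t), sr)  # 1 s, 220 Hz test tone
    print("Swahili :", app.transcribe_audio(path, "Swahili"))   # SeamlessM4T generate() path
    print("Chichewa:", app.transcribe_audio(path, "Chichewa"))  # Wav2Vec2 CTC argmax path
    print("Tigrinya:", app.transcribe_audio(path, "Tigrinya"))  # unavailable-language message

if __name__ == "__main__":
    smoke_test()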