ABAO77 committed
Commit 54a64d4 · 1 Parent(s): 225134a

update: new model xlsr

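In short, this commit swaps the Whisper-based transcription path for a wav2vec2 XLSR model selected through the new create_inference() helper (see the wave2vec_inference.py and speaking_controller.py diffs below). A minimal sketch of the new call pattern, assuming the package is importable and a local test.wav exists:

    from src.AI_Models.wave2vec_inference import create_inference, DEFAULT_MODEL

    # DEFAULT_MODEL is "jonatasgrosman/wav2vec2-large-xlsr-53-english" after this commit
    asr = create_inference("english_large")   # key from WAVE2VEC2_MODELS, or a full HF model name
    print(asr.file_to_text("test.wav"))       # lower-cased, stripped transcript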
app.py CHANGED
@@ -1,36 +1,12 @@
- """
- English Tutor API - Main Application
- Optimized with Whisper model preloading for faster pronunciation assessment
- """
-
  from dotenv import load_dotenv

  load_dotenv()
-
  from src.apis.create_app import create_app, api_router
  import uvicorn
- from loguru import logger

- # Create FastAPI app with Whisper preloading
- app = create_app()
- app.include_router(api_router)
+ app = create_app()

- # Add root endpoint
- @app.get("/")
- async def root():
-     return {
-         "message": "🎓 English Tutor API with Optimized Whisper",
-         "status": "ready",
-         "docs": "/docs",
-         "health": "/health"
-     }
+ app.include_router(api_router)


  if __name__ == "__main__":
-     logger.info("🚀 Starting English Tutor API server...")
-     uvicorn.run(
-         "app:app",
-         host="0.0.0.0",
-         port=8000,
-         reload=False,  # Set to False to avoid reloading and losing preloaded model
-         log_level="info"
-     )
+     uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
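The new app.py drops the loguru logging and the root ("/") endpoint, so only the routes registered by api_router (plus FastAPI's built-in docs) remain. A quick smoke test, assuming FastAPI's TestClient (httpx) is installed and create_app() keeps the default docs_url:

    from fastapi.testclient import TestClient
    from app import app  # note: importing app runs create_app(), which loads the ASR model

    client = TestClient(app)
    assert client.get("/docs").status_code == 200  # the old "/" endpoint is gone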
example_model_usage.py ADDED
@@ -0,0 +1,79 @@
+ #!/usr/bin/env python3
+ """
+ Example usage of Wave2Vec2Inference with dynamic model switching
+ """
+
+ from src.AI_Models.wave2vec_inference import (
+     create_inference,
+     get_available_models,
+     get_model_name,
+     DEFAULT_MODEL
+ )
+
+ def main():
+     print("=== Wave2Vec2 Model Selection Example ===\n")
+
+     # Show available models
+     print("Available models:")
+     models = get_available_models()
+     for key, model_name in models.items():
+         print(f"  {key}: {model_name}")
+     print(f"\nDefault model: {DEFAULT_MODEL}\n")
+
+     # Example 1: Using default model
+     print("1. Creating inference with default model:")
+     asr_default = create_inference()
+     print(f"   Loaded: {asr_default.model_name}\n")
+
+     # Example 2: Using model key
+     print("2. Creating inference with model key 'english_large':")
+     asr_key = create_inference("english_large")
+     print(f"   Loaded: {asr_key.model_name}\n")
+
+     # Example 3: Using full model name
+     print("3. Creating inference with full model name:")
+     asr_full = create_inference("facebook/wav2vec2-base-960h")
+     print(f"   Loaded: {asr_full.model_name}\n")
+
+     # Example 4: Dynamic model switching
+     print("4. Dynamic model switching:")
+     model_keys = ["english_large", "multilingual", "base_english"]
+
+     for model_key in model_keys:
+         print(f"   Switching to: {model_key}")
+         asr = create_inference(model_key)
+         print(f"   Active model: {asr.model_name}")
+
+         # Example transcription (if you have an audio file)
+         # result = asr.file_to_text("your_audio_file.wav")
+         # print(f"   Result: {result}")
+         print()
+
+     # Example 5: Using with ONNX
+     print("5. Creating ONNX inference with model selection:")
+     try:
+         asr_onnx = create_inference("english_large", use_onnx=True)
+         print(f"   ONNX model loaded: {asr_onnx.model_name}")
+     except Exception as e:
+         print(f"   ONNX conversion needed: {e}")
+
+     print("\n=== Usage Examples ===")
+     print("# Use default model")
+     print("asr = create_inference()")
+     print()
+     print("# Use model key")
+     print("asr = create_inference('english_large')")
+     print()
+     print("# Use full model name")
+     print("asr = create_inference('facebook/wav2vec2-base-960h')")
+     print()
+     print("# Use with ONNX")
+     print("asr = create_inference('english_large', use_onnx=True)")
+     print()
+     print("# Transcribe audio")
+     print("result = asr.file_to_text('audio.wav')")
+     print("# or")
+     print("result = asr.buffer_to_text(audio_array)")
+
+ if __name__ == "__main__":
+     main()
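For reference, the model lookup this example depends on is a plain dictionary with a pass-through fallback (keys and names as added in wave2vec_inference.py below); a small sanity check under that assumption:

    from src.AI_Models.wave2vec_inference import get_model_name, DEFAULT_MODEL

    assert get_model_name("english_large") == "jonatasgrosman/wav2vec2-large-xlsr-53-english"
    assert get_model_name() == DEFAULT_MODEL  # None resolves to the XLSR English model
    assert get_model_name("facebook/wav2vec2-base-960h") == "facebook/wav2vec2-base-960h"  # unknown keys pass through unchanged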
src/AI_Models/wave2vec_inference.py CHANGED
@@ -1,15 +1,63 @@
1
  import torch
2
- from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, Wav2Vec2ForCTC
 
 
 
 
 
3
  import onnxruntime as rt
4
  import numpy as np
5
  import librosa
6
  import warnings
7
  import os
 
8
  warnings.filterwarnings("ignore")
9
10
 
11
  class Wave2Vec2Inference:
12
- def __init__(self, model_name, use_gpu=True):
 
 
 
13
  # Auto-detect device
14
  if use_gpu:
15
  if torch.backends.mps.is_available():
@@ -20,15 +68,25 @@ class Wave2Vec2Inference:
20
  self.device = "cpu"
21
  else:
22
  self.device = "cpu"
23
-
24
  print(f"Using device: {self.device}")
 
 
 
 
25
 
26
- # Load model and processor
27
- self.processor = AutoProcessor.from_pretrained(model_name)
28
- self.model = AutoModelForCTC.from_pretrained(model_name)
 
 
 
 
 
 
29
  self.model.to(self.device)
30
  self.model.eval()
31
-
32
  # Disable gradients for inference
33
  torch.set_grad_enabled(False)
34
 
@@ -52,7 +110,11 @@ class Wave2Vec2Inference:
52
 
53
  # Move to device
54
  input_values = inputs.input_values.to(self.device)
55
- attention_mask = inputs.attention_mask.to(self.device) if "attention_mask" in inputs else None
 
 
 
 
56
 
57
  # Inference
58
  with torch.no_grad():
@@ -65,7 +127,7 @@ class Wave2Vec2Inference:
65
  predicted_ids = torch.argmax(logits, dim=-1)
66
  if self.device != "cpu":
67
  predicted_ids = predicted_ids.cpu()
68
-
69
  transcription = self.processor.batch_decode(predicted_ids)[0]
70
  return transcription.lower().strip()
71
 
@@ -79,20 +141,25 @@ class Wave2Vec2Inference:
79
 
80
 
81
  class Wave2Vec2ONNXInference:
82
- def __init__(self, model_name, onnx_path, use_gpu=True):
83
- self.processor = Wav2Vec2Processor.from_pretrained(model_name)
 
 
84
 
 
 
 
85
  # Setup ONNX Runtime
86
  options = rt.SessionOptions()
87
  options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
88
-
89
  # Choose providers based on GPU availability
90
  providers = []
91
  if use_gpu and rt.get_available_providers():
92
- if 'CUDAExecutionProvider' in rt.get_available_providers():
93
- providers.append('CUDAExecutionProvider')
94
- providers.append('CPUExecutionProvider')
95
-
96
  self.model = rt.InferenceSession(onnx_path, options, providers=providers)
97
  self.input_name = self.model.get_inputs()[0].name
98
  print(f"ONNX model loaded with providers: {self.model.get_providers()}")
@@ -118,7 +185,7 @@ class Wave2Vec2ONNXInference:
118
  # ONNX inference
119
  input_values = inputs.input_values.astype(np.float32)
120
  onnx_outputs = self.model.run(None, {self.input_name: input_values})[0]
121
-
122
  # Decode
123
  prediction = np.argmax(onnx_outputs, axis=-1)
124
  transcription = self.processor.decode(prediction.squeeze().tolist())
@@ -138,7 +205,7 @@ def convert_to_onnx(model_id_or_path, onnx_model_name):
138
  print(f"Converting {model_id_or_path} to ONNX...")
139
  model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
140
  model.eval()
141
-
142
  # Create dummy input
143
  audio_len = 250000
144
  dummy_input = torch.randn(1, audio_len, requires_grad=True)
@@ -166,9 +233,7 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path):
166
  from onnxruntime.quantization import quantize_dynamic, QuantType
167
 
168
  quantize_dynamic(
169
- onnx_model_path,
170
- quantized_model_path,
171
- weight_type=QuantType.QUInt8
172
  )
173
  print(f"Quantized model saved to: {quantized_model_path}")
174
 
@@ -176,52 +241,57 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path):
176
  def export_to_onnx(model_name, quantize=False):
177
  """
178
  Export model to ONNX format with optional quantization
179
-
180
  Args:
181
  model_name: HuggingFace model name
182
  quantize: Whether to also create quantized version
183
-
184
  Returns:
185
  tuple: (onnx_path, quantized_path or None)
186
  """
187
  onnx_filename = f"{model_name.split('/')[-1]}.onnx"
188
  convert_to_onnx(model_name, onnx_filename)
189
-
190
  quantized_path = None
191
  if quantize:
192
- quantized_path = onnx_filename.replace('.onnx', '.quantized.onnx')
193
  quantize_onnx_model(onnx_filename, quantized_path)
194
-
195
  return onnx_filename, quantized_path
196
 
197
 
198
- def create_inference(model_name, use_onnx=False, onnx_path=None, use_gpu=True, use_onnx_quantize=False):
 
 
199
  """
200
  Create optimized inference instance
201
-
202
  Args:
203
- model_name: HuggingFace model name
204
  use_onnx: Whether to use ONNX runtime
205
  onnx_path: Path to ONNX model file
206
  use_gpu: Whether to use GPU if available
207
  use_onnx_quantize: Whether to use quantized ONNX model
208
-
209
  Returns:
210
  Inference instance
211
  """
 
 
 
212
  if use_onnx:
213
  if not onnx_path or not os.path.exists(onnx_path):
214
  # Convert to ONNX if path not provided or doesn't exist
215
- onnx_filename = f"{model_name.split('/')[-1]}.onnx"
216
- convert_to_onnx(model_name, onnx_filename)
217
  onnx_path = onnx_filename
218
-
219
  if use_onnx_quantize:
220
- quantized_path = onnx_path.replace('.onnx', '.quantized.onnx')
221
  if not os.path.exists(quantized_path):
222
  quantize_onnx_model(onnx_path, quantized_path)
223
  onnx_path = quantized_path
224
-
225
  print(f"Using ONNX model: {onnx_path}")
226
  return Wave2Vec2ONNXInference(model_name, onnx_path, use_gpu)
227
  else:
@@ -231,39 +301,70 @@ def create_inference(model_name, use_onnx=False, onnx_path=None, use_gpu=True, u
231
 
232
  if __name__ == "__main__":
233
  import time
234
-
235
- model_name = "facebook/wav2vec2-large-960h-lv60-self"
 
 
 
 
 
 
 
 
236
  test_file = "test.wav"
237
-
238
  if not os.path.exists(test_file):
239
  print(f"Test file {test_file} not found. Please provide a valid audio file.")
240
- exit(1)
241
-
242
- # Test different configurations
243
- configs = [
244
- {"use_onnx": False, "use_gpu": True},
245
- {"use_onnx": True, "use_gpu": True, "use_onnx_quantize": False},
246
- {"use_onnx": True, "use_gpu": True, "use_onnx_quantize": True},
247
- ]
248
-
249
- for config in configs:
250
- print(f"\n=== Testing config: {config} ===")
251
 
252
- # Create inference instance
253
- asr = create_inference(model_name, **config)
254
 
255
- # Warm up
256
- asr.file_to_text(test_file)
 
 
257
 
258
- # Test performance
259
- times = []
260
- for i in range(5):
261
- start_time = time.time()
262
- text = asr.file_to_text(test_file)
263
- end_time = time.time()
264
- execution_time = end_time - start_time
265
- times.append(execution_time)
266
- print(f"Run {i+1}: {execution_time:.3f}s - {text[:50]}...")
267
 
268
- avg_time = sum(times) / len(times)
269
- print(f"Average time: {avg_time:.3f}s")
1
  import torch
2
+ from transformers import (
3
+ AutoModelForCTC,
4
+ AutoProcessor,
5
+ Wav2Vec2Processor,
6
+ Wav2Vec2ForCTC,
7
+ )
8
  import onnxruntime as rt
9
  import numpy as np
10
  import librosa
11
  import warnings
12
  import os
13
+
14
  warnings.filterwarnings("ignore")
15
 
16
+ # Available Wave2Vec2 models
17
+ WAVE2VEC2_MODELS = {
18
+ "english_large": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
19
+ "multilingual": "facebook/wav2vec2-large-xlsr-53",
20
+ "english_960h": "facebook/wav2vec2-large-960h-lv60-self",
21
+ "base_english": "facebook/wav2vec2-base-960h",
22
+ "large_english": "facebook/wav2vec2-large-960h",
23
+ "xlsr_english": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
24
+ "xlsr_multilingual": "facebook/wav2vec2-large-xlsr-53"
25
+ }
26
+
27
+ # Default model
28
+ DEFAULT_MODEL = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
29
+
30
+
31
+ def get_available_models():
32
+ """Return dictionary of available Wave2Vec2 models"""
33
+ return WAVE2VEC2_MODELS.copy()
34
+
35
+
36
+ def get_model_name(model_key=None):
37
+ """
38
+ Get model name from key or return default
39
+
40
+ Args:
41
+ model_key: Key from WAVE2VEC2_MODELS or full model name
42
+
43
+ Returns:
44
+ str: Full model name
45
+ """
46
+ if model_key is None:
47
+ return DEFAULT_MODEL
48
+
49
+ if model_key in WAVE2VEC2_MODELS:
50
+ return WAVE2VEC2_MODELS[model_key]
51
+
52
+ # If it's already a full model name, return as is
53
+ return model_key
54
+
55
 
56
  class Wave2Vec2Inference:
57
+ def __init__(self, model_name=None, use_gpu=True):
58
+ # Get the actual model name using helper function
59
+ self.model_name = get_model_name(model_name)
60
+
61
  # Auto-detect device
62
  if use_gpu:
63
  if torch.backends.mps.is_available():
 
68
  self.device = "cpu"
69
  else:
70
  self.device = "cpu"
71
+
72
  print(f"Using device: {self.device}")
73
+ print(f"Loading model: {self.model_name}")
74
+
75
+ # Check if model is XLSR and use appropriate processor/model
76
+ is_xlsr = "xlsr" in self.model_name.lower()
77
 
78
+ if is_xlsr:
79
+ print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
80
+ self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
81
+ self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
82
+ else:
83
+ print("Using AutoProcessor and AutoModelForCTC")
84
+ self.processor = AutoProcessor.from_pretrained(self.model_name)
85
+ self.model = AutoModelForCTC.from_pretrained(self.model_name)
86
+
87
  self.model.to(self.device)
88
  self.model.eval()
89
+
90
  # Disable gradients for inference
91
  torch.set_grad_enabled(False)
92
 
 
110
 
111
  # Move to device
112
  input_values = inputs.input_values.to(self.device)
113
+ attention_mask = (
114
+ inputs.attention_mask.to(self.device)
115
+ if "attention_mask" in inputs
116
+ else None
117
+ )
118
 
119
  # Inference
120
  with torch.no_grad():
 
127
  predicted_ids = torch.argmax(logits, dim=-1)
128
  if self.device != "cpu":
129
  predicted_ids = predicted_ids.cpu()
130
+
131
  transcription = self.processor.batch_decode(predicted_ids)[0]
132
  return transcription.lower().strip()
133
 
 
141
 
142
 
143
  class Wave2Vec2ONNXInference:
144
+ def __init__(self, model_name=None, onnx_path=None, use_gpu=True):
145
+ # Get the actual model name using helper function
146
+ self.model_name = get_model_name(model_name)
147
+ print(f"Loading ONNX model: {self.model_name}")
148
 
149
+ # Always use Wav2Vec2Processor for ONNX (works for all models)
150
+ self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
151
+
152
  # Setup ONNX Runtime
153
  options = rt.SessionOptions()
154
  options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
155
+
156
  # Choose providers based on GPU availability
157
  providers = []
158
  if use_gpu and rt.get_available_providers():
159
+ if "CUDAExecutionProvider" in rt.get_available_providers():
160
+ providers.append("CUDAExecutionProvider")
161
+ providers.append("CPUExecutionProvider")
162
+
163
  self.model = rt.InferenceSession(onnx_path, options, providers=providers)
164
  self.input_name = self.model.get_inputs()[0].name
165
  print(f"ONNX model loaded with providers: {self.model.get_providers()}")
 
185
  # ONNX inference
186
  input_values = inputs.input_values.astype(np.float32)
187
  onnx_outputs = self.model.run(None, {self.input_name: input_values})[0]
188
+
189
  # Decode
190
  prediction = np.argmax(onnx_outputs, axis=-1)
191
  transcription = self.processor.decode(prediction.squeeze().tolist())
 
205
  print(f"Converting {model_id_or_path} to ONNX...")
206
  model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
207
  model.eval()
208
+
209
  # Create dummy input
210
  audio_len = 250000
211
  dummy_input = torch.randn(1, audio_len, requires_grad=True)
 
233
  from onnxruntime.quantization import quantize_dynamic, QuantType
234
 
235
  quantize_dynamic(
236
+ onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
 
 
237
  )
238
  print(f"Quantized model saved to: {quantized_model_path}")
239
 
 
241
  def export_to_onnx(model_name, quantize=False):
242
  """
243
  Export model to ONNX format with optional quantization
244
+
245
  Args:
246
  model_name: HuggingFace model name
247
  quantize: Whether to also create quantized version
248
+
249
  Returns:
250
  tuple: (onnx_path, quantized_path or None)
251
  """
252
  onnx_filename = f"{model_name.split('/')[-1]}.onnx"
253
  convert_to_onnx(model_name, onnx_filename)
254
+
255
  quantized_path = None
256
  if quantize:
257
+ quantized_path = onnx_filename.replace(".onnx", ".quantized.onnx")
258
  quantize_onnx_model(onnx_filename, quantized_path)
259
+
260
  return onnx_filename, quantized_path
261
 
262
 
263
+ def create_inference(
264
+ model_name=None, use_onnx=False, onnx_path=None, use_gpu=True, use_onnx_quantize=False
265
+ ):
266
  """
267
  Create optimized inference instance
268
+
269
  Args:
270
+ model_name: Model key from WAVE2VEC2_MODELS or full HuggingFace model name (default: uses DEFAULT_MODEL)
271
  use_onnx: Whether to use ONNX runtime
272
  onnx_path: Path to ONNX model file
273
  use_gpu: Whether to use GPU if available
274
  use_onnx_quantize: Whether to use quantized ONNX model
275
+
276
  Returns:
277
  Inference instance
278
  """
279
+ # Get the actual model name
280
+ actual_model_name = get_model_name(model_name)
281
+
282
  if use_onnx:
283
  if not onnx_path or not os.path.exists(onnx_path):
284
  # Convert to ONNX if path not provided or doesn't exist
285
+ onnx_filename = f"{actual_model_name.split('/')[-1]}.onnx"
286
+ convert_to_onnx(actual_model_name, onnx_filename)
287
  onnx_path = onnx_filename
288
+
289
  if use_onnx_quantize:
290
+ quantized_path = onnx_path.replace(".onnx", ".quantized.onnx")
291
  if not os.path.exists(quantized_path):
292
  quantize_onnx_model(onnx_path, quantized_path)
293
  onnx_path = quantized_path
294
+
295
  print(f"Using ONNX model: {onnx_path}")
296
  return Wave2Vec2ONNXInference(model_name, onnx_path, use_gpu)
297
  else:
 
301
 
302
  if __name__ == "__main__":
303
  import time
304
+
305
+ # Display available models
306
+ print("Available Wave2Vec2 models:")
307
+ for key, model_name in get_available_models().items():
308
+ print(f" {key}: {model_name}")
309
+ print(f"\nDefault model: {DEFAULT_MODEL}")
310
+ print()
311
+
312
+ # Test with different models
313
+ test_models = ["english_large", "multilingual", "english_960h"]
314
  test_file = "test.wav"
315
+
316
  if not os.path.exists(test_file):
317
  print(f"Test file {test_file} not found. Please provide a valid audio file.")
318
+ print("Creating example usage without actual file...")
 
 
 
 
 
 
 
 
 
 
319
 
320
+ # Example usage without file
321
+ print("\n=== Example Usage ===")
322
 
323
+ # Using default model
324
+ print("1. Using default model:")
325
+ asr_default = create_inference()
326
+ print(f" Model loaded: {asr_default.model_name}")
327
 
328
+ # Using model key
329
+ print("\n2. Using model key 'english_large':")
330
+ asr_key = create_inference("english_large")
331
+ print(f" Model loaded: {asr_key.model_name}")
 
 
 
 
 
332
 
333
+ # Using full model name
334
+ print("\n3. Using full model name:")
335
+ asr_full = create_inference("facebook/wav2vec2-base-960h")
336
+ print(f" Model loaded: {asr_full.model_name}")
337
+
338
+ exit(0)
339
+
340
+ # Test different model configurations
341
+ for model_key in test_models:
342
+ print(f"\n=== Testing model: {model_key} ===")
343
+
344
+ # Test different configurations
345
+ configs = [
346
+ {"use_onnx": False, "use_gpu": True},
347
+ {"use_onnx": True, "use_gpu": True, "use_onnx_quantize": False},
348
+ ]
349
+
350
+ for config in configs:
351
+ print(f"\nConfig: {config}")
352
+
353
+ # Create inference instance with model selection
354
+ asr = create_inference(model_key, **config)
355
+
356
+ # Warm up
357
+ asr.file_to_text(test_file)
358
+
359
+ # Test performance
360
+ times = []
361
+ for i in range(3):
362
+ start_time = time.time()
363
+ text = asr.file_to_text(test_file)
364
+ end_time = time.time()
365
+ execution_time = end_time - start_time
366
+ times.append(execution_time)
367
+ print(f"Run {i+1}: {execution_time:.3f}s - {text[:50]}...")
368
+
369
+ avg_time = sum(times) / len(times)
370
+ print(f"Average time: {avg_time:.3f}s")
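Putting the new API together: create_inference() resolves the model key, and with use_onnx=True it exports "&lt;model&gt;.onnx" on first use (plus a ".quantized.onnx" variant when use_onnx_quantize=True) before handing the file to onnxruntime. A hedged end-to-end sketch, assuming a short test.wav is present and there is disk space for the export:

    from src.AI_Models.wave2vec_inference import create_inference

    asr = create_inference("english_large", use_onnx=True, use_onnx_quantize=True)
    print(asr.file_to_text("test.wav"))  # session falls back to CPUExecutionProvider when CUDA is unavailable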
src/apis/__pycache__/create_app.cpython-311.pyc CHANGED
Binary files a/src/apis/__pycache__/create_app.cpython-311.pyc and b/src/apis/__pycache__/create_app.cpython-311.pyc differ
 
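As a pointer for the controller diff that follows: the Whisper wrapper (EnhancedWhisperASR) is replaced by EnhancedWav2Vec2CharacterASR, which delegates transcription to create_inference(). A hypothetical call site, assuming the module imports cleanly outside the FastAPI app:

    from src.apis.controllers.speaking_controller import EnhancedWav2Vec2CharacterASR

    asr = EnhancedWav2Vec2CharacterASR(onnx=False)       # defaults to the XLSR English checkpoint
    features = asr.transcribe_with_features("test.wav")  # dict with character/phoneme transcripts and audio features
    print(features["character_transcript"])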
src/apis/controllers/speaking_controller.py CHANGED
@@ -13,8 +13,10 @@ from loguru import logger
13
  import Levenshtein
14
  from dataclasses import dataclass
15
  from enum import Enum
16
- import whisper
17
- import os
 
 
18
 
19
  # Download required NLTK data
20
  try:
@@ -23,168 +25,6 @@ try:
23
  except:
24
  print("Warning: NLTK data not available")
25
 
26
- # Pre-computed phoneme mappings for instant lookup (Top 1000 English words)
27
- COMMON_WORD_PHONEMES = {
28
- "the": ["ð", "ə"],
29
- "be": ["b", "i"],
30
- "to": ["t", "u"],
31
- "of": ["ʌ", "v"],
32
- "and": ["æ", "n", "d"],
33
- "a": ["ə"],
34
- "in": ["ɪ", "n"],
35
- "that": ["ð", "æ", "t"],
36
- "have": ["h", "æ", "v"],
37
- "i": ["aɪ"],
38
- "it": ["ɪ", "t"],
39
- "for": ["f", "ɔr"],
40
- "not": ["n", "ɑ", "t"],
41
- "on": ["ɑ", "n"],
42
- "with": ["w", "ɪ", "θ"],
43
- "he": ["h", "i"],
44
- "as": ["æ", "z"],
45
- "you": ["j", "u"],
46
- "do": ["d", "u"],
47
- "at": ["æ", "t"],
48
- "this": ["ð", "ɪ", "s"],
49
- "but": ["b", "ʌ", "t"],
50
- "his": ["h", "ɪ", "z"],
51
- "by": ["b", "aɪ"],
52
- "from": ["f", "r", "ʌ", "m"],
53
- "they": ["ð", "eɪ"],
54
- "we": ["w", "i"],
55
- "say": ["s", "eɪ"],
56
- "her": ["h", "ɝ"],
57
- "she": ["ʃ", "i"],
58
- "or": ["ɔr"],
59
- "an": ["æ", "n"],
60
- "will": ["w", "ɪ", "l"],
61
- "my": ["m", "aɪ"],
62
- "one": ["w", "ʌ", "n"],
63
- "all": ["ɔ", "l"],
64
- "would": ["w", "ʊ", "d"],
65
- "there": ["ð", "ɛr"],
66
- "their": ["ð", "ɛr"],
67
- "what": ["w", "ʌ", "t"],
68
- "so": ["s", "oʊ"],
69
- "up": ["ʌ", "p"],
70
- "out": ["aʊ", "t"],
71
- "if": ["ɪ", "f"],
72
- "about": ["ə", "b", "aʊ", "t"],
73
- "who": ["h", "u"],
74
- "get": ["ɡ", "ɛ", "t"],
75
- "which": ["w", "ɪ", "tʃ"],
76
- "go": ["ɡ", "oʊ"],
77
- "me": ["m", "i"],
78
- "when": ["w", "ɛ", "n"],
79
- "make": ["m", "eɪ", "k"],
80
- "can": ["k", "æ", "n"],
81
- "like": ["l", "aɪ", "k"],
82
- "time": ["t", "aɪ", "m"],
83
- "no": ["n", "oʊ"],
84
- "just": ["dʒ", "ʌ", "s", "t"],
85
- "him": ["h", "ɪ", "m"],
86
- "know": ["n", "oʊ"],
87
- "take": ["t", "eɪ", "k"],
88
- "people": ["p", "i", "p", "ə", "l"],
89
- "into": ["ɪ", "n", "t", "u"],
90
- "year": ["j", "ɪr"],
91
- "your": ["j", "ʊr"],
92
- "good": ["ɡ", "ʊ", "d"],
93
- "some": ["s", "ʌ", "m"],
94
- "could": ["k", "ʊ", "d"],
95
- "them": ["ð", "ɛ", "m"],
96
- "see": ["s", "i"],
97
- "other": ["ʌ", "ð", "ər"],
98
- "than": ["ð", "æ", "n"],
99
- "then": ["ð", "ɛ", "n"],
100
- "now": ["n", "aʊ"],
101
- "look": ["l", "ʊ", "k"],
102
- "only": ["oʊ", "n", "l", "i"],
103
- "come": ["k", "ʌ", "m"],
104
- "its": ["ɪ", "t", "s"],
105
- "over": ["oʊ", "v", "ər"],
106
- "think": ["θ", "ɪ", "ŋ", "k"],
107
- "also": ["ɔ", "l", "s", "oʊ"],
108
- "your": ["j", "ʊr"],
109
- "work": ["w", "ɝ", "k"],
110
- "life": ["l", "aɪ", "f"],
111
- "only": ["oʊ", "n", "l", "i"],
112
- "new": ["n", "u"],
113
- "way": ["w", "eɪ"],
114
- "may": ["m", "eɪ"],
115
- "say": ["s", "eɪ"],
116
- "first": ["f", "ɝ", "s", "t"],
117
- "well": ["w", "ɛ", "l"],
118
- "great": ["ɡ", "r", "eɪ", "t"],
119
- "little": ["l", "ɪ", "t", "ə", "l"],
120
- "own": ["oʊ", "n"],
121
- "old": ["oʊ", "l", "d"],
122
- "right": ["r", "aɪ", "t"],
123
- "big": ["b", "ɪ", "ɡ"],
124
- "high": ["h", "aɪ"],
125
- "different": ["d", "ɪ", "f", "ər", "ə", "n", "t"],
126
- "small": ["s", "m", "ɔ", "l"],
127
- "large": ["l", "ɑr", "dʒ"],
128
- "next": ["n", "ɛ", "k", "s", "t"],
129
- "early": ["ɝ", "l", "i"],
130
- "young": ["j", "ʌ", "ŋ"],
131
- "important": ["ɪ", "m", "p", "ɔr", "t", "ə", "n", "t"],
132
- "few": ["f", "j", "u"],
133
- "public": ["p", "ʌ", "b", "l", "ɪ", "k"],
134
- "bad": ["b", "æ", "d"],
135
- "same": ["s", "eɪ", "m"],
136
- "able": ["eɪ", "b", "ə", "l"],
137
- "hello": ["h", "ə", "l", "oʊ"],
138
- "world": ["w", "ɝ", "l", "d"],
139
- "how": ["h", "aʊ"],
140
- "are": ["ɑr"],
141
- "today": ["t", "ə", "d", "eɪ"],
142
- "pronunciation": ["p", "r", "ə", "n", "ʌ", "n", "s", "i", "eɪ", "ʃ", "ə", "n"]
143
- }
144
-
145
- class LazyImports:
146
- """Lazy load heavy dependencies only when needed"""
147
-
148
- @property
149
- def psutil(self):
150
- if not hasattr(self, '_psutil'):
151
- try:
152
- import psutil
153
- self._psutil = psutil
154
- except ImportError:
155
- # Create a mock psutil if not available
156
- class MockPsutil:
157
- def cpu_count(self): return 4
158
- def cpu_percent(self, interval=0.1): return 50
159
- self._psutil = MockPsutil()
160
- return self._psutil
161
-
162
- @property
163
- def librosa(self):
164
- if not hasattr(self, '_librosa'):
165
- import librosa
166
- self._librosa = librosa
167
- return self._librosa
168
-
169
- class ObjectPool:
170
- """Object pool to avoid creating/destroying objects continuously"""
171
- def __init__(self):
172
- self.g2p_pool = []
173
- self.comparator_pool = []
174
-
175
- def get_g2p(self):
176
- if self.g2p_pool:
177
- return self.g2p_pool.pop()
178
- return None # Will create new if needed
179
-
180
- def return_g2p(self, obj):
181
- if len(self.g2p_pool) < 5: # Limit pool size
182
- self.g2p_pool.append(obj)
183
-
184
- # Global instances for optimization
185
- lazy_imports = LazyImports()
186
- object_pool = ObjectPool()
187
-
188
 
189
  class AssessmentMode(Enum):
190
  WORD = "word"
@@ -213,119 +53,56 @@ class CharacterError:
213
  color: str
214
 
215
 
216
- class EnhancedWhisperASR:
217
- """Enhanced Whisper ASR with prosody analysis support"""
218
 
219
- def __init__(self, whisper_model: str = "base.en"):
 
 
 
 
 
 
 
220
  self.sample_rate = 16000
221
- self.whisper_model_name = whisper_model
222
-
223
- # Load Whisper model
224
- logger.info(f"Loading Whisper model: {whisper_model}")
225
- self.whisper_model = whisper.load_model(whisper_model, in_memory=True)
226
- logger.info("Whisper model loaded successfully")
227
-
228
- # Initialize G2P once and reuse (optimization fix)
229
- self.g2p = EnhancedG2P()
230
- logger.info("G2P converter initialized and ready for reuse")
231
 
232
- def _characters_to_phoneme_representation(self, text: str) -> str:
233
- """Convert character-based transcript to phoneme representation - Optimized reuse"""
234
- if not text:
235
- return ""
236
-
237
- # Reuse the initialized G2P converter instead of creating new instances
238
- return self.g2p.get_phoneme_string(text)
239
-
240
- @lru_cache(maxsize=100)
241
- def _cached_audio_features(self, audio_path: str, file_mtime: float) -> Dict:
242
- """Cache audio features based on file modification time"""
243
- return self._extract_basic_audio_features_uncached(audio_path)
244
 
245
- def _extract_basic_audio_features(self, audio_path: str) -> Dict:
246
- """Extract audio features with caching optimization"""
247
- import os
248
- try:
249
- file_mtime = os.path.getmtime(audio_path)
250
- return self._cached_audio_features(audio_path, file_mtime)
251
- except:
252
- # Fallback to uncached version
253
- return self._extract_basic_audio_features_uncached(audio_path)
254
 
255
- def _extract_basic_audio_features_uncached(self, audio_path: str) -> Dict:
256
- """Ultra-fast basic features using minimal librosa"""
257
- try:
258
- # Load with aggressive downsampling
259
- y, sr = lazy_imports.librosa.load(audio_path, sr=8000) # Very low sample rate
260
- duration = len(y) / sr
261
-
262
- if duration < 0.1:
263
- return {"duration": duration, "error": "Audio too short"}
264
-
265
- # Simple energy-based features
266
- energy = y ** 2
267
-
268
- # Basic "pitch" using zero-crossing rate as proxy
269
- zcr = lazy_imports.librosa.feature.zero_crossing_rate(y, frame_length=1024,
270
- hop_length=512)[0]
271
- pseudo_pitch = sr / (2 * np.mean(zcr)) if np.mean(zcr) > 0 else 0
272
-
273
- # Simple rhythm from energy peaks
274
- frame_length = int(0.1 * sr) # 100ms frames
275
- energy_frames = [np.mean(energy[i:i+frame_length])
276
- for i in range(0, len(energy)-frame_length, frame_length)]
277
-
278
- # Count energy peaks as beats
279
- if len(energy_frames) > 2:
280
- threshold = np.mean(energy_frames) + 0.5 * np.std(energy_frames)
281
- beats = sum(1 for e in energy_frames if e > threshold)
282
- tempo = (beats / duration) * 60 if duration > 0 else 120
283
- else:
284
- tempo = 120
285
- beats = 2
286
-
287
- # RMS from energy
288
- rms = np.sqrt(np.mean(energy))
289
-
290
- return {
291
- "duration": duration,
292
- "pseudo_pitch": pseudo_pitch,
293
- "tempo": tempo,
294
- "rms": rms,
295
- "beats": beats,
296
- "frame_count": len(energy_frames),
297
- }
298
-
299
- except Exception as e:
300
- logger.warning(f"Audio feature extraction failed: {e}")
301
- return {"duration": 0, "error": str(e)}
302
 
303
- # Rest of the methods remain unchanged...
304
  def transcribe_with_features(self, audio_path: str) -> Dict:
305
- """Enhanced transcription with audio features for prosody analysis - Whisper only"""
306
  try:
307
  start_time = time.time()
308
 
309
- # Use Whisper for transcription
310
- logger.info("Using Whisper for transcription")
311
- result = self.whisper_model.transcribe(audio_path)
312
- character_transcript = result["text"]
313
- logger.info(f"transcript time: {time.time() - start_time:.2f}s")
314
-
315
- clean_character_time = time.time()
316
- character_transcript = self._clean_character_transcript(character_transcript)
317
- logger.info(f"clean_character_time: {time.time() - clean_character_time:.2f}s")
318
 
319
- phone_transform_time = time.time()
320
- phoneme_representation = self._characters_to_phoneme_representation(character_transcript)
321
- logger.info(f"phone_transform_time: {time.time() - phone_transform_time:.2f}s")
 
322
 
323
  # Basic audio features (simplified for speed)
324
- time_feature_start = time.time()
325
  audio_features = self._extract_basic_audio_features(audio_path)
326
- logger.info(f"time_feature_extraction: {time.time() - time_feature_start:.2f}s")
327
 
328
- logger.info(f"Optimized transcription time: {time.time() - start_time:.2f}s")
 
 
329
 
330
  return {
331
  "character_transcript": character_transcript,
@@ -338,82 +115,114 @@ class EnhancedWhisperASR:
338
  logger.error(f"Enhanced ASR error: {e}")
339
  return self._empty_result()
340
 
341
- # All other methods remain exactly the same...
342
- def _extract_basic_audio_features_uncached(self, audio_path: str) -> Dict:
343
- """Ultra-fast basic features using minimal librosa"""
344
  try:
345
- # Load with aggressive downsampling
346
- y, sr = librosa.load(audio_path, sr=8000) # Very low sample rate
347
  duration = len(y) / sr
348
-
349
- if duration < 0.1:
350
- return {"duration": duration, "error": "Audio too short"}
351
-
352
- # Simple energy-based features
353
- energy = y ** 2
354
-
355
- # Basic "pitch" using zero-crossing rate as proxy
356
- zcr = librosa.feature.zero_crossing_rate(y, frame_length=1024,
357
- hop_length=512)[0]
358
- pseudo_pitch = sr / (2 * np.mean(zcr)) if np.mean(zcr) > 0 else 0
359
-
360
- # Simple rhythm from energy peaks
361
- frame_length = int(0.1 * sr) # 100ms frames
362
- energy_frames = [np.mean(energy[i:i+frame_length])
363
- for i in range(0, len(energy)-frame_length, frame_length)]
364
-
365
- # Count energy peaks as beats
366
- if len(energy_frames) > 2:
367
- threshold = np.mean(energy_frames) + 0.5 * np.std(energy_frames)
368
- beats = sum(1 for e in energy_frames if e > threshold)
369
- tempo = (beats / duration) * 60 if duration > 0 else 120
370
- else:
371
- tempo = 120
372
- beats = 2
373
-
374
- # RMS from energy
375
- rms_mean = np.sqrt(np.mean(energy))
376
- rms_std = np.sqrt(np.std(energy))
377
-
378
  return {
379
  "duration": duration,
380
  "pitch": {
381
- "values": [pseudo_pitch] if pseudo_pitch > 0 else [],
382
- "mean": pseudo_pitch,
383
- "std": 0,
384
- "range": 0,
385
- "cv": 0,
 
 
 
 
 
 
 
 
386
  },
387
  "rhythm": {
388
  "tempo": tempo,
389
- "beats_per_second": beats / duration if duration > 0 else 0,
390
  },
391
  "intensity": {
392
- "rms_mean": rms_mean,
393
- "rms_std": rms_std,
394
- }
395
  }
396
-
397
  except Exception as e:
398
- logger.error(f"Ultra-fast audio feature extraction error: {e}")
399
  return {"duration": 0, "error": str(e)}
400
 
401
  def _clean_character_transcript(self, transcript: str) -> str:
402
- """Clean and standardize character transcript - Remove punctuation for better scoring"""
403
  logger.info(f"Raw transcript before cleaning: {transcript}")
404
- # Remove punctuation marks that can affect scoring
405
- cleaned = re.sub(r'[.,!?;:"()[\]{}]', '', transcript)
406
- # Normalize whitespace
407
- cleaned = re.sub(r"\s+", " ", cleaned)
408
  return cleaned.strip().lower()
409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  def _simple_letter_to_phoneme(self, word: str) -> List[str]:
411
  """Fallback letter-to-phoneme conversion"""
412
  letter_to_phoneme = {
413
- "a": "æ", "b": "b", "c": "k", "d": "d", "e": "ɛ", "f": "f", "g": "ɡ",
414
- "h": "h", "i": "ɪ", "j": "dʒ", "k": "k", "l": "l", "m": "m", "n": "n",
415
- "o": "ʌ", "p": "p", "q": "k", "r": "r", "s": "s", "t": "t", "u": "ʌ",
416
- "v": "v", "w": "w", "x": "ks", "y": "j", "z": "z",
417
  }
418
 
419
  return [
@@ -439,8 +248,9 @@ class EnhancedWhisperASR:
439
  "confidence": 0.0,
440
  }
441
 
 
442
  class EnhancedG2P:
443
- """Enhanced Grapheme-to-Phoneme converter with visualization support - Hybrid Optimized"""
444
 
445
  def __init__(self):
446
  try:
@@ -449,240 +259,70 @@ class EnhancedG2P:
449
  self.cmu_dict = {}
450
  logger.warning("CMU dictionary not available")
451
 
452
- # Pre-build CMU to IPA mapping for faster access
453
- self.cmu_to_ipa_map = {
454
- "AA": "ɑ", "AE": "æ", "AH": "ʌ", "AO": "ɔ", "AW": "aʊ", "AY": "aɪ",
455
- "EH": "ɛ", "ER": "ɝ", "EY": "eɪ", "IH": "ɪ", "IY": "i", "OW": "oʊ",
456
- "OY": "ɔɪ", "UH": "ʊ", "UW": "u", "B": "b", "CH": "tʃ", "D": "d",
457
- "DH": "ð", "F": "f", "G": "ɡ", "HH": "h", "JH": "dʒ", "K": "k",
458
- "L": "l", "M": "m", "N": "n", "NG": "ŋ", "P": "p", "R": "r",
459
- "S": "s", "SH": "ʃ", "T": "t", "TH": "θ", "V": "v", "W": "w",
460
- "Y": "j", "Z": "z", "ZH": "ʒ",
461
- }
462
-
463
- # Fast pattern mapping for common combinations
464
- self.fast_patterns = {
465
- 'th': 'θ', 'sh': 'ʃ', 'ch': 'tʃ', 'ng': 'ŋ', 'ck': 'k',
466
- 'ph': 'f', 'qu': 'kw', 'tion': 'ʃən', 'ing': 'ɪŋ', 'ed': 'd',
467
- 'er': 'ɝ', 'ar': 'ɑr', 'or': 'ɔr', 'oo': 'u', 'ee': 'i',
468
- 'oa': 'oʊ', 'ai': 'eɪ', 'ay': 'eɪ', 'ow': 'aʊ', 'oy': 'ɔɪ'
469
- }
470
-
471
- # Fast character mapping
472
- self.char_to_phoneme_map = {
473
- 'a': 'æ', 'e': 'ɛ', 'i': 'ɪ', 'o': 'ʌ', 'u': 'ʌ',
474
- 'b': 'b', 'c': 'k', 'd': 'd', 'f': 'f', 'g': 'ɡ',
475
- 'h': 'h', 'j': 'dʒ', 'k': 'k', 'l': 'l', 'm': 'm',
476
- 'n': 'n', 'p': 'p', 'r': 'r', 's': 's', 't': 't',
477
- 'v': 'v', 'w': 'w', 'x': 'ks', 'y': 'j', 'z': 'z'
478
- }
479
-
480
- # Vietnamese speaker substitution patterns (unchanged)
481
  self.vn_substitutions = {
482
- "θ": ["f", "s", "t", "d"], "ð": ["d", "z", "v", "t"],
483
- "v": ["w", "f", "b"], "w": ["v", "b"], "r": ["l", "n"],
484
- "l": ["r", "n"], "z": ["s", "j"], "ʒ": ["ʃ", "z", "s"],
485
- "ʃ": ["s", "ʒ"], "ŋ": ["n", "m"], "tʃ": ["ʃ", "s", "k"],
486
- "": ["ʒ", "j", "g"], "æ": ["ɛ", "a"], "ɪ": ["i"], "ʊ": ["u"],
 
 
 
 
 
 
 
 
 
 
487
  }
488
 
489
- # Difficulty scores (unchanged)
490
  self.difficulty_scores = {
491
- "θ": 0.9, "ð": 0.9, "v": 0.8, "z": 0.8, "ʒ": 0.9, "r": 0.7,
492
- "l": 0.6, "w": 0.5, "æ": 0.7, "ɪ": 0.6, "ʊ": 0.6, "ŋ": 0.3,
493
- "f": 0.2, "s": 0.2, "ʃ": 0.5, "tʃ": 0.4, "dʒ": 0.5,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
  }
495
 
496
- @lru_cache(maxsize=5000) # Increased from 1000 for common words
497
  def word_to_phonemes(self, word: str) -> List[str]:
498
- """Convert word to phoneme list - Optimized with hybrid approach"""
499
  word_lower = word.lower().strip()
500
 
501
- # Check pre-computed dictionary first (instant lookup)
502
- if word_lower in COMMON_WORD_PHONEMES:
503
- return COMMON_WORD_PHONEMES[word_lower]
504
-
505
  if word_lower in self.cmu_dict:
506
  cmu_phonemes = self.cmu_dict[word_lower][0]
507
- return self._convert_cmu_to_ipa_fast(cmu_phonemes)
508
  else:
509
- return self._fast_estimate_phonemes(word_lower)
510
 
511
- @lru_cache(maxsize=1000) # Decreased from 2000 for text-level operations
512
  def get_phoneme_string(self, text: str) -> str:
513
- """Get space-separated phoneme string - Hybrid optimized"""
514
- return self._characters_to_phoneme_representation_optimized(text)
515
-
516
- def _characters_to_phoneme_representation_optimized(self, text: str) -> str:
517
- """Optimized phoneme conversion - Smart threading strategy"""
518
- if not text:
519
- return ""
520
-
521
  words = self._clean_text(text).split()
522
- if not words:
523
- return ""
524
-
525
- # Smart threading strategy - avoid overhead for small texts
526
- return self._smart_parallel_processing(words)
527
 
528
- def _smart_parallel_processing(self, words: List[str]) -> str:
529
- """Intelligent parallel processing based on system resources and text length"""
530
- try:
531
- # Only use parallel processing if:
532
- # 1. Text is long enough (>10 words, increased threshold)
533
- # 2. System has enough resources
534
- try:
535
- cpu_count = lazy_imports.psutil.cpu_count()
536
- cpu_usage = lazy_imports.psutil.cpu_percent(interval=0.1)
537
- except:
538
- # Fallback if psutil not available
539
- cpu_count = 4
540
- cpu_usage = 50
541
-
542
- if (len(words) > 10 and # Increased threshold from 5
543
- cpu_count >= 4 and
544
- cpu_usage < 70):
545
- return self._parallel_phoneme_processing(words)
546
- else:
547
- return self._batch_cmu_lookup(words)
548
- except:
549
- # Fallback to batch processing if anything fails
550
- if len(words) > 10:
551
- return self._parallel_phoneme_processing(words)
552
- else:
553
- return self._batch_cmu_lookup(words)
554
-
555
- def _fast_short_text_phonemes(self, words: List[str]) -> str:
556
- """Ultra-fast processing for 1-2 words"""
557
- phonemes = []
558
  for word in words:
559
- word_lower = word.lower()
560
- if word_lower in self.cmu_dict:
561
- # Direct CMU conversion
562
- cmu_phonemes = self.cmu_dict[word_lower][0]
563
- for phone in cmu_phonemes:
564
- clean_phone = re.sub(r"[0-9]", "", phone)
565
- ipa_phone = self.cmu_to_ipa_map.get(clean_phone, clean_phone.lower())
566
- phonemes.append(ipa_phone)
567
- else:
568
- phonemes.extend(self._ultra_fast_estimate(word_lower))
569
-
570
- return " ".join(phonemes)
571
 
572
- def _batch_cmu_lookup(self, words: List[str]) -> str:
573
- """Batch CMU dictionary lookup with pre-computed optimization - 5x faster"""
574
- phonemes = []
575
-
576
- for word in words:
577
- word_lower = word.lower()
578
-
579
- # Check pre-computed dictionary first (instant lookup)
580
- if word_lower in COMMON_WORD_PHONEMES:
581
- phonemes.extend(COMMON_WORD_PHONEMES[word_lower])
582
- elif word_lower in self.cmu_dict:
583
- # Direct conversion without method overhead
584
- cmu_phones = self.cmu_dict[word_lower][0]
585
- for phone in cmu_phones:
586
- clean_phone = re.sub(r"[0-9]", "", phone)
587
- ipa_phone = self.cmu_to_ipa_map.get(clean_phone, clean_phone.lower())
588
- phonemes.append(ipa_phone)
589
- else:
590
- # Fast fallback
591
- phonemes.extend(self._ultra_fast_estimate(word_lower))
592
-
593
- return " ".join(phonemes)
594
-
595
- def _parallel_phoneme_processing(self, words: List[str]) -> str:
596
- """Parallel processing for longer texts - Optimized with larger chunks"""
597
- # Use 3 chunks instead of 2 for better load balancing
598
- chunk_size = max(5, len(words) // 3) # Minimum 5 words per chunk
599
- chunks = [words[i:i + chunk_size] for i in range(0, len(words), chunk_size)]
600
-
601
- # Process chunks in parallel using thread pool
602
- import concurrent.futures
603
- with concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(chunks))) as executor:
604
- futures = [executor.submit(self._process_word_chunk, chunk) for chunk in chunks]
605
-
606
- all_phonemes = []
607
- for future in concurrent.futures.as_completed(futures):
608
- all_phonemes.extend(future.result())
609
-
610
  return " ".join(all_phonemes)
611
 
612
- def _process_word_chunk(self, words: List[str]) -> List[str]:
613
- """Process a chunk of words with pre-computed dictionary optimization"""
614
- phonemes = []
615
- for word in words:
616
- word_lower = word.lower()
617
-
618
- # Check pre-computed dictionary first (instant lookup)
619
- if word_lower in COMMON_WORD_PHONEMES:
620
- phonemes.extend(COMMON_WORD_PHONEMES[word_lower])
621
- elif word_lower in self.cmu_dict:
622
- cmu_phones = self.cmu_dict[word_lower][0]
623
- for phone in cmu_phones:
624
- clean_phone = re.sub(r"[0-9]", "", phone)
625
- ipa_phone = self.cmu_to_ipa_map.get(clean_phone, clean_phone.lower())
626
- phonemes.append(ipa_phone)
627
- else:
628
- phonemes.extend(self._ultra_fast_estimate(word_lower))
629
- return phonemes
630
-
631
- def _ultra_fast_estimate(self, word: str) -> List[str]:
632
- """Ultra-fast phoneme estimation using pattern matching"""
633
- if not word:
634
- return []
635
-
636
- phonemes = []
637
- i = 0
638
-
639
- while i < len(word):
640
- # Check for 4-char patterns first
641
- if i <= len(word) - 4:
642
- four_char = word[i:i+4]
643
- if four_char in self.fast_patterns:
644
- phonemes.append(self.fast_patterns[four_char])
645
- i += 4
646
- continue
647
-
648
- # Check for 3-char patterns
649
- if i <= len(word) - 3:
650
- three_char = word[i:i+3]
651
- if three_char in self.fast_patterns:
652
- phonemes.append(self.fast_patterns[three_char])
653
- i += 3
654
- continue
655
-
656
- # Check for 2-char patterns
657
- if i <= len(word) - 2:
658
- two_char = word[i:i+2]
659
- if two_char in self.fast_patterns:
660
- phonemes.append(self.fast_patterns[two_char])
661
- i += 2
662
- continue
663
-
664
- # Single character mapping
665
- char = word[i]
666
- if char in self.char_to_phoneme_map:
667
- phonemes.append(self.char_to_phoneme_map[char])
668
- i += 1
669
-
670
- return phonemes
671
-
672
- def _convert_cmu_to_ipa_fast(self, cmu_phonemes: List[str]) -> List[str]:
673
- """Fast CMU to IPA conversion using pre-built mapping"""
674
- ipa_phonemes = []
675
- for phoneme in cmu_phonemes:
676
- clean_phoneme = re.sub(r"[0-9]", "", phoneme)
677
- ipa_phoneme = self.cmu_to_ipa_map.get(clean_phoneme, clean_phoneme.lower())
678
- ipa_phonemes.append(ipa_phoneme)
679
- return ipa_phonemes
680
-
681
- def _fast_estimate_phonemes(self, word: str) -> List[str]:
682
- """Optimized phoneme estimation - kept for backward compatibility"""
683
- return self._ultra_fast_estimate(word)
684
-
685
- # Rest of the methods remain unchanged for backward compatibility
686
  def text_to_phonemes(self, text: str) -> List[Dict]:
687
  """Convert text to phoneme sequence with visualization data"""
688
  words = self._clean_text(text).split()
@@ -703,12 +343,110 @@ class EnhancedG2P:
703
  return phoneme_sequence
704
 
705
  def _convert_cmu_to_ipa(self, cmu_phonemes: List[str]) -> List[str]:
706
- """Original method - kept for backward compatibility"""
707
- return self._convert_cmu_to_ipa_fast(cmu_phonemes)
708
 
709
  def _estimate_phonemes(self, word: str) -> List[str]:
710
- """Original method - kept for backward compatibility"""
711
- return self._ultra_fast_estimate(word)
712
 
713
  def _clean_text(self, text: str) -> str:
714
  """Clean text for processing"""
@@ -741,7 +479,21 @@ class EnhancedG2P:
741
  def _get_phoneme_color_category(self, phoneme: str) -> str:
742
  """Categorize phonemes by color for visualization"""
743
  vowel_phonemes = {
744
- "ɑ", "æ", "ʌ", "ɔ", "aʊ", "aɪ", "ɛ", "ɝ", "eɪ", "ɪ", "i", "oʊ", "ɔɪ", "ʊ", "u",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745
  }
746
  difficult_consonants = {"θ", "ð", "v", "z", "ʒ", "r", "w"}
747
 
@@ -778,7 +530,6 @@ class EnhancedG2P:
778
  return self.difficulty_scores.get(phoneme, 0.3)
779
 
780
 
781
-
782
  class AdvancedPhonemeComparator:
783
  """Enhanced phoneme comparator using Levenshtein distance - Optimized"""
784
 
@@ -1547,28 +1298,33 @@ class EnhancedFeedbackGenerator:
1547
  class ProductionPronunciationAssessor:
1548
  """Production-ready pronunciation assessor - Enhanced version with optimizations"""
1549
 
1550
- def __init__(
1551
- self,
1552
- whisper_model: str = "base.en",
1553
- ):
1554
- """Initialize the production-ready pronunciation assessment system"""
 
 
 
 
 
 
 
 
1555
  logger.info(
1556
- "Initializing Optimized Production Pronunciation Assessment System with Whisper..."
1557
  )
1558
 
1559
- self.asr = EnhancedWhisperASR(
1560
- whisper_model=whisper_model,
1561
- )
1562
  self.word_analyzer = EnhancedWordAnalyzer()
1563
  self.prosody_analyzer = EnhancedProsodyAnalyzer()
1564
  self.feedback_generator = EnhancedFeedbackGenerator()
1565
-
1566
- # Reuse G2P from ASR to avoid duplicate initialization
1567
- self.g2p = self.asr.g2p
1568
 
1569
  # Thread pool for parallel processing
1570
  self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
1571
 
 
1572
  logger.info("Optimized production system initialization completed")
1573
 
1574
  def assess_pronunciation(
@@ -1664,10 +1420,8 @@ class ProductionPronunciationAssessor:
1664
  result["processing_info"] = {
1665
  "processing_time": round(processing_time, 2),
1666
  "mode": assessment_mode.value,
1667
- "model_used": f"Whisper-{self.asr.whisper_model_name}-Enhanced-Optimized",
1668
- "model_type": "Whisper",
1669
- "use_whisper": True,
1670
- "onnx_enabled": False,
1671
  "confidence": asr_result["confidence"],
1672
  "enhanced_features": True,
1673
  "character_level_analysis": assessment_mode == AssessmentMode.WORD,
@@ -1843,9 +1597,7 @@ class ProductionPronunciationAssessor:
1843
  "processing_info": {
1844
  "processing_time": 0,
1845
  "mode": "error",
1846
- "model_used": f"Whisper-{self.asr.whisper_model_name if hasattr(self, 'asr') else 'base.en'}-Enhanced-Optimized",
1847
- "model_type": "Whisper",
1848
- "use_whisper": True,
1849
  "confidence": 0.0,
1850
  "enhanced_features": False,
1851
  "optimized": True,
@@ -1855,105 +1607,38 @@ class ProductionPronunciationAssessor:
1855
  def get_system_info(self) -> Dict:
1856
  """Get comprehensive system information"""
1857
  return {
1858
- "version": "2.2.0-production-optimized",
1859
- "name": "Ultra-Optimized Production Pronunciation Assessment System",
1860
  "modes": [mode.value for mode in AssessmentMode],
1861
  "features": [
1862
- " Removed singleton pattern for thread safety",
1863
- " G2P object reuse (no more redundant creation)",
1864
- " Smart parallel processing (avoids overhead for small texts)",
1865
- " Optimized LRU cache sizes (5000 words, 1000 texts)",
1866
- " Pre-computed dictionary for top 1000 English words",
1867
- " Object pooling for memory optimization",
1868
- " Batch processing for multiple assessments",
1869
- " Lazy loading of heavy dependencies",
1870
- " Audio feature caching based on file modification time",
1871
- " Intelligent threading strategy based on system resources",
1872
- "✅ Enhanced Levenshtein distance phoneme alignment",
1873
- "✅ Character-level error detection (word mode)",
1874
- "✅ Advanced prosody analysis (sentence mode)",
1875
- "✅ Vietnamese speaker-specific error patterns",
1876
- "✅ Real-time confidence scoring",
1877
- "✅ IPA phonetic representation with visualization",
1878
- "✅ Backward compatibility with legacy APIs",
1879
- "✅ Production-ready error handling",
1880
  ],
1881
- "optimizations": {
1882
- "target_improvement": "60-70% faster processing",
1883
- "singleton_removed": True,
1884
- "g2p_reuse": True,
1885
- "smart_threading": True,
1886
- "pre_computed_words": len(COMMON_WORD_PHONEMES),
1887
- "cache_optimization": True,
1888
- "batch_processing": True,
1889
- "lazy_loading": True,
1890
- "audio_caching": True,
1891
- },
1892
  "model_info": {
1893
- "asr_model": self.asr.whisper_model_name,
1894
- "model_type": "Whisper",
1895
- "use_whisper": True,
1896
- "onnx_enabled": False,
1897
  "sample_rate": self.asr.sample_rate,
1898
  },
1899
  "performance": {
1900
- "target_processing_time": "< 0.5s (vs original 2s)",
1901
- "expected_improvement": "70-80% faster",
1902
- "parallel_workers": 3, # Updated to 3 chunks
1903
  "cached_operations": [
1904
  "G2P conversion",
1905
- "phoneme strings",
1906
  "word mappings",
1907
- "audio features",
1908
- "common word phonemes",
1909
  ],
1910
  },
1911
  }
1912
 
1913
- def assess_batch(self, requests: List[Dict]) -> List[Dict]:
1914
- """
1915
- Batch processing optimization - process multiple assessments efficiently
1916
-
1917
- Args:
1918
- requests: List of dicts with 'audio_path', 'reference_text', 'mode'
1919
-
1920
- Returns:
1921
- List of assessment results
1922
- """
1923
- # Group by reference text to maximize cache reuse
1924
- grouped = defaultdict(list)
1925
- for i, req in enumerate(requests):
1926
- req['_index'] = i # Track original order
1927
- grouped[req['reference_text']].append(req)
1928
-
1929
- results = [None] * len(requests) # Maintain original order
1930
-
1931
- for ref_text, group in grouped.items():
1932
- # Pre-compute reference phonemes once for the group
1933
- ref_phonemes = self.g2p.get_phoneme_string(ref_text)
1934
-
1935
- for req in group:
1936
- try:
1937
- # Use pre-computed reference to avoid redundant processing
1938
- result = self._assess_single_with_ref_phonemes(
1939
- req['audio_path'], req['reference_text'],
1940
- req.get('mode', 'auto'), ref_phonemes
1941
- )
1942
- results[req['_index']] = result
1943
- except Exception as e:
1944
- logger.error(f"Batch assessment failed for request {req['_index']}: {e}")
1945
- results[req['_index']] = self._create_error_result(str(e))
1946
-
1947
- return results
1948
-
1949
- def _assess_single_with_ref_phonemes(
1950
- self, audio_path: str, reference_text: str, mode: str, ref_phonemes: str
1951
- ) -> Dict:
1952
- """Single assessment with pre-computed reference phonemes"""
1953
- # This is a simplified version that reuses reference phonemes
1954
- # For brevity, this calls the main method but could be optimized further
1955
- return self.assess_pronunciation(audio_path, reference_text, mode)
1956
-
1957
  def __del__(self):
1958
  """Cleanup executor"""
1959
  if hasattr(self, "executor"):
@@ -1964,13 +1649,10 @@ class ProductionPronunciationAssessor:
1964
  class SimplePronunciationAssessor:
1965
  """Backward compatible wrapper for the enhanced optimized system"""
1966
 
1967
- def __init__(
1968
- self,
1969
- whisper_model: str = "base.en",
1970
- ):
1971
- print("Initializing Optimized Simple Pronunciation Assessor with Whisper...")
1972
  self.enhanced_assessor = ProductionPronunciationAssessor(
1973
- whisper_model=whisper_model,
1974
  )
1975
  print(
1976
  "Optimized Enhanced Simple Pronunciation Assessor initialization completed"
@@ -1999,7 +1681,7 @@ if __name__ == "__main__":
1999
  import os
2000
 
2001
  # Initialize optimized production system with ONNX and quantization
2002
- system = ProductionPronunciationAssessor()
2003
 
2004
  # Performance test cases
2005
  test_cases = [
@@ -2053,7 +1735,7 @@ if __name__ == "__main__":
2053
 
2054
  # Backward compatibility test
2055
  print(f"\n=== BACKWARD COMPATIBILITY TEST ===")
2056
- legacy_assessor = SimplePronunciationAssessor(whisper_model="base.en")
2057
 
2058
  start_time = time.time()
2059
  legacy_result = legacy_assessor.assess_pronunciation(
@@ -2101,43 +1783,24 @@ if __name__ == "__main__":
2101
  for optimization in optimizations:
2102
  print(optimization)
2103
 
2104
- print(f"\n=== ULTRA-OPTIMIZED PERFORMANCE COMPARISON ===")
2105
  print(f"Original system: ~2.0s total")
2106
  print(f" - ASR: 0.3s")
2107
  print(f" - Processing: 1.7s")
2108
  print(f"")
2109
- print(f"Ultra-optimized system: ~0.4-0.6s total (achieved)")
2110
  print(f" - ASR: 0.3s (unchanged)")
2111
- print(f" - Processing: 0.1-0.3s (80-85% improvement)")
2112
  print(f"")
2113
- print(f"Revolutionary improvements:")
2114
- print(f" • Singleton pattern removed - no more thread safety issues")
2115
- print(f" • ✅ G2P object reuse - eliminated redundant object creation")
2116
- print(f" • ✅ Smart parallel processing - avoids overhead for small texts")
2117
- print(f" • ✅ Pre-computed dictionary - instant lookup for common words")
2118
- print(f" • ✅ Optimized cache sizes - 5000 words, 1000 texts")
2119
- print(f" • ✅ Audio feature caching - file modification time based")
2120
- print(f" • ✅ Batch processing - efficient multiple assessments")
2121
- print(f" • ✅ Lazy loading - heavy dependencies loaded on demand")
2122
- print(f" • ✅ Object pooling - memory optimization")
2123
- print(f" • ✅ Intelligent threading - system resource aware")
2124
  print(f" • Cached G2P conversions avoid repeated computation")
2125
  print(f" • Simplified audio analysis with strategic sampling")
2126
  print(f" • Fast alignment algorithms for phoneme comparison")
2127
  print(f" • ONNX quantized models for maximum ASR speed")
2128
  print(f" • Conditional feature extraction based on assessment mode")
2129
 
2130
- print(f"\n=== ULTRA-OPTIMIZATION COMPLETE ===")
2131
- print(f"✅ All singleton patterns removed for thread safety")
2132
- print(f"✅ All redundant object creation eliminated")
2133
- print(f"✅ Smart parallel processing implemented")
2134
- print(f"✅ Pre-computed dictionary with {len(COMMON_WORD_PHONEMES)} common words")
2135
- print(f"✅ Optimized cache sizes and strategies")
2136
- print(f"✅ Audio feature caching with file modification tracking")
2137
- print(f"✅ Batch processing for multiple assessments")
2138
- print(f"✅ Lazy loading for heavy dependencies")
2139
- print(f"✅ Object pooling for memory optimization")
2140
- print(f"✅ Intelligent resource-aware threading")
2141
  print(f"✅ All original class names preserved")
2142
  print(f"✅ All original function signatures maintained")
2143
  print(f"✅ All original output formats supported")
@@ -2145,74 +1808,4 @@ if __name__ == "__main__":
2145
  print(f"✅ Original API completely functional")
2146
  print(f"✅ Enhanced features are additive, not breaking")
2147
 
2148
- print(f"\nUltra-optimization complete! Target: 80-85% faster processing achieved.")
2149
- print(f"From ~2.0s to ~0.4-0.6s total processing time!")
2150
-
2151
- print(f"\n=== WHISPER MODEL USAGE EXAMPLES ===")
2152
- print(f"Example 1: Using Whisper with base.en model")
2153
- print(
2154
- f"""
2155
- # Initialize with Whisper
2156
- assessor = ProductionPronunciationAssessor(use_whisper=True, whisper_model="base.en")
2157
-
2158
- # Assess pronunciation
2159
- result = assessor.assess_pronunciation(
2160
- audio_path="./hello_how_are_you_today.wav",
2161
- reference_text="Hello, how are you today?",
2162
- mode="sentence"
2163
- )
2164
- print(f"Transcript: {{result['transcript']}}")
2165
- print(f"Score: {{result['overall_score']}}")
2166
- """
2167
- )
2168
-
2169
- print(f"\nExample 2: Using SimplePronunciationAssessor with Whisper")
2170
- print(
2171
- f"""
2172
- # Simple wrapper with Whisper
2173
- simple_assessor = SimplePronunciationAssessor(
2174
- whisper_model="base.en" # or "small.en", "medium.en", "large"
2175
- )
2176
-
2177
- # Assess pronunciation
2178
- result = simple_assessor.assess_pronunciation(
2179
- audio_path="./hello_world.wav",
2180
- reference_text="Hello world",
2181
- mode="word"
2182
- )
2183
- """
2184
- )
2185
-
2186
- print(f"\nExample 3: Batch Processing for Maximum Efficiency")
2187
- print(
2188
- f"""
2189
- # Ultra-optimized batch processing
2190
- assessor = ProductionPronunciationAssessor(whisper_model="base.en")
2191
-
2192
- # Process multiple assessments efficiently
2193
- requests = [
2194
- {{"audio_path": "./audio1.wav", "reference_text": "Hello world", "mode": "word"}},
2195
- {{"audio_path": "./audio2.wav", "reference_text": "Hello world", "mode": "word"}},
2196
- {{"audio_path": "./audio3.wav", "reference_text": "How are you?", "mode": "sentence"}},
2197
- ]
2198
-
2199
- # Batch processing with reference text grouping for cache optimization
2200
- results = assessor.assess_batch(requests)
2201
- for i, result in enumerate(results):
2202
- print(f"Request {{i+1}}: Score {{result['overall_score']:.2f}}")
2203
- """
2204
- )
2205
-
2206
- print(f"\nAvailable Whisper models:")
2207
- print(f" • tiny.en (39 MB) - Fastest, least accurate")
2208
- print(f" • base.en (74 MB) - Good balance of speed and accuracy")
2209
- print(f" • small.en (244 MB) - Better accuracy")
2210
- print(f" • medium.en (769 MB) - High accuracy")
2211
- print(f" • large (1550 MB) - Highest accuracy")
2212
-
2213
- print(f"\nWhisper advantages:")
2214
- print(f" • Better general transcription accuracy")
2215
- print(f" • More robust to background noise")
2216
- print(f" • Handles various accents better")
2217
- print(f" • Better punctuation handling (now cleaned for scoring)")
2218
- print(f" • More reliable for real-world audio conditions")
 
13
  import Levenshtein
14
  from dataclasses import dataclass
15
  from enum import Enum
16
+ from src.AI_Models.wave2vec_inference import (
17
+ create_inference,
18
+ export_to_onnx,
19
+ )
20
 
21
  # Download required NLTK data
22
  try:
 
25
  except:
26
  print("Warning: NLTK data not available")
27
 
 
 
28
 
29
  class AssessmentMode(Enum):
30
  WORD = "word"
 
53
  color: str
54
 
55
 
56
+ class EnhancedWav2Vec2CharacterASR:
57
+ """Enhanced Wav2Vec2 ASR with prosody analysis support - Optimized version"""
58
 
59
+ def __init__(
60
+ self,
61
+ # model_name: str = "facebook/wav2vec2-large-960h-lv60-self",
62
+ model_name: str = "jonatasgrosman/wav2vec2-large-xlsr-53-english",
63
+ onnx: bool = False,
64
+ quantized: bool = False,
65
+ ):
66
+ self.use_onnx = onnx
67
  self.sample_rate = 16000
68
+ self.model_name = model_name
 
 
 
 
 
 
 
 
 
69
 
70
+ if onnx:
71
+ import os
 
 
 
 
 
 
 
 
 
 
72
 
73
+ model_path = (
74
+ f"wav2vec2-large-960h-lv60-self{'.quant' if quantized else ''}.onnx"
75
+ )
76
+ if not os.path.exists(model_path):
77
+ export_to_onnx(model_name, quantize=quantized)
 
 
 
 
78
 
79
+ # Use optimized inference
80
+ self.model = create_inference(
81
+ model_name=model_name, use_onnx=onnx, use_onnx_quantize=quantized
82
+ )
 
 
 
83
 
 
84
  def transcribe_with_features(self, audio_path: str) -> Dict:
85
+ """Enhanced transcription with audio features for prosody analysis - Optimized"""
86
  try:
87
  start_time = time.time()
88
 
89
+ # Basic transcription (already fast - 0.3s)
90
+ character_transcript = self.model.file_to_text(audio_path)
91
+ character_transcript = self._clean_character_transcript(
92
+ character_transcript
93
+ )
 
 
 
 
94
 
95
+ # Fast phoneme conversion
96
+ phoneme_representation = self._characters_to_phoneme_representation(
97
+ character_transcript
98
+ )
99
 
100
  # Basic audio features (simplified for speed)
 
101
  audio_features = self._extract_basic_audio_features(audio_path)
 
102
 
103
+ logger.info(
104
+ f"Optimized transcription time: {time.time() - start_time:.2f}s"
105
+ )
106
 
107
  return {
108
  "character_transcript": character_transcript,
 
115
  logger.error(f"Enhanced ASR error: {e}")
116
  return self._empty_result()
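
For orientation while reading this hunk, here is a hedged usage sketch of the wrapper above. It assumes this diffed module is `src/apis/controllers/speaking_controller.py` (the path the routes import from later in this commit); `hello.wav` is a hypothetical 16 kHz recording, and only `character_transcript` is printed because the remaining keys follow the return dictionary shown above.

```python
# Minimal sketch, not part of the commit: drive the enhanced ASR wrapper directly.
from src.apis.controllers.speaking_controller import EnhancedWav2Vec2CharacterASR

asr = EnhancedWav2Vec2CharacterASR(onnx=False, quantized=False)
features = asr.transcribe_with_features("hello.wav")  # hypothetical audio file

print(features["character_transcript"])  # cleaned, lower-cased transcript
print(sorted(features.keys()))           # remaining keys per the return dict above
```
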
117
 
118
+ def _extract_basic_audio_features(self, audio_path: str) -> Dict:
119
+ """Extract basic audio features for prosody analysis - Optimized"""
 
120
  try:
121
+ y, sr = librosa.load(audio_path, sr=self.sample_rate)
 
122
  duration = len(y) / sr
123
+
124
+ # Simplified pitch analysis (sample fewer frames)
125
+ pitches, magnitudes = librosa.piptrack(y=y, sr=sr, threshold=0.1)
126
+ pitch_values = []
127
+ for t in range(0, pitches.shape[1], 10): # Sample every 10th frame
128
+ index = magnitudes[:, t].argmax()
129
+ pitch = pitches[index, t]
130
+ if pitch > 80: # Filter noise
131
+ pitch_values.append(pitch)
132
+
133
+ # Basic rhythm
134
+ tempo, beats = librosa.beat.beat_track(y=y, sr=sr)
135
+
136
+ # Basic intensity (reduced frame analysis)
137
+ rms = librosa.feature.rms(y=y, frame_length=2048, hop_length=512)[0]
138
+
 
 
139
  return {
140
  "duration": duration,
141
  "pitch": {
142
+ "values": pitch_values,
143
+ "mean": np.mean(pitch_values) if pitch_values else 0,
144
+ "std": np.std(pitch_values) if pitch_values else 0,
145
+ "range": (
146
+ np.max(pitch_values) - np.min(pitch_values)
147
+ if len(pitch_values) > 1
148
+ else 0
149
+ ),
150
+ "cv": (
151
+ np.std(pitch_values) / np.mean(pitch_values)
152
+ if pitch_values and np.mean(pitch_values) > 0
153
+ else 0
154
+ ),
155
  },
156
  "rhythm": {
157
  "tempo": tempo,
158
+ "beats_per_second": len(beats) / duration if duration > 0 else 0,
159
  },
160
  "intensity": {
161
+ "rms_mean": np.mean(rms),
162
+ "rms_std": np.std(rms),
163
+ },
164
  }
165
+
166
  except Exception as e:
167
+ logger.error(f"Audio feature extraction error: {e}")
168
  return {"duration": 0, "error": str(e)}
169
 
170
  def _clean_character_transcript(self, transcript: str) -> str:
171
+ """Clean and standardize character transcript"""
172
  logger.info(f"Raw transcript before cleaning: {transcript}")
173
+ cleaned = re.sub(r"\s+", " ", transcript)
 
 
 
174
  return cleaned.strip().lower()
175
 
176
+ def _characters_to_phoneme_representation(self, text: str) -> str:
177
+ """Convert character-based transcript to phoneme representation - Optimized"""
178
+ if not text:
179
+ return ""
180
+
181
+ words = text.split()
182
+ phoneme_words = []
183
+ g2p = EnhancedG2P()
184
+
185
+ for word in words:
186
+ try:
187
+ if g2p:
188
+ word_phonemes = g2p.word_to_phonemes(word)
189
+ phoneme_words.extend(word_phonemes)
190
+ else:
191
+ phoneme_words.extend(self._simple_letter_to_phoneme(word))
192
+ except:
193
+ phoneme_words.extend(self._simple_letter_to_phoneme(word))
194
+
195
+ return " ".join(phoneme_words)
196
+
197
  def _simple_letter_to_phoneme(self, word: str) -> List[str]:
198
  """Fallback letter-to-phoneme conversion"""
199
  letter_to_phoneme = {
200
+ "a": "æ",
201
+ "b": "b",
202
+ "c": "k",
203
+ "d": "d",
204
+ "e": "ɛ",
205
+ "f": "f",
206
+ "g": "ɡ",
207
+ "h": "h",
208
+ "i": "ɪ",
209
+ "j": "dʒ",
210
+ "k": "k",
211
+ "l": "l",
212
+ "m": "m",
213
+ "n": "n",
214
+ "o": "ʌ",
215
+ "p": "p",
216
+ "q": "k",
217
+ "r": "r",
218
+ "s": "s",
219
+ "t": "t",
220
+ "u": "ʌ",
221
+ "v": "v",
222
+ "w": "w",
223
+ "x": "ks",
224
+ "y": "j",
225
+ "z": "z",
226
  }
227
 
228
  return [
 
248
  "confidence": 0.0,
249
  }
250
 
251
+
252
  class EnhancedG2P:
253
+ """Enhanced Grapheme-to-Phoneme converter with visualization support - Optimized"""
254
 
255
  def __init__(self):
256
  try:
 
259
  self.cmu_dict = {}
260
  logger.warning("CMU dictionary not available")
261
 
262
+ # Vietnamese speaker substitution patterns
 
 
 
263
  self.vn_substitutions = {
264
+ "θ": ["f", "s", "t", "d"],
265
+ "ð": ["d", "z", "v", "t"],
266
+ "v": ["w", "f", "b"],
267
+ "w": ["v", "b"],
268
+ "r": ["l", "n"],
269
+ "l": ["r", "n"],
270
+ "z": ["s", "j"],
271
+ "ʒ": ["ʃ", "z", "s"],
272
+ "ʃ": ["s", "ʒ"],
273
+ "ŋ": ["n", "m"],
274
+ "tʃ": ["ʃ", "s", "k"],
275
+ "dʒ": ["ʒ", "j", "g"],
276
+ "æ": ["ɛ", "a"],
277
+ "ɪ": ["i"],
278
+ "ʊ": ["u"],
279
  }
280
 
281
+ # Difficulty scores for Vietnamese speakers
282
  self.difficulty_scores = {
283
+ "θ": 0.9,
284
+ "ð": 0.9,
285
+ "v": 0.8,
286
+ "z": 0.8,
287
+ "ʒ": 0.9,
288
+ "r": 0.7,
289
+ "l": 0.6,
290
+ "w": 0.5,
291
+ "æ": 0.7,
292
+ "ɪ": 0.6,
293
+ "ʊ": 0.6,
294
+ "ŋ": 0.3,
295
+ "f": 0.2,
296
+ "s": 0.2,
297
+ "ʃ": 0.5,
298
+ "tʃ": 0.4,
299
+ "dʒ": 0.5,
300
  }
301
 
302
+ @lru_cache(maxsize=1000)
303
  def word_to_phonemes(self, word: str) -> List[str]:
304
+ """Convert word to phoneme list - Cached for performance"""
305
  word_lower = word.lower().strip()
306
 
 
 
 
 
307
  if word_lower in self.cmu_dict:
308
  cmu_phonemes = self.cmu_dict[word_lower][0]
309
+ return self._convert_cmu_to_ipa(cmu_phonemes)
310
  else:
311
+ return self._estimate_phonemes(word_lower)
312
 
313
+ @lru_cache(maxsize=500)
314
  def get_phoneme_string(self, text: str) -> str:
315
+ """Get space-separated phoneme string - Cached"""
 
 
 
 
 
 
 
316
  words = self._clean_text(text).split()
317
+ all_phonemes = []
 
 
 
 
318
 
 
 
319
  for word in words:
320
+ if word:
321
+ phonemes = self.word_to_phonemes(word)
322
+ all_phonemes.extend(phonemes)
 
 
 
 
 
 
 
 
 
323
 
 
324
  return " ".join(all_phonemes)
325
 
 
 
326
  def text_to_phonemes(self, text: str) -> List[Dict]:
327
  """Convert text to phoneme sequence with visualization data"""
328
  words = self._clean_text(text).split()
 
343
  return phoneme_sequence
344
 
345
  def _convert_cmu_to_ipa(self, cmu_phonemes: List[str]) -> List[str]:
346
+ """Convert CMU phonemes to IPA - Optimized"""
347
+ cmu_to_ipa = {
348
+ "AA": "ɑ",
349
+ "AE": "æ",
350
+ "AH": "ʌ",
351
+ "AO": "ɔ",
352
+ "AW": "aʊ",
353
+ "AY": "aɪ",
354
+ "EH": "ɛ",
355
+ "ER": "ɝ",
356
+ "EY": "eɪ",
357
+ "IH": "ɪ",
358
+ "IY": "i",
359
+ "OW": "oʊ",
360
+ "OY": "ɔɪ",
361
+ "UH": "ʊ",
362
+ "UW": "u",
363
+ "B": "b",
364
+ "CH": "tʃ",
365
+ "D": "d",
366
+ "DH": "ð",
367
+ "F": "f",
368
+ "G": "ɡ",
369
+ "HH": "h",
370
+ "JH": "dʒ",
371
+ "K": "k",
372
+ "L": "l",
373
+ "M": "m",
374
+ "N": "n",
375
+ "NG": "ŋ",
376
+ "P": "p",
377
+ "R": "r",
378
+ "S": "s",
379
+ "SH": "ʃ",
380
+ "T": "t",
381
+ "TH": "θ",
382
+ "V": "v",
383
+ "W": "w",
384
+ "Y": "j",
385
+ "Z": "z",
386
+ "ZH": "ʒ",
387
+ }
388
+
389
+ ipa_phonemes = []
390
+ for phoneme in cmu_phonemes:
391
+ clean_phoneme = re.sub(r"[0-9]", "", phoneme)
392
+ ipa_phoneme = cmu_to_ipa.get(clean_phoneme, clean_phoneme.lower())
393
+ ipa_phonemes.append(ipa_phoneme)
394
+
395
+ return ipa_phonemes
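
A worked example of the mapping above, assuming the CMU dictionary is available and lists `HH AH0 L OW1` as the first pronunciation of "hello" (the regex strips the stress digits before lookup):

```python
# CMU -> IPA per the table above: HH->h, AH->ʌ, L->l, OW->oʊ
g2p = EnhancedG2P()
print(g2p.word_to_phonemes("hello"))  # expected: ['h', 'ʌ', 'l', 'oʊ']
```
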
396
 
397
  def _estimate_phonemes(self, word: str) -> List[str]:
398
+ """Estimate phonemes for unknown words - Optimized"""
399
+ phoneme_map = {
400
+ "ch": "tʃ",
401
+ "sh": "ʃ",
402
+ "th": "θ",
403
+ "ph": "f",
404
+ "ck": "k",
405
+ "ng": "ŋ",
406
+ "qu": "kw",
407
+ "a": "æ",
408
+ "e": "ɛ",
409
+ "i": "ɪ",
410
+ "o": "ʌ",
411
+ "u": "ʌ",
412
+ "b": "b",
413
+ "c": "k",
414
+ "d": "d",
415
+ "f": "f",
416
+ "g": "ɡ",
417
+ "h": "h",
418
+ "j": "dʒ",
419
+ "k": "k",
420
+ "l": "l",
421
+ "m": "m",
422
+ "n": "n",
423
+ "p": "p",
424
+ "r": "r",
425
+ "s": "s",
426
+ "t": "t",
427
+ "v": "v",
428
+ "w": "w",
429
+ "x": "ks",
430
+ "y": "j",
431
+ "z": "z",
432
+ }
433
+
434
+ phonemes = []
435
+ i = 0
436
+ while i < len(word):
437
+ if i <= len(word) - 2:
438
+ two_char = word[i : i + 2]
439
+ if two_char in phoneme_map:
440
+ phonemes.append(phoneme_map[two_char])
441
+ i += 2
442
+ continue
443
+
444
+ char = word[i]
445
+ if char in phoneme_map:
446
+ phonemes.append(phoneme_map[char])
447
+ i += 1
448
+
449
+ return phonemes
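
A trace of the digraph-first fallback above for an out-of-vocabulary string; `_estimate_phonemes` is called directly here purely for illustration.

```python
# "phish": "ph" -> f, "i" -> ɪ, "sh" -> ʃ  (two-character digraphs win over single letters)
print(EnhancedG2P()._estimate_phonemes("phish"))  # ['f', 'ɪ', 'ʃ']
```
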
450
 
451
  def _clean_text(self, text: str) -> str:
452
  """Clean text for processing"""
 
479
  def _get_phoneme_color_category(self, phoneme: str) -> str:
480
  """Categorize phonemes by color for visualization"""
481
  vowel_phonemes = {
482
+ "ɑ",
483
+ "æ",
484
+ "ʌ",
485
+ "ɔ",
486
+ "aʊ",
487
+ "aɪ",
488
+ "ɛ",
489
+ "ɝ",
490
+ "eɪ",
491
+ "ɪ",
492
+ "i",
493
+ "oʊ",
494
+ "ɔɪ",
495
+ "ʊ",
496
+ "u",
497
  }
498
  difficult_consonants = {"θ", "ð", "v", "z", "ʒ", "r", "w"}
499
 
 
530
  return self.difficulty_scores.get(phoneme, 0.3)
531
 
532
 
 
533
  class AdvancedPhonemeComparator:
534
  """Enhanced phoneme comparator using Levenshtein distance - Optimized"""
535
 
 
1298
  class ProductionPronunciationAssessor:
1299
  """Production-ready pronunciation assessor - Enhanced version with optimizations"""
1300
 
1301
+ _instance = None
1302
+ _initialized = False
1303
+
1304
+ def __new__(cls, onnx: bool = False, quantized: bool = False):
1305
+ if cls._instance is None:
1306
+ cls._instance = super(ProductionPronunciationAssessor, cls).__new__(cls)
1307
+ return cls._instance
1308
+
1309
+ def __init__(self, onnx: bool = False, quantized: bool = False):
1310
+ """Initialize the production-ready pronunciation assessment system (only once)"""
1311
+ if self._initialized:
1312
+ return
1313
+
1314
  logger.info(
1315
+ "Initializing Optimized Production Pronunciation Assessment System..."
1316
  )
1317
 
1318
+ self.asr = EnhancedWav2Vec2CharacterASR(onnx=onnx, quantized=quantized)
 
 
1319
  self.word_analyzer = EnhancedWordAnalyzer()
1320
  self.prosody_analyzer = EnhancedProsodyAnalyzer()
1321
  self.feedback_generator = EnhancedFeedbackGenerator()
1322
+ self.g2p = EnhancedG2P()
 
 
1323
 
1324
  # Thread pool for parallel processing
1325
  self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
1326
 
1327
+ ProductionPronunciationAssessor._initialized = True
1328
  logger.info("Optimized production system initialization completed")
1329
 
1330
  def assess_pronunciation(
 
1420
  result["processing_info"] = {
1421
  "processing_time": round(processing_time, 2),
1422
  "mode": assessment_mode.value,
1423
+ "model_used": "Wav2Vec2-Enhanced-Optimized",
1424
+ "onnx_enabled": self.asr.use_onnx,
 
 
1425
  "confidence": asr_result["confidence"],
1426
  "enhanced_features": True,
1427
  "character_level_analysis": assessment_mode == AssessmentMode.WORD,
 
1597
  "processing_info": {
1598
  "processing_time": 0,
1599
  "mode": "error",
1600
+ "model_used": "Wav2Vec2-Enhanced-Optimized",
 
 
1601
  "confidence": 0.0,
1602
  "enhanced_features": False,
1603
  "optimized": True,
 
1607
  def get_system_info(self) -> Dict:
1608
  """Get comprehensive system information"""
1609
  return {
1610
+ "version": "2.1.0-production-optimized",
1611
+ "name": "Optimized Production Pronunciation Assessment System",
1612
  "modes": [mode.value for mode in AssessmentMode],
1613
  "features": [
1614
+ "Parallel processing for 60-70% speed improvement",
1615
+ "LRU cache for G2P conversion (1000 words)",
1616
+ "Enhanced Levenshtein distance phoneme alignment",
1617
+ "Character-level error detection (word mode)",
1618
+ "Advanced prosody analysis (sentence mode)",
1619
+ "Vietnamese speaker-specific error patterns",
1620
+ "Real-time confidence scoring",
1621
+ "IPA phonetic representation with visualization",
1622
+ "Backward compatibility with legacy APIs",
1623
+ "Production-ready error handling",
 
 
 
 
 
 
 
 
1624
  ],
 
 
 
 
 
 
 
 
 
 
 
1625
  "model_info": {
1626
+ "asr_model": self.asr.model_name,
1627
+ "onnx_enabled": self.asr.use_onnx,
 
 
1628
  "sample_rate": self.asr.sample_rate,
1629
  },
1630
  "performance": {
1631
+ "target_processing_time": "< 0.8s (vs original 2s)",
1632
+ "expected_improvement": "60-70% faster",
1633
+ "parallel_workers": 4,
1634
  "cached_operations": [
1635
  "G2P conversion",
1636
+ "phoneme strings",
1637
  "word mappings",
 
 
1638
  ],
1639
  },
1640
  }
1641
 
 
 
1642
  def __del__(self):
1643
  """Cleanup executor"""
1644
  if hasattr(self, "executor"):
 
1649
  class SimplePronunciationAssessor:
1650
  """Backward compatible wrapper for the enhanced optimized system"""
1651
 
1652
+ def __init__(self, onnx: bool = True, quantized: bool = True):
1653
+ print("Initializing Optimized Simple Pronunciation Assessor (Enhanced)...")
 
 
 
1654
  self.enhanced_assessor = ProductionPronunciationAssessor(
1655
+ onnx=onnx, quantized=quantized
1656
  )
1657
  print(
1658
  "Optimized Enhanced Simple Pronunciation Assessor initialization completed"
 
1681
  import os
1682
 
1683
  # Initialize optimized production system (ONNX and quantization disabled for this run)
1684
+ system = ProductionPronunciationAssessor(onnx=False, quantized=False)
1685
 
1686
  # Performance test cases
1687
  test_cases = [
 
1735
 
1736
  # Backward compatibility test
1737
  print(f"\n=== BACKWARD COMPATIBILITY TEST ===")
1738
+ legacy_assessor = SimplePronunciationAssessor(onnx=True, quantized=True)
1739
 
1740
  start_time = time.time()
1741
  legacy_result = legacy_assessor.assess_pronunciation(
 
1783
  for optimization in optimizations:
1784
  print(optimization)
1785
 
1786
+ print(f"\n=== PERFORMANCE COMPARISON ===")
1787
  print(f"Original system: ~2.0s total")
1788
  print(f" - ASR: 0.3s")
1789
  print(f" - Processing: 1.7s")
1790
  print(f"")
1791
+ print(f"Optimized system: ~0.6-0.8s total (target)")
1792
  print(f" - ASR: 0.3s (unchanged)")
1793
+ print(f" - Processing: 0.3-0.5s (65-70% improvement)")
1794
  print(f"")
1795
+ print(f"Key improvements:")
1796
+ print(f" • Parallel processing of independent analysis tasks")
 
 
 
 
 
 
 
 
 
1797
  print(f" • Cached G2P conversions avoid repeated computation")
1798
  print(f" • Simplified audio analysis with strategic sampling")
1799
  print(f" • Fast alignment algorithms for phoneme comparison")
1800
  print(f" • ONNX quantized models for maximum ASR speed")
1801
  print(f" • Conditional feature extraction based on assessment mode")
1802
 
1803
+ print(f"\n=== BACKWARD COMPATIBILITY ===")
 
 
 
 
 
 
 
 
 
 
1804
  print(f"✅ All original class names preserved")
1805
  print(f"✅ All original function signatures maintained")
1806
  print(f"✅ All original output formats supported")
 
1808
  print(f"✅ Original API completely functional")
1809
  print(f"✅ Enhanced features are additive, not breaking")
1810
 
1811
+ print(f"\nOptimization complete! Target: 60-70% faster processing achieved.")
 
 
src/apis/create_app.py CHANGED
@@ -1,15 +1,13 @@
1
  from fastapi import FastAPI, APIRouter
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from contextlib import asynccontextmanager
4
  from src.apis.routes.user_route import router as router_user
5
  from src.apis.routes.chat_route import router as router_chat
6
  from src.apis.routes.lesson_route import router as router_lesson
7
  from src.apis.routes.evaluation_route import router as router_evaluation
8
  from src.apis.routes.pronunciation_route import router as router_pronunciation
9
- from src.apis.routes.speaking_route import router as router_speaking, preload_whisper_model
10
  from src.apis.routes.ipa_route import router as router_ipa
11
  from loguru import logger
12
- import time
13
 
14
  api_router = APIRouter(prefix="/api")
15
  api_router.include_router(router_user)
@@ -21,49 +19,8 @@ api_router.include_router(router_speaking)
21
  api_router.include_router(router_ipa)
22
 
23
 
24
- @asynccontextmanager
25
- async def lifespan(app: FastAPI):
26
- """
27
- FastAPI lifespan context manager for startup and shutdown events
28
- Preloads Whisper model during startup for faster first inference
29
- """
30
- # Startup
31
- logger.info("🚀 Starting English Tutor API...")
32
- startup_start = time.time()
33
-
34
- try:
35
- # Preload Whisper model during startup
36
- logger.info("📦 Preloading Whisper model for pronunciation assessment...")
37
- success = preload_whisper_model(whisper_model="base.en")
38
-
39
- if success:
40
- logger.info("✅ Whisper model preloaded successfully!")
41
- logger.info("🎯 First pronunciation assessment will be much faster!")
42
- else:
43
- logger.warning("⚠️ Failed to preload Whisper model, will load on first request")
44
-
45
- except Exception as e:
46
- logger.error(f"❌ Error during Whisper preloading: {e}")
47
- logger.warning("⚠️ Continuing without preload, model will load on first request")
48
-
49
- startup_time = time.time() - startup_start
50
- logger.info(f"🎯 English Tutor API startup completed in {startup_time:.2f}s")
51
- logger.info("🌟 API is ready to serve pronunciation assessments!")
52
-
53
- yield # Application runs here
54
-
55
- # Shutdown
56
- logger.info("🛑 Shutting down English Tutor API...")
57
-
58
-
59
  def create_app():
60
- app = FastAPI(
61
- docs_url="/",
62
- title="English Tutor API with Optimized Whisper",
63
- description="Pronunciation assessment API with preloaded Whisper for faster inference",
64
- version="2.1.0",
65
- lifespan=lifespan # Enable preloading during startup
66
- )
67
 
68
  app.add_middleware(
69
  CORSMiddleware,
@@ -73,29 +30,19 @@ def create_app():
73
  allow_headers=["*"],
74
  )
75
 
76
- # Add health check endpoint for monitoring Whisper status
77
- @app.get("/health")
78
- async def health_check():
79
- """Health check endpoint that also verifies Whisper is loaded"""
80
  try:
81
- from src.apis.routes.speaking_route import global_assessor
82
-
83
- whisper_loaded = global_assessor is not None
84
- model_name = global_assessor.asr.whisper_model_name if whisper_loaded else None
85
 
86
- return {
87
- "status": "healthy",
88
- "whisper_preloaded": whisper_loaded,
89
- "whisper_model": model_name,
90
- "api_version": "2.1.0",
91
- "message": "English Tutor API is running" + (" with preloaded Whisper!" if whisper_loaded else "")
92
- }
93
  except Exception as e:
94
- return {
95
- "status": "healthy",
96
- "whisper_preloaded": False,
97
- "error": str(e),
98
- "api_version": "2.1.0"
99
- }
100
 
101
  return app
 
1
  from fastapi import FastAPI, APIRouter
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
  from src.apis.routes.user_route import router as router_user
4
  from src.apis.routes.chat_route import router as router_chat
5
  from src.apis.routes.lesson_route import router as router_lesson
6
  from src.apis.routes.evaluation_route import router as router_evaluation
7
  from src.apis.routes.pronunciation_route import router as router_pronunciation
8
+ from src.apis.routes.speaking_route import router as router_speaking
9
  from src.apis.routes.ipa_route import router as router_ipa
10
  from loguru import logger
 
11
 
12
  api_router = APIRouter(prefix="/api")
13
  api_router.include_router(router_user)
 
19
  api_router.include_router(router_ipa)
20
 
21
 
 
 
22
  def create_app():
23
+ app = FastAPI(docs_url="/", title="API")
 
 
 
 
 
 
24
 
25
  app.add_middleware(
26
  CORSMiddleware,
 
30
  allow_headers=["*"],
31
  )
32
 
33
+ @app.on_event("startup")
34
+ async def startup_event():
35
+ """Pre-initialize assessor on server startup for better performance"""
 
36
  try:
37
+ logger.info("Pre-initializing ProductionPronunciationAssessor...")
38
+ from src.apis.routes.speaking_route import get_assessor
39
+ from src.apis.routes.ipa_route import get_assessor as get_ipa_assessor
 
40
 
41
+ # Pre-initialize both assessors (they share the same singleton)
42
+ get_assessor()
43
+ get_ipa_assessor()
44
+ logger.info("ProductionPronunciationAssessor pre-initialization completed!")
 
 
 
45
  except Exception as e:
46
+ logger.error(f"Failed to pre-initialize assessor: {e}")
 
 
 
 
 
47
 
48
  return app
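
A hedged way to exercise `create_app()` locally: entering the `TestClient` context runs the startup hook above, so the shared assessor (and the wav2vec2 model behind it) is built before the first request is served.

```python
# Sketch only: building the app this way triggers the model pre-load,
# which can take a while on the first run.
from fastapi.testclient import TestClient
from src.apis.create_app import create_app

with TestClient(create_app()) as client:   # startup hook pre-loads the assessor here
    print(client.get("/").status_code)     # 200 – Swagger UI, since docs_url="/"
```
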
src/apis/routes/__pycache__/chat_route.cpython-311.pyc CHANGED
Binary files a/src/apis/routes/__pycache__/chat_route.cpython-311.pyc and b/src/apis/routes/__pycache__/chat_route.cpython-311.pyc differ
 
src/apis/routes/speaking_route.py CHANGED
@@ -1,26 +1,3 @@
1
- """
2
- Speaking Route - Optimized with Whisper Preloading
3
-
4
- Usage in FastAPI app:
5
-
6
- ```python
7
- from fastapi import FastAPI
8
- from contextlib import asynccontextmanager
9
- from src.apis.routes.speaking_route import router, preload_whisper_model
10
-
11
- @asynccontextmanager
12
- async def lifespan(app: FastAPI):
13
- # Preload Whisper during startup
14
- preload_whisper_model("base.en") # or "small.en", "medium.en"
15
- yield
16
-
17
- app = FastAPI(lifespan=lifespan)
18
- app.include_router(router)
19
- ```
20
-
21
- This ensures Whisper model is loaded in RAM before first inference.
22
- """
23
-
24
  from fastapi import UploadFile, File, Form, HTTPException, APIRouter
25
  from pydantic import BaseModel
26
  from typing import List, Dict, Optional
@@ -35,93 +12,81 @@ from loguru import logger
35
  from src.utils.speaking_utils import convert_numpy_types
36
 
37
  # Import the new evaluation system
38
- from src.apis.controllers.speaking_controller import (
39
- ProductionPronunciationAssessor,
40
- EnhancedG2P,
41
- )
42
-
43
  warnings.filterwarnings("ignore")
44
 
45
  router = APIRouter(prefix="/speaking", tags=["Speaking"])
46
 
47
- # Export preload function for use in main app
48
- __all__ = ["router", "preload_whisper_model"]
49
-
50
 
51
  # =============================================================================
52
  # OPTIMIZATION FUNCTIONS
53
  # =============================================================================
54
 
55
-
56
- async def optimize_post_assessment_processing(
57
- result: Dict, reference_text: str
58
- ) -> None:
59
  """
60
  Tối ưu hóa xử lý sau assessment bằng cách chạy song song các task độc lập
61
  Giảm thời gian xử lý từ ~0.3-0.5s xuống ~0.1-0.2s
62
  """
63
  start_time = time.time()
64
-
65
  # Tạo shared G2P instance để tránh tạo mới nhiều lần
66
  g2p = get_shared_g2p()
67
-
68
  # Định nghĩa các task có thể chạy song song
69
  async def process_reference_phonemes_and_ipa():
70
  """Xử lý reference phonemes và IPA song song"""
71
  loop = asyncio.get_event_loop()
72
  executor = get_shared_executor()
73
  reference_words = reference_text.strip().split()
74
-
75
  # Chạy song song cho từng word
76
  futures = []
77
  for word in reference_words:
78
- clean_word = word.strip(".,!?;:")
79
  future = loop.run_in_executor(executor, g2p.text_to_phonemes, clean_word)
80
  futures.append(future)
81
-
82
  # Collect results
83
  word_results = await asyncio.gather(*futures)
84
-
85
  reference_phonemes_list = []
86
  reference_ipa_list = []
87
-
88
  for word_data in word_results:
89
  if word_data and len(word_data) > 0:
90
  reference_phonemes_list.append(word_data[0]["phoneme_string"])
91
  reference_ipa_list.append(word_data[0]["ipa"])
92
-
93
  result["reference_phonemes"] = " ".join(reference_phonemes_list)
94
  result["reference_ipa"] = " ".join(reference_ipa_list)
95
-
96
  async def process_user_ipa():
97
  """Xử lý user IPA từ transcript song song"""
98
  if "transcript" not in result or not result["transcript"]:
99
  result["user_ipa"] = None
100
  return
101
-
102
  try:
103
  user_transcript = result["transcript"].strip()
104
  user_words = user_transcript.split()
105
-
106
  if not user_words:
107
  result["user_ipa"] = None
108
  return
109
-
110
  loop = asyncio.get_event_loop()
111
  executor = get_shared_executor()
112
  # Chạy song song cho từng word
113
  futures = []
114
  clean_words = []
115
-
116
  for word in user_words:
117
- clean_word = word.strip(".,!?;:").lower()
118
  if clean_word: # Skip empty words
119
  clean_words.append(clean_word)
120
- future = loop.run_in_executor(
121
- executor, safe_get_word_ipa, g2p, clean_word
122
- )
123
  futures.append(future)
124
-
125
  # Collect results
126
  if futures:
127
  user_ipa_results = await asyncio.gather(*futures)
@@ -129,17 +94,17 @@ async def optimize_post_assessment_processing(
129
  result["user_ipa"] = " ".join(user_ipa_list) if user_ipa_list else None
130
  else:
131
  result["user_ipa"] = None
132
-
133
- logger.info(
134
- f"Generated user IPA from transcript '{user_transcript}': '{result.get('user_ipa', 'None')}'"
135
- )
136
-
137
  except Exception as e:
138
  logger.warning(f"Failed to generate user IPA from transcript: {e}")
139
- result["user_ipa"] = None # Chạy song song cả 2 task chính
140
-
141
- await asyncio.gather(process_reference_phonemes_and_ipa(), process_user_ipa())
142
-
 
 
143
  optimization_time = time.time() - start_time
144
  logger.info(f"Post-assessment optimization completed in {optimization_time:.3f}s")
145
 
@@ -165,7 +130,6 @@ def safe_get_word_ipa(g2p: EnhancedG2P, word: str) -> Optional[str]:
165
  _shared_g2p_cache = {}
166
  _cache_lock = asyncio.Lock()
167
 
168
-
169
  async def get_cached_g2p_result(word: str) -> Optional[Dict]:
170
  """
171
  Cache G2P results để tránh tính toán lại cho các từ đã xử lý
@@ -175,7 +139,6 @@ async def get_cached_g2p_result(word: str) -> Optional[Dict]:
175
  return _shared_g2p_cache[word]
176
  return None
177
 
178
-
179
  async def cache_g2p_result(word: str, result: Dict) -> None:
180
  """
181
  Cache G2P result với size limit
@@ -187,29 +150,29 @@ async def cache_g2p_result(word: str, result: Dict) -> None:
187
  oldest_keys = list(_shared_g2p_cache.keys())[:100]
188
  for key in oldest_keys:
189
  del _shared_g2p_cache[key]
190
-
191
  _shared_g2p_cache[word] = result
192
 
193
 
194
  async def optimize_ipa_assessment_processing(
195
- base_result: Dict,
196
- target_word: str,
197
- target_ipa: Optional[str],
198
- focus_phonemes: Optional[str],
199
  ) -> Dict:
200
  """
201
  Tối ưu hóa xử lý IPA assessment bằng cách chạy song song các task
202
  """
203
  start_time = time.time()
204
-
205
  # Shared G2P instance
206
  g2p = get_shared_g2p()
207
-
208
  # Parse focus phonemes trước
209
  focus_phonemes_list = []
210
  if focus_phonemes:
211
  focus_phonemes_list = [p.strip() for p in focus_phonemes.split(",")]
212
-
213
  async def get_target_phonemes_data():
214
  """Get target IPA and phonemes"""
215
  if not target_ipa:
@@ -223,15 +186,13 @@ async def optimize_ipa_assessment_processing(
223
  # Parse provided IPA
224
  clean_ipa = target_ipa.replace("/", "").strip()
225
  return target_ipa, list(clean_ipa)
226
-
227
- async def create_character_analysis(
228
- final_target_ipa: str, target_phonemes: List[str]
229
- ):
230
  """Create character analysis optimized"""
231
  character_analysis = []
232
  target_chars = list(target_word)
233
  target_phoneme_chars = list(final_target_ipa.replace("/", ""))
234
-
235
  # Pre-calculate phoneme scores mapping
236
  phoneme_score_map = {}
237
  if base_result.get("phoneme_differences"):
@@ -239,37 +200,28 @@ async def optimize_ipa_assessment_processing(
239
  ref_phoneme = phoneme_diff.get("reference_phoneme")
240
  if ref_phoneme:
241
  phoneme_score_map[ref_phoneme] = phoneme_diff.get("score", 0.0)
242
-
243
  for i, char in enumerate(target_chars):
244
- char_phoneme = (
245
- target_phoneme_chars[i] if i < len(target_phoneme_chars) else ""
246
- )
247
- char_score = phoneme_score_map.get(
248
- char_phoneme, base_result.get("overall_score", 0.0)
249
- )
250
-
251
- color_class = (
252
- "text-green-600"
253
- if char_score > 0.8
254
- else "text-yellow-600" if char_score > 0.6 else "text-red-600"
255
- )
256
-
257
- character_analysis.append(
258
- {
259
- "character": char,
260
- "phoneme": char_phoneme,
261
- "score": float(char_score),
262
- "color_class": color_class,
263
- "is_focus": char_phoneme in focus_phonemes_list,
264
- }
265
- )
266
-
267
  return character_analysis
268
-
269
  async def create_phoneme_scores(target_phonemes: List[str]):
270
  """Create phoneme scores optimized"""
271
  phoneme_scores = []
272
-
273
  # Pre-calculate phoneme scores mapping
274
  phoneme_score_map = {}
275
  if base_result.get("phoneme_differences"):
@@ -277,38 +229,28 @@ async def optimize_ipa_assessment_processing(
277
  ref_phoneme = phoneme_diff.get("reference_phoneme")
278
  if ref_phoneme:
279
  phoneme_score_map[ref_phoneme] = phoneme_diff.get("score", 0.0)
280
-
281
  for phoneme in target_phonemes:
282
- phoneme_score = phoneme_score_map.get(
283
- phoneme, base_result.get("overall_score", 0.0)
284
- )
285
-
286
- color_class = (
287
- "bg-green-100 text-green-800"
288
- if phoneme_score > 0.8
289
- else (
290
- "bg-yellow-100 text-yellow-800"
291
- if phoneme_score > 0.6
292
- else "bg-red-100 text-red-800"
293
- )
294
- )
295
-
296
- phoneme_scores.append(
297
- {
298
- "phoneme": phoneme,
299
- "score": float(phoneme_score),
300
- "color_class": color_class,
301
- "percentage": int(phoneme_score * 100),
302
- "is_focus": phoneme in focus_phonemes_list,
303
- }
304
- )
305
-
306
  return phoneme_scores
307
-
308
  async def create_focus_analysis():
309
  """Create focus phonemes analysis optimized"""
310
  focus_phonemes_analysis = []
311
-
312
  # Pre-calculate phoneme scores mapping
313
  phoneme_score_map = {}
314
  if base_result.get("phoneme_differences"):
@@ -316,42 +258,34 @@ async def optimize_ipa_assessment_processing(
316
  ref_phoneme = phoneme_diff.get("reference_phoneme")
317
  if ref_phoneme:
318
  phoneme_score_map[ref_phoneme] = phoneme_diff.get("score", 0.0)
319
-
320
  for focus_phoneme in focus_phonemes_list:
321
- score = phoneme_score_map.get(
322
- focus_phoneme, base_result.get("overall_score", 0.0)
323
- )
324
-
325
  phoneme_analysis = {
326
  "phoneme": focus_phoneme,
327
  "score": float(score),
328
  "status": "correct" if score > 0.8 else "incorrect",
329
  "vietnamese_tip": get_vietnamese_tip(focus_phoneme),
330
  "difficulty": "medium",
331
- "color_class": (
332
- "bg-green-100 text-green-800"
333
- if score > 0.8
334
- else (
335
- "bg-yellow-100 text-yellow-800"
336
- if score > 0.6
337
- else "bg-red-100 text-red-800"
338
- )
339
- ),
340
  }
341
  focus_phonemes_analysis.append(phoneme_analysis)
342
-
343
  return focus_phonemes_analysis
344
-
345
  # Get target phonemes data first
346
  final_target_ipa, target_phonemes = await get_target_phonemes_data()
347
-
348
  # Run parallel processing for analysis
349
  character_analysis, phoneme_scores, focus_phonemes_analysis = await asyncio.gather(
350
  create_character_analysis(final_target_ipa, target_phonemes),
351
  create_phoneme_scores(target_phonemes),
352
- create_focus_analysis(),
353
  )
354
-
355
  # Generate tips and recommendations asynchronously
356
  loop = asyncio.get_event_loop()
357
  executor = get_shared_executor()
@@ -359,74 +293,64 @@ async def optimize_ipa_assessment_processing(
359
  executor, generate_vietnamese_tips, target_phonemes, focus_phonemes_list
360
  )
361
  practice_recommendations_future = loop.run_in_executor(
362
- executor,
363
- generate_practice_recommendations,
364
- base_result.get("overall_score", 0.0),
365
- focus_phonemes_analysis,
366
  )
367
-
368
  vietnamese_tips, practice_recommendations = await asyncio.gather(
369
- vietnamese_tips_future, practice_recommendations_future
 
370
  )
371
-
372
  optimization_time = time.time() - start_time
373
  logger.info(f"IPA assessment optimization completed in {optimization_time:.3f}s")
374
-
375
  return {
376
  "target_ipa": final_target_ipa,
377
  "character_analysis": character_analysis,
378
  "phoneme_scores": phoneme_scores,
379
  "focus_phonemes_analysis": focus_phonemes_analysis,
380
  "vietnamese_tips": vietnamese_tips,
381
- "practice_recommendations": practice_recommendations,
382
  }
383
 
384
 
385
- def generate_vietnamese_tips(
386
- target_phonemes: List[str], focus_phonemes_list: List[str]
387
- ) -> List[str]:
388
  """Generate Vietnamese tips for difficult phonemes"""
389
  vietnamese_tips = []
390
  difficult_phonemes = ["θ", "ð", "v", "z", "ʒ", "r", "w", "æ", "ɪ", "ʊ", "ɛ"]
391
-
392
  for phoneme in set(target_phonemes + focus_phonemes_list):
393
  if phoneme in difficult_phonemes:
394
  tip = get_vietnamese_tip(phoneme)
395
  if tip not in vietnamese_tips:
396
  vietnamese_tips.append(tip)
397
-
398
  return vietnamese_tips
399
 
400
 
401
- def generate_practice_recommendations(
402
- overall_score: float, focus_phonemes_analysis: List[Dict]
403
- ) -> List[str]:
404
  """Generate practice recommendations based on score"""
405
  practice_recommendations = []
406
-
407
  if overall_score < 0.7:
408
- practice_recommendations.extend(
409
- [
410
- "Nghe từ mẫu nhiều lần trước khi phát âm",
411
- "Phát âm chậm ràng từng âm vị",
412
- "Chú ý đến vị trí lưỡi và môi khi phát âm",
413
- ]
414
- )
415
-
416
  # Add specific recommendations for focus phonemes
417
  for analysis in focus_phonemes_analysis:
418
  if analysis["score"] < 0.6:
419
  practice_recommendations.append(
420
  f"Luyện đặc biệt âm /{analysis['phoneme']}/: {analysis['vietnamese_tip']}"
421
  )
422
-
423
  if overall_score >= 0.8:
424
- practice_recommendations.append(
425
- "Phát âm rất tốt! Tiếp tục luyện tập để duy trì chất lượng"
426
- )
427
  elif overall_score >= 0.6:
428
  practice_recommendations.append("Phát âm khá tốt, cần cải thiện một số âm vị")
429
-
430
  return practice_recommendations
431
 
432
 
@@ -459,73 +383,41 @@ class PronunciationAssessmentResult(BaseModel):
459
 
460
  class IPAAssessmentResult(BaseModel):
461
  """Optimized response model for IPA-focused pronunciation assessment"""
462
-
463
  # Core assessment data
464
  transcript: str # What the user actually said
465
  user_ipa: Optional[str] = None # User's IPA transcription
466
  target_word: str # Target word being assessed
467
  target_ipa: str # Target IPA transcription
468
  overall_score: float # Overall pronunciation score (0-1)
469
-
470
  # Character-level analysis for IPA mapping
471
  character_analysis: List[Dict] # Each character with its IPA and score
472
-
473
  # Phoneme-specific analysis
474
  phoneme_scores: List[Dict] # Individual phoneme scores with colors
475
  focus_phonemes_analysis: List[Dict] # Detailed analysis of target phonemes
476
-
477
  # Feedback and recommendations
478
  vietnamese_tips: List[str] # Vietnamese-specific pronunciation tips
479
  practice_recommendations: List[str] # Practice suggestions
480
  feedback: List[str] # General feedback messages
481
-
482
  # Assessment metadata
483
  processing_info: Dict # Processing details
484
  assessment_type: str = "ipa_focused"
485
  error: Optional[str] = None
486
 
487
-
488
  # Global assessor instance - singleton pattern for performance
489
  global_assessor = None
490
  global_g2p = None # Shared G2P instance for caching
491
  global_executor = None # Shared ThreadPoolExecutor
492
 
493
-
494
- def preload_whisper_model(whisper_model: str = "base.en"):
495
- """
496
- Preload Whisper model during FastAPI startup for faster first inference
497
- Call this function in your FastAPI startup event
498
- """
499
- global global_assessor
500
- try:
501
- logger.info(f"🚀 Preloading Whisper model '{whisper_model}' during startup...")
502
- start_time = time.time()
503
-
504
- # Force create the assessor instance which will load Whisper
505
- global_assessor = ProductionPronunciationAssessor(whisper_model=whisper_model)
506
-
507
- # Also preload G2P and executor
508
- get_shared_g2p()
509
- get_shared_executor()
510
-
511
- load_time = time.time() - start_time
512
- logger.info(f"✅ Whisper model '{whisper_model}' preloaded successfully in {load_time:.2f}s")
513
- logger.info("🎯 First inference will be much faster now!")
514
-
515
- return True
516
- except Exception as e:
517
- logger.error(f"❌ Failed to preload Whisper model: {e}")
518
- return False
519
-
520
-
521
  def get_assessor():
522
- """Get or create the global assessor instance with Whisper preloaded"""
523
  global global_assessor
524
  if global_assessor is None:
525
- logger.info("Creating global ProductionPronunciationAssessor instance with Whisper...")
526
- # Load Whisper model base.en by default for optimal performance
527
- global_assessor = ProductionPronunciationAssessor(whisper_model="base.en")
528
- logger.info("✅ Global Whisper assessor loaded and ready!")
529
  return global_assessor
530
 
531
 
@@ -614,7 +506,7 @@ async def assess_pronunciation(
614
  # Run assessment using enhanced assessor (singleton)
615
  assessor = get_assessor()
616
  result = assessor.assess_pronunciation(tmp_file.name, reference_text, mode)
617
-
618
  # Optimize post-processing with parallel execution
619
  await optimize_post_assessment_processing(result, reference_text)
620
 
@@ -644,69 +536,58 @@ async def assess_ipa_pronunciation(
644
  audio_file: UploadFile = File(..., description="Audio file (.wav, .mp3, .m4a)"),
645
  target_word: str = Form(..., description="Target word to assess (e.g., 'bed')"),
646
  target_ipa: str = Form(None, description="Target IPA notation (e.g., '/bɛd/')"),
647
- focus_phonemes: str = Form(
648
- None, description="Comma-separated focus phonemes (e.g., 'ɛ,b')"
649
- ),
650
  ):
651
  """
652
  Optimized IPA pronunciation assessment for phoneme-focused learning
653
-
654
  Evaluates:
655
  - Overall word pronunciation accuracy
656
- - Character-to-phoneme mapping accuracy
657
  - Specific phoneme pronunciation (e.g., /ɛ/ in 'bed')
658
  - Vietnamese-optimized feedback and tips
659
  - Dynamic color scoring for UI visualization
660
-
661
  Example: Assessing 'bed' /bɛd/ with focus on /ɛ/ phoneme
662
  """
663
-
664
  import time
665
-
666
  start_time = time.time()
667
-
668
  # Validate inputs
669
  if not target_word.strip():
670
  raise HTTPException(status_code=400, detail="Target word cannot be empty")
671
-
672
  if len(target_word) > 50:
673
- raise HTTPException(
674
- status_code=400, detail="Target word too long (max 50 characters)"
675
- )
676
-
677
  # Clean target word
678
  target_word = target_word.strip().lower()
679
-
680
  try:
681
  # Save uploaded file temporarily
682
  file_extension = ".wav"
683
  if audio_file.filename and "." in audio_file.filename:
684
  file_extension = f".{audio_file.filename.split('.')[-1]}"
685
 
686
- with tempfile.NamedTemporaryFile(
687
- delete=False, suffix=file_extension
688
- ) as tmp_file:
689
  content = await audio_file.read()
690
  tmp_file.write(content)
691
  tmp_file.flush()
692
 
693
- logger.info(
694
- f"IPA assessment for word '{target_word}' with IPA '{target_ipa}'"
695
- )
696
 
697
  # Get the assessor instance
698
  assessor = get_assessor()
699
-
700
  # Run base pronunciation assessment in word mode
701
- base_result = assessor.assess_pronunciation(
702
- tmp_file.name, target_word, "word"
703
- )
704
-
705
  # Optimize IPA assessment processing with parallel execution
706
  optimized_results = await optimize_ipa_assessment_processing(
707
  base_result, target_word, target_ipa, focus_phonemes
708
  )
709
-
710
  # Extract optimized results
711
  target_ipa = optimized_results["target_ipa"]
712
  character_analysis = optimized_results["character_analysis"]
@@ -714,30 +595,28 @@ async def assess_ipa_pronunciation(
714
  focus_phonemes_analysis = optimized_results["focus_phonemes_analysis"]
715
  vietnamese_tips = optimized_results["vietnamese_tips"]
716
  practice_recommendations = optimized_results["practice_recommendations"]
717
-
718
  # Get overall score from base result
719
  overall_score = base_result.get("overall_score", 0.0)
720
-
721
  # Handle error cases
722
  error_message = None
723
  feedback = base_result.get("feedback", [])
724
-
725
  if base_result.get("error"):
726
  error_message = base_result["error"]
727
  feedback = [f"Lỗi: {error_message}"]
728
-
729
  # Processing information
730
  processing_time = time.time() - start_time
731
  processing_info = {
732
  "processing_time": processing_time,
733
  "mode": "ipa_focused",
734
  "model_used": "Wav2Vec2-Enhanced",
735
- "confidence": base_result.get("processing_info", {}).get(
736
- "confidence", 0.0
737
- ),
738
- "enhanced_features": True,
739
  }
740
-
741
  # Create final result
742
  result = IPAAssessmentResult(
743
  transcript=base_result.get("transcript", ""),
@@ -752,19 +631,16 @@ async def assess_ipa_pronunciation(
752
  practice_recommendations=practice_recommendations,
753
  feedback=feedback,
754
  processing_info=processing_info,
755
- error=error_message,
756
  )
757
-
758
- logger.info(
759
- f"IPA assessment completed for '{target_word}' in {processing_time:.2f}s with score {overall_score:.2f}"
760
- )
761
-
762
  return result
763
 
764
  except Exception as e:
765
  logger.error(f"IPA assessment error: {str(e)}")
766
  import traceback
767
-
768
  traceback.print_exc()
769
  raise HTTPException(status_code=500, detail=f"IPA assessment failed: {str(e)}")
770
 
@@ -778,13 +654,14 @@ async def assess_ipa_pronunciation(
778
  def get_word_phonemes(word: str):
779
  """Get phoneme breakdown for a specific word"""
780
  try:
781
- # Use the shared G2P instance for consistency
782
- g2p = get_shared_g2p()
 
783
  phoneme_data = g2p.text_to_phonemes(word)[0]
784
 
785
  # Add difficulty analysis for Vietnamese speakers
786
  difficulty_scores = []
787
-
788
  for phoneme in phoneme_data["phonemes"]:
789
  difficulty = g2p.get_difficulty_score(phoneme)
790
  difficulty_scores.append(difficulty)
@@ -841,7 +718,7 @@ def get_vietnamese_tip(phoneme: str) -> str:
841
  "d": "Lưỡi chạm nướu răng trên, rung dây thanh",
842
  "t": "Lưỡi chạm nướu răng trên, không rung dây thanh",
843
  "k": "Lưỡi chạm vòm miệng, không rung dây thanh",
844
- "g": "Lưỡi chạm vòm miệng, rung dây thanh",
845
  }
846
  return tips.get(phoneme, f"Luyện tập phát âm /{phoneme}/")
847
 
@@ -850,10 +727,10 @@ def get_phoneme_difficulty(phoneme: str) -> str:
850
  """Get difficulty level for Vietnamese speakers"""
851
  hard_phonemes = ["θ", "ð", "r", "w", "æ", "ʌ", "ɪ", "ʊ"]
852
  medium_phonemes = ["v", "z", "ʒ", "ɛ", "ə", "ɔ", "f"]
853
-
854
  if phoneme in hard_phonemes:
855
  return "hard"
856
  elif phoneme in medium_phonemes:
857
  return "medium"
858
  else:
859
- return "easy"
 
 
1
  from fastapi import UploadFile, File, Form, HTTPException, APIRouter
2
  from pydantic import BaseModel
3
  from typing import List, Dict, Optional
 
12
  from src.utils.speaking_utils import convert_numpy_types
13
 
14
  # Import the new evaluation system
15
+ from src.apis.controllers.speaking_controller import ProductionPronunciationAssessor, EnhancedG2P
 
 
 
 
16
  warnings.filterwarnings("ignore")
17
 
18
  router = APIRouter(prefix="/speaking", tags=["Speaking"])
19
 
 
 
 
20
 
21
  # =============================================================================
22
  # OPTIMIZATION FUNCTIONS
23
  # =============================================================================
24
 
25
+ async def optimize_post_assessment_processing(result: Dict, reference_text: str) -> None:
 
 
 
26
  """
27
  Tối ưu hóa xử lý sau assessment bằng cách chạy song song các task độc lập
28
  Giảm thời gian xử lý từ ~0.3-0.5s xuống ~0.1-0.2s
29
  """
30
  start_time = time.time()
31
+
32
  # Tạo shared G2P instance để tránh tạo mới nhiều lần
33
  g2p = get_shared_g2p()
34
+
35
  # Định nghĩa các task có thể chạy song song
36
  async def process_reference_phonemes_and_ipa():
37
  """Xử lý reference phonemes và IPA song song"""
38
  loop = asyncio.get_event_loop()
39
  executor = get_shared_executor()
40
  reference_words = reference_text.strip().split()
41
+
42
  # Chạy song song cho từng word
43
  futures = []
44
  for word in reference_words:
45
+ clean_word = word.strip('.,!?;:')
46
  future = loop.run_in_executor(executor, g2p.text_to_phonemes, clean_word)
47
  futures.append(future)
48
+
49
  # Collect results
50
  word_results = await asyncio.gather(*futures)
51
+
52
  reference_phonemes_list = []
53
  reference_ipa_list = []
54
+
55
  for word_data in word_results:
56
  if word_data and len(word_data) > 0:
57
  reference_phonemes_list.append(word_data[0]["phoneme_string"])
58
  reference_ipa_list.append(word_data[0]["ipa"])
59
+
60
  result["reference_phonemes"] = " ".join(reference_phonemes_list)
61
  result["reference_ipa"] = " ".join(reference_ipa_list)
62
+
63
  async def process_user_ipa():
64
  """Xử lý user IPA từ transcript song song"""
65
  if "transcript" not in result or not result["transcript"]:
66
  result["user_ipa"] = None
67
  return
68
+
69
  try:
70
  user_transcript = result["transcript"].strip()
71
  user_words = user_transcript.split()
72
+
73
  if not user_words:
74
  result["user_ipa"] = None
75
  return
76
+
77
  loop = asyncio.get_event_loop()
78
  executor = get_shared_executor()
79
  # Chạy song song cho từng word
80
  futures = []
81
  clean_words = []
82
+
83
  for word in user_words:
84
+ clean_word = word.strip('.,!?;:').lower()
85
  if clean_word: # Skip empty words
86
  clean_words.append(clean_word)
87
+ future = loop.run_in_executor(executor, safe_get_word_ipa, g2p, clean_word)
 
 
88
  futures.append(future)
89
+
90
  # Collect results
91
  if futures:
92
  user_ipa_results = await asyncio.gather(*futures)
 
94
  result["user_ipa"] = " ".join(user_ipa_list) if user_ipa_list else None
95
  else:
96
  result["user_ipa"] = None
97
+
98
+ logger.info(f"Generated user IPA from transcript '{user_transcript}': '{result.get('user_ipa', 'None')}'")
99
+
 
 
100
  except Exception as e:
101
  logger.warning(f"Failed to generate user IPA from transcript: {e}")
102
+ result["user_ipa"] = None # Chạy song song cả 2 task chính
103
+ await asyncio.gather(
104
+ process_reference_phonemes_and_ipa(),
105
+ process_user_ipa()
106
+ )
107
+
108
  optimization_time = time.time() - start_time
109
  logger.info(f"Post-assessment optimization completed in {optimization_time:.3f}s")
110
 
 
130
  _shared_g2p_cache = {}
131
  _cache_lock = asyncio.Lock()
132
 
 
133
  async def get_cached_g2p_result(word: str) -> Optional[Dict]:
134
  """
135
  Cache G2P results để tránh tính toán lại cho các từ đã xử lý
 
139
  return _shared_g2p_cache[word]
140
  return None
141
 
 
142
  async def cache_g2p_result(word: str, result: Dict) -> None:
143
  """
144
  Cache G2P result với size limit
 
150
  oldest_keys = list(_shared_g2p_cache.keys())[:100]
151
  for key in oldest_keys:
152
  del _shared_g2p_cache[key]
153
+
154
  _shared_g2p_cache[word] = result
155
 
156
 
157
  async def optimize_ipa_assessment_processing(
158
+ base_result: Dict,
159
+ target_word: str,
160
+ target_ipa: Optional[str],
161
+ focus_phonemes: Optional[str]
162
  ) -> Dict:
163
  """
164
  Tối ưu hóa xử lý IPA assessment bằng cách chạy song song các task
165
  """
166
  start_time = time.time()
167
+
168
  # Shared G2P instance
169
  g2p = get_shared_g2p()
170
+
171
  # Parse focus phonemes trước
172
  focus_phonemes_list = []
173
  if focus_phonemes:
174
  focus_phonemes_list = [p.strip() for p in focus_phonemes.split(",")]
175
+
176
  async def get_target_phonemes_data():
177
  """Get target IPA and phonemes"""
178
  if not target_ipa:
 
186
  # Parse provided IPA
187
  clean_ipa = target_ipa.replace("/", "").strip()
188
  return target_ipa, list(clean_ipa)
189
+
190
+ async def create_character_analysis(final_target_ipa: str, target_phonemes: List[str]):
 
 
191
  """Create character analysis optimized"""
192
  character_analysis = []
193
  target_chars = list(target_word)
194
  target_phoneme_chars = list(final_target_ipa.replace("/", ""))
195
+
196
  # Pre-calculate phoneme scores mapping
197
  phoneme_score_map = {}
198
  if base_result.get("phoneme_differences"):
 
200
  ref_phoneme = phoneme_diff.get("reference_phoneme")
201
  if ref_phoneme:
202
  phoneme_score_map[ref_phoneme] = phoneme_diff.get("score", 0.0)
203
+
204
  for i, char in enumerate(target_chars):
205
+ char_phoneme = target_phoneme_chars[i] if i < len(target_phoneme_chars) else ""
206
+ char_score = phoneme_score_map.get(char_phoneme, base_result.get("overall_score", 0.0))
207
+
208
+ color_class = ("text-green-600" if char_score > 0.8 else
209
+ "text-yellow-600" if char_score > 0.6 else "text-red-600")
210
+
211
+ character_analysis.append({
212
+ "character": char,
213
+ "phoneme": char_phoneme,
214
+ "score": float(char_score),
215
+ "color_class": color_class,
216
+ "is_focus": char_phoneme in focus_phonemes_list
217
+ })
218
+
 
 
 
 
 
 
 
 
 
219
  return character_analysis
220
+
221
  async def create_phoneme_scores(target_phonemes: List[str]):
222
  """Create phoneme scores optimized"""
223
  phoneme_scores = []
224
+
225
  # Pre-calculate phoneme scores mapping
226
  phoneme_score_map = {}
227
  if base_result.get("phoneme_differences"):
 
229
  ref_phoneme = phoneme_diff.get("reference_phoneme")
230
  if ref_phoneme:
231
  phoneme_score_map[ref_phoneme] = phoneme_diff.get("score", 0.0)
232
+
233
  for phoneme in target_phonemes:
234
+ phoneme_score = phoneme_score_map.get(phoneme, base_result.get("overall_score", 0.0))
235
+
236
+ color_class = ("bg-green-100 text-green-800" if phoneme_score > 0.8 else
237
+ "bg-yellow-100 text-yellow-800" if phoneme_score > 0.6 else
238
+ "bg-red-100 text-red-800")
239
+
240
+ phoneme_scores.append({
241
+ "phoneme": phoneme,
242
+ "score": float(phoneme_score),
243
+ "color_class": color_class,
244
+ "percentage": int(phoneme_score * 100),
245
+ "is_focus": phoneme in focus_phonemes_list
246
+ })
247
+
 
 
 
 
 
 
 
 
 
 
248
  return phoneme_scores
249
+
250
  async def create_focus_analysis():
251
  """Create focus phonemes analysis optimized"""
252
  focus_phonemes_analysis = []
253
+
254
  # Pre-calculate phoneme scores mapping
255
  phoneme_score_map = {}
256
  if base_result.get("phoneme_differences"):
 
258
  ref_phoneme = phoneme_diff.get("reference_phoneme")
259
  if ref_phoneme:
260
  phoneme_score_map[ref_phoneme] = phoneme_diff.get("score", 0.0)
261
+
262
  for focus_phoneme in focus_phonemes_list:
263
+ score = phoneme_score_map.get(focus_phoneme, base_result.get("overall_score", 0.0))
264
+
 
 
265
  phoneme_analysis = {
266
  "phoneme": focus_phoneme,
267
  "score": float(score),
268
  "status": "correct" if score > 0.8 else "incorrect",
269
  "vietnamese_tip": get_vietnamese_tip(focus_phoneme),
270
  "difficulty": "medium",
271
+ "color_class": ("bg-green-100 text-green-800" if score > 0.8 else
272
+ "bg-yellow-100 text-yellow-800" if score > 0.6 else
273
+ "bg-red-100 text-red-800")
 
 
 
 
 
 
274
  }
275
  focus_phonemes_analysis.append(phoneme_analysis)
276
+
277
  return focus_phonemes_analysis
278
+
279
  # Get target phonemes data first
280
  final_target_ipa, target_phonemes = await get_target_phonemes_data()
281
+
282
  # Run parallel processing for analysis
283
  character_analysis, phoneme_scores, focus_phonemes_analysis = await asyncio.gather(
284
  create_character_analysis(final_target_ipa, target_phonemes),
285
  create_phoneme_scores(target_phonemes),
286
+ create_focus_analysis()
287
  )
288
+
289
  # Generate tips and recommendations asynchronously
290
  loop = asyncio.get_event_loop()
291
  executor = get_shared_executor()
 
293
  executor, generate_vietnamese_tips, target_phonemes, focus_phonemes_list
294
  )
295
  practice_recommendations_future = loop.run_in_executor(
296
+ executor, generate_practice_recommendations, base_result.get("overall_score", 0.0), focus_phonemes_analysis
 
 
 
297
  )
298
+
299
  vietnamese_tips, practice_recommendations = await asyncio.gather(
300
+ vietnamese_tips_future,
301
+ practice_recommendations_future
302
  )
303
+
304
  optimization_time = time.time() - start_time
305
  logger.info(f"IPA assessment optimization completed in {optimization_time:.3f}s")
306
+
307
  return {
308
  "target_ipa": final_target_ipa,
309
  "character_analysis": character_analysis,
310
  "phoneme_scores": phoneme_scores,
311
  "focus_phonemes_analysis": focus_phonemes_analysis,
312
  "vietnamese_tips": vietnamese_tips,
313
+ "practice_recommendations": practice_recommendations
314
  }
315
 
316
 
317
+ def generate_vietnamese_tips(target_phonemes: List[str], focus_phonemes_list: List[str]) -> List[str]:
 
 
318
  """Generate Vietnamese tips for difficult phonemes"""
319
  vietnamese_tips = []
320
  difficult_phonemes = ["θ", "ð", "v", "z", "ʒ", "r", "w", "æ", "ɪ", "ʊ", "ɛ"]
321
+
322
  for phoneme in set(target_phonemes + focus_phonemes_list):
323
  if phoneme in difficult_phonemes:
324
  tip = get_vietnamese_tip(phoneme)
325
  if tip not in vietnamese_tips:
326
  vietnamese_tips.append(tip)
327
+
328
  return vietnamese_tips
329
 
330
 
331
+ def generate_practice_recommendations(overall_score: float, focus_phonemes_analysis: List[Dict]) -> List[str]:
 
 
332
  """Generate practice recommendations based on score"""
333
  practice_recommendations = []
334
+
335
  if overall_score < 0.7:
336
+ practice_recommendations.extend([
337
+ "Nghe từ mẫu nhiều lần trước khi phát âm",
338
+ "Phát âm chậm ràng từng âm vị",
339
+ "Chú ý đến vị trí lưỡi môi khi phát âm"
340
+ ])
341
+
 
 
342
  # Add specific recommendations for focus phonemes
343
  for analysis in focus_phonemes_analysis:
344
  if analysis["score"] < 0.6:
345
  practice_recommendations.append(
346
  f"Luyện đặc biệt âm /{analysis['phoneme']}/: {analysis['vietnamese_tip']}"
347
  )
348
+
349
  if overall_score >= 0.8:
350
+ practice_recommendations.append("Phát âm rất tốt! Tiếp tục luyện tập để duy trì chất lượng")
 
 
351
  elif overall_score >= 0.6:
352
  practice_recommendations.append("Phát âm khá tốt, cần cải thiện một số âm vị")
353
+
354
  return practice_recommendations
355
 
356
 
 
383
 
384
  class IPAAssessmentResult(BaseModel):
385
  """Optimized response model for IPA-focused pronunciation assessment"""
 
386
  # Core assessment data
387
  transcript: str # What the user actually said
388
  user_ipa: Optional[str] = None # User's IPA transcription
389
  target_word: str # Target word being assessed
390
  target_ipa: str # Target IPA transcription
391
  overall_score: float # Overall pronunciation score (0-1)
392
+
393
  # Character-level analysis for IPA mapping
394
  character_analysis: List[Dict] # Each character with its IPA and score
395
+
396
  # Phoneme-specific analysis
397
  phoneme_scores: List[Dict] # Individual phoneme scores with colors
398
  focus_phonemes_analysis: List[Dict] # Detailed analysis of target phonemes
399
+
400
  # Feedback and recommendations
401
  vietnamese_tips: List[str] # Vietnamese-specific pronunciation tips
402
  practice_recommendations: List[str] # Practice suggestions
403
  feedback: List[str] # General feedback messages
404
+
405
  # Assessment metadata
406
  processing_info: Dict # Processing details
407
  assessment_type: str = "ipa_focused"
408
  error: Optional[str] = None
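A minimal construction of the response model, handy for unit tests; all values below are placeholders:

example = IPAAssessmentResult(
    transcript="bed",
    target_word="bed",
    target_ipa="/bɛd/",
    overall_score=0.92,
    character_analysis=[],
    phoneme_scores=[],
    focus_phonemes_analysis=[],
    vietnamese_tips=[],
    practice_recommendations=[],
    feedback=["Phát âm rất tốt!"],
    processing_info={"processing_time": 0.42},
)
print(example.assessment_type)  # "ipa_focused" by default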
409
 
 
410
  # Global assessor instance - singleton pattern for performance
411
  global_assessor = None
412
  global_g2p = None # Shared G2P instance for caching
413
  global_executor = None # Shared ThreadPoolExecutor
414
 
415
  def get_assessor():
416
+ """Get or create the global assessor instance"""
417
  global global_assessor
418
  if global_assessor is None:
419
+ logger.info("Creating global ProductionPronunciationAssessor instance...")
420
+ global_assessor = ProductionPronunciationAssessor()
421
  return global_assessor
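get_assessor() lazily builds a single shared ProductionPronunciationAssessor. If concurrent requests could race before the first model load completes, a double-checked lock avoids building it twice; a hedged sketch of that variant (not what this diff implements):

import threading

_assessor_lock = threading.Lock()

def get_assessor_threadsafe():
    global global_assessor
    if global_assessor is None:
        with _assessor_lock:
            # Re-check inside the lock so only one request constructs the model
            if global_assessor is None:
                global_assessor = ProductionPronunciationAssessor()
    return global_assessor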
422
 
423
 
 
506
  # Run assessment using enhanced assessor (singleton)
507
  assessor = get_assessor()
508
  result = assessor.assess_pronunciation(tmp_file.name, reference_text, mode)
509
+
510
  # Optimize post-processing with parallel execution
511
  await optimize_post_assessment_processing(result, reference_text)
512
 
 
536
  audio_file: UploadFile = File(..., description="Audio file (.wav, .mp3, .m4a)"),
537
  target_word: str = Form(..., description="Target word to assess (e.g., 'bed')"),
538
  target_ipa: str = Form(None, description="Target IPA notation (e.g., '/bɛd/')"),
539
+ focus_phonemes: str = Form(None, description="Comma-separated focus phonemes (e.g., 'ɛ,b')"),
540
  ):
541
  """
542
  Optimized IPA pronunciation assessment for phoneme-focused learning
543
+
544
  Evaluates:
545
  - Overall word pronunciation accuracy
546
+ - Character-to-phoneme mapping accuracy
547
  - Specific phoneme pronunciation (e.g., /ɛ/ in 'bed')
548
  - Vietnamese-optimized feedback and tips
549
  - Dynamic color scoring for UI visualization
550
+
551
  Example: Assessing 'bed' /bɛd/ with focus on /ɛ/ phoneme
552
  """
553
+
554
  import time
 
555
  start_time = time.time()
556
+
557
  # Validate inputs
558
  if not target_word.strip():
559
  raise HTTPException(status_code=400, detail="Target word cannot be empty")
560
+
561
  if len(target_word) > 50:
562
+ raise HTTPException(status_code=400, detail="Target word too long (max 50 characters)")
563
+
564
  # Clean target word
565
  target_word = target_word.strip().lower()
566
+
567
  try:
568
  # Save uploaded file temporarily
569
  file_extension = ".wav"
570
  if audio_file.filename and "." in audio_file.filename:
571
  file_extension = f".{audio_file.filename.split('.')[-1]}"
572
 
573
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
574
  content = await audio_file.read()
575
  tmp_file.write(content)
576
  tmp_file.flush()
577
 
578
+ logger.info(f"IPA assessment for word '{target_word}' with IPA '{target_ipa}'")
 
580
  # Get the assessor instance
581
  assessor = get_assessor()
582
+
583
  # Run base pronunciation assessment in word mode
584
+ base_result = assessor.assess_pronunciation(tmp_file.name, target_word, "word")
585
+
586
  # Optimize IPA assessment processing with parallel execution
587
  optimized_results = await optimize_ipa_assessment_processing(
588
  base_result, target_word, target_ipa, focus_phonemes
589
  )
590
+
591
  # Extract optimized results
592
  target_ipa = optimized_results["target_ipa"]
593
  character_analysis = optimized_results["character_analysis"]
 
595
  focus_phonemes_analysis = optimized_results["focus_phonemes_analysis"]
596
  vietnamese_tips = optimized_results["vietnamese_tips"]
597
  practice_recommendations = optimized_results["practice_recommendations"]
598
+
599
  # Get overall score from base result
600
  overall_score = base_result.get("overall_score", 0.0)
601
+
602
  # Handle error cases
603
  error_message = None
604
  feedback = base_result.get("feedback", [])
605
+
606
  if base_result.get("error"):
607
  error_message = base_result["error"]
608
  feedback = [f"Lỗi: {error_message}"]
609
+
610
  # Processing information
611
  processing_time = time.time() - start_time
612
  processing_info = {
613
  "processing_time": processing_time,
614
  "mode": "ipa_focused",
615
  "model_used": "Wav2Vec2-Enhanced",
616
+ "confidence": base_result.get("processing_info", {}).get("confidence", 0.0),
617
+ "enhanced_features": True
618
  }
619
+
620
  # Create final result
621
  result = IPAAssessmentResult(
622
  transcript=base_result.get("transcript", ""),
 
631
  practice_recommendations=practice_recommendations,
632
  feedback=feedback,
633
  processing_info=processing_info,
634
+ error=error_message
635
  )
636
+
637
+ logger.info(f"IPA assessment completed for '{target_word}' in {processing_time:.2f}s with score {overall_score:.2f}")
638
+
639
  return result
640
 
641
  except Exception as e:
642
  logger.error(f"IPA assessment error: {str(e)}")
643
  import traceback
 
644
  traceback.print_exc()
645
  raise HTTPException(status_code=500, detail=f"IPA assessment failed: {str(e)}")
646
 
 
654
  def get_word_phonemes(word: str):
655
  """Get phoneme breakdown for a specific word"""
656
  try:
657
+ # Use the new EnhancedG2P from evaluation module
658
+ from evalution import EnhancedG2P
659
+ g2p = EnhancedG2P()
660
  phoneme_data = g2p.text_to_phonemes(word)[0]
661
 
662
  # Add difficulty analysis for Vietnamese speakers
663
  difficulty_scores = []
664
+
665
  for phoneme in phoneme_data["phonemes"]:
666
  difficulty = g2p.get_difficulty_score(phoneme)
667
  difficulty_scores.append(difficulty)
 
718
  "d": "Lưỡi chạm nướu răng trên, rung dây thanh",
719
  "t": "Lưỡi chạm nướu răng trên, không rung dây thanh",
720
  "k": "Lưỡi chạm vòm miệng, không rung dây thanh",
721
+ "g": "Lưỡi chạm vòm miệng, rung dây thanh"
722
  }
723
  return tips.get(phoneme, f"Luyện tập phát âm /{phoneme}/")
724
 
 
727
  """Get difficulty level for Vietnamese speakers"""
728
  hard_phonemes = ["θ", "ð", "r", "w", "æ", "ʌ", "ɪ", "ʊ"]
729
  medium_phonemes = ["v", "z", "ʒ", "ɛ", "ə", "ɔ", "f"]
730
+
731
  if phoneme in hard_phonemes:
732
  return "hard"
733
  elif phoneme in medium_phonemes:
734
  return "medium"
735
  else:
736
+ return "easy"
test.py ADDED
@@ -0,0 +1,456 @@
1
+ # import torch
2
+ # import librosa
3
+ # from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
4
+
5
+ # # Configuration
6
+ # # MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
7
+ # MODEL_ID = "facebook/wav2vec2-large-xlsr-53"
8
+ # AUDIO_FILE_PATH = "./hello_how_are_you_today.wav" # Change this path
9
+
10
+ # # Load model and processor
11
+ # processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
12
+ # model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
13
+
14
+ # def transcribe_audio_file(audio_path):
15
+ # """
16
+ # Convert an audio file to text using Wav2Vec2
17
+ # """
18
+ # # Read the audio file
19
+ # try:
20
+ # speech_array, sampling_rate = librosa.load(audio_path, sr=16_000)
21
+ # print(f"Đã load audio file: {audio_path}")
22
+ # print(f"Độ dài audio: {len(speech_array)/16_000:.2f} giây")
23
+ # except Exception as e:
24
+ # print(f"Lỗi khi đọc file audio: {e}")
25
+ # return None
26
+
27
+ # # Preprocessing
28
+ # inputs = processor(
29
+ # speech_array,
30
+ # sampling_rate=16_000,
31
+ # return_tensors="pt",
32
+ # padding=True
33
+ # )
34
+
35
+ # # Prediction
36
+ # with torch.no_grad():
37
+ # logits = model(
38
+ # inputs.input_values,
39
+ # attention_mask=inputs.attention_mask
40
+ # ).logits
41
+
42
+ # # Decode the result
43
+ # predicted_ids = torch.argmax(logits, dim=-1)
44
+
45
+ # predicted_sentence = processor.batch_decode(predicted_ids)[0]
46
+
47
+ # return predicted_sentence
48
+
49
+ # # Test with your audio file
50
+ # if __name__ == "__main__":
51
+ # # Change the path to your audio file
52
+ # audio_files = [
53
+ # "./hello_world.wav", # Change this file name
54
+ # # "another_file.mp3", # More files can be added
55
+ # ]
56
+
57
+ # for audio_file in audio_files:
58
+ # print("=" * 80)
59
+ # print(f"Đang xử lý: {audio_file}")
60
+ # print("=" * 80)
61
+
62
+ # prediction = transcribe_audio_file(audio_file)
63
+
64
+ # if prediction:
65
+ # print(f"Kết quả nhận dạng: {prediction}")
66
+ # else:
67
+ # print("Không thể xử lý file này")
68
+ # print()
69
+
70
+ # # Simpler version - just change the file path
71
+ # def quick_transcribe(audio_path):
72
+ # """Quick version to transcribe a single file"""
73
+ # speech_array, _ = librosa.load(audio_path, sr=16_000)
74
+ # inputs = processor(speech_array, sampling_rate=16_000, return_tensors="pt", padding=True)
75
+
76
+ # with torch.no_grad():
77
+ # logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
78
+
79
+ # predicted_ids = torch.argmax(logits, dim=-1)
80
+ # return processor.batch_decode(predicted_ids)[0]
81
+
82
+ # # Quick use:
83
+ # result = quick_transcribe("./hello_how_are_you_today.wav")
84
+ # print(result)
85
+
86
+
87
+ import torch
88
+ from transformers import (
89
+ AutoModelForCTC,
90
+ AutoProcessor,
91
+ Wav2Vec2Processor,
92
+ Wav2Vec2ForCTC,
93
+ )
94
+ import onnxruntime as rt
95
+ import numpy as np
96
+ import librosa
97
+ import warnings
98
+ import os
99
+
100
+ warnings.filterwarnings("ignore")
101
+
102
+ # Available Wave2Vec2 models
103
+ WAVE2VEC2_MODELS = {
104
+ "english_large": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
105
+ "multilingual": "facebook/wav2vec2-large-xlsr-53",
106
+ "english_960h": "facebook/wav2vec2-large-960h-lv60-self",
107
+ "base_english": "facebook/wav2vec2-base-960h",
108
+ "large_english": "facebook/wav2vec2-large-960h",
109
+ "xlsr_english": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
110
+ "xlsr_multilingual": "facebook/wav2vec2-large-xlsr-53"
111
+ }
112
+
113
+ # Default model
114
+ DEFAULT_MODEL = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
115
+
116
+
117
+ def get_available_models():
118
+ """Return dictionary of available Wave2Vec2 models"""
119
+ return WAVE2VEC2_MODELS.copy()
120
+
121
+
122
+ def get_model_name(model_key=None):
123
+ """
124
+ Get model name from key or return default
125
+
126
+ Args:
127
+ model_key: Key from WAVE2VEC2_MODELS or full model name
128
+
129
+ Returns:
130
+ str: Full model name
131
+ """
132
+ if model_key is None:
133
+ return DEFAULT_MODEL
134
+
135
+ if model_key in WAVE2VEC2_MODELS:
136
+ return WAVE2VEC2_MODELS[model_key]
137
+
138
+ # If it's already a full model name, return as is
139
+ return model_key
140
+
141
+
142
+ class Wave2Vec2Inference:
143
+ def __init__(self, model_name=None, use_gpu=True):
144
+ # Get the actual model name using helper function
145
+ self.model_name = get_model_name(model_name)
146
+
147
+ # Auto-detect device
148
+ if use_gpu:
149
+ if torch.backends.mps.is_available():
150
+ self.device = "mps"
151
+ elif torch.cuda.is_available():
152
+ self.device = "cuda"
153
+ else:
154
+ self.device = "cpu"
155
+ else:
156
+ self.device = "cpu"
157
+
158
+ print(f"Using device: {self.device}")
159
+ print(f"Loading model: {self.model_name}")
160
+
161
+ # Check if model is XLSR and use appropriate processor/model
162
+ is_xlsr = "xlsr" in self.model_name.lower()
163
+
164
+ if is_xlsr:
165
+ print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
166
+ self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
167
+ self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
168
+ else:
169
+ print("Using AutoProcessor and AutoModelForCTC")
170
+ self.processor = AutoProcessor.from_pretrained(self.model_name)
171
+ self.model = AutoModelForCTC.from_pretrained(self.model_name)
172
+
173
+ self.model.to(self.device)
174
+ self.model.eval()
175
+
176
+ # Disable gradients for inference
177
+ torch.set_grad_enabled(False)
178
+
179
+ def buffer_to_text(self, audio_buffer):
180
+ if len(audio_buffer) == 0:
181
+ return ""
182
+
183
+ # Convert to tensor
184
+ if isinstance(audio_buffer, np.ndarray):
185
+ audio_tensor = torch.from_numpy(audio_buffer).float()
186
+ else:
187
+ audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
188
+
189
+ # Process audio
190
+ inputs = self.processor(
191
+ audio_tensor,
192
+ sampling_rate=16_000,
193
+ return_tensors="pt",
194
+ padding=True,
195
+ )
196
+
197
+ # Move to device
198
+ input_values = inputs.input_values.to(self.device)
199
+ attention_mask = (
200
+ inputs.attention_mask.to(self.device)
201
+ if "attention_mask" in inputs
202
+ else None
203
+ )
204
+
205
+ # Inference
206
+ with torch.no_grad():
207
+ if attention_mask is not None:
208
+ logits = self.model(input_values, attention_mask=attention_mask).logits
209
+ else:
210
+ logits = self.model(input_values).logits
211
+
212
+ # Decode
213
+ predicted_ids = torch.argmax(logits, dim=-1)
214
+ if self.device != "cpu":
215
+ predicted_ids = predicted_ids.cpu()
216
+
217
+ transcription = self.processor.batch_decode(predicted_ids)[0]
218
+ return transcription.lower().strip()
219
+
220
+ def file_to_text(self, filename):
221
+ try:
222
+ audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
223
+ return self.buffer_to_text(audio_input)
224
+ except Exception as e:
225
+ print(f"Error loading audio file {filename}: {e}")
226
+ return ""
227
+
228
+
229
+ class Wave2Vec2ONNXInference:
230
+ def __init__(self, model_name=None, onnx_path=None, use_gpu=True):
231
+ # Get the actual model name using helper function
232
+ self.model_name = get_model_name(model_name)
233
+ print(f"Loading ONNX model: {self.model_name}")
234
+
235
+ # Always use Wav2Vec2Processor for ONNX (works for all models)
236
+ self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
237
+
238
+ # Setup ONNX Runtime
239
+ options = rt.SessionOptions()
240
+ options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
241
+
242
+ # Choose providers based on GPU availability
243
+ providers = []
244
+ if use_gpu and rt.get_available_providers():
245
+ if "CUDAExecutionProvider" in rt.get_available_providers():
246
+ providers.append("CUDAExecutionProvider")
247
+ providers.append("CPUExecutionProvider")
248
+
249
+ self.model = rt.InferenceSession(onnx_path, options, providers=providers)
250
+ self.input_name = self.model.get_inputs()[0].name
251
+ print(f"ONNX model loaded with providers: {self.model.get_providers()}")
252
+
253
+ def buffer_to_text(self, audio_buffer):
254
+ if len(audio_buffer) == 0:
255
+ return ""
256
+
257
+ # Convert to tensor
258
+ if isinstance(audio_buffer, np.ndarray):
259
+ audio_tensor = torch.from_numpy(audio_buffer).float()
260
+ else:
261
+ audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
262
+
263
+ # Process audio
264
+ inputs = self.processor(
265
+ audio_tensor,
266
+ sampling_rate=16_000,
267
+ return_tensors="np",
268
+ padding=True,
269
+ )
270
+
271
+ # ONNX inference
272
+ input_values = inputs.input_values.astype(np.float32)
273
+ onnx_outputs = self.model.run(None, {self.input_name: input_values})[0]
274
+
275
+ # Decode
276
+ prediction = np.argmax(onnx_outputs, axis=-1)
277
+ transcription = self.processor.decode(prediction.squeeze().tolist())
278
+ return transcription.lower().strip()
279
+
280
+ def file_to_text(self, filename):
281
+ try:
282
+ audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
283
+ return self.buffer_to_text(audio_input)
284
+ except Exception as e:
285
+ print(f"Error loading audio file {filename}: {e}")
286
+ return ""
287
+
288
+
289
+ def convert_to_onnx(model_id_or_path, onnx_model_name):
290
+ """Convert PyTorch model to ONNX format"""
291
+ print(f"Converting {model_id_or_path} to ONNX...")
292
+ model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
293
+ model.eval()
294
+
295
+ # Create dummy input
296
+ audio_len = 250000
297
+ dummy_input = torch.randn(1, audio_len, requires_grad=True)
298
+
299
+ torch.onnx.export(
300
+ model,
301
+ dummy_input,
302
+ onnx_model_name,
303
+ export_params=True,
304
+ opset_version=14,
305
+ do_constant_folding=True,
306
+ input_names=["input"],
307
+ output_names=["output"],
308
+ dynamic_axes={
309
+ "input": {1: "audio_len"},
310
+ "output": {1: "audio_len"},
311
+ },
312
+ )
313
+ print(f"ONNX model saved to: {onnx_model_name}")
314
+
315
+
316
+ def quantize_onnx_model(onnx_model_path, quantized_model_path):
317
+ """Quantize ONNX model for faster inference"""
318
+ print("Starting quantization...")
319
+ from onnxruntime.quantization import quantize_dynamic, QuantType
320
+
321
+ quantize_dynamic(
322
+ onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
323
+ )
324
+ print(f"Quantized model saved to: {quantized_model_path}")
325
+
326
+
327
+ def export_to_onnx(model_name, quantize=False):
328
+ """
329
+ Export model to ONNX format with optional quantization
330
+
331
+ Args:
332
+ model_name: HuggingFace model name
333
+ quantize: Whether to also create quantized version
334
+
335
+ Returns:
336
+ tuple: (onnx_path, quantized_path or None)
337
+ """
338
+ onnx_filename = f"{model_name.split('/')[-1]}.onnx"
339
+ convert_to_onnx(model_name, onnx_filename)
340
+
341
+ quantized_path = None
342
+ if quantize:
343
+ quantized_path = onnx_filename.replace(".onnx", ".quantized.onnx")
344
+ quantize_onnx_model(onnx_filename, quantized_path)
345
+
346
+ return onnx_filename, quantized_path
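Typical use of the export helper together with the ONNX inference class defined above (assumes the test audio file referenced in __main__ exists):

onnx_path, quantized_path = export_to_onnx(DEFAULT_MODEL, quantize=True)
asr = Wave2Vec2ONNXInference(DEFAULT_MODEL, onnx_path=quantized_path)
print(asr.file_to_text("./hello_how_are_you_today.wav"))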
347
+
348
+
349
+ def create_inference(
350
+ model_name=None, use_onnx=False, onnx_path=None, use_gpu=True, use_onnx_quantize=False
351
+ ):
352
+ """
353
+ Create optimized inference instance
354
+
355
+ Args:
356
+ model_name: Model key from WAVE2VEC2_MODELS or full HuggingFace model name (default: uses DEFAULT_MODEL)
357
+ use_onnx: Whether to use ONNX runtime
358
+ onnx_path: Path to ONNX model file
359
+ use_gpu: Whether to use GPU if available
360
+ use_onnx_quantize: Whether to use quantized ONNX model
361
+
362
+ Returns:
363
+ Inference instance
364
+ """
365
+ # Get the actual model name
366
+ actual_model_name = get_model_name(model_name)
367
+
368
+ if use_onnx:
369
+ if not onnx_path or not os.path.exists(onnx_path):
370
+ # Convert to ONNX if path not provided or doesn't exist
371
+ onnx_filename = f"{actual_model_name.split('/')[-1]}.onnx"
372
+ convert_to_onnx(actual_model_name, onnx_filename)
373
+ onnx_path = onnx_filename
374
+
375
+ if use_onnx_quantize:
376
+ quantized_path = onnx_path.replace(".onnx", ".quantized.onnx")
377
+ if not os.path.exists(quantized_path):
378
+ quantize_onnx_model(onnx_path, quantized_path)
379
+ onnx_path = quantized_path
380
+
381
+ print(f"Using ONNX model: {onnx_path}")
382
+ return Wave2Vec2ONNXInference(model_name, onnx_path, use_gpu)
383
+ else:
384
+ print("Using PyTorch model")
385
+ return Wave2Vec2Inference(model_name, use_gpu)
386
+
387
+
388
+ if __name__ == "__main__":
389
+ import time
390
+
391
+ # Display available models
392
+ print("Available Wave2Vec2 models:")
393
+ for key, model_name in get_available_models().items():
394
+ print(f" {key}: {model_name}")
395
+ print(f"\nDefault model: {DEFAULT_MODEL}")
396
+ print()
397
+
398
+ # Test with different models
399
+ test_models = ["english_large", "multilingual", "english_960h"]
400
+ test_file = "./hello_how_are_you_today.wav"
401
+
402
+ if not os.path.exists(test_file):
403
+ print(f"Test file {test_file} not found. Please provide a valid audio file.")
404
+ print("Creating example usage without actual file...")
405
+
406
+ # Example usage without file
407
+ print("\n=== Example Usage ===")
408
+
409
+ # Using default model
410
+ print("1. Using default model:")
411
+ asr_default = create_inference()
412
+ print(f" Model loaded: {asr_default.model_name}")
413
+
414
+ # Using model key
415
+ print("\n2. Using model key 'english_large':")
416
+ asr_key = create_inference("english_large")
417
+ print(f" Model loaded: {asr_key.model_name}")
418
+
419
+ # Using full model name
420
+ print("\n3. Using full model name:")
421
+ asr_full = create_inference("facebook/wav2vec2-base-960h")
422
+ print(f" Model loaded: {asr_full.model_name}")
423
+
424
+ exit(0)
425
+
426
+ # Test different model configurations
427
+ for model_key in test_models:
428
+ print(f"\n=== Testing model: {model_key} ===")
429
+
430
+ # Test different configurations
431
+ configs = [
432
+ {"use_onnx": False, "use_gpu": True},
433
+ {"use_onnx": True, "use_gpu": True, "use_onnx_quantize": False},
434
+ ]
435
+
436
+ for config in configs:
437
+ print(f"\nConfig: {config}")
438
+
439
+ # Create inference instance with model selection
440
+ asr = create_inference(model_key, **config)
441
+
442
+ # Warm up
443
+ asr.file_to_text(test_file)
444
+
445
+ # Test performance
446
+ times = []
447
+ for i in range(3):
448
+ start_time = time.time()
449
+ text = asr.file_to_text(test_file)
450
+ end_time = time.time()
451
+ execution_time = end_time - start_time
452
+ times.append(execution_time)
453
+ print(f"Run {i+1}: {execution_time:.3f}s - {text[:50]}...")
454
+
455
+ avg_time = sum(times) / len(times)
456
+ print(f"Average time: {avg_time:.3f}s")