datbkpro commited on
Commit
1129e66
·
verified ·
1 Parent(s): 474ccb5

Update core/silero_vad.py

Browse files
Files changed (1) hide show
  1. core/silero_vad.py +70 -24
core/silero_vad.py CHANGED
@@ -2,29 +2,63 @@ import torch
2
  import numpy as np
3
  from typing import Optional, Callable
4
  from config.settings import settings
 
5
 
6
  class SileroVAD:
7
  def __init__(self):
8
  self.model = None
9
- self.sample_rate = 16000 # Silero VAD yêu cầu 16kHz
10
  self.is_streaming = False
11
  self.speech_callback = None
12
  self.audio_buffer = []
13
  self._initialize_model()
14
 
15
  def _initialize_model(self):
16
- """Khởi tạo Silero VAD model"""
17
  try:
18
- print("🔄 Đang tải Silero VAD model...")
19
- torch.hub.download_url_to_file(
20
- 'https://raw.githubusercontent.com/snakers4/silero-vad/master/files/model.jit',
21
- 'silero_vad.jit'
 
 
 
 
22
  )
23
- self.model = torch.jit.load('silero_vad.jit')
24
- self.model.eval()
25
  print("✅ Đã tải Silero VAD model thành công")
 
26
  except Exception as e:
27
  print(f"❌ Lỗi tải Silero VAD model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  self.model = None
29
 
30
  def start_stream(self, speech_callback: Callable):
@@ -52,16 +86,16 @@ class SileroVAD:
52
  return
53
 
54
  try:
55
- # Resample nếu cần (Silero yêu cầu 16kHz)
56
  if sample_rate != self.sample_rate:
57
  audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
58
 
59
  # Thêm vào buffer
60
  self.audio_buffer.extend(audio_chunk)
61
 
62
- # Xử lý khi buffer đủ lớn (1 giây - Silero làm việc tốt với chunk nhỏ)
63
  buffer_duration = len(self.audio_buffer) / self.sample_rate
64
- if buffer_duration >= 1.0: # Giảm từ 2.0 xuống 1.0 giây
65
  self._process_buffer()
66
 
67
  except Exception as e:
@@ -70,7 +104,6 @@ class SileroVAD:
70
  def _process_buffer(self):
71
  """Xử lý buffer audio với Silero VAD"""
72
  try:
73
- # Silero VAD làm việc tốt với chunk 1 giây
74
  chunk_size = self.sample_rate # 1 giây
75
  if len(self.audio_buffer) < chunk_size:
76
  return
@@ -80,10 +113,15 @@ class SileroVAD:
80
 
81
  # Chuẩn hóa audio cho Silero
82
  if audio_chunk.dtype != np.float32:
83
- audio_chunk = audio_chunk.astype(np.float32) / 32768.0 # Normalize to [-1, 1]
 
 
 
 
 
84
 
85
  # Chuyển thành tensor
86
- audio_tensor = torch.from_numpy(audio_chunk).unsqueeze(0)
87
 
88
  # Phát hiện speech với Silero VAD
89
  with torch.no_grad():
@@ -91,7 +129,7 @@ class SileroVAD:
91
 
92
  print(f"🎯 Silero VAD speech probability: {speech_prob:.3f}")
93
 
94
- # Ngưỡng phát hiện speech (có thể điều chỉnh)
95
  if speech_prob > settings.VAD_THRESHOLD:
96
  print(f"🎯 Silero VAD phát hiện speech: {speech_prob:.3f}")
97
 
@@ -99,7 +137,7 @@ class SileroVAD:
99
  if self.speech_callback:
100
  self.speech_callback(audio_chunk, self.sample_rate)
101
 
102
- # Giữ lại 0.3 giây cuối để overlap (Silero nhạy hơn)
103
  keep_samples = int(self.sample_rate * 0.3)
104
  if len(self.audio_buffer) > keep_samples:
105
  self.audio_buffer = self.audio_buffer[-keep_samples:]
@@ -142,15 +180,19 @@ class SileroVAD:
142
 
143
  # Chuẩn hóa audio
144
  if audio_chunk.dtype != np.float32:
145
- audio_chunk = audio_chunk.astype(np.float32) / 32768.0
 
 
 
 
146
 
147
  # Đảm bảo độ dài phù hợp
148
- if len(audio_chunk) < 512: # Silero cần ít nhất 512 samples
149
- padding = np.zeros(512 - len(audio_chunk))
150
  audio_chunk = np.concatenate([audio_chunk, padding])
151
 
152
  # Chuyển thành tensor
153
- audio_tensor = torch.from_numpy(audio_chunk).unsqueeze(0)
154
 
155
  # Phát hiện speech
156
  with torch.no_grad():
@@ -164,7 +206,7 @@ class SileroVAD:
164
  return True
165
 
166
  def get_speech_probability(self, audio_chunk: np.ndarray, sample_rate: int) -> float:
167
- """Lấy xác suất speech (dùng cho debugging)"""
168
  if self.model is None:
169
  return 0.0
170
 
@@ -175,15 +217,19 @@ class SileroVAD:
175
 
176
  # Chuẩn hóa audio
177
  if audio_chunk.dtype != np.float32:
178
- audio_chunk = audio_chunk.astype(np.float32) / 32768.0
 
 
 
 
179
 
180
  # Đảm bảo độ dài phù hợp
181
  if len(audio_chunk) < 512:
182
- padding = np.zeros(512 - len(audio_chunk))
183
  audio_chunk = np.concatenate([audio_chunk, padding])
184
 
185
  # Chuyển thành tensor
186
- audio_tensor = torch.from_numpy(audio_chunk).unsqueeze(0)
187
 
188
  # Phát hiện speech
189
  with torch.no_grad():
 
2
  import numpy as np
3
  from typing import Optional, Callable
4
  from config.settings import settings
5
+ import os
6
 
7
  class SileroVAD:
8
  def __init__(self):
9
  self.model = None
10
+ self.sample_rate = 16000
11
  self.is_streaming = False
12
  self.speech_callback = None
13
  self.audio_buffer = []
14
  self._initialize_model()
15
 
16
  def _initialize_model(self):
17
+ """Khởi tạo Silero VAD model sử dụng torch.hub"""
18
  try:
19
+ print("🔄 Đang tải Silero VAD model từ torch.hub...")
20
+
21
+ # Sử dụng torch.hub để load model (cách chính thức)
22
+ self.model = torch.hub.load(
23
+ repo_or_dir=settings.VAD_MODEL,
24
+ model='silero_vad',
25
+ force_reload=False, # Sử dụng cache nếu có
26
+ trust_repo=True
27
  )
28
+
 
29
  print("✅ Đã tải Silero VAD model thành công")
30
+
31
  except Exception as e:
32
  print(f"❌ Lỗi tải Silero VAD model: {e}")
33
+ print("🔄 Đang thử cách tải thay thế...")
34
+ self._initialize_model_fallback()
35
+
36
+ def _initialize_model_fallback(self):
37
+ """Fallback method nếu cách chính thức không hoạt động"""
38
+ try:
39
+ # Cách 2: Sử dụng direct download
40
+ model_urls = {
41
+ 'silero_vad.jit': 'https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.jit'
42
+ }
43
+
44
+ # Tạo thư mục cache
45
+ os.makedirs('./models', exist_ok=True)
46
+ model_path = './models/silero_vad.jit'
47
+
48
+ if not os.path.exists(model_path):
49
+ print("📥 Đang download Silero VAD model...")
50
+ torch.hub.download_url_to_file(
51
+ model_urls['silero_vad.jit'],
52
+ model_path
53
+ )
54
+
55
+ # Load model
56
+ self.model = torch.jit.load(model_path)
57
+ self.model.eval()
58
+ print("✅ Đã tải Silero VAD model thành công (fallback)")
59
+
60
+ except Exception as e:
61
+ print(f"❌ Lỗi tải Silero VAD model fallback: {e}")
62
  self.model = None
63
 
64
  def start_stream(self, speech_callback: Callable):
 
86
  return
87
 
88
  try:
89
+ # Resample nếu cần
90
  if sample_rate != self.sample_rate:
91
  audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)
92
 
93
  # Thêm vào buffer
94
  self.audio_buffer.extend(audio_chunk)
95
 
96
+ # Xử lý khi buffer đủ lớn (1 giây)
97
  buffer_duration = len(self.audio_buffer) / self.sample_rate
98
+ if buffer_duration >= 1.0:
99
  self._process_buffer()
100
 
101
  except Exception as e:
 
104
  def _process_buffer(self):
105
  """Xử lý buffer audio với Silero VAD"""
106
  try:
 
107
  chunk_size = self.sample_rate # 1 giây
108
  if len(self.audio_buffer) < chunk_size:
109
  return
 
113
 
114
  # Chuẩn hóa audio cho Silero
115
  if audio_chunk.dtype != np.float32:
116
+ audio_chunk = audio_chunk.astype(np.float32)
117
+ if np.max(np.abs(audio_chunk)) > 1.0:
118
+ audio_chunk = audio_chunk / 32768.0 # Normalize từ int16
119
+
120
+ # Đảm bảo audio trong range [-1, 1]
121
+ audio_chunk = np.clip(audio_chunk, -1.0, 1.0)
122
 
123
  # Chuyển thành tensor
124
+ audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)
125
 
126
  # Phát hiện speech với Silero VAD
127
  with torch.no_grad():
 
129
 
130
  print(f"🎯 Silero VAD speech probability: {speech_prob:.3f}")
131
 
132
+ # Ngưỡng phát hiện speech
133
  if speech_prob > settings.VAD_THRESHOLD:
134
  print(f"🎯 Silero VAD phát hiện speech: {speech_prob:.3f}")
135
 
 
137
  if self.speech_callback:
138
  self.speech_callback(audio_chunk, self.sample_rate)
139
 
140
+ # Giữ lại 0.3 giây cuối để overlap
141
  keep_samples = int(self.sample_rate * 0.3)
142
  if len(self.audio_buffer) > keep_samples:
143
  self.audio_buffer = self.audio_buffer[-keep_samples:]
 
180
 
181
  # Chuẩn hóa audio
182
  if audio_chunk.dtype != np.float32:
183
+ audio_chunk = audio_chunk.astype(np.float32)
184
+ if np.max(np.abs(audio_chunk)) > 1.0:
185
+ audio_chunk = audio_chunk / 32768.0
186
+
187
+ audio_chunk = np.clip(audio_chunk, -1.0, 1.0)
188
 
189
  # Đảm bảo độ dài phù hợp
190
+ if len(audio_chunk) < 512:
191
+ padding = np.zeros(512 - len(audio_chunk), dtype=np.float32)
192
  audio_chunk = np.concatenate([audio_chunk, padding])
193
 
194
  # Chuyển thành tensor
195
+ audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)
196
 
197
  # Phát hiện speech
198
  with torch.no_grad():
 
206
  return True
207
 
208
  def get_speech_probability(self, audio_chunk: np.ndarray, sample_rate: int) -> float:
209
+ """Lấy xác suất speech"""
210
  if self.model is None:
211
  return 0.0
212
 
 
217
 
218
  # Chuẩn hóa audio
219
  if audio_chunk.dtype != np.float32:
220
+ audio_chunk = audio_chunk.astype(np.float32)
221
+ if np.max(np.abs(audio_chunk)) > 1.0:
222
+ audio_chunk = audio_chunk / 32768.0
223
+
224
+ audio_chunk = np.clip(audio_chunk, -1.0, 1.0)
225
 
226
  # Đảm bảo độ dài phù hợp
227
  if len(audio_chunk) < 512:
228
+ padding = np.zeros(512 - len(audio_chunk), dtype=np.float32)
229
  audio_chunk = np.concatenate([audio_chunk, padding])
230
 
231
  # Chuyển thành tensor
232
+ audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)
233
 
234
  # Phát hiện speech
235
  with torch.no_grad():