|
|
import torch |
|
|
import torchaudio |
|
|
import numpy as np |
|
|
from typing import Optional, Callable |
|
|
from config.settings import settings |
|
|
|
|
|
class SpeechBrainVAD:
    """Voice-activity detection wrapper around a pretrained SpeechBrain VAD model.

    Incoming audio chunks are accumulated in a buffer; once roughly two
    seconds are available the model is run over the buffer and a
    user-supplied callback is invoked for every detected speech segment.
    The class degrades gracefully: if the model cannot be loaded,
    streaming refuses to start and ``is_speech`` fails open (returns True)
    so downstream processing is never starved.
    """

    def __init__(self) -> None:
        # SpeechBrain VAD instance, or None when loading failed.
        self.model = None
        # Target sample rate the model expects; chunks are resampled to it.
        self.sample_rate: int = settings.SAMPLE_RATE
        self.is_streaming: bool = False
        # Called as callback(audio: np.ndarray, sample_rate: int) per segment.
        self.speech_callback: Optional[Callable] = None
        # Flat list of float samples awaiting VAD processing.
        self.audio_buffer: list = []
        self._initialize_model()

    def _initialize_model(self) -> None:
        """Load the pretrained VAD model from SpeechBrain (best effort).

        On any failure ``self.model`` stays ``None``; callers must check it
        before relying on VAD results.
        """
        try:
            # SpeechBrain 1.0 moved pretrained interfaces from
            # ``speechbrain.pretrained`` to ``speechbrain.inference``;
            # try the new location first and fall back to the legacy one.
            try:
                from speechbrain.inference import VAD
            except ImportError:
                from speechbrain.pretrained import VAD

            import os
            import tempfile

            print("🔄 Đang tải VAD model từ SpeechBrain...")
            # Portable cache directory instead of a hardcoded "/tmp"
            # (identical on Linux, also works on Windows/macOS).
            savedir = os.path.join(
                tempfile.gettempdir(), settings.VAD_MODEL.replace('/', '_')
            )
            self.model = VAD.from_hparams(
                source=settings.VAD_MODEL,
                savedir=savedir
            )
            print("✅ Đã tải VAD model thành công")
        except Exception as e:
            print(f"❌ Lỗi tải VAD model: {e}")
            self.model = None

    def start_stream(self, speech_callback: Callable) -> bool:
        """Begin streaming VAD.

        Args:
            speech_callback: invoked as ``callback(audio, sample_rate)`` for
                each detected speech segment.

        Returns:
            True when streaming started, False when no model is loaded.
        """
        if self.model is None:
            print("❌ VAD model chưa được khởi tạo")
            return False

        self.is_streaming = True
        self.speech_callback = speech_callback
        self.audio_buffer = []
        print("🎙️ Bắt đầu VAD streaming...")
        return True

    def stop_stream(self) -> None:
        """Stop streaming, drop the callback and any buffered audio."""
        self.is_streaming = False
        self.speech_callback = None
        self.audio_buffer = []
        print("🛑 Đã dừng VAD streaming")

    def process_stream(self, audio_chunk: np.ndarray, sample_rate: int) -> None:
        """Accumulate one audio chunk and run VAD when enough is buffered.

        Args:
            audio_chunk: 1-D float samples (mono assumed — TODO confirm
                against the capture side).
            sample_rate: rate of ``audio_chunk``; resampled to the model
                rate when it differs.
        """
        if not self.is_streaming or self.model is None:
            return

        try:
            if sample_rate != self.sample_rate:
                audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)

            self.audio_buffer.extend(audio_chunk)

            # Run the (relatively expensive) model only once >= 2 s of audio
            # has accumulated, to amortize inference cost.
            buffer_duration = len(self.audio_buffer) / self.sample_rate
            if buffer_duration >= 2.0:
                self._process_buffer()
        except Exception as e:
            print(f"❌ Lỗi xử lý VAD: {e}")

    def _process_buffer(self) -> None:
        """Run VAD over the buffered audio and emit detected speech segments.

        Keeps the trailing 0.5 s of audio as context for the next pass so
        speech spanning a buffer boundary is not lost; on error the buffer
        is discarded entirely to avoid unbounded growth.
        """
        try:
            audio_tensor = torch.FloatTensor(self.audio_buffer).unsqueeze(0)

            # Threshold is relaxed and min-silence lengthened relative to
            # the raw settings to bias towards merging nearby segments.
            boundaries = self.model.get_speech_segments(
                audio_tensor,
                threshold=settings.VAD_THRESHOLD - 0.1,
                min_silence_duration=settings.VAD_MIN_SILENCE_DURATION + 0.3,
                speech_pad_duration=settings.VAD_SPEECH_PAD_DURATION
            )

            if len(boundaries) > 0:
                for start, end in boundaries:
                    # Boundaries are in seconds — convert to sample indices.
                    start_sample = int(start * self.sample_rate)
                    end_sample = int(end * self.sample_rate)

                    speech_audio = np.array(self.audio_buffer[start_sample:end_sample])

                    # Ignore segments shorter than 0.5 s (likely noise/clicks).
                    if len(speech_audio) > self.sample_rate * 0.5:
                        print(f"🎯 VAD phát hiện speech: {len(speech_audio)/self.sample_rate:.2f}s")

                        if self.speech_callback:
                            self.speech_callback(speech_audio, self.sample_rate)

            # Retain only the last 0.5 s as overlap context.
            keep_samples = int(self.sample_rate * 0.5)
            if len(self.audio_buffer) > keep_samples:
                self.audio_buffer = self.audio_buffer[-keep_samples:]
            else:
                self.audio_buffer = []
        except Exception as e:
            print(f"❌ Lỗi xử lý VAD buffer: {e}")
            self.audio_buffer = []

    def _resample_audio(self, audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Resample ``audio`` from ``orig_sr`` to ``target_sr`` if needed.

        Returns the input unchanged when the rates already match or when
        resampling fails (best effort — the caller keeps working with the
        original-rate audio).
        """
        if orig_sr == target_sr:
            return audio

        try:
            audio_tensor = torch.FloatTensor(audio).unsqueeze(0)
            resampler = torchaudio.transforms.Resample(orig_sr, target_sr)
            resampled = resampler(audio_tensor)
            return resampled.squeeze(0).numpy()
        except Exception as e:
            print(f"⚠️ Lỗi resample: {e}")
            return audio

    def is_speech(self, audio_chunk: np.ndarray, sample_rate: int) -> bool:
        """Return True when ``audio_chunk`` likely contains speech.

        Fails open: when the model is missing or inference errors out this
        returns True so downstream processing is never silently dropped.
        """
        if self.model is None:
            return True

        try:
            if sample_rate != self.sample_rate:
                audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)

            audio_tensor = torch.FloatTensor(audio_chunk).unsqueeze(0)

            prob_speech = self.model.get_speech_prob_chunk(audio_tensor)

            # Same relaxed threshold as segment detection, for consistency.
            return prob_speech.mean().item() > (settings.VAD_THRESHOLD - 0.1)
        except Exception as e:
            print(f"❌ Lỗi kiểm tra speech: {e}")
            return True