# voicebot/core/speechbrain_vad.py
# datbkpro's picture
# Create speechbrain_vad.py
# b5e51ac verified
import torch
import torchaudio
import numpy as np
from typing import Optional, Callable
from config.settings import settings
class SpeechBrainVAD:
    """Streaming voice-activity detector built on a SpeechBrain VAD model.

    Audio chunks pushed through :meth:`process_stream` are accumulated in
    ``audio_buffer``.  Once ``BUFFER_SECONDS`` of audio is available the
    buffer is scored by the model, speech segments are extracted, and each
    sufficiently long segment is handed to the callback registered with
    :meth:`start_stream`.

    Known limitation: the buffer is analysed in independent windows, so an
    utterance that crosses a window boundary is delivered in pieces.
    """

    # Seconds of audio accumulated before the buffer is analysed.
    BUFFER_SECONDS = 2.0
    # Seconds of trailing audio kept after each analysis (window overlap).
    OVERLAP_SECONDS = 0.5
    # Segments must be strictly longer than this to reach the callback.
    MIN_SPEECH_SECONDS = 0.5
    # Subtracted from settings.VAD_THRESHOLD to make detection more sensitive.
    THRESHOLD_OFFSET = 0.1
    # Added to settings.VAD_MIN_SILENCE_DURATION so short pauses do not split
    # an utterance.
    EXTRA_SILENCE_SECONDS = 0.3

    def __init__(self):
        """Set up empty streaming state and try to load the VAD model."""
        self.model = None
        self.sample_rate = settings.SAMPLE_RATE
        self.is_streaming = False
        self.speech_callback: Optional[Callable] = None
        self.audio_buffer = []
        # Resample transforms keyed by (orig_sr, target_sr); building one per
        # streamed chunk would recompute the filter kernel every time.
        self._resamplers = {}
        self._initialize_model()

    def _initialize_model(self):
        """Load the SpeechBrain VAD model named by ``settings.VAD_MODEL``.

        On any failure the error is printed and ``self.model`` stays ``None``;
        callers must treat a missing model as "VAD unavailable".
        """
        try:
            try:
                # Location of the inference interfaces in SpeechBrain >= 1.0.
                from speechbrain.inference.VAD import VAD
            except ImportError:
                # Legacy location used by older SpeechBrain releases.
                from speechbrain.pretrained import VAD
            print("🔄 Đang tải VAD model từ SpeechBrain...")
            self.model = VAD.from_hparams(
                source=settings.VAD_MODEL,
                savedir=f"/tmp/{settings.VAD_MODEL.replace('/', '_')}"
            )
            print("✅ Đã tải VAD model thành công")
        except Exception as e:
            print(f"❌ Lỗi tải VAD model: {e}")
            self.model = None

    def start_stream(self, speech_callback: Callable):
        """Begin streaming and register the speech callback.

        Args:
            speech_callback: Called as ``speech_callback(audio, sample_rate)``
                for every detected speech segment.

        Returns:
            ``True`` when streaming started, ``False`` when no model is loaded.
        """
        if self.model is None:
            print("❌ VAD model chưa được khởi tạo")
            return False
        self.is_streaming = True
        self.speech_callback = speech_callback
        self.audio_buffer = []
        print("🎙️ Bắt đầu VAD streaming...")
        return True

    def stop_stream(self):
        """Stop streaming, drop the callback and discard buffered audio."""
        self.is_streaming = False
        self.speech_callback = None
        self.audio_buffer = []
        print("🛑 Đã dừng VAD streaming")

    def process_stream(self, audio_chunk: np.ndarray, sample_rate: int):
        """Feed one audio chunk into the streaming detector.

        The chunk is resampled to ``self.sample_rate`` if necessary and
        appended to the buffer; the buffer is analysed once it holds at least
        ``BUFFER_SECONDS`` of audio.  Does nothing when not streaming or when
        no model is loaded.  Errors are printed, never raised.

        Args:
            audio_chunk: 1-D sequence of samples (mono is assumed).
            sample_rate: Sample rate of ``audio_chunk`` in Hz.
        """
        if not self.is_streaming or self.model is None:
            return
        try:
            # Nothing to buffer; avoid going through the exception path.
            if audio_chunk is None or len(audio_chunk) == 0:
                return
            if sample_rate != self.sample_rate:
                audio_chunk = self._resample_audio(
                    audio_chunk, sample_rate, self.sample_rate
                )
            self.audio_buffer.extend(audio_chunk)
            buffer_duration = len(self.audio_buffer) / self.sample_rate
            if buffer_duration >= self.BUFFER_SECONDS:
                self._process_buffer()
        except Exception as e:
            print(f"❌ Lỗi xử lý VAD: {e}")

    def _process_buffer(self):
        """Run VAD over the buffered audio and dispatch speech segments.

        Segments strictly longer than ``MIN_SPEECH_SECONDS`` are passed to the
        registered callback.  Afterwards only the last ``OVERLAP_SECONDS`` of
        audio is kept.  On any error the buffer is cleared and the error is
        printed.
        """
        try:
            samples = np.asarray(self.audio_buffer, dtype=np.float32)
            audio_tensor = torch.from_numpy(samples).unsqueeze(0)

            # SpeechBrain's get_speech_segments() works on an audio *file*
            # and has no threshold / min_silence_duration /
            # speech_pad_duration keywords, so the in-memory buffer is scored
            # with get_speech_prob_chunk() and post-processed locally.
            with torch.no_grad():
                frame_probs = self.model.get_speech_prob_chunk(audio_tensor)
            frame_probs = frame_probs.detach().reshape(-1).cpu().tolist()

            boundaries = self._probs_to_segments(
                frame_probs,
                len(samples) / self.sample_rate,
                # Lower threshold: more sensitive detection.
                threshold=settings.VAD_THRESHOLD - self.THRESHOLD_OFFSET,
                # Longer silence needed before an utterance is split.
                min_silence_duration=(
                    settings.VAD_MIN_SILENCE_DURATION
                    + self.EXTRA_SILENCE_SECONDS
                ),
                speech_pad_duration=settings.VAD_SPEECH_PAD_DURATION,
            )

            # Snapshot so a stop_stream() issued from inside the callback
            # does not change behaviour for the remaining segments.
            callback = self.speech_callback
            min_samples = self.sample_rate * self.MIN_SPEECH_SECONDS
            for start, end in boundaries:
                start_sample = int(start * self.sample_rate)
                end_sample = int(end * self.sample_rate)
                speech_audio = np.array(
                    self.audio_buffer[start_sample:end_sample]
                )
                if len(speech_audio) > min_samples:
                    print(f"🎯 VAD phát hiện speech: {len(speech_audio)/self.sample_rate:.2f}s")
                    if callback:
                        callback(speech_audio, self.sample_rate)

            # Keep the tail of the buffer as overlap for the next window.
            keep_samples = int(self.sample_rate * self.OVERLAP_SECONDS)
            if len(self.audio_buffer) > keep_samples:
                self.audio_buffer = self.audio_buffer[-keep_samples:]
            else:
                self.audio_buffer = []
        except Exception as e:
            print(f"❌ Lỗi xử lý VAD buffer: {e}")
            self.audio_buffer = []

    @staticmethod
    def _probs_to_segments(frame_probs, total_duration, threshold,
                           min_silence_duration=0.0, speech_pad_duration=0.0):
        """Convert per-frame speech probabilities into time segments.

        Frames are assumed to be evenly spaced over ``total_duration``.

        Args:
            frame_probs: Sequence of per-frame speech probabilities.
            total_duration: Duration in seconds covered by ``frame_probs``.
            threshold: A frame counts as speech when its probability is
                strictly greater than this value.
            min_silence_duration: Speech runs separated by a silence shorter
                than this many seconds are merged.
            speech_pad_duration: Seconds added before and after each segment
                (clamped to ``[0, total_duration]``).

        Returns:
            List of ``(start, end)`` tuples in seconds, sorted and
            non-overlapping.
        """
        frame_count = len(frame_probs)
        if frame_count == 0 or total_duration <= 0:
            return []
        frame_duration = total_duration / frame_count

        # 1. Runs of consecutive frames above the threshold.
        runs = []
        run_start = None
        for index, prob in enumerate(frame_probs):
            if prob > threshold:
                if run_start is None:
                    run_start = index
            elif run_start is not None:
                runs.append([run_start * frame_duration,
                             index * frame_duration])
                run_start = None
        if run_start is not None:
            # Speech still active at the end of the window.
            runs.append([run_start * frame_duration, total_duration])

        # 2. Bridge silences that are too short to count as a pause.
        merged = []
        for run in runs:
            if merged and run[0] - merged[-1][1] < min_silence_duration:
                merged[-1][1] = run[1]
            else:
                merged.append(run)

        # 3. Pad, clamp to the window, and fuse segments the padding joined.
        padded = []
        for start, end in merged:
            start = max(0.0, start - speech_pad_duration)
            end = min(total_duration, end + speech_pad_duration)
            if padded and start <= padded[-1][1]:
                padded[-1][1] = max(padded[-1][1], end)
            else:
                padded.append([start, end])
        return [(start, end) for start, end in padded]

    def _resample_audio(self, audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Resample ``audio`` from ``orig_sr`` to ``target_sr``.

        Returns the input unchanged when the rates already match.  If
        resampling fails the error is printed and the original audio is
        returned as a best-effort fallback.
        """
        if orig_sr == target_sr:
            return audio
        try:
            key = (orig_sr, target_sr)
            resampler = self._resamplers.get(key)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(orig_sr, target_sr)
                self._resamplers[key] = resampler
            # Explicit float32 conversion so integer PCM input is accepted.
            audio_tensor = torch.from_numpy(
                np.asarray(audio, dtype=np.float32)
            ).unsqueeze(0)
            with torch.no_grad():
                resampled = resampler(audio_tensor)
            return resampled.squeeze(0).numpy()
        except Exception as e:
            print(f"⚠️ Lỗi resample: {e}")
            return audio

    def is_speech(self, audio_chunk: np.ndarray, sample_rate: int) -> bool:
        """Return whether ``audio_chunk`` is classified as speech.

        The chunk is speech when the mean frame probability exceeds
        ``settings.VAD_THRESHOLD - THRESHOLD_OFFSET``.  When no model is
        loaded, or scoring fails, the method falls back to ``True`` so audio
        is never dropped because VAD is unavailable.
        """
        if self.model is None:
            return True  # Fallback: always treat as speech.
        try:
            if sample_rate != self.sample_rate:
                audio_chunk = self._resample_audio(
                    audio_chunk, sample_rate, self.sample_rate
                )
            audio_tensor = torch.from_numpy(
                np.asarray(audio_chunk, dtype=np.float32)
            ).unsqueeze(0)
            with torch.no_grad():
                prob_speech = self.model.get_speech_prob_chunk(audio_tensor)
            threshold = settings.VAD_THRESHOLD - self.THRESHOLD_OFFSET
            return prob_speech.mean().item() > threshold
        except Exception as e:
            print(f"❌ Lỗi kiểm tra speech: {e}")
            return True