|
|
import torch |
|
|
import torchaudio |
|
|
import numpy as np |
|
|
from speechbrain.inference import VAD |
|
|
from typing import List, Tuple, Optional |
|
|
import queue |
|
|
import threading |
|
|
import time |
|
|
from config.settings import settings |
|
|
|
|
|
class SpeechBrainVAD:
    """Streaming voice-activity detector backed by a SpeechBrain VAD model.

    Audio chunks are fed through :meth:`process_stream`.  Chunks classified
    as speech are accumulated in ``speech_buffer``; once silence has lasted
    at least ``min_silence_duration`` seconds, the accumulated samples are
    handed to the callback registered with :meth:`start_stream`.

    If the SpeechBrain model cannot be loaded, or fails on a chunk, a simple
    energy threshold is used instead.
    """

    # Mean-square energy above which the fallback detector reports speech.
    ENERGY_THRESHOLD = 0.01

    def __init__(self):
        """Read the VAD configuration from ``settings`` and load the model."""
        self.vad_model = None
        self.sample_rate = settings.SAMPLE_RATE
        self.threshold = settings.VAD_THRESHOLD
        self.min_silence_duration = settings.VAD_MIN_SILENCE_DURATION
        # Stored from settings; not read by any method of this class.
        self.speech_pad_duration = settings.VAD_SPEECH_PAD_DURATION
        self.is_running = False
        # Created here but not consumed by any method of this class.
        self.audio_queue = queue.Queue()
        # Flat list of samples belonging to the speech segment in progress.
        self.speech_buffer = []
        # time.monotonic() reading taken at the first silent chunk after
        # speech, or None while no silence interval is being measured.
        self.silence_start_time = None
        self.callback = None
        # Resample transforms keyed by (source rate, target rate) so the
        # transform is built once per rate pair instead of once per chunk.
        self._resamplers = {}

        self._initialize_model()

    def _initialize_model(self):
        """Load the SpeechBrain VAD model named by ``settings.VAD_MODEL``.

        Loading is best-effort: on any failure the error is printed and
        ``vad_model`` is left as ``None`` so detection uses the energy
        fallback.
        """
        try:
            print("🔄 Đang tải mô hình SpeechBrain VAD...")
            self.vad_model = VAD.from_hparams(
                source=settings.VAD_MODEL,
                savedir=f"pretrained_models/{settings.VAD_MODEL}"
            )
            print("✅ Đã tải mô hình VAD thành công")
        except Exception as e:
            print(f"❌ Lỗi tải mô hình VAD: {e}")
            self.vad_model = None

    def _get_resampler(self, original_sr: int):
        """Return a cached Resample transform from *original_sr* to ``sample_rate``."""
        key = (original_sr, self.sample_rate)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr,
                new_freq=self.sample_rate
            )
            self._resamplers[key] = resampler
        return resampler

    def preprocess_audio(self, audio_data: np.ndarray, original_sr: int) -> np.ndarray:
        """Downmix, resample and peak-normalize one audio chunk.

        Args:
            audio_data: Samples to prepare.  Input with more than one
                dimension is averaged over axis 0 (assumes a
                ``(channels, samples)`` layout -- TODO confirm with callers).
            original_sr: Sample rate of *audio_data* in Hz.

        Returns:
            A mono array at ``self.sample_rate`` scaled so that its largest
            absolute sample is 1.0.  All-zero and empty input is returned
            without scaling.
        """
        audio_data = np.asarray(audio_data)

        # Downmix regardless of sample rate: the detector and the speech
        # buffer both expect a 1-D signal.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=0)

        # Nothing to resample or normalize; np.max would raise on no samples.
        if audio_data.size == 0:
            return audio_data

        if original_sr != self.sample_rate:
            audio_tensor = torch.from_numpy(np.ascontiguousarray(audio_data)).float()
            audio_data = self._get_resampler(original_sr)(audio_tensor).numpy()

        # NOTE(review): normalizing every chunk to unit peak means the energy
        # fallback's fixed threshold depends on the chunk's waveform shape,
        # not its loudness -- confirm this is intended.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak

        return audio_data

    def detect_voice_activity(self, audio_chunk: np.ndarray) -> bool:
        """Decide whether *audio_chunk* contains speech.

        Uses the SpeechBrain model when it is loaded and falls back to
        :meth:`_energy_based_vad` when it is missing or raises.

        Args:
            audio_chunk: 1-D array of samples at ``self.sample_rate``.

        Returns:
            True when the chunk is classified as speech.
        """
        if self.vad_model is None:
            return self._energy_based_vad(audio_chunk)

        try:
            # Add the batch dimension the model expects.
            audio_tensor = torch.from_numpy(audio_chunk).float().unsqueeze(0)

            with torch.no_grad():
                prob = self.vad_model.get_speech_prob_chunk(audio_tensor)

            # The model emits one posterior per time step, so a chunk yields
            # several values; calling .item() directly raises for more than
            # one element.  Average the frames into a single chunk score.
            return prob.mean().item() > self.threshold

        except Exception as e:
            print(f"❌ Lỗi VAD detection: {e}")
            return self._energy_based_vad(audio_chunk)

    def _energy_based_vad(self, audio_chunk: np.ndarray) -> bool:
        """Fallback detector: compare mean-square energy to ``ENERGY_THRESHOLD``.

        Args:
            audio_chunk: Samples to inspect.

        Returns:
            True when the mean of the squared samples exceeds the threshold;
            False for an empty chunk.
        """
        # Work in float64 so integer PCM input cannot overflow when squared.
        samples = np.asarray(audio_chunk, dtype=np.float64)
        if samples.size == 0:
            return False
        energy = np.mean(samples ** 2)
        return bool(energy > self.ENERGY_THRESHOLD)

    def process_stream(self, audio_chunk: np.ndarray, original_sr: int) -> Optional[bool]:
        """Feed one chunk of a live stream through the detector.

        Speech chunks are appended to ``speech_buffer``.  On the first silent
        chunk a silence timer is started; on a later silent chunk, if speech
        is buffered and the timer has reached ``min_silence_duration``, the
        buffered segment is delivered to the callback.

        Args:
            audio_chunk: Raw samples of the chunk.
            original_sr: Sample rate of *audio_chunk* in Hz.

        Returns:
            Whether the chunk was classified as speech, or None when the
            stream is not running.
        """
        if not self.is_running:
            return None

        processed_audio = self.preprocess_audio(audio_chunk, original_sr)
        is_speech = self.detect_voice_activity(processed_audio)

        if is_speech:
            self.silence_start_time = None
            self.speech_buffer.extend(processed_audio)
            print("🎤 Đang nói...")
            return is_speech

        # Monotonic clock: a wall-clock adjustment must not shorten or
        # lengthen the measured silence.
        now = time.monotonic()
        if self.silence_start_time is None:
            self.silence_start_time = now
        elif len(self.speech_buffer) > 0:
            silence_duration = now - self.silence_start_time
            if silence_duration >= self.min_silence_duration:
                self._process_speech_segment()

        return is_speech

    def _process_speech_segment(self):
        """Deliver the buffered speech segment to the callback and reset state.

        Does nothing when the buffer is empty.  Exceptions raised by the
        callback propagate to the caller.
        """
        if len(self.speech_buffer) == 0:
            return

        speech_audio = np.array(self.speech_buffer)

        # Reset before invoking the callback so that a callback which raises,
        # or which calls stop_stream(), cannot cause the same segment to be
        # delivered again.
        self.speech_buffer = []
        self.silence_start_time = None

        if callable(self.callback):
            self.callback(speech_audio, self.sample_rate)

        print("✅ Đã xử lý segment giọng nói")

    def start_stream(self, callback: callable):
        """Start accepting chunks and register the segment callback.

        Args:
            callback: Called as ``callback(speech_audio, sample_rate)`` for
                every completed speech segment.
        """
        self.is_running = True
        self.callback = callback
        self.speech_buffer = []
        self.silence_start_time = None
        print("🎙️ Bắt đầu stream VAD...")

    def stop_stream(self):
        """Stop accepting chunks and flush any speech still buffered."""
        self.is_running = False

        # Deliver a segment that was still in progress when the stream ended.
        if len(self.speech_buffer) > 0:
            self._process_speech_segment()
        print("🛑 Đã dừng stream VAD")

    def get_audio_chunk_from_stream(self, stream, chunk_size: int = 1024) -> Optional[np.ndarray]:
        """Read one chunk from an input stream (for microphone input).

        Args:
            stream: Object exposing ``read(n, exception_on_overflow=False)``
                that returns 16-bit PCM bytes (presumably a PyAudio stream --
                verify against caller).
            chunk_size: Size passed to ``stream.read``.

        Returns:
            The chunk as float32 samples scaled by 1/32768, or None when the
            read fails.
        """
        try:
            data = stream.read(chunk_size, exception_on_overflow=False)
            audio_data = np.frombuffer(data, dtype=np.int16)
            return audio_data.astype(np.float32) / 32768.0
        except Exception as e:
            print(f"❌ Lỗi đọc audio stream: {e}")
            return None