import queue
import threading
import time
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
import torchaudio
from speechbrain.inference import VAD

from config.settings import settings


class SpeechBrainVAD:
    """Streaming voice-activity detector built on a SpeechBrain VAD model.

    Audio chunks are fed through :meth:`process_stream`. Chunks classified as
    speech are accumulated in ``speech_buffer``. Once at least
    ``min_silence_duration`` seconds of silent audio follow buffered speech,
    the accumulated samples are passed to the callback registered with
    :meth:`start_stream` as ``callback(speech_audio, sample_rate)``.

    If the SpeechBrain model cannot be loaded, or inference raises, a simple
    energy-threshold detector is used instead.
    """

    def __init__(self):
        self.vad_model = None
        self.sample_rate = settings.SAMPLE_RATE
        self.threshold = settings.VAD_THRESHOLD
        self.min_silence_duration = settings.VAD_MIN_SILENCE_DURATION
        # NOTE(review): speech_pad_duration is stored but not read anywhere in
        # this class; segments are currently emitted without padding.
        self.speech_pad_duration = settings.VAD_SPEECH_PAD_DURATION
        self.is_running = False
        # NOTE(review): audio_queue is not used by any method in this class.
        self.audio_queue = queue.Queue()
        # Flat list of the individual samples of the current speech segment.
        self.speech_buffer = []
        # Wall-clock time at which the current run of silence started (None
        # while speech is active). Kept for external observers; segment
        # boundaries are decided from _silence_duration instead.
        self.silence_start_time = None
        # Seconds of silent *audio* seen since the last speech chunk.
        self._silence_duration = 0.0
        # Resamplers keyed by source sample rate, so the resampling kernel is
        # built once per rate instead of once per chunk.
        self._resamplers: Dict[int, torchaudio.transforms.Resample] = {}
        self.callback = None
        self._initialize_model()

    def _initialize_model(self):
        """Load the SpeechBrain VAD model named by ``settings.VAD_MODEL``.

        On any loading error the message is printed and ``self.vad_model`` is
        left as ``None``, so detection falls back to the energy-based VAD.
        """
        try:
            print("🔄 Đang tải mô hình SpeechBrain VAD...")
            self.vad_model = VAD.from_hparams(
                source=settings.VAD_MODEL,
                savedir=f"pretrained_models/{settings.VAD_MODEL}"
            )
            print("✅ Đã tải mô hình VAD thành công")
        except Exception as e:
            print(f"❌ Lỗi tải mô hình VAD: {e}")
            self.vad_model = None

    def _get_resampler(self, original_sr: int) -> torchaudio.transforms.Resample:
        """Return a cached resampler from ``original_sr`` to ``self.sample_rate``."""
        resampler = self._resamplers.get(original_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr, new_freq=self.sample_rate
            )
            self._resamplers[original_sr] = resampler
        return resampler

    def preprocess_audio(self, audio_data: np.ndarray, original_sr: int) -> np.ndarray:
        """Prepare a raw chunk for the VAD.

        Processing steps, in order:

        1. Cast to float32.
        2. Down-mix multi-channel input to mono.
        3. Resample to ``self.sample_rate`` when ``original_sr`` differs.
        4. Peak-normalise to [-1, 1].

        Args:
            audio_data: Audio samples. Multi-channel input is averaged over
                axis 0 (assumes channels-first, i.e. (channels, samples) —
                TODO confirm against callers).
            original_sr: Sample rate of ``audio_data`` in Hz.

        Returns:
            1-D float32 array at ``self.sample_rate`` (empty if the input
            was empty).
        """
        audio = np.asarray(audio_data, dtype=np.float32)
        # Down-mix whether or not resampling is needed, so downstream code
        # (model input, speech_buffer) always sees 1-D audio.
        if audio.ndim > 1:
            audio = audio.mean(axis=0)
        if audio.size == 0:
            return audio

        if original_sr != self.sample_rate:
            resampler = self._get_resampler(original_sr)
            with torch.no_grad():
                audio = resampler(torch.from_numpy(np.ascontiguousarray(audio))).numpy()

        # NOTE(review): per-chunk peak normalisation scales low-level
        # background noise up to full scale, which can make the 0.01 energy
        # fallback in _energy_based_vad fire on noise, and gives each chunk
        # of a segment a different gain.
        peak = np.max(np.abs(audio))
        if peak > 0:
            audio = audio / peak
        return audio

    def detect_voice_activity(self, audio_chunk: np.ndarray) -> bool:
        """Return True if ``audio_chunk`` is classified as speech.

        Uses the SpeechBrain model when it is loaded. Otherwise, or if
        inference raises, falls back to :meth:`_energy_based_vad`.
        An empty chunk is never speech.
        """
        if audio_chunk.size == 0:
            return False
        if self.vad_model is None:
            # Fallback: simple energy-based VAD.
            return self._energy_based_vad(audio_chunk)

        try:
            # Add a batch dimension: (samples,) -> (1, samples).
            audio_tensor = torch.from_numpy(
                np.ascontiguousarray(audio_chunk, dtype=np.float32)
            ).unsqueeze(0)
            with torch.no_grad():
                prob = self.vad_model.get_speech_prob_chunk(audio_tensor)
            # The model may return one posterior per frame, in which case
            # calling .item() on the raw output raises. Average over all
            # frames so the chunk yields a single score.
            return prob.mean().item() > self.threshold
        except Exception as e:
            print(f"❌ Lỗi VAD detection: {e}")
            return self._energy_based_vad(audio_chunk)

    def _energy_based_vad(self, audio_chunk: np.ndarray, threshold: float = 0.01) -> bool:
        """Fallback VAD: speech iff the mean squared amplitude exceeds ``threshold``.

        Squares are computed in float64 so integer input cannot overflow.
        An empty chunk is never speech.
        """
        if audio_chunk.size == 0:
            return False
        energy = np.mean(np.square(audio_chunk, dtype=np.float64))
        return bool(energy > threshold)

    def process_stream(self, audio_chunk: np.ndarray, original_sr: int) -> Optional[bool]:
        """Feed one chunk of streaming audio through the VAD state machine.

        Speech chunks are appended to ``speech_buffer``. Silence is measured
        in seconds of audio (chunk length / sample rate), not wall-clock
        time, so segmentation also works when chunks arrive faster or slower
        than real time (e.g. when replaying a file). Once at least
        ``min_silence_duration`` seconds of silence follow buffered speech,
        the segment is flushed via :meth:`_process_speech_segment`.

        Returns:
            ``None`` if the stream is not running; otherwise whether this
            chunk was classified as speech.
        """
        if not self.is_running:
            return None

        processed_audio = self.preprocess_audio(audio_chunk, original_sr)
        is_speech = self.detect_voice_activity(processed_audio)

        if is_speech:
            self.silence_start_time = None
            self._silence_duration = 0.0
            self.speech_buffer.extend(processed_audio)
            print("🎤 Đang nói...")
            return is_speech

        # Silence detected.
        if self.silence_start_time is None:
            self.silence_start_time = time.time()
        if self.speech_buffer:
            self._silence_duration += len(processed_audio) / self.sample_rate
            if self._silence_duration >= self.min_silence_duration:
                # End of speech segment.
                self._process_speech_segment()
        return is_speech

    def _process_speech_segment(self):
        """Send the buffered speech segment to the callback and reset state.

        The buffer and silence tracking are cleared *before* the callback
        runs. If the callback raises, the segment is therefore not left
        buffered to be re-sent on the next silent chunk, or merged with
        later speech. Exceptions from the callback still propagate.
        """
        if not self.speech_buffer:
            return

        speech_audio = np.asarray(self.speech_buffer, dtype=np.float32)
        self.speech_buffer = []
        self.silence_start_time = None
        self._silence_duration = 0.0

        if callable(self.callback):
            self.callback(speech_audio, self.sample_rate)

        print("✅ Đã xử lý segment giọng nói")

    def start_stream(self, callback: Callable[[np.ndarray, int], None]):
        """Start processing and register ``callback(speech_audio, sample_rate)``."""
        self.is_running = True
        self.callback = callback
        self.speech_buffer = []
        self.silence_start_time = None
        self._silence_duration = 0.0
        print("🎙️ Bắt đầu stream VAD...")

    def stop_stream(self):
        """Stop processing and flush any speech still in the buffer."""
        self.is_running = False
        if self.speech_buffer:
            self._process_speech_segment()
        print("🛑 Đã dừng stream VAD")

    def get_audio_chunk_from_stream(self, stream, chunk_size: int = 1024) -> Optional[np.ndarray]:
        """Read ``chunk_size`` frames from a microphone stream as float32 in [-1, 1).

        The raw bytes are interpreted as int16 samples. The
        ``stream.read(..., exception_on_overflow=False)`` call presumably
        targets a PyAudio stream opened with 16-bit samples — confirm.
        Multi-channel data would come back interleaved.

        Returns:
            The samples, or ``None`` if reading or decoding fails.
        """
        try:
            data = stream.read(chunk_size, exception_on_overflow=False)
            samples = np.frombuffer(data, dtype=np.int16)
            return samples.astype(np.float32) / 32768.0  # Normalize to [-1, 1]
        except Exception as e:
            print(f"❌ Lỗi đọc audio stream: {e}")
            return None