File size: 5,738 Bytes
dbf2148
 
 
b5e51ac
dbf2148
 
 
 
b5e51ac
dbf2148
b5e51ac
 
 
dbf2148
b5e51ac
dbf2148
b5e51ac
dbf2148
b5e51ac
 
 
dbf2148
b5e51ac
dbf2148
b5e51ac
dbf2148
b5e51ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbf2148
b5e51ac
 
 
 
 
 
 
dbf2148
b5e51ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbf2148
 
b5e51ac
 
 
 
 
 
 
 
 
 
 
 
dbf2148
b5e51ac
 
 
 
 
 
 
dbf2148
b5e51ac
 
 
 
 
dbf2148
 
b5e51ac
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import torchaudio
import numpy as np
from typing import Optional, Callable
from config.settings import settings

class SpeechBrainVAD:
    """Streaming voice-activity detection (VAD) backed by a SpeechBrain model.

    Incoming audio chunks are resampled to the configured rate and accumulated
    in an internal buffer. Once roughly ``_BUFFER_SECONDS`` of audio has been
    collected, the buffer is scanned for speech segments and each sufficiently
    long segment is handed to the registered callback. All failures are
    best-effort: errors are printed and the detector degrades gracefully
    (e.g. ``is_speech`` treats everything as speech when no model is loaded).
    """

    # Seconds of audio to accumulate before running segment detection.
    _BUFFER_SECONDS = 2.0
    # Seconds kept from the end of a processed buffer so a segment spanning
    # two windows is not cut off.
    _OVERLAP_SECONDS = 0.5
    # Minimum speech-segment length (seconds) forwarded to the callback.
    _MIN_SPEECH_SECONDS = 0.5
    # The configured threshold is lowered by this amount so detection is more
    # sensitive; used consistently by segment detection and is_speech().
    _THRESHOLD_OFFSET = 0.1
    # The configured minimum-silence duration is raised by this amount so
    # short pauses do not split one utterance into several segments.
    _SILENCE_PAD = 0.3

    def __init__(self):
        self.model = None
        self.sample_rate = settings.SAMPLE_RATE
        self.is_streaming = False
        self.speech_callback: Optional[Callable] = None
        self.audio_buffer: list = []
        self._initialize_model()

    def _initialize_model(self) -> None:
        """Load the pretrained VAD model; on failure leave ``self.model`` as None."""
        try:
            # SpeechBrain >= 1.0 moved the inference classes to
            # speechbrain.inference; fall back to the deprecated location so
            # older installs keep working.
            try:
                from speechbrain.inference import VAD
            except ImportError:
                from speechbrain.pretrained import VAD
            print("🔄 Đang tải VAD model từ SpeechBrain...")
            self.model = VAD.from_hparams(
                source=settings.VAD_MODEL,
                savedir=f"/tmp/{settings.VAD_MODEL.replace('/', '_')}"
            )
            print("✅ Đã tải VAD model thành công")
        except Exception as e:
            print(f"❌ Lỗi tải VAD model: {e}")
            self.model = None

    def start_stream(self, speech_callback: Callable) -> bool:
        """Begin streaming; ``speech_callback(audio, sample_rate)`` receives segments.

        Returns False (and does nothing) when no model is loaded.
        """
        if self.model is None:
            print("❌ VAD model chưa được khởi tạo")
            return False

        self.is_streaming = True
        self.speech_callback = speech_callback
        self.audio_buffer = []
        print("🎙️ Bắt đầu VAD streaming...")
        return True

    def stop_stream(self) -> None:
        """Stop streaming and discard any buffered audio."""
        self.is_streaming = False
        self.speech_callback = None
        self.audio_buffer = []
        print("🛑 Đã dừng VAD streaming")

    def process_stream(self, audio_chunk: np.ndarray, sample_rate: int) -> None:
        """Accumulate one audio chunk; run detection when the buffer is full.

        No-op unless streaming is active and a model is loaded.
        """
        if not self.is_streaming or self.model is None:
            return

        try:
            # Resample to the detector's rate if the source differs.
            if sample_rate != self.sample_rate:
                audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)

            self.audio_buffer.extend(audio_chunk)

            # Only analyze once enough audio has been collected.
            buffer_duration = len(self.audio_buffer) / self.sample_rate
            if buffer_duration >= self._BUFFER_SECONDS:
                self._process_buffer()

        except Exception as e:
            print(f"❌ Lỗi xử lý VAD: {e}")

    def _process_buffer(self) -> None:
        """Detect speech segments in the buffer and dispatch them to the callback.

        On any error the buffer is dropped so a poisoned buffer cannot stall
        the stream; on success the trailing ``_OVERLAP_SECONDS`` are retained.
        """
        try:
            audio_tensor = torch.FloatTensor(self.audio_buffer).unsqueeze(0)

            # NOTE(review): recent SpeechBrain releases expose
            # VAD.get_speech_segments(audio_file, ...) with differently named
            # kwargs (e.g. activation_th); verify this call against the
            # installed SpeechBrain version.
            boundaries = self.model.get_speech_segments(
                audio_tensor,
                # Lower threshold / longer silence tolerance for sensitivity.
                threshold=settings.VAD_THRESHOLD - self._THRESHOLD_OFFSET,
                min_silence_duration=settings.VAD_MIN_SILENCE_DURATION + self._SILENCE_PAD,
                speech_pad_duration=settings.VAD_SPEECH_PAD_DURATION
            )

            # Boundaries are (start, end) pairs in seconds.
            if len(boundaries) > 0:
                for start, end in boundaries:
                    start_sample = int(start * self.sample_rate)
                    end_sample = int(end * self.sample_rate)

                    speech_audio = np.array(self.audio_buffer[start_sample:end_sample])

                    # Ignore segments shorter than the minimum speech length.
                    if len(speech_audio) > self.sample_rate * self._MIN_SPEECH_SECONDS:
                        print(f"🎯 VAD phát hiện speech: {len(speech_audio)/self.sample_rate:.2f}s")

                        if self.speech_callback:
                            self.speech_callback(speech_audio, self.sample_rate)

            # Keep a short tail so speech spanning two buffers overlaps.
            keep_samples = int(self.sample_rate * self._OVERLAP_SECONDS)
            if len(self.audio_buffer) > keep_samples:
                self.audio_buffer = self.audio_buffer[-keep_samples:]
            else:
                self.audio_buffer = []

        except Exception as e:
            print(f"❌ Lỗi xử lý VAD buffer: {e}")
            self.audio_buffer = []

    def _resample_audio(self, audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Resample ``audio`` from ``orig_sr`` to ``target_sr``.

        Returns the input unchanged when rates match or resampling fails.
        """
        if orig_sr == target_sr:
            return audio

        try:
            audio_tensor = torch.FloatTensor(audio).unsqueeze(0)
            resampler = torchaudio.transforms.Resample(orig_sr, target_sr)
            resampled = resampler(audio_tensor)
            return resampled.squeeze(0).numpy()
        except Exception as e:
            print(f"⚠️ Lỗi resample: {e}")
            return audio

    def is_speech(self, audio_chunk: np.ndarray, sample_rate: int) -> bool:
        """Return True when ``audio_chunk`` likely contains speech.

        Fail-open: returns True when no model is loaded or detection errors,
        so downstream processing is never starved by a broken detector.
        """
        if self.model is None:
            return True  # Fallback: treat everything as speech.

        try:
            if sample_rate != self.sample_rate:
                audio_chunk = self._resample_audio(audio_chunk, sample_rate, self.sample_rate)

            audio_tensor = torch.FloatTensor(audio_chunk).unsqueeze(0)

            # Frame-wise speech probabilities, averaged over the chunk.
            prob_speech = self.model.get_speech_prob_chunk(audio_tensor)

            # Same lowered threshold as segment detection for consistency.
            return prob_speech.mean().item() > (settings.VAD_THRESHOLD - self._THRESHOLD_OFFSET)

        except Exception as e:
            print(f"❌ Lỗi kiểm tra speech: {e}")
            return True