# mimo_audio_chat/webrtc_vad.py
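"""Realtime voice-activity detection (VAD) glue for FastRTC.

Wraps the TEN VAD model behind a process-wide singleton, runs a small
hysteresis state machine over 16 kHz hops (RealtimeVAD), and exposes a
fastrtc.StreamHandler (VADStreamHandler) that interrupts playback when the
user starts speaking and hands each finished utterance to a streamer callback.
"""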
import threading
from dataclasses import dataclass
from typing import Callable, Generator, override
import fastrtc
import librosa
import numpy as np
from ten_vad import TenVad


@dataclass
class VADEvent:
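    """Event yielded by RealtimeVAD.process.

    Each event sets exactly one field: `interrupt_signal` once enough speech
    has accumulated (the consumer should stop current playback), or
    `full_audio` as `(sample_rate, audio)` with `audio` a `(1, n_samples)`
    int16 array, once the utterance has ended in silence.
    """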
interrupt_signal: bool | None = None
full_audio: tuple[int, np.ndarray] | None = None
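

# One TenVad instance is shared across the whole process and guarded by a
# lock so concurrent handlers serialize their calls into the native VAD
# (assumed not to be safe for unsynchronized concurrent use).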
global_ten_vad: TenVad | None = None
global_vad_lock = threading.Lock()


def global_vad_process(audio_data: np.ndarray) -> float:
"""
Process audio data (hop_size=256) with global TenVad instance.
Returns:
speech probability.
"""
global global_ten_vad
with global_vad_lock:
if global_ten_vad is None:
global_ten_vad = TenVad()
        prob, _ = global_ten_vad.process(audio_data)
        return prob


class RealtimeVAD:
def __init__(
self,
src_sr: int = 24000,
start_threshold: float = 0.8,
end_threshold: float = 0.7,
pad_start_s: float = 0.6,
min_positive_s: float = 0.4,
min_silence_s: float = 1.2,
):
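        """
        Args:
            src_sr: sample rate of the incoming audio, in Hz.
            start_threshold: speech probability needed to become active.
            end_threshold: while active, probabilities below this count as silence.
            pad_start_s: seconds of audio kept before the detected speech start.
            min_positive_s: accumulated speech needed before emitting an interrupt.
            min_silence_s: trailing silence that ends the utterance.
        """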
self.src_sr = src_sr
self.vad_sr = 16000
self.hop_size = 256
self.start_threshold = start_threshold
self.end_threshold = end_threshold
self.pad_start_s = pad_start_s
self.min_positive_s = min_positive_s
self.min_silence_s = min_silence_s
self.vad_buffer = np.array([], dtype=np.int16)
"""
VAD Buffer to store audio data for VAD processing
Stores 16kHz int16 PCM. Process and cut for each `hop_size` samples.
"""
self.src_buffer = np.array([], dtype=np.int16)
"""
Source Buffer to store original audio data
Stores original sampling rate (24kHz) int16 PCM.
Cut when pause detected (after `min_silence_s`).
Sliding window `pad_start_s` when inactive.
"""
self.vad_buffer_offset = 0
self.src_buffer_offset = 0
self.active = False
self.interrupt_signal = False
self.sum_positive_s = 0.0
        self.silence_start_s: float | None = None

    def process(self, audio_data: np.ndarray):
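        """
        Append one frame of source-rate int16 PCM (1-D, or FastRTC-style
        [channels, samples]) and yield VADEvent objects as complete hops are
        processed.
        """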
if audio_data.ndim == 2:
# FastRTC style [channels, samples]
audio_data = audio_data[0]
# Append to buffers
self.src_buffer = np.concatenate((self.src_buffer, audio_data))
vad_audio_data = librosa.resample(
audio_data.astype(np.float32) / 32768.0,
orig_sr=self.src_sr,
target_sr=self.vad_sr,
)
vad_audio_data = (vad_audio_data * 32767.0).round().astype(np.int16)
self.vad_buffer = np.concatenate((self.vad_buffer, vad_audio_data))
vad_buffer_size = self.vad_buffer.shape[0]
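
        # Per-hop state machine: while inactive, keep only the last
        # `pad_start_s` seconds of source audio; once the start threshold is
        # crossed, accumulate speech, emit an interrupt after `min_positive_s`
        # of it, and emit the full utterance after `min_silence_s` of silence.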
def process_chunk(chunk_offset_s: float, vad_chunk: np.ndarray):
speech_prob = global_vad_process(vad_chunk)
hop_s = self.hop_size / self.vad_sr
if not self.active:
if speech_prob >= self.start_threshold:
self.active = True
self.sum_positive_s = hop_s
print(f"[VAD] Active at {chunk_offset_s:.2f}s, {speech_prob=:.3f}")
else:
new_src_offset = int(
(chunk_offset_s - self.pad_start_s) * self.src_sr
)
cut_pos = new_src_offset - self.src_buffer_offset
if cut_pos > 0:
self.src_buffer = self.src_buffer[cut_pos:]
self.src_buffer_offset = new_src_offset
return
chunk_src_pos = int(chunk_offset_s * self.src_sr)
if speech_prob >= self.end_threshold:
self.silence_start_s = None
self.sum_positive_s += hop_s
if (
not self.interrupt_signal
and self.sum_positive_s >= self.min_positive_s
):
self.interrupt_signal = True
yield VADEvent(interrupt_signal=True)
print(
f"[VAD] Interrupt signal at {chunk_offset_s:.2f}s, {speech_prob=:.3f}"
)
elif self.silence_start_s is None:
self.silence_start_s = chunk_offset_s
if (
self.silence_start_s is not None
and chunk_offset_s - self.silence_start_s >= self.min_silence_s
):
# Inactive now
cut_pos = chunk_src_pos - self.src_buffer_offset
if self.interrupt_signal:
webrtc_audio = self.src_buffer[np.newaxis, :cut_pos]
yield VADEvent(full_audio=(self.src_sr, webrtc_audio))
print(
f"[VAD] Full audio at {chunk_offset_s:.2f}s, {webrtc_audio.shape=}"
)
self.src_buffer = self.src_buffer[cut_pos:]
self.src_buffer_offset = chunk_src_pos
self.active = False
self.interrupt_signal = False
self.sum_positive_s = 0.0
self.silence_start_s = None
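
        # Run VAD hop by hop over the buffered 16 kHz audio; leftover samples
        # stay in vad_buffer for the next call.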
processed_samples = 0
for chunk_pos in range(0, vad_buffer_size - self.hop_size, self.hop_size):
processed_samples = chunk_pos + self.hop_size
chunk_offset_s = (self.vad_buffer_offset + chunk_pos) / self.vad_sr
vad_chunk = self.vad_buffer[chunk_pos : chunk_pos + self.hop_size]
yield from process_chunk(chunk_offset_s, vad_chunk)
self.vad_buffer = self.vad_buffer[processed_samples:]
        self.vad_buffer_offset += processed_samples


def init_global_ten_vad(input_sample_rate: int = 24000):
"""
Call this once at the start of the program to avoid latency on first use.
No-op if already initialized.
"""
global global_ten_vad
require_warmup = False
with global_vad_lock:
if global_ten_vad is None:
global_ten_vad = TenVad()
require_warmup = True
if require_warmup:
print("[VAD] Initializing global TenVad...")
realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
        for _ in range(25):  # Warm up with 1 s of silence (25 frames of 960 samples at 24 kHz)
for _ in realtime_vad.process(np.zeros(960, dtype=np.int16)):
pass
print("[VAD] Global VAD initialized")
type StreamerGenerator = Generator[fastrtc.tracks.EmitType, None, None]
type StreamerFn = Callable[[tuple[int, np.ndarray], str], StreamerGenerator]
class VADStreamHandler(fastrtc.StreamHandler):
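    """fastrtc stream handler that gates a streamer function with VAD.

    Incoming microphone frames are fed to RealtimeVAD: an interrupt event
    drops the current output generator and clears the output queue, and a
    full-audio event starts `streamer_fn` on the captured utterance.
    """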
def __init__(
self,
streamer_fn: StreamerFn,
input_sample_rate: int = 24000,
):
super().__init__(
"mono",
24000,
None,
input_sample_rate,
30,
)
self.streamer_fn = streamer_fn
self.realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
self.generator: StreamerGenerator | None = None
        init_global_ten_vad()

    @override
def emit(self) -> fastrtc.tracks.EmitType:
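        """Return the next item from the active streamer generator, if any."""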
if self.generator is None:
return None
try:
return next(self.generator)
except StopIteration:
self.generator = None
            return None

    @override
def receive(self, frame: tuple[int, np.ndarray]):
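        """Feed one incoming frame to the VAD and react to its events."""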
_, audio_data = frame
for event in self.realtime_vad.process(audio_data):
if event.interrupt_signal:
self.generator = None
self.clear_queue()
if event.full_audio is not None:
self.wait_for_args_sync()
self.latest_args[0] = event.full_audio
                self.generator = self.streamer_fn(*self.latest_args)

    @override
def copy(self):
return VADStreamHandler(
self.streamer_fn,
input_sample_rate=self.input_sample_rate,
)
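

# ---------------------------------------------------------------------------
# Minimal offline smoke test (an illustrative sketch, not part of the original
# streaming setup): feeds synthetic silence and white noise through
# RealtimeVAD and prints whatever events it yields. The 20 ms frame size and
# the noise amplitude are assumptions made for this example only; real input
# arrives from FastRTC via VADStreamHandler.receive. Loud white noise may or
# may not be classified as speech by TEN VAD.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    init_global_ten_vad()
    vad = RealtimeVAD(src_sr=24000)
    rng = np.random.default_rng(0)
    frame_len = 480  # 20 ms frames at 24 kHz (assumed for this test)

    def synth_frames(seconds: float, noisy: bool):
        # Yield int16 frames of either silence or white noise.
        for _ in range(int(seconds * 24000) // frame_len):
            if noisy:
                samples = rng.normal(0.0, 8000.0, frame_len)
                yield samples.clip(-32768, 32767).astype(np.int16)
            else:
                yield np.zeros(frame_len, dtype=np.int16)

    # 1 s of silence, 2 s of noise, then 2 s of silence.
    for frame in (
        *synth_frames(1.0, noisy=False),
        *synth_frames(2.0, noisy=True),
        *synth_frames(2.0, noisy=False),
    ):
        for event in vad.process(frame):
            print("event:", event)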