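"""Realtime voice-activity detection (TEN VAD) glue for fastrtc.

RealtimeVAD turns a stream of int16 PCM frames into VADEvents: an interrupt
signal once the user is confirmed to be speaking, then the full utterance
audio once they pause. VADStreamHandler adapts it to fastrtc's StreamHandler
interface.
"""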
import threading
from dataclasses import dataclass
from typing import Callable, Generator, override

import fastrtc
import librosa
import numpy as np
from ten_vad import TenVad
@dataclass
class VADEvent:
    """Event emitted by RealtimeVAD: either an interrupt signal (speech
    confirmed) or the full utterance audio (speech ended)."""

    interrupt_signal: bool | None = None
    full_audio: tuple[int, np.ndarray] | None = None
global_ten_vad: TenVad | None = None
global_vad_lock = threading.Lock()
def global_vad_process(audio_data: np.ndarray) -> float:
    """
    Process one hop (hop_size=256 samples of 16 kHz int16 PCM) with the
    global TenVad instance.

    Returns:
        Speech probability in [0, 1].
    """
    global global_ten_vad
    with global_vad_lock:
        if global_ten_vad is None:
            global_ten_vad = TenVad()
        prob, _ = global_ten_vad.process(audio_data)
    return prob
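# Quick sanity check for the helper above (a sketch; assumes the ten_vad
# package and its bundled model are installed): one 256-sample hop of int16
# silence should yield a low speech probability.
#
#     prob = global_vad_process(np.zeros(256, dtype=np.int16))
#     assert 0.0 <= prob <= 1.0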
class RealtimeVAD:
    def __init__(
        self,
        src_sr: int = 24000,
        start_threshold: float = 0.8,
        end_threshold: float = 0.7,
        pad_start_s: float = 0.6,
        min_positive_s: float = 0.4,
        min_silence_s: float = 1.2,
    ):
        self.src_sr = src_sr
        self.vad_sr = 16000
        self.hop_size = 256
        self.start_threshold = start_threshold
        self.end_threshold = end_threshold
        self.pad_start_s = pad_start_s
        self.min_positive_s = min_positive_s
        self.min_silence_s = min_silence_s
        self.vad_buffer = np.array([], dtype=np.int16)
        """
        Buffer of audio awaiting VAD processing.
        Stores 16 kHz int16 PCM; consumed and trimmed in `hop_size`-sample steps.
        """
        self.src_buffer = np.array([], dtype=np.int16)
        """
        Buffer of the original audio.
        Stores int16 PCM at the source rate (24 kHz by default).
        Cut when a pause is detected (after `min_silence_s`);
        kept as a sliding window of `pad_start_s` seconds while inactive.
        """
        self.vad_buffer_offset = 0
        self.src_buffer_offset = 0
        self.active = False
        self.interrupt_signal = False
        self.sum_positive_s = 0.0
        self.silence_start_s: float | None = None
    def process(self, audio_data: np.ndarray) -> Generator[VADEvent, None, None]:
        if audio_data.ndim == 2:
            # FastRTC style [channels, samples]; keep the first channel.
            audio_data = audio_data[0]
        # Append to both buffers: raw audio as-is, VAD audio resampled to 16 kHz.
        self.src_buffer = np.concatenate((self.src_buffer, audio_data))
        vad_audio_data = librosa.resample(
            audio_data.astype(np.float32) / 32768.0,
            orig_sr=self.src_sr,
            target_sr=self.vad_sr,
        )
        vad_audio_data = (vad_audio_data * 32767.0).round().astype(np.int16)
        self.vad_buffer = np.concatenate((self.vad_buffer, vad_audio_data))
        vad_buffer_size = self.vad_buffer.shape[0]

        def process_chunk(chunk_offset_s: float, vad_chunk: np.ndarray):
            speech_prob = global_vad_process(vad_chunk)
            hop_s = self.hop_size / self.vad_sr
            if not self.active:
                if speech_prob >= self.start_threshold:
                    self.active = True
                    self.sum_positive_s = hop_s
                    print(f"[VAD] Active at {chunk_offset_s:.2f}s, {speech_prob=:.3f}")
                else:
                    # Still inactive: slide the source window so it keeps only
                    # the last `pad_start_s` seconds of pre-speech padding.
                    new_src_offset = int(
                        (chunk_offset_s - self.pad_start_s) * self.src_sr
                    )
                    cut_pos = new_src_offset - self.src_buffer_offset
                    if cut_pos > 0:
                        self.src_buffer = self.src_buffer[cut_pos:]
                        self.src_buffer_offset = new_src_offset
                return
            chunk_src_pos = int(chunk_offset_s * self.src_sr)
            if speech_prob >= self.end_threshold:
                self.silence_start_s = None
                self.sum_positive_s += hop_s
                if (
                    not self.interrupt_signal
                    and self.sum_positive_s >= self.min_positive_s
                ):
                    # Enough accumulated speech: tell the consumer to stop playback.
                    self.interrupt_signal = True
                    yield VADEvent(interrupt_signal=True)
                    print(
                        f"[VAD] Interrupt signal at {chunk_offset_s:.2f}s, {speech_prob=:.3f}"
                    )
            elif self.silence_start_s is None:
                self.silence_start_s = chunk_offset_s
            if (
                self.silence_start_s is not None
                and chunk_offset_s - self.silence_start_s >= self.min_silence_s
            ):
                # Silence lasted long enough: emit the utterance and go inactive.
                cut_pos = chunk_src_pos - self.src_buffer_offset
                if self.interrupt_signal:
                    webrtc_audio = self.src_buffer[np.newaxis, :cut_pos]
                    yield VADEvent(full_audio=(self.src_sr, webrtc_audio))
                    print(
                        f"[VAD] Full audio at {chunk_offset_s:.2f}s, {webrtc_audio.shape=}"
                    )
                self.src_buffer = self.src_buffer[cut_pos:]
                self.src_buffer_offset = chunk_src_pos
                self.active = False
                self.interrupt_signal = False
                self.sum_positive_s = 0.0
                self.silence_start_s = None

        # Walk the VAD buffer one hop at a time; leftover samples stay
        # buffered for the next call.
        processed_samples = 0
        for chunk_pos in range(0, vad_buffer_size - self.hop_size, self.hop_size):
            processed_samples = chunk_pos + self.hop_size
            chunk_offset_s = (self.vad_buffer_offset + chunk_pos) / self.vad_sr
            vad_chunk = self.vad_buffer[chunk_pos : chunk_pos + self.hop_size]
            yield from process_chunk(chunk_offset_s, vad_chunk)
        self.vad_buffer = self.vad_buffer[processed_samples:]
        self.vad_buffer_offset += processed_samples
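# Offline usage sketch for the generator above (hypothetical driver code, not
# used by the handler below): feed fixed-size 24 kHz int16 frames and collect
# each completed utterance.
#
#     vad = RealtimeVAD(src_sr=24000)
#     utterances: list[tuple[int, np.ndarray]] = []
#     for start in range(0, len(pcm), 960):  # pcm: 1-D int16 @ 24 kHz
#         for event in vad.process(pcm[start : start + 960]):
#             if event.full_audio is not None:
#                 utterances.append(event.full_audio)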
def init_global_ten_vad(input_sample_rate: int = 24000):
    """
    Call this once at the start of the program to avoid latency on first use.
    No-op if already initialized.
    """
    global global_ten_vad
    require_warmup = False
    with global_vad_lock:
        if global_ten_vad is None:
            global_ten_vad = TenVad()
            require_warmup = True
    if require_warmup:
        print("[VAD] Initializing global TenVad...")
        realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
        for _ in range(25):  # Warm up with 25 x 960 samples (1 s at 24 kHz)
            for _ in realtime_vad.process(np.zeros(960, dtype=np.int16)):
                pass
        print("[VAD] Global VAD initialized")
type StreamerGenerator = Generator[fastrtc.tracks.EmitType, None, None]
type StreamerFn = Callable[[tuple[int, np.ndarray], str], StreamerGenerator]
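# Contract for StreamerFn, as inferred from VADStreamHandler.receive() below:
# fastrtc caches the UI input values in `latest_args`, the handler overwrites
# `latest_args[0]` with the captured `(sample_rate, audio)` tuple, and the
# remaining `str` argument (assumed here to be a single extra text input,
# e.g. a prompt) is passed through unchanged.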
class VADStreamHandler(fastrtc.StreamHandler):
    def __init__(
        self,
        streamer_fn: StreamerFn,
        input_sample_rate: int = 24000,
    ):
        super().__init__(
            "mono",
            24000,
            None,
            input_sample_rate,
            30,
        )
        self.streamer_fn = streamer_fn
        self.realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
        self.generator: StreamerGenerator | None = None
        init_global_ten_vad()

    @override
    def emit(self) -> fastrtc.tracks.EmitType:
        if self.generator is None:
            return None
        try:
            return next(self.generator)
        except StopIteration:
            self.generator = None
            return None

    @override
    def receive(self, frame: tuple[int, np.ndarray]):
        _, audio_data = frame
        for event in self.realtime_vad.process(audio_data):
            if event.interrupt_signal:
                # User started speaking: drop the current response and
                # flush any queued audio.
                self.generator = None
                self.clear_queue()
            if event.full_audio is not None:
                # Utterance complete: swap it into the cached args and
                # start a new response generator.
                self.wait_for_args_sync()
                self.latest_args[0] = event.full_audio
                self.generator = self.streamer_fn(*self.latest_args)

    @override
    def copy(self):
        return VADStreamHandler(
            self.streamer_fn,
            input_sample_rate=self.input_sample_rate,
        )
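if __name__ == "__main__":
    # Wiring sketch, assuming fastrtc's Stream/Gradio UI API and one extra
    # Textbox input (matching the `str` parameter of StreamerFn). The echo
    # streamer is a hypothetical stand-in for a real ASR/LLM/TTS pipeline.
    import gradio as gr

    def echo_streamer(
        audio: tuple[int, np.ndarray], _prompt: str
    ) -> StreamerGenerator:
        # Play the captured utterance straight back to the caller.
        yield audio

    stream = fastrtc.Stream(
        handler=VADStreamHandler(echo_streamer),
        modality="audio",
        mode="send-receive",
        additional_inputs=[gr.Textbox(label="Prompt")],
    )
    stream.ui.launch()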