import gradio as gr
import numpy as np
import torch
import torchaudio
import time
import os
import urllib.request
from scipy.spatial.distance import cosine
import threading
import queue
from collections import deque
import asyncio
from typing import Generator, Tuple, List, Optional
import whisper
from transformers import pipeline
# Configuration parameters (keeping original models)
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.7
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.0
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10
SAMPLE_RATE = 16000
CHUNK_DURATION = 2.0  # Process audio in 2-second chunks

# Speaker labels
SPEAKER_LABELS = [f"Speaker {i+1}" for i in range(ABSOLUTE_MAX_SPEAKERS)]
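
# Illustrative sketch (not called anywhere): how the change threshold is used below.
# A new embedding is compared against the current speaker's running mean with cosine
# similarity; a value below the threshold suggests a speaker change. The vectors here
# are made-up stand-ins for real ECAPA-TDNN embeddings.
def _example_threshold_check(threshold: float = DEFAULT_CHANGE_THRESHOLD) -> bool:
    rng = np.random.default_rng(0)
    current_mean = rng.normal(size=192)   # running mean of the current speaker
    new_embedding = rng.normal(size=192)  # embedding of the newest audio chunk
    similarity = 1.0 - cosine(new_embedding, current_mean)
    return similarity < threshold         # True -> treat as a possible speaker change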
class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        self.embedding_dim = 192
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def load_model(self):
        """Load the ECAPA-TDNN model"""
        try:
            from speechbrain.pretrained import EncoderClassifier
            self.model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=self.cache_dir,
                run_opts={"device": self.device}
            )
            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract speaker embedding from audio"""
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
            else:
                waveform = audio.unsqueeze(0)
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)
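
# Illustrative sketch (not called anywhere): loading the encoder and embedding one
# second of synthetic audio. Assumes the speechbrain package is installed; the
# pretrained weights are downloaded to ~/.cache/speechbrain on first use.
def _example_encoder_usage():
    encoder = SpeechBrainEncoder(device="cuda" if torch.cuda.is_available() else "cpu")
    if not encoder.load_model():
        return None
    tone = np.sin(2 * np.pi * 220.0 * np.arange(SAMPLE_RATE) / SAMPLE_RATE).astype(np.float32)
    embedding = encoder.embed_utterance(tone, sr=SAMPLE_RATE)
    print(embedding.shape)  # expected: (192,)
    return embedding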
class SpeakerChangeDetector:
    """Speaker change detector that supports configurable number of speakers"""
    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = set([0])

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        if new_max < self.max_speakers:
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0
        if new_max > self.max_speakers:
            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes"""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check if there's a speaker change"""
        current_time = timestamp or time.time()
        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
        self.last_similarity = similarity

        time_since_last_change = current_time - self.last_change_time
        is_speaker_change = False
        if time_since_last_change >= MIN_SEGMENT_DURATION:
            if similarity < self.change_threshold:
                best_speaker = self.current_speaker
                best_similarity = similarity
                for speaker_id in range(self.max_speakers):
                    if speaker_id == self.current_speaker:
                        continue
                    speaker_mean = self.mean_embeddings[speaker_id]
                    if speaker_mean is not None:
                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                        if speaker_similarity > best_similarity:
                            best_similarity = speaker_similarity
                            best_speaker = speaker_id
                if best_speaker != self.current_speaker:
                    is_speaker_change = True
                    self.current_speaker = best_speaker
                elif len(self.active_speakers) < self.max_speakers:
                    for new_id in range(self.max_speakers):
                        if new_id not in self.active_speakers:
                            is_speaker_change = True
                            self.current_speaker = new_id
                            self.active_speakers.add(new_id)
                            break
        if is_speaker_change:
            self.last_change_time = current_time

        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)
        self.speaker_embeddings[self.current_speaker].append(embedding)
        self.active_speakers.add(self.current_speaker)
        if len(self.speaker_embeddings[self.current_speaker]) > 30:
            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
        if self.speaker_embeddings[self.current_speaker]:
            self.mean_embeddings[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )
        return self.current_speaker, similarity
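
# Illustrative sketch (not called anywhere): driving the detector with made-up
# embeddings from two synthetic "voices". Timestamps are spaced further apart than
# MIN_SEGMENT_DURATION so that a change is allowed to register.
def _example_speaker_change_detection():
    rng = np.random.default_rng(0)
    voice_a = rng.normal(size=192)
    voice_b = -voice_a  # maximally dissimilar under cosine similarity
    detector = SpeakerChangeDetector(change_threshold=0.5, max_speakers=2)
    base = time.time()
    for i, emb in enumerate([voice_a, voice_a, voice_b, voice_b]):
        speaker_id, similarity = detector.add_embedding(emb, timestamp=base + 2.0 * i)
        print(f"chunk {i}: {SPEAKER_LABELS[speaker_id]} (similarity {similarity:.2f})")
    # Expected: chunks 0-1 map to Speaker 1, chunks 2-3 switch to Speaker 2.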
class AudioProcessor:
    """Processes audio data to extract speaker embeddings"""
    def __init__(self, encoder):
        self.encoder = encoder

    def extract_embedding(self, audio_data):
        try:
            # Ensure audio is float32 and normalized
            if audio_data.dtype != np.float32:
                audio_data = audio_data.astype(np.float32)
            # Normalize if needed
            if np.abs(audio_data).max() > 1.0:
                audio_data = audio_data / np.abs(audio_data).max()
            # Extract embedding using the loaded encoder
            embedding = self.encoder.embed_utterance(audio_data)
            return embedding
        except Exception as e:
            print(f"Embedding extraction error: {e}")
            return np.zeros(self.encoder.embedding_dim)
class RealTimeSpeakerDiarization:
    """Main class for real-time speaker diarization with FastRTC"""
    def __init__(self, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.transcription_pipeline = None
        self.change_threshold = change_threshold
        self.max_speakers = max_speakers
        self.transcript_history = []
        self.is_initialized = False
        # Audio processing
        self.audio_buffer = deque(maxlen=int(SAMPLE_RATE * 10))  # 10-second buffer
        self.processing_queue = queue.Queue()
        self.last_processed_time = 0
        self.current_transcript = ""

    def initialize(self):
        """Initialize the speaker diarization system"""
        if self.is_initialized:
            return True
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Initializing models on {device_str}...")
            # Initialize speaker encoder
            self.encoder = SpeechBrainEncoder(device=device_str)
            success = self.encoder.load_model()
            if not success:
                return False
            # Initialize transcription pipeline.
            # Note: the Distil-Whisper checkpoints are published under the
            # "distil-whisper" organization on the Hub, not "openai/whisper-...".
            self.transcription_pipeline = pipeline(
                "automatic-speech-recognition",
                model=f"distil-whisper/{REALTIME_TRANSCRIPTION_MODEL}",
                device=0 if torch.cuda.is_available() else -1,
                return_timestamps=True
            )
            self.audio_processor = AudioProcessor(self.encoder)
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers
            )
            self.is_initialized = True
            print("Speaker diarization system initialized successfully!")
            return True
        except Exception as e:
            print(f"Initialization error: {e}")
            return False
    def update_settings(self, change_threshold, max_speakers):
        """Update diarization settings"""
        self.change_threshold = change_threshold
        self.max_speakers = max_speakers
        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(change_threshold)
            self.speaker_detector.set_max_speakers(max_speakers)

    def process_audio_stream(self, audio_chunk, sample_rate):
        """Process real-time audio stream from FastRTC"""
        if not self.is_initialized:
            return self.get_current_transcript(), "System not initialized"
        try:
            # Convert to a numpy array if needed
            if hasattr(audio_chunk, 'numpy'):
                audio_data = audio_chunk.numpy()
            else:
                audio_data = np.array(audio_chunk)
            # Work in float32; integer PCM (e.g. int16 from the browser) is scaled to [-1, 1]
            if np.issubdtype(audio_data.dtype, np.integer):
                audio_data = audio_data.astype(np.float32) / 32768.0
            else:
                audio_data = audio_data.astype(np.float32)
            # Handle different audio formats
            if len(audio_data.shape) > 1:
                audio_data = audio_data.mean(axis=1)  # Convert to mono
            # Resample if needed
            if sample_rate != SAMPLE_RATE:
                audio_data = torchaudio.functional.resample(
                    torch.tensor(audio_data), sample_rate, SAMPLE_RATE
                ).numpy()
            # Add to buffer
            self.audio_buffer.extend(audio_data)
            # Process if we have enough audio
            current_time = time.time()
            if (current_time - self.last_processed_time) >= CHUNK_DURATION:
                self.process_buffered_audio()
                self.last_processed_time = current_time
            return self.get_current_transcript(), f"Processing... Buffer: {len(self.audio_buffer)} samples"
        except Exception as e:
            error_msg = f"Error processing audio stream: {str(e)}"
            print(error_msg)
            return self.get_current_transcript(), error_msg
    def process_buffered_audio(self):
        """Process buffered audio for transcription and speaker diarization"""
        if len(self.audio_buffer) < int(SAMPLE_RATE * MIN_LENGTH_OF_RECORDING):
            return
        try:
            # Get audio data from buffer
            audio_data = np.array(list(self.audio_buffer))
            # Transcribe audio. distil-small.en is English-only, so no language is
            # forced; the sampling rate is passed explicitly to avoid ambiguity.
            if len(audio_data) > 0:
                result = self.transcription_pipeline(
                    {"raw": audio_data, "sampling_rate": SAMPLE_RATE},
                    return_timestamps=True
                )
                transcription = result["text"].strip()
                if transcription and len(transcription) > 0:
                    # Extract speaker embedding
                    embedding = self.audio_processor.extract_embedding(audio_data)
                    # Detect speaker
                    speaker_id, similarity = self.speaker_detector.add_embedding(embedding)
                    # Format text with speaker label
                    speaker_label = SPEAKER_LABELS[speaker_id]
                    formatted_text = f"{speaker_label}: {transcription}"
                    # Add to transcript
                    self.add_to_transcript(formatted_text)
                    print(f"Transcribed: {formatted_text} (Similarity: {similarity:.3f})")
            # Once the buffer exceeds 5 seconds, keep only the most recent 3 seconds
            # to prevent memory issues
            if len(self.audio_buffer) > SAMPLE_RATE * 5:
                self.audio_buffer = deque(list(self.audio_buffer)[-SAMPLE_RATE * 3:], maxlen=int(SAMPLE_RATE * 10))
        except Exception as e:
            print(f"Error in process_buffered_audio: {e}")
    def get_current_transcript(self):
        """Get the current transcript"""
        return "\n".join(self.transcript_history) if self.transcript_history else "Listening..."

    def add_to_transcript(self, formatted_text: str):
        """Add formatted text to transcript history"""
        self.transcript_history.append(formatted_text)
        # Keep only last 50 entries to prevent memory issues
        if len(self.transcript_history) > 50:
            self.transcript_history = self.transcript_history[-50:]

    def clear_transcript(self):
        """Clear transcript history and reset speaker detector"""
        self.transcript_history = []
        self.audio_buffer.clear()
        if self.speaker_detector:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers
            )

    def get_status(self):
        """Get current system status"""
        if not self.is_initialized:
            return "System not initialized"
        if self.speaker_detector:
            active_speakers = len(self.speaker_detector.active_speakers)
            current_speaker = self.speaker_detector.current_speaker + 1
            similarity = self.speaker_detector.last_similarity
            return f"Active: {active_speakers} speakers | Current: Speaker {current_speaker} | Similarity: {similarity:.3f}"
        return "Ready"
# Global instance
diarization_system = RealTimeSpeakerDiarization()

def initialize_system():
    """Initialize the diarization system"""
    success = diarization_system.initialize()
    if success:
        return "✅ Speaker diarization system initialized successfully!"
    else:
        return "❌ Failed to initialize speaker diarization system. Please check your setup."

def process_realtime_audio(audio_stream, change_threshold, max_speakers):
    """Process real-time audio stream from FastRTC"""
    if not diarization_system.is_initialized:
        return "Please initialize the system first.", "System not ready"
    # Update settings
    diarization_system.update_settings(change_threshold, max_speakers)
    if audio_stream is None:
        return diarization_system.get_current_transcript(), diarization_system.get_status()
    # Process the audio stream
    transcript, _ = diarization_system.process_audio_stream(audio_stream, SAMPLE_RATE)
    return transcript, diarization_system.get_status()

def clear_conversation():
    """Clear the conversation transcript"""
    diarization_system.clear_transcript()
    return "Conversation cleared. Listening...", "Ready"
def create_gradio_interface():
    """Create and return the Gradio interface with FastRTC"""
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎙️ Real-time Speaker Diarization with FastRTC")
        gr.Markdown("Speak into your microphone for real-time speaker diarization and transcription.")

        # Initialization section
        with gr.Row():
            init_btn = gr.Button("🚀 Initialize System", variant="primary", scale=1)
            init_status = gr.Textbox(label="System Status", interactive=False, scale=2)

        # Settings section
        with gr.Row():
            with gr.Column():
                change_threshold = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    step=0.05,
                    label="Speaker Change Threshold",
                    info="Higher values = more sensitive to speaker changes"
                )
            with gr.Column():
                max_speakers = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    value=DEFAULT_MAX_SPEAKERS,
                    step=1,
                    label="Maximum Number of Speakers",
                    info="Maximum number of speakers to detect"
                )

        # FastRTC audio input
        with gr.Row():
            with gr.Column():
                # FastRTC component for real-time audio.
                # Note: this assumes a Gradio build that exposes FastRTC as gr.FastRTC;
                # in other setups the equivalent WebRTC component comes from the
                # separate fastrtc package.
                audio_input = gr.FastRTC(
                    audio=True,
                    video=False,
                    label="🎤 Real-time Audio Input",
                    audio_sample_rate=SAMPLE_RATE,
                    audio_channels=1
                )
                clear_btn = gr.Button("🗑️ Clear Conversation", variant="stop")
            with gr.Column():
                current_status = gr.Textbox(
                    label="Current Status",
                    interactive=False,
                    value="Click Initialize to start"
                )

        # Output section
        transcript_output = gr.Textbox(
            label="🔴 Live Transcript with Speaker Labels",
            lines=15,
            max_lines=25,
            interactive=False,
            value="Click Initialize, then start speaking...",
            autoscroll=True
        )
        # Event handlers
        init_btn.click(
            fn=initialize_system,
            outputs=[init_status]
        )

        # FastRTC stream processing
        audio_input.stream(
            fn=process_realtime_audio,
            inputs=[audio_input, change_threshold, max_speakers],
            outputs=[transcript_output, current_status],
            time_limit=30  # Process in 30-second chunks
        )

        clear_btn.click(
            fn=clear_conversation,
            outputs=[transcript_output, current_status]
        )

        # Instructions
        with gr.Accordion("📝 Instructions", open=False):
            gr.Markdown("""
## How to Use:
1. **Initialize**: Click "🚀 Initialize System" to load the AI models (this may take a moment)
2. **Allow Microphone**: Your browser will ask for microphone permission - please allow it
3. **Adjust Settings**:
   - **Speaker Change Threshold**:
     - Lower (0.3-0.5) for speakers with different voices
     - Higher (0.6-0.8) for speakers with similar voices
   - **Max Speakers**: Set expected number of speakers (2-10)
4. **Start Speaking**: The system will automatically transcribe and identify speakers
5. **View Results**: See real-time transcript with speaker labels (Speaker 1, Speaker 2, etc.)
6. **Clear**: Use "Clear Conversation" to reset and start fresh

## Features:
- ✅ Real-time audio processing via FastRTC
- ✅ Automatic speech recognition with Whisper
- ✅ Speaker diarization with ECAPA-TDNN
- ✅ Live transcript with speaker labels
- ✅ Configurable sensitivity settings
- ✅ Support for up to 10 speakers

## Tips:
- Speak clearly and allow brief pauses between speakers
- The system learns speaker characteristics over time
- Better results with distinct speaker voices
- Ensure good microphone quality for best performance
""")
    return demo
if __name__ == "__main__":
    # Create and launch the Gradio interface
    demo = create_gradio_interface()
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )
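
# Usage note: run locally with `python app.py` and open http://localhost:7860 (or the
# URL printed in the console). When hosted on Hugging Face Spaces, share=True is
# typically ignored and the Space URL is used instead.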