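"""Real-time speech transcription with speaker diarization.

Combines RealtimeSTT for streaming transcription, a SpeechBrain ECAPA-TDNN
encoder for speaker embeddings, FastRTC for browser audio streaming, and a
Gradio UI with an optional FastAPI backend.
"""
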
import gradio as gr
import numpy as np
import queue
import torch
import time
import threading
import os
import urllib.request
import torchaudio
from scipy.spatial.distance import cosine
from RealtimeSTT import AudioToTextRecorder
from fastapi import FastAPI, APIRouter
from fastrtc import Stream, AsyncStreamHandler
import json
import asyncio
import uvicorn
from queue import Queue
import logging
from gradio_webrtc import WebRTC

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Simplified configuration parameters
SILENCE_THRESHS = [0, 0.4]
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.65
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.5
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 8

# Global variables
SAMPLE_RATE = 16000
BUFFER_SIZE = 1024
CHANNELS = 1

# Speaker colors - more distinguishable colors
SPEAKER_COLORS = [
    "#FF6B6B",  # Red
    "#4ECDC4",  # Teal
    "#45B7D1",  # Blue
    "#96CEB4",  # Green
    "#FFEAA7",  # Yellow
    "#DDA0DD",  # Plum
    "#98D8C8",  # Mint
    "#F7DC6F",  # Gold
]
SPEAKER_COLOR_NAMES = [
    "Red", "Teal", "Blue", "Green", "Yellow", "Plum", "Mint", "Gold"
]


class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""
    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        self.embedding_dim = 192
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def load_model(self):
        """Load the ECAPA-TDNN model"""
        try:
            from speechbrain.pretrained import EncoderClassifier
            self.model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=self.cache_dir,
                run_opts={"device": self.device}
            )
            self.model_loaded = True
            logger.info("ECAPA-TDNN model loaded successfully!")
            return True
        except Exception as e:
            logger.error(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract speaker embedding from audio"""
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")
        try:
            if isinstance(audio, np.ndarray):
                # Ensure audio is float32 and properly normalized
                audio = audio.astype(np.float32)
                if np.max(np.abs(audio)) > 1.0:
                    audio = audio / np.max(np.abs(audio))
                waveform = torch.tensor(audio).unsqueeze(0)
            else:
                waveform = audio.unsqueeze(0)
            # Resample if necessary
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            logger.error(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)


class AudioProcessor:
    """Processes audio data to extract speaker embeddings"""
    def __init__(self, encoder):
        self.encoder = encoder
        self.audio_buffer = []
        self.min_audio_length = int(SAMPLE_RATE * 1.0)  # Minimum 1 second of audio

    def add_audio_chunk(self, audio_chunk):
        """Add audio chunk to buffer"""
        self.audio_buffer.extend(audio_chunk)
        # Keep buffer from getting too large
        max_buffer_size = int(SAMPLE_RATE * 10)  # 10 seconds max
        if len(self.audio_buffer) > max_buffer_size:
            self.audio_buffer = self.audio_buffer[-max_buffer_size:]

    def extract_embedding_from_buffer(self):
        """Extract embedding from current audio buffer"""
        if len(self.audio_buffer) < self.min_audio_length:
            return None
        try:
            # Use the last portion of the buffer for embedding
            audio_segment = np.array(self.audio_buffer[-self.min_audio_length:], dtype=np.float32)
            # Normalize audio
            if np.max(np.abs(audio_segment)) > 0:
                audio_segment = audio_segment / np.max(np.abs(audio_segment))
            else:
                return None
            embedding = self.encoder.embed_utterance(audio_segment)
            return embedding
        except Exception as e:
            logger.error(f"Embedding extraction error: {e}")
            return None


class SpeakerChangeDetector:
    """Improved speaker change detector"""
    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.speaker_centroids = [None] * self.max_speakers
        self.last_change_time = time.time()
        self.last_similarity = 1.0
        self.active_speakers = set([0])
        self.segment_counter = 0

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        if new_max < self.max_speakers:
            # Remove speakers beyond the new limit
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0
        # Resize arrays
        if new_max > self.max_speakers:
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
            self.speaker_centroids.extend([None] * (new_max - self.max_speakers))
        else:
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
            self.speaker_centroids = self.speaker_centroids[:new_max]
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes"""
        self.change_threshold = max(0.1, min(threshold, 0.95))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and detect speaker changes"""
        current_time = timestamp or time.time()
        self.segment_counter += 1

        # Initialize first speaker
        if not self.speaker_embeddings[0]:
            self.speaker_embeddings[0].append(embedding)
            self.speaker_centroids[0] = embedding.copy()
            self.active_speakers.add(0)
            return 0, 1.0

        # Calculate similarity with current speaker
        current_centroid = self.speaker_centroids[self.current_speaker]
        if current_centroid is not None:
            similarity = 1.0 - cosine(embedding, current_centroid)
        else:
            similarity = 0.5
        self.last_similarity = similarity

        # Check for speaker change
        time_since_last_change = current_time - self.last_change_time
        speaker_changed = False
        if time_since_last_change >= MIN_SEGMENT_DURATION and similarity < self.change_threshold:
            # Find best matching speaker
            best_speaker = self.current_speaker
            best_similarity = similarity
            for speaker_id in self.active_speakers:
                if speaker_id == self.current_speaker:
                    continue
                centroid = self.speaker_centroids[speaker_id]
                if centroid is not None:
                    speaker_similarity = 1.0 - cosine(embedding, centroid)
                    if speaker_similarity > best_similarity and speaker_similarity > self.change_threshold:
                        best_similarity = speaker_similarity
                        best_speaker = speaker_id
            # If no good match found and we can add a new speaker
            if best_speaker == self.current_speaker and len(self.active_speakers) < self.max_speakers:
                for new_id in range(self.max_speakers):
                    if new_id not in self.active_speakers:
                        best_speaker = new_id
                        self.active_speakers.add(new_id)
                        break
            if best_speaker != self.current_speaker:
                self.current_speaker = best_speaker
                self.last_change_time = current_time
                speaker_changed = True

        # Update speaker embeddings and centroids
        self.speaker_embeddings[self.current_speaker].append(embedding)
        # Keep only recent embeddings (sliding window)
        max_embeddings = 20
        if len(self.speaker_embeddings[self.current_speaker]) > max_embeddings:
            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-max_embeddings:]
        # Update centroid
        if self.speaker_embeddings[self.current_speaker]:
            self.speaker_centroids[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )
        return self.current_speaker, similarity

    def get_color_for_speaker(self, speaker_id):
        """Return color for speaker ID"""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"

    def get_status_info(self):
        """Return status information"""
        speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
        return {
            "current_speaker": self.current_speaker,
            "speaker_counts": speaker_counts,
            "active_speakers": len(self.active_speakers),
            "max_speakers": self.max_speakers,
            "last_similarity": self.last_similarity,
            "threshold": self.change_threshold,
            "segment_counter": self.segment_counter,
        }


class RealtimeSpeakerDiarization:
    def __init__(self):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.recorder = None
        self.sentence_queue = queue.Queue()
        self.full_sentences = []
        self.sentence_speakers = []
        self.pending_sentences = []
        self.current_conversation = ""
        self.is_running = False
        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
        self.max_speakers = DEFAULT_MAX_SPEAKERS
        self.last_transcription = ""
        self.transcription_lock = threading.Lock()

    def initialize_models(self):
        """Initialize the speaker encoder model"""
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {device_str}")
            self.encoder = SpeechBrainEncoder(device=device_str)
            success = self.encoder.load_model()
            if success:
                self.audio_processor = AudioProcessor(self.encoder)
                self.speaker_detector = SpeakerChangeDetector(
                    embedding_dim=self.encoder.embedding_dim,
                    change_threshold=self.change_threshold,
                    max_speakers=self.max_speakers
                )
                logger.info("Models initialized successfully!")
                return True
            else:
                logger.error("Failed to load models")
                return False
        except Exception as e:
            logger.error(f"Model initialization error: {e}")
            return False

    def live_text_detected(self, text):
        """Callback for real-time transcription updates"""
        with self.transcription_lock:
            self.last_transcription = text.strip()

    def process_final_text(self, text):
        """Process final transcribed text with speaker embedding"""
        text = text.strip()
        if text:
            try:
                # Get audio data for this transcription
                audio_bytes = getattr(self.recorder, 'last_transcription_bytes', None)
                if audio_bytes:
                    self.sentence_queue.put((text, audio_bytes))
                else:
                    # If no audio bytes, use current speaker
                    self.sentence_queue.put((text, None))
            except Exception as e:
                logger.error(f"Error processing final text: {e}")

    def process_sentence_queue(self):
        """Process sentences in the queue for speaker detection"""
        while self.is_running:
            try:
                text, audio_bytes = self.sentence_queue.get(timeout=1)
                current_speaker = self.speaker_detector.current_speaker
                if audio_bytes:
                    # Convert audio data and extract embedding
                    audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
                    audio_float = audio_int16.astype(np.float32) / 32768.0
                    # Extract embedding
                    embedding = self.audio_processor.encoder.embed_utterance(audio_float)
                    if embedding is not None:
                        current_speaker, similarity = self.speaker_detector.add_embedding(embedding)
                # Store sentence with speaker
                with self.transcription_lock:
                    self.full_sentences.append((text, current_speaker))
                    self.update_conversation_display()
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Error processing sentence: {e}")

    def update_conversation_display(self):
        """Update the conversation display"""
        try:
            sentences_with_style = []
            for sentence_text, speaker_id in self.full_sentences:
                color = self.speaker_detector.get_color_for_speaker(speaker_id)
                speaker_name = f"Speaker {speaker_id + 1}"
                sentences_with_style.append(
                    f'<span style="color:{color}; font-weight: bold;">{speaker_name}:</span> '
                    f'<span style="color:#333333;">{sentence_text}</span>'
                )
            # Add current transcription if available
            if self.last_transcription:
                current_color = self.speaker_detector.get_color_for_speaker(self.speaker_detector.current_speaker)
                current_speaker = f"Speaker {self.speaker_detector.current_speaker + 1}"
                sentences_with_style.append(
                    f'<span style="color:{current_color}; font-weight: bold; opacity: 0.7;">{current_speaker}:</span> '
                    f'<span style="color:#666666; font-style: italic;">{self.last_transcription}...</span>'
                )
            if sentences_with_style:
                self.current_conversation = "<br><br>".join(sentences_with_style)
            else:
                self.current_conversation = "<i>Waiting for speech input...</i>"
        except Exception as e:
            logger.error(f"Error updating conversation display: {e}")
            self.current_conversation = f"<i>Error: {str(e)}</i>"

    def start_recording(self):
        """Start the recording and transcription process"""
        if self.encoder is None:
            return "Please initialize models first!"
        try:
            # Setup recorder configuration
            recorder_config = {
                'spinner': False,
                'use_microphone': False,  # False on Hugging Face Spaces: audio arrives via FastRTC, not a server microphone
                'model': FINAL_TRANSCRIPTION_MODEL,
                'language': TRANSCRIPTION_LANGUAGE,
                'silero_sensitivity': SILERO_SENSITIVITY,
                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
                'post_speech_silence_duration': SILENCE_THRESHS[1],
                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
                'min_gap_between_recordings': 0,
                'enable_realtime_transcription': True,
                'realtime_processing_pause': 0.1,
                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
                'on_realtime_transcription_update': self.live_text_detected,
                'beam_size': FINAL_BEAM_SIZE,
                'beam_size_realtime': REALTIME_BEAM_SIZE,
                'sample_rate': SAMPLE_RATE,
            }
            self.recorder = AudioToTextRecorder(**recorder_config)

            # Start processing threads
            self.is_running = True
            self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
            self.sentence_thread.start()
            self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
            self.transcription_thread.start()
            return "Recording started successfully!"
        except Exception as e:
            logger.error(f"Error starting recording: {e}")
            return f"Error starting recording: {e}"

    def run_transcription(self):
        """Run the transcription loop"""
        try:
            while self.is_running:
                self.recorder.text(self.process_final_text)
        except Exception as e:
            logger.error(f"Transcription error: {e}")

    def stop_recording(self):
        """Stop the recording process"""
        self.is_running = False
        if self.recorder:
            self.recorder.stop()
        return "Recording stopped!"

    def clear_conversation(self):
        """Clear all conversation data"""
        with self.transcription_lock:
            self.full_sentences = []
            self.last_transcription = ""
            self.current_conversation = "Conversation cleared!"
        if self.speaker_detector:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers
            )
        return "Conversation cleared!"

    def update_settings(self, threshold, max_speakers):
        """Update speaker detection settings"""
        self.change_threshold = threshold
        self.max_speakers = max_speakers
        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(threshold)
            self.speaker_detector.set_max_speakers(max_speakers)
        return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"

    def get_formatted_conversation(self):
        """Get the formatted conversation"""
        return self.current_conversation

    def get_status_info(self):
        """Get current status information"""
        if not self.speaker_detector:
            return "Speaker detector not initialized"
        try:
            status = self.speaker_detector.get_status_info()
            status_lines = [
                f"**Current Speaker:** {status['current_speaker'] + 1}",
                f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
                f"**Last Similarity:** {status['last_similarity']:.3f}",
                f"**Change Threshold:** {status['threshold']:.2f}",
                f"**Total Sentences:** {len(self.full_sentences)}",
                f"**Segments Processed:** {status['segment_counter']}",
                "",
                "**Speaker Activity:**"
            ]
            for i in range(status['max_speakers']):
                color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
                count = status['speaker_counts'][i]
                active = "🟢" if count > 0 else "⚫"
                status_lines.append(f"{active} Speaker {i+1} ({color_name}): {count} segments")
            return "\n".join(status_lines)
        except Exception as e:
            return f"Error getting status: {e}"

    def process_audio_chunk(self, audio_data, sample_rate=16000):
        """Process audio chunk from FastRTC input"""
        if not self.is_running or self.audio_processor is None:
            return
        try:
            # Ensure audio is float32
            if isinstance(audio_data, np.ndarray):
                if audio_data.dtype != np.float32:
                    audio_data = audio_data.astype(np.float32)
            else:
                audio_data = np.array(audio_data, dtype=np.float32)

            # Ensure mono
            if len(audio_data.shape) > 1:
                audio_data = np.mean(audio_data, axis=1) if audio_data.shape[1] > 1 else audio_data.flatten()

            # Normalize if needed
            if np.max(np.abs(audio_data)) > 1.0:
                audio_data = audio_data / np.max(np.abs(audio_data))

            # Add to audio processor buffer for speaker detection
            self.audio_processor.add_audio_chunk(audio_data)
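            # Note: chunks are only buffered here for speaker-embedding extraction.
            # The RealtimeSTT recorder runs with use_microphone=False, so if it should
            # also transcribe this same stream, the chunk would additionally have to be
            # forwarded to it (e.g. via the recorder's feed_audio method); this method
            # does not do that wiring.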
            # Periodically extract embeddings for speaker detection (roughly every
            # 0.5 seconds, tracked with a running sample counter; the buffer length
            # itself only rarely lands on an exact multiple of SAMPLE_RATE // 2)
            self._samples_since_embedding = getattr(self, "_samples_since_embedding", 0) + len(audio_data)
            if self._samples_since_embedding >= SAMPLE_RATE // 2:
                self._samples_since_embedding = 0
                embedding = self.audio_processor.extract_embedding_from_buffer()
                if embedding is not None:
                    self.speaker_detector.add_embedding(embedding)
        except Exception as e:
            logger.error(f"Error processing audio chunk: {e}")


# FastRTC Audio Handler
class DiarizationHandler(AsyncStreamHandler):
    def __init__(self, diarization_system):
        super().__init__()
        self.diarization_system = diarization_system
        self.audio_buffer = []
        self.buffer_size = BUFFER_SIZE

    def copy(self):
        """Return a fresh handler for each new stream connection"""
        return DiarizationHandler(self.diarization_system)

    async def emit(self):
        """Not used - we only receive audio"""
        return None

    async def receive(self, frame):
        """Receive audio data from FastRTC"""
        try:
            if not self.diarization_system.is_running:
                return
            # Extract audio data
            audio_data = getattr(frame, 'data', frame)
            # Convert to numpy array
            if isinstance(audio_data, bytes):
                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
            elif isinstance(audio_data, (list, tuple)):
                audio_array = np.array(audio_data, dtype=np.float32)
            else:
                audio_array = np.array(audio_data, dtype=np.float32)
            # Ensure 1D
            if len(audio_array.shape) > 1:
                audio_array = audio_array.flatten()
            # Buffer audio chunks
            self.audio_buffer.extend(audio_array)
            # Process in chunks
            while len(self.audio_buffer) >= self.buffer_size:
                chunk = np.array(self.audio_buffer[:self.buffer_size])
                self.audio_buffer = self.audio_buffer[self.buffer_size:]
                # Process asynchronously
                await self.process_audio_async(chunk)
        except Exception as e:
            logger.error(f"Error in FastRTC receive: {e}")

    async def process_audio_async(self, audio_data):
        """Process audio data asynchronously"""
        try:
            loop = asyncio.get_event_loop()
            await loop.run_in_executor(
                None,
                self.diarization_system.process_audio_chunk,
                audio_data,
                SAMPLE_RATE
            )
        except Exception as e:
            logger.error(f"Error in async audio processing: {e}")


# Global instances
diarization_system = RealtimeSpeakerDiarization()
audio_handler = None


def initialize_system():
    """Initialize the diarization system"""
    global audio_handler
    try:
        success = diarization_system.initialize_models()
        if success:
            audio_handler = DiarizationHandler(diarization_system)
            return "✅ System initialized successfully!"
        else:
            return "❌ Failed to initialize system. Check logs for details."
    except Exception as e:
        logger.error(f"Initialization error: {e}")
        return f"❌ Initialization error: {str(e)}"


def start_recording():
    """Start recording and transcription"""
    try:
        # The WebRTC component lives inside the Gradio interface and is wired up in
        # its on_start handler; it is not accessible from module scope here.
        result = diarization_system.start_recording()
        return result
    except Exception as e:
        return f"❌ Failed to start recording: {str(e)}"


def stop_recording():
    """Stop recording and transcription"""
    try:
        result = diarization_system.stop_recording()
        return f"⏹️ {result}"
    except Exception as e:
        return f"❌ Failed to stop recording: {str(e)}"


def clear_conversation():
    """Clear the conversation"""
    try:
        result = diarization_system.clear_conversation()
        return f"🗑️ {result}"
    except Exception as e:
        return f"❌ Failed to clear conversation: {str(e)}"


def update_settings(threshold, max_speakers):
    """Update system settings"""
    try:
        result = diarization_system.update_settings(threshold, max_speakers)
        return f"⚙️ {result}"
    except Exception as e:
        return f"❌ Failed to update settings: {str(e)}"


def get_conversation():
    """Get the current conversation"""
    try:
        return diarization_system.get_formatted_conversation()
    except Exception as e:
        return f"<i>Error getting conversation: {str(e)}</i>"


def get_status():
    """Get system status"""
    try:
        return diarization_system.get_status_info()
    except Exception as e:
        return f"Error getting status: {str(e)}"


# Create Gradio interface
def create_interface():
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as interface:
        gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
        gr.Markdown("Live transcription with automatic speaker identification using FastRTC audio streaming.")

        with gr.Row():
            with gr.Column(scale=2):
                # Add WebRTC component for audio streaming
                audio_webrtc = WebRTC(
                    label="Audio Input",
                    sources=["microphone"],
                    streaming=True,
                    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
                )

                # Conversation display
                conversation_output = gr.HTML(
                    value="<div style='padding: 20px; background: #f8f9fa; border-radius: 10px; min-height: 300px;'><i>Click 'Initialize System' to start...</i></div>",
                    label="Live Conversation"
                )

                # Control buttons
                with gr.Row():
                    init_btn = gr.Button("🔧 Initialize System", variant="secondary", size="lg")
                    start_btn = gr.Button("🎙️ Start", variant="primary", size="lg", interactive=False)
                    stop_btn = gr.Button("⏹️ Stop", variant="stop", size="lg", interactive=False)
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="lg", interactive=False)

                # Status display
                status_output = gr.Textbox(
                    label="System Status",
                    value="Ready to initialize...",
                    lines=8,
                    interactive=False
                )

            with gr.Column(scale=1):
                # Settings
                gr.Markdown("## ⚙️ Settings")
                threshold_slider = gr.Slider(
                    minimum=0.3,
                    maximum=0.9,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Higher values detect speaker changes more readily"
                )
                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Speakers"
                )
                update_btn = gr.Button("Update Settings", variant="secondary")

                # Instructions
                gr.Markdown("""
                ## 📋 Instructions
                1. **Initialize** the system (loads AI models)
                2. **Start** recording
                3. **Speak** - the system will transcribe and identify speakers
                4. **Monitor** real-time results below

                ## 🎨 Speaker Colors
                - 🔴 Speaker 1 (Red)
                - 🟢 Speaker 2 (Teal)
                - 🔵 Speaker 3 (Blue)
                - 🟡 Speaker 4 (Green)
                - 🟣 Speaker 5 (Yellow)
                - 🟤 Speaker 6 (Plum)
                - 🟫 Speaker 7 (Mint)
                - 🟨 Speaker 8 (Gold)
                """)

        # Event handlers
        def on_initialize():
            result = initialize_system()
            if "✅" in result:
                return result, gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True)
            else:
                return result, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)

        def on_start():
            result = start_recording()
            # Connect WebRTC to the server-side FastRTC endpoint
            audio_webrtc.stream_url = "/stream"
            return result, gr.update(interactive=False), gr.update(interactive=True)

        def on_stop():
            result = stop_recording()
            return result, gr.update(interactive=True), gr.update(interactive=False)

        def on_clear():
            result = clear_conversation()
            return result

        def on_update_settings(threshold, max_speakers):
            result = update_settings(threshold, int(max_speakers))
            return result

        def refresh_conversation():
            return get_conversation()

        def refresh_status():
            return get_status()

        # Button click handlers
        init_btn.click(
            fn=on_initialize,
            outputs=[status_output, start_btn, stop_btn, clear_btn]
        )
        start_btn.click(
            fn=on_start,
            outputs=[status_output, start_btn, stop_btn]
        )
        stop_btn.click(
            fn=on_stop,
            outputs=[status_output, start_btn, stop_btn]
        )
        clear_btn.click(
            fn=on_clear,
            outputs=[status_output]
        )
        update_btn.click(
            fn=on_update_settings,
            inputs=[threshold_slider, max_speakers_slider],
            outputs=[status_output]
        )

        # Auto-refresh conversation display every 1 second
        conversation_timer = gr.Timer(1)
        conversation_timer.tick(refresh_conversation, outputs=[conversation_output])

        # Auto-refresh status every 2 seconds
        status_timer = gr.Timer(2)
        status_timer.tick(refresh_status, outputs=[status_output])

    return interface


# FastAPI setup for FastRTC integration
app = FastAPI()

# Note: the original handlers below carried no route decorators; the paths used here
# are an assumption added so that the API mode actually exposes these endpoints.
@app.get("/")
async def root():
    return {"message": "Real-time Speaker Diarization API"}

@app.get("/health")
async def health_check():
    return {"status": "healthy", "system_running": diarization_system.is_running}

@app.post("/initialize")
async def api_initialize():
    result = initialize_system()
    return {"result": result, "success": "✅" in result}

@app.post("/start")
async def api_start():
    result = start_recording()
    return {"result": result, "success": "successfully" in result}

@app.post("/stop")
async def api_stop():
    result = stop_recording()
    return {"result": result, "success": "⏹️" in result}

@app.post("/clear")
async def api_clear():
    result = clear_conversation()
    return {"result": result}

@app.get("/conversation")
async def api_get_conversation():
    return {"conversation": get_conversation()}

@app.get("/status")
async def api_get_status():
    return {"status": get_status()}

@app.post("/settings")
async def api_update_settings(threshold: float, max_speakers: int):
    result = update_settings(threshold, max_speakers)
    return {"result": result}


# FastRTC Stream setup
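# Note: at import time audio_handler is still None, so the block below only creates
# and mounts the Stream if initialize_system() was called before this module-level
# code ran. Depending on the fastrtc version, Stream may also expect additional
# arguments (e.g. modality="audio", mode="send-receive").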
if audio_handler:
    stream = Stream(handler=audio_handler)
    app.include_router(stream.router, prefix="/stream")


# Main execution
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Real-time Speaker Diarization System")
    parser.add_argument("--mode", choices=["gradio", "api", "both"], default="gradio",
                        help="Run mode: gradio interface, API only, or both")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
    parser.add_argument("--api-port", type=int, default=8000, help="API port (when running both)")
    args = parser.parse_args()

    if args.mode == "gradio":
        # Run Gradio interface only
        interface = create_interface()
        interface.launch(
            server_name=args.host,
            server_port=args.port,
            share=True,
            show_error=True
        )
    elif args.mode == "api":
        # Run FastAPI only
        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            log_level="info"
        )
| elif args.mode == "both": | |
| # Run both Gradio and FastAPI | |
| import multiprocessing | |
| import threading | |
| def run_gradio(): | |
| interface = create_interface() | |
| interface.launch( | |
| server_name=args.host, | |
| server_port=args.port, | |
| share=True, | |
| show_error=True | |
| ) | |
| def run_fastapi(): | |
| uvicorn.run( | |
| app, | |
| host=args.host, | |
| port=args.api_port, | |
| log_level="info" | |
| ) | |
| # Start FastAPI in a separate thread | |
| api_thread = threading.Thread(target=run_fastapi, daemon=True) | |
| api_thread.start() | |
| # Start Gradio in main thread | |
| run_gradio() |
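
# Example invocations (assuming this file is saved as app.py):
#   python app.py --mode gradio --port 7860
#   python app.py --mode both --port 7860 --api-port 8000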