import gradio as gr
import numpy as np
import queue
import torch
import time
import threading
import os
import urllib.request
import torchaudio
from scipy.spatial.distance import cosine
from RealtimeSTT import AudioToTextRecorder
from fastapi import FastAPI, APIRouter
from fastrtc import (
    Stream,
    AsyncStreamHandler,
    ReplyOnPause,
    get_cloudflare_turn_credentials_async,
    get_cloudflare_turn_credentials,
)
import json
import io
import wave
import asyncio
import uvicorn
import socket

# Simplified configuration parameters
SILENCE_THRESHS = [0, 0.4]
FINAL_TRANSCRIPTION_MODEL = "distil-large-v3"
FINAL_BEAM_SIZE = 5
REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
REALTIME_BEAM_SIZE = 5
TRANSCRIPTION_LANGUAGE = "en"
SILERO_SENSITIVITY = 0.4
WEBRTC_SENSITIVITY = 3
MIN_LENGTH_OF_RECORDING = 0.7
PRE_RECORDING_BUFFER_DURATION = 0.35

# Speaker change detection parameters
DEFAULT_CHANGE_THRESHOLD = 0.7
EMBEDDING_HISTORY_SIZE = 5
MIN_SEGMENT_DURATION = 1.0
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 10

# Global variables
FAST_SENTENCE_END = True
SAMPLE_RATE = 16000
BUFFER_SIZE = 512
CHANNELS = 1

# Speaker colors
SPEAKER_COLORS = [
    "#FFFF00",  # Yellow
    "#FF0000",  # Red
    "#00FF00",  # Green
    "#00FFFF",  # Cyan
    "#FF00FF",  # Magenta
    "#0000FF",  # Blue
    "#FF8000",  # Orange
    "#00FF80",  # Spring Green
    "#8000FF",  # Purple
    "#FFFFFF",  # White
]
SPEAKER_COLOR_NAMES = [
    "Yellow", "Red", "Green", "Cyan", "Magenta",
    "Blue", "Orange", "Spring Green", "Purple", "White"
]

class SpeechBrainEncoder:
    """ECAPA-TDNN encoder from SpeechBrain for speaker embeddings"""

    def __init__(self, device="cpu"):
        self.device = device
        self.model = None
        self.embedding_dim = 192
        self.model_loaded = False
        self.cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "speechbrain")
        os.makedirs(self.cache_dir, exist_ok=True)

    def _download_model(self):
        """Download the pre-trained SpeechBrain ECAPA-TDNN checkpoint if not present"""
        model_url = "https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb/resolve/main/embedding_model.ckpt"
        model_path = os.path.join(self.cache_dir, "embedding_model.ckpt")
        if not os.path.exists(model_path):
            print(f"Downloading ECAPA-TDNN model to {model_path}...")
            urllib.request.urlretrieve(model_url, model_path)
        return model_path

    def load_model(self):
        """Load the ECAPA-TDNN model"""
        try:
            # SpeechBrain 1.x moved the pretrained interfaces to speechbrain.inference;
            # fall back to the older speechbrain.pretrained module for older installs.
            try:
                from speechbrain.inference import EncoderClassifier
            except ImportError:
                from speechbrain.pretrained import EncoderClassifier
            self._download_model()
            self.model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=self.cache_dir,
                run_opts={"device": self.device}
            )
            self.model_loaded = True
            return True
        except Exception as e:
            print(f"Error loading ECAPA-TDNN model: {e}")
            return False

    def embed_utterance(self, audio, sr=16000):
        """Extract a speaker embedding from audio"""
        if not self.model_loaded:
            raise ValueError("Model not loaded. Call load_model() first.")
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
            else:
                waveform = audio.unsqueeze(0)
            if sr != 16000:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
            with torch.no_grad():
                embedding = self.model.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)
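
# Usage sketch (illustrative only, not executed here): load the encoder once and reuse
# it for every utterance. One second of silence stands in for real microphone audio.
#
#   encoder = SpeechBrainEncoder(device="cpu")
#   if encoder.load_model():
#       one_second_of_silence = np.zeros(16000, dtype=np.float32)
#       embedding = encoder.embed_utterance(one_second_of_silence, sr=16000)
#       print(embedding.shape)  # expected: (192,) for ECAPA-TDNN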

class AudioProcessor:
    """Processes audio data to extract speaker embeddings"""

    def __init__(self, encoder):
        self.encoder = encoder

    def extract_embedding(self, audio_int16):
        try:
            float_audio = audio_int16.astype(np.float32) / 32768.0
            if np.abs(float_audio).max() > 1.0:
                float_audio = float_audio / np.abs(float_audio).max()
            embedding = self.encoder.embed_utterance(float_audio)
            return embedding
        except Exception as e:
            print(f"Embedding extraction error: {e}")
            return np.zeros(self.encoder.embedding_dim)

class SpeakerChangeDetector:
    """Speaker change detector that supports a configurable number of speakers"""

    def __init__(self, embedding_dim=192, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = set([0])

    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        if new_max < self.max_speakers:
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0
        if new_max > self.max_speakers:
            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]
        self.max_speakers = new_max

    def set_change_threshold(self, threshold):
        """Update the threshold for detecting speaker changes"""
        self.change_threshold = max(0.1, min(threshold, 0.99))

    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check whether the speaker has changed"""
        current_time = timestamp or time.time()

        # First embedding: assign it to speaker 0
        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        # Cosine similarity against the current speaker's running mean embedding
        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])
        self.last_similarity = similarity

        time_since_last_change = current_time - self.last_change_time
        is_speaker_change = False
        if time_since_last_change >= MIN_SEGMENT_DURATION:
            if similarity < self.change_threshold:
                # Look for a better-matching known speaker
                best_speaker = self.current_speaker
                best_similarity = similarity
                for speaker_id in range(self.max_speakers):
                    if speaker_id == self.current_speaker:
                        continue
                    speaker_mean = self.mean_embeddings[speaker_id]
                    if speaker_mean is not None:
                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                        if speaker_similarity > best_similarity:
                            best_similarity = speaker_similarity
                            best_speaker = speaker_id
                if best_speaker != self.current_speaker:
                    is_speaker_change = True
                    self.current_speaker = best_speaker
                elif len(self.active_speakers) < self.max_speakers:
                    # No good match among known speakers: open a slot for a new one
                    for new_id in range(self.max_speakers):
                        if new_id not in self.active_speakers:
                            is_speaker_change = True
                            self.current_speaker = new_id
                            self.active_speakers.add(new_id)
                            break
        if is_speaker_change:
            self.last_change_time = current_time

        # Update history and per-speaker statistics
        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)
        self.speaker_embeddings[self.current_speaker].append(embedding)
        self.active_speakers.add(self.current_speaker)
        if len(self.speaker_embeddings[self.current_speaker]) > 30:
            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]
        if self.speaker_embeddings[self.current_speaker]:
            self.mean_embeddings[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )
        return self.current_speaker, similarity

    def get_color_for_speaker(self, speaker_id):
        """Return the display color for a speaker ID"""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"

    def get_status_info(self):
        """Return status information about the speaker change detector"""
        speaker_counts = [len(self.speaker_embeddings[i]) for i in range(self.max_speakers)]
        return {
            "current_speaker": self.current_speaker,
            "speaker_counts": speaker_counts,
            "active_speakers": len(self.active_speakers),
            "max_speakers": self.max_speakers,
            "last_similarity": self.last_similarity,
            "threshold": self.change_threshold
        }
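
# Usage sketch (illustrative only, not executed here): the detector consumes one
# embedding per finalized sentence and returns the assigned speaker plus the cosine
# similarity against that speaker's running mean. The random vectors below are
# placeholders standing in for real ECAPA-TDNN embeddings.
#
#   detector = SpeakerChangeDetector(embedding_dim=192, change_threshold=0.7, max_speakers=4)
#   for _ in range(3):
#       fake_embedding = np.random.randn(192).astype(np.float32)
#       speaker_id, similarity = detector.add_embedding(fake_embedding)
#       print(f"Assigned to speaker {speaker_id + 1} (similarity {similarity:.2f})")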

class RealtimeSpeakerDiarization:
    def __init__(self):
        self.encoder = None
        self.audio_processor = None
        self.speaker_detector = None
        self.recorder = None
        self.sentence_queue = queue.Queue()
        self.full_sentences = []
        self.sentence_speakers = []
        self.pending_sentences = []
        self.displayed_text = ""
        self.last_realtime_text = ""
        self.is_running = False
        self.change_threshold = DEFAULT_CHANGE_THRESHOLD
        self.max_speakers = DEFAULT_MAX_SPEAKERS
        self.current_conversation = ""
        self.audio_buffer = []

    def initialize_models(self):
        """Initialize the speaker encoder model"""
        try:
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device_str}")
            self.encoder = SpeechBrainEncoder(device=device_str)
            success = self.encoder.load_model()
            if success:
                self.audio_processor = AudioProcessor(self.encoder)
                self.speaker_detector = SpeakerChangeDetector(
                    embedding_dim=self.encoder.embedding_dim,
                    change_threshold=self.change_threshold,
                    max_speakers=self.max_speakers
                )
                print("ECAPA-TDNN model loaded successfully!")
                return True
            else:
                print("Failed to load ECAPA-TDNN model")
                return False
        except Exception as e:
            print(f"Model initialization error: {e}")
            return False

    def live_text_detected(self, text):
        """Callback for real-time transcription updates"""
        text = text.strip()
        if text:
            sentence_delimiters = '.?!。'
            prob_sentence_end = (
                len(self.last_realtime_text) > 0
                and text[-1] in sentence_delimiters
                and self.last_realtime_text[-1] in sentence_delimiters
            )
            self.last_realtime_text = text
            if prob_sentence_end and FAST_SENTENCE_END:
                self.recorder.stop()
            elif prob_sentence_end:
                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
            else:
                self.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]

    def process_final_text(self, text):
        """Process final transcribed text with its speaker audio"""
        text = text.strip()
        if text:
            try:
                bytes_data = self.recorder.last_transcription_bytes
                self.sentence_queue.put((text, bytes_data))
                self.pending_sentences.append(text)
            except Exception as e:
                print(f"Error processing final text: {e}")

    def process_sentence_queue(self):
        """Process sentences in the queue for speaker detection"""
        while self.is_running:
            try:
                text, bytes_data = self.sentence_queue.get(timeout=1)
                # Convert audio data to int16
                audio_int16 = np.frombuffer(bytes_data, dtype=np.int16)
                # Extract speaker embedding
                speaker_embedding = self.audio_processor.extract_embedding(audio_int16)
                # Store sentence and embedding
                self.full_sentences.append((text, speaker_embedding))
                # Fill in missing speaker assignments
                while len(self.sentence_speakers) < len(self.full_sentences) - 1:
                    self.sentence_speakers.append(0)
                # Detect speaker changes
                speaker_id, similarity = self.speaker_detector.add_embedding(speaker_embedding)
                self.sentence_speakers.append(speaker_id)
                # Remove from pending
                if text in self.pending_sentences:
                    self.pending_sentences.remove(text)
                # Update conversation display
                self.current_conversation = self.get_formatted_conversation()
            except queue.Empty:
                continue
            except Exception as e:
                print(f"Error processing sentence: {e}")

    def start_recording(self):
        """Start the recording and transcription process"""
        if self.encoder is None:
            return "Please initialize models first!"
        try:
            # Recorder configuration for manually fed audio (no microphone device)
            recorder_config = {
                'spinner': False,
                'use_microphone': False,  # Audio is fed in via feed_audio_data()
                'model': FINAL_TRANSCRIPTION_MODEL,
                'language': TRANSCRIPTION_LANGUAGE,
                'silero_sensitivity': SILERO_SENSITIVITY,
                'webrtc_sensitivity': WEBRTC_SENSITIVITY,
                'post_speech_silence_duration': SILENCE_THRESHS[1],
                'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
                'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
                'min_gap_between_recordings': 0,
                'enable_realtime_transcription': True,
                'realtime_processing_pause': 0,
                'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
                'on_realtime_transcription_update': self.live_text_detected,
                'beam_size': FINAL_BEAM_SIZE,
                'beam_size_realtime': REALTIME_BEAM_SIZE,
                'buffer_size': BUFFER_SIZE,
                'sample_rate': SAMPLE_RATE,
            }
            self.recorder = AudioToTextRecorder(**recorder_config)

            # Start sentence processing thread
            self.is_running = True
            self.sentence_thread = threading.Thread(target=self.process_sentence_queue, daemon=True)
            self.sentence_thread.start()

            # Start transcription thread
            self.transcription_thread = threading.Thread(target=self.run_transcription, daemon=True)
            self.transcription_thread.start()

            return "Recording started successfully! FastRTC audio input ready."
        except Exception as e:
            return f"Error starting recording: {e}"

    def run_transcription(self):
        """Run the transcription loop"""
        try:
            while self.is_running:
                self.recorder.text(self.process_final_text)
        except Exception as e:
            print(f"Transcription error: {e}")

    def stop_recording(self):
        """Stop the recording process"""
        self.is_running = False
        if self.recorder:
            self.recorder.stop()
        return "Recording stopped!"

    def clear_conversation(self):
        """Clear all conversation data"""
        self.full_sentences = []
        self.sentence_speakers = []
        self.pending_sentences = []
        self.displayed_text = ""
        self.last_realtime_text = ""
        self.current_conversation = "Conversation cleared!"
        if self.speaker_detector:
            self.speaker_detector = SpeakerChangeDetector(
                embedding_dim=self.encoder.embedding_dim,
                change_threshold=self.change_threshold,
                max_speakers=self.max_speakers
            )
        return "Conversation cleared!"

    def update_settings(self, threshold, max_speakers):
        """Update speaker detection settings"""
        self.change_threshold = threshold
        self.max_speakers = max_speakers
        if self.speaker_detector:
            self.speaker_detector.set_change_threshold(threshold)
            self.speaker_detector.set_max_speakers(max_speakers)
        return f"Settings updated: Threshold={threshold:.2f}, Max Speakers={max_speakers}"

    def get_formatted_conversation(self):
        """Return the conversation as HTML with speaker colors"""
        try:
            sentences_with_style = []
            # Completed sentences, colored by assigned speaker
            for i, sentence in enumerate(self.full_sentences):
                sentence_text, _ = sentence
                if i >= len(self.sentence_speakers):
                    color = "#FFFFFF"
                    speaker_name = "Unknown"
                else:
                    speaker_id = self.sentence_speakers[i]
                    color = self.speaker_detector.get_color_for_speaker(speaker_id)
                    speaker_name = f"Speaker {speaker_id + 1}"
                sentences_with_style.append(
                    f'<span style="color:{color};"><b>{speaker_name}:</b> {sentence_text}</span>')
            # Pending sentences still being processed
            for pending_sentence in self.pending_sentences:
                sentences_with_style.append(
                    f'<span style="color:#60FFFF;"><b>Processing:</b> {pending_sentence}</span>')
            if sentences_with_style:
                return "<br><br>".join(sentences_with_style)
            else:
                return "Waiting for speech input..."
        except Exception as e:
            return f"Error formatting conversation: {e}"

    def get_status_info(self):
        """Get current status information"""
        if not self.speaker_detector:
            return "Speaker detector not initialized"
        try:
            status = self.speaker_detector.get_status_info()
            status_lines = [
                f"**Current Speaker:** {status['current_speaker'] + 1}",
                f"**Active Speakers:** {status['active_speakers']} of {status['max_speakers']}",
                f"**Last Similarity:** {status['last_similarity']:.3f}",
                f"**Change Threshold:** {status['threshold']:.2f}",
                f"**Total Sentences:** {len(self.full_sentences)}",
                "",
                "**Speaker Segment Counts:**"
            ]
            for i in range(status['max_speakers']):
                color_name = SPEAKER_COLOR_NAMES[i] if i < len(SPEAKER_COLOR_NAMES) else f"Speaker {i+1}"
                status_lines.append(f"Speaker {i+1} ({color_name}): {status['speaker_counts'][i]}")
            return "\n".join(status_lines)
        except Exception as e:
            return f"Error getting status: {e}"

    def feed_audio_data(self, audio_data):
        """Feed audio data to the recorder"""
        if not self.is_running or not self.recorder:
            return
        try:
            # Ensure audio is 16-bit PCM before handing it to RealtimeSTT
            if isinstance(audio_data, np.ndarray):
                if audio_data.dtype != np.int16:
                    if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
                        # Scale float samples in [-1.0, 1.0] to the int16 range
                        audio_data = (audio_data * 32767).astype(np.int16)
                    else:
                        audio_data = audio_data.astype(np.int16)
                # Convert to bytes
                audio_bytes = audio_data.tobytes()
            else:
                audio_bytes = audio_data
            # Feed to recorder
            self.recorder.feed_audio(audio_bytes)
        except Exception as e:
            print(f"Error feeding audio data: {e}")

# FastRTC audio handler
class DiarizationHandler(AsyncStreamHandler):
    def __init__(self, diarization_system):
        super().__init__()
        self.diarization_system = diarization_system

    def copy(self):
        # Return a fresh handler for each new stream connection
        return DiarizationHandler(self.diarization_system)

    async def emit(self):
        """Not used in this implementation; audio is only consumed"""
        return None

    async def receive(self, frame):
        """Receive audio data from FastRTC and process it"""
        try:
            if self.diarization_system.is_running:
                # FastRTC typically delivers audio as a (sample_rate, numpy array) tuple;
                # fall back to frame.data or the raw frame for other shapes.
                if isinstance(frame, tuple) and len(frame) == 2:
                    _, audio_data = frame
                elif hasattr(frame, 'data'):
                    audio_data = frame.data
                else:
                    audio_data = frame
                # Feed audio data to the diarization system
                self.diarization_system.feed_audio_data(audio_data)
        except Exception as e:
            print(f"Error in FastRTC handler: {e}")

# Global instance
diarization_system = RealtimeSpeakerDiarization()


def initialize_system():
    """Initialize the diarization system"""
    success = diarization_system.initialize_models()
    if success:
        return "✅ System initialized successfully! Models loaded."
    else:
        return "❌ Failed to initialize system. Please check the logs."


def start_recording():
    """Start recording and transcription"""
    return diarization_system.start_recording()


def stop_recording():
    """Stop recording and transcription"""
    return diarization_system.stop_recording()


def clear_conversation():
    """Clear the conversation"""
    return diarization_system.clear_conversation()


def update_settings(threshold, max_speakers):
    """Update system settings"""
    return diarization_system.update_settings(threshold, max_speakers)


def get_conversation():
    """Get the current conversation"""
    return diarization_system.get_formatted_conversation()


def get_status():
    """Get system status"""
    return diarization_system.get_status_info()

# Create Gradio interface
def create_interface():
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Monochrome()) as interface:
        gr.Markdown("# 🎤 Real-time Speech Recognition with Speaker Diarization")
        gr.Markdown("This app performs real-time speech recognition with automatic speaker identification and color coding.")

        with gr.Row():
            with gr.Column(scale=2):
                # Main conversation display
                conversation_output = gr.HTML(
                    value="<i>Click 'Initialize System' to start...</i>",
                    label="Live Conversation"
                )
                # Control buttons
                with gr.Row():
                    init_btn = gr.Button("🔧 Initialize System", variant="secondary")
                    start_btn = gr.Button("🎙️ Start Recording", variant="primary", interactive=False)
                    stop_btn = gr.Button("⏹️ Stop Recording", variant="stop", interactive=False)
                    clear_btn = gr.Button("🗑️ Clear Conversation", interactive=False)
                # Status display
                status_output = gr.Textbox(
                    label="System Status",
                    value="System not initialized",
                    lines=8,
                    interactive=False
                )
            with gr.Column(scale=1):
                # Settings panel
                gr.Markdown("## ⚙️ Settings")
                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.95,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Lower values = more sensitive to speaker changes"
                )
                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Number of Speakers"
                )
                update_settings_btn = gr.Button("Update Settings")
                # Instructions
                gr.Markdown("## 📝 Instructions")
                gr.Markdown("""
                1. Click **Initialize System** to load the models
                2. Click **Start Recording** to begin processing
                3. Use the FastRTC interface below to connect your microphone
                4. Allow microphone access when prompted
                5. Speak into your microphone
                6. Watch the real-time transcription with speaker labels
                7. Adjust settings as needed
                """)
                # Speaker color legend
                gr.Markdown("## 🎨 Speaker Colors")
                color_info = []
                for i, (color, name) in enumerate(zip(SPEAKER_COLORS, SPEAKER_COLOR_NAMES)):
                    color_info.append(f'<span style="color:{color};">■</span> Speaker {i+1} ({name})')
                gr.HTML("<br>".join(color_info[:DEFAULT_MAX_SPEAKERS]))
                # FastRTC integration notice
                gr.Markdown("""
                ## ℹ️ About FastRTC
                This app uses FastRTC for low-latency audio streaming.
                For optimal performance, use a modern browser and allow microphone access when prompted.
                """)

        # Auto-refresh conversation and status
        def refresh_display():
            return diarization_system.get_formatted_conversation(), diarization_system.get_status_info()

        # Event handlers
        def on_initialize():
            # status_output gets the init result; the timer below keeps it refreshed afterwards.
            result = initialize_system()
            if "successfully" in result:
                return (
                    result,
                    gr.update(interactive=True),   # start_btn
                    gr.update(interactive=True),   # clear_btn
                    get_conversation()
                )
            else:
                return (
                    result,
                    gr.update(interactive=False),  # start_btn
                    gr.update(interactive=False),  # clear_btn
                    get_conversation()
                )

        def on_start():
            result = start_recording()
            return (
                result,
                gr.update(interactive=False),  # start_btn
                gr.update(interactive=True),   # stop_btn
            )

        def on_stop():
            result = stop_recording()
            return (
                result,
                gr.update(interactive=True),   # start_btn
                gr.update(interactive=False),  # stop_btn
            )

        # Connect event handlers (each output component is listed once per event)
        init_btn.click(
            on_initialize,
            outputs=[status_output, start_btn, clear_btn, conversation_output]
        )
        start_btn.click(
            on_start,
            outputs=[status_output, start_btn, stop_btn]
        )
        stop_btn.click(
            on_stop,
            outputs=[status_output, start_btn, stop_btn]
        )
        clear_btn.click(
            clear_conversation,
            outputs=[status_output]
        )
        update_settings_btn.click(
            update_settings,
            inputs=[threshold_slider, max_speakers_slider],
            outputs=[status_output]
        )

        # Auto-refresh the conversation and status every 2 seconds
        refresh_timer = gr.Timer(2.0)
        refresh_timer.tick(
            refresh_display,
            outputs=[conversation_output, status_output]
        )
    return interface

# Create an API router for the HTTP endpoints
router = APIRouter()


@router.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "system_initialized": diarization_system.encoder is not None,
        "recording_active": diarization_system.is_running
    }


@router.get("/api/conversation")
async def get_conversation_api():
    """API endpoint to get the current conversation"""
    return {
        "conversation": diarization_system.get_formatted_conversation(),
        "status": diarization_system.get_status_info(),
        "is_recording": diarization_system.is_running
    }


@router.post("/api/control/{action}")
async def control_recording(action: str):
    """API endpoint to control recording (start/stop/clear/initialize)"""
    if action == "start":
        result = diarization_system.start_recording()
    elif action == "stop":
        result = diarization_system.stop_recording()
    elif action == "clear":
        result = diarization_system.clear_conversation()
    elif action == "initialize":
        result = initialize_system()
    else:
        return {"error": "Invalid action. Use: start, stop, clear, or initialize"}
    return {"result": result, "is_recording": diarization_system.is_running}

# Main application setup
def create_app():
    """Create and configure the FastAPI app with Gradio and FastRTC"""
    # Create FastAPI app
    app = FastAPI(
        title="Real-time Speaker Diarization",
        description="Real-time speech recognition with speaker diarization using FastRTC",
        version="1.0.0"
    )
    # Include API routes
    app.include_router(router)
    # Create and mount the Gradio interface
    gradio_interface = create_interface()
    app = gr.mount_gradio_app(app, gradio_interface, path="/")

    # Setup FastRTC stream
    try:
        # Create the handler
        handler = DiarizationHandler(diarization_system)

        # Get TURN credentials
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            print("Warning: HF_TOKEN not set. Audio streaming may not work properly.")
            # Fall back to a public STUN server
            rtc_config = {
                "iceServers": [{"urls": "stun:stun.l.google.com:19302"}]
            }
        else:
            # Get Cloudflare TURN credentials (pass the token by keyword to avoid
            # relying on the positional parameter order of the fastrtc helper)
            try:
                turn_credentials = get_cloudflare_turn_credentials(hf_token=hf_token)
                # Safely extract credentials from the response
                ice_servers = []
                # Always add a STUN server
                ice_servers.append({"urls": "stun:stun.l.google.com:19302"})
                # Add a TURN server if credentials are available
                if turn_credentials and isinstance(turn_credentials, dict):
                    # Handle the possible response structures
                    if 'iceServers' in turn_credentials:
                        # The credentials are already a complete RTC configuration
                        rtc_config = turn_credentials
                    elif 'urls' in turn_credentials and isinstance(turn_credentials['urls'], list) and turn_credentials['urls']:
                        # Structure: {urls: [...], username: "...", credential: "..."}
                        ice_servers.append({
                            "urls": [f"turn:{url}" for url in turn_credentials["urls"]],
                            "username": turn_credentials.get("username", ""),
                            "credential": turn_credentials.get("credential", "")
                        })
                        rtc_config = {"iceServers": ice_servers}
                    elif 'url' in turn_credentials:
                        # Structure with a single URL
                        ice_servers.append({
                            "urls": f"turn:{turn_credentials['url']}",
                            "username": turn_credentials.get("username", ""),
                            "credential": turn_credentials.get("credential", "")
                        })
                        rtc_config = {"iceServers": ice_servers}
                    else:
                        print("Warning: Unexpected TURN credentials format. Using STUN only.")
                        rtc_config = {"iceServers": ice_servers}
                else:
                    print("Warning: Could not get TURN credentials. Using STUN only.")
                    rtc_config = {"iceServers": ice_servers}
            except Exception as e:
                print(f"Warning: Error getting TURN credentials: {e}. Using STUN only.")
                rtc_config = {
                    "iceServers": [{"urls": "stun:stun.l.google.com:19302"}]
                }

        # Create the FastRTC stream
        stream = Stream(
            handler=handler,
            rtc_configuration=rtc_config,
            modality="audio",
            mode="send-receive"
        )
        # Mount the FastRTC stream on the FastAPI app
        stream.mount(app)
        print("FastRTC stream configured successfully!")
    except Exception as e:
        print(f"Warning: Failed to setup FastRTC stream: {e}")
        print("Audio streaming will not be available.")
    return app

# Main entry point
if __name__ == "__main__":
    # Create the app
    app = create_app()

    # Configuration
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", 7860))

    # Find an available port if the requested one is in use
    def find_available_port(start_port=7860, max_tries=10):
        """Find an available port starting from start_port"""
        for port_offset in range(max_tries):
            candidate = start_port + port_offset
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.bind(('0.0.0.0', candidate))
                sock.close()
                return candidate
            except OSError:
                continue
        # If no port is free, return the starting port and let the server report the error
        return start_port

    available_port = find_available_port(port)
    if available_port != port:
        print(f"Port {port} is in use, using port {available_port} instead.")
        port = available_port

    print(f"""
🎤 Real-time Speaker Diarization Server
=======================================
Starting server on: http://{host}:{port}

Features:
- Real-time speech recognition
- Speaker diarization with color coding
- FastRTC low-latency audio streaming
- Web interface for easy interaction

Make sure to:
1. Set the HF_TOKEN environment variable for TURN server access
2. Allow microphone access in your browser
3. Use a modern browser for best performance

API Endpoints:
- GET  /health                - Health check
- GET  /api/conversation      - Get the current conversation
- POST /api/control/{{action}}  - Control recording (start/stop/clear/initialize)
- WebRTC signalling endpoints are added by FastRTC when the stream is mounted
""")

    # Run the server
    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
        access_log=True
    )