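"""
Speaker Diarization & Transcription demo (Gradio app).

Pipeline: MFCC-statistics speaker embeddings -> cosine-similarity speaker
change detection -> OpenAI Whisper (via transformers) for transcription.

Assumed dependencies (inferred from the imports below): gradio, numpy,
torch, torchaudio, scipy, transformers.
"""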
import gradio as gr
import numpy as np
import torch
import torchaudio
from scipy.spatial.distance import cosine
import tempfile
import os
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

try:
    from transformers import pipeline
except ImportError:
    pipeline = None  # checked again when AudioProcessor initializes the transcriber
    print("transformers not found. Install with: pip install transformers")

# Configuration
class Config:
    # Audio settings
    SAMPLE_RATE = 16000

    # Speaker detection
    CHANGE_THRESHOLD = 0.65
    MAX_SPEAKERS = 4
    MIN_SEGMENT_DURATION = 1.0
    EMBEDDING_HISTORY_SIZE = 3
    SPEAKER_MEMORY_SIZE = 20

    # Console colors for speakers (HTML version)
    SPEAKER_COLORS = [
        "#FFD700",  # Gold
        "#FF6B6B",  # Red
        "#4ECDC4",  # Teal
        "#45B7D1",  # Blue
        "#96CEB4",  # Mint
        "#FFEAA7",  # Light Yellow
        "#DDA0DD",  # Plum
        "#98D8C8",  # Mint Green
    ]

class SpeakerEncoder:
    """Simplified speaker encoder using torchaudio transforms"""

    def __init__(self, device="cpu"):
        self.device = device
        self.embedding_dim = 128
        self.model_loaded = False
        self._setup_model()

    def _setup_model(self):
        """Set up a simple MFCC-based feature extractor"""
        try:
            self.mfcc_transform = torchaudio.transforms.MFCC(
                sample_rate=Config.SAMPLE_RATE,
                n_mfcc=13,
                melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23}
            ).to(self.device)
            self.model_loaded = True
            print("Simple MFCC-based encoder initialized")
        except Exception as e:
            print(f"Error setting up encoder: {e}")
            self.model_loaded = False
    def extract_embedding(self, audio):
        """Extract speaker embedding from audio"""
        if not self.model_loaded:
            return np.zeros(self.embedding_dim)
        try:
            # Ensure audio is a float32 tensor
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio).float()

            # Normalize audio
            if audio.abs().max() > 0:
                audio = audio / audio.abs().max()

            # Add batch dimension if needed
            if audio.dim() == 1:
                audio = audio.unsqueeze(0)

            # Extract MFCC features
            with torch.no_grad():
                mfcc = self.mfcc_transform(audio)

            # Statistics-based embedding: mean/std/max/min over time for
            # each of the 13 MFCC coefficients (4 x 13 = 52 values)
            embedding = torch.cat([
                mfcc.mean(dim=2).flatten(),
                mfcc.std(dim=2).flatten(),
                mfcc.max(dim=2)[0].flatten(),
                mfcc.min(dim=2)[0].flatten()
            ])

            # Pad or truncate to the fixed embedding size
            if embedding.size(0) > self.embedding_dim:
                embedding = embedding[:self.embedding_dim]
            elif embedding.size(0) < self.embedding_dim:
                padding = torch.zeros(self.embedding_dim - embedding.size(0))
                embedding = torch.cat([embedding, padding])

            return embedding.cpu().numpy()
        except Exception as e:
            print(f"Error extracting embedding: {e}")
            return np.zeros(self.embedding_dim)
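
# Illustrative sketch (not part of the app; the file name is hypothetical):
# embeddings are plain numpy vectors, so two chunks can be compared directly
# with cosine similarity, which is exactly what SpeakerDetector does below.
#
#   enc = SpeakerEncoder()
#   wav, sr = torchaudio.load("sample.wav")
#   emb_a = enc.extract_embedding(wav[0, :sr * 3])
#   emb_b = enc.extract_embedding(wav[0, sr * 3:sr * 6])
#   similarity = 1.0 - cosine(emb_a, emb_b)  # close to 1.0 for the same speaker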

class SpeakerDetector:
    """Speaker change detection using embeddings"""

    def __init__(self, threshold=Config.CHANGE_THRESHOLD, max_speakers=Config.MAX_SPEAKERS):
        self.threshold = threshold
        self.max_speakers = max_speakers
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(max_speakers)]
        self.speaker_centroids = [None] * max_speakers
        self.active_speakers = {0}

    def reset(self):
        """Reset speaker detection state"""
        self.current_speaker = 0
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.speaker_centroids = [None] * self.max_speakers
        self.active_speakers = {0}
    def detect_speaker(self, embedding):
        """Detect current speaker from embedding"""
        # Initialize first speaker
        if not self.speaker_embeddings[0]:
            self.speaker_embeddings[0].append(embedding)
            self.speaker_centroids[0] = embedding.copy()
            return 0, 1.0

        # Calculate similarity with current speaker
        current_centroid = self.speaker_centroids[self.current_speaker]
        if current_centroid is not None:
            similarity = 1.0 - cosine(embedding, current_centroid)
        else:
            similarity = 0.0

        # Check for speaker change
        if similarity < self.threshold:
            # Find best matching existing speaker
            best_speaker = self.current_speaker
            best_similarity = similarity
            for speaker_id in self.active_speakers:
                if speaker_id == self.current_speaker:
                    continue
                centroid = self.speaker_centroids[speaker_id]
                if centroid is not None:
                    sim = 1.0 - cosine(embedding, centroid)
                    if sim > best_similarity and sim > self.threshold:
                        best_similarity = sim
                        best_speaker = speaker_id

            # Create new speaker if no good match and slots available
            if (best_speaker == self.current_speaker and
                    len(self.active_speakers) < self.max_speakers):
                for new_id in range(self.max_speakers):
                    if new_id not in self.active_speakers:
                        best_speaker = new_id
                        best_similarity = 0.0
                        self.active_speakers.add(new_id)
                        break

            # Update current speaker if changed
            if best_speaker != self.current_speaker:
                self.current_speaker = best_speaker
                similarity = best_similarity

        # Update speaker model
        self._update_speaker_model(self.current_speaker, embedding)
        return self.current_speaker, similarity
    def _update_speaker_model(self, speaker_id, embedding):
        """Update speaker model with new embedding"""
        self.speaker_embeddings[speaker_id].append(embedding)

        # Keep only recent embeddings
        if len(self.speaker_embeddings[speaker_id]) > Config.SPEAKER_MEMORY_SIZE:
            self.speaker_embeddings[speaker_id] = \
                self.speaker_embeddings[speaker_id][-Config.SPEAKER_MEMORY_SIZE:]

        # Update centroid
        if self.speaker_embeddings[speaker_id]:
            self.speaker_centroids[speaker_id] = np.mean(
                self.speaker_embeddings[speaker_id], axis=0
            )
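
# Note on the threshold: "similarity" above is 1 - cosine distance. With the
# default CHANGE_THRESHOLD of 0.65, a chunk scoring e.g. 0.55 against the
# current speaker's centroid triggers a search over the other active speakers;
# if none scores above 0.65 and a slot is free, a new speaker is created.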

class AudioProcessor:
    """Handles audio processing and transcription"""

    def __init__(self):
        self.encoder = SpeakerEncoder()
        self.detector = SpeakerDetector()

        # Initialize Whisper model for transcription
        try:
            self.transcriber = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-base",
                chunk_length_s=30,
                device=0 if torch.cuda.is_available() else -1
            )
            print("Whisper model loaded successfully")
        except Exception as e:
            print(f"Error loading Whisper model: {e}")
            self.transcriber = None
    def process_audio_file(self, audio_file):
        """Process uploaded audio file"""
        if audio_file is None:
            return "Please upload an audio file.", ""
        try:
            # Reset speaker detection for new file
            self.detector.reset()

            # Load audio file
            waveform, sample_rate = torchaudio.load(audio_file)

            # Convert to mono if stereo
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            # Resample to 16kHz if needed
            if sample_rate != Config.SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(sample_rate, Config.SAMPLE_RATE)
                waveform = resampler(waveform)

            # Convert to numpy
            audio_data = waveform.squeeze().numpy()

            # Transcribe entire audio
            if self.transcriber:
                transcription_result = self.transcriber(audio_file)
                full_transcription = transcription_result['text']
            else:
                full_transcription = "Transcription service unavailable"

            # Process audio in chunks for speaker detection
            chunk_duration = 3.0  # 3-second chunks
            chunk_samples = int(chunk_duration * Config.SAMPLE_RATE)
            results = []

            for i in range(0, len(audio_data), chunk_samples // 2):  # 50% overlap
                chunk = audio_data[i:i + chunk_samples]
                if len(chunk) < Config.SAMPLE_RATE:  # Skip chunks shorter than 1 second
                    continue

                # Extract speaker embedding
                embedding = self.encoder.extract_embedding(chunk)
                speaker_id, similarity = self.detector.detect_speaker(embedding)

                # Get timestamps
                start_time = i / Config.SAMPLE_RATE
                end_time = (i + len(chunk)) / Config.SAMPLE_RATE

                # Transcribe chunk
                if self.transcriber and len(chunk) > Config.SAMPLE_RATE:
                    # Save chunk to a temporary file for transcription; the file
                    # is closed before writing so the path can be reopened safely
                    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
                        tmp_path = tmp_file.name
                    torchaudio.save(tmp_path, torch.from_numpy(chunk).unsqueeze(0), Config.SAMPLE_RATE)
                    chunk_result = self.transcriber(tmp_path)
                    chunk_text = chunk_result['text'].strip()
                    os.unlink(tmp_path)  # Clean up temp file
                else:
                    chunk_text = ""

                if chunk_text:  # Only add segments with actual text
                    results.append({
                        'speaker_id': speaker_id,
                        'start_time': start_time,
                        'end_time': end_time,
                        'text': chunk_text,
                        'similarity': similarity
                    })

            # Format results
            formatted_output = self._format_results(results)
            return formatted_output, full_transcription
        except Exception as e:
            return f"Error processing audio: {str(e)}", ""
    def _format_results(self, results):
        """Format results with speaker colors"""
        if not results:
            return "No speech detected in the audio file."

        formatted_lines = []
        formatted_lines.append("🎤 **Speaker Diarization Results**\n")

        for result in results:
            speaker_id = result['speaker_id']
            start_time = result['start_time']
            end_time = result['end_time']
            text = result['text']
            similarity = result['similarity']
            color = Config.SPEAKER_COLORS[speaker_id % len(Config.SPEAKER_COLORS)]

            # Format timestamp
            start_min, start_sec = divmod(int(start_time), 60)
            end_min, end_sec = divmod(int(end_time), 60)
            timestamp = f"[{start_min:02d}:{start_sec:02d} - {end_min:02d}:{end_sec:02d}]"

            # Create colored HTML output
            formatted_lines.append(
                f'<div style="margin-bottom: 10px; padding: 8px; border-left: 4px solid {color}; background-color: {color}20;">'
                f'<strong style="color: {color};">Speaker {speaker_id + 1}</strong> '
                f'<span style="color: #666; font-size: 0.9em;">{timestamp}</span><br>'
                f'<span style="color: #333;">{text}</span>'
                f'</div>'
            )
        return "".join(formatted_lines)

# Global processor instance
processor = AudioProcessor()


def process_audio(audio_file, sensitivity):
    """Process audio file with speaker detection"""
    if audio_file is None:
        return "Please upload an audio file.", ""

    # Update sensitivity
    processor.detector.threshold = sensitivity

    # Process the audio
    diarized_output, full_transcription = processor.process_audio_file(audio_file)
    return diarized_output, full_transcription

# Create Gradio interface
def create_interface():
    """Create Gradio interface"""
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Speaker Diarization & Transcription",
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .speaker-output {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
        }
        """
    ) as demo:
        gr.Markdown(
            """
            # 🎙️ Speaker Diarization & Transcription
            Upload an audio file to automatically detect different speakers and transcribe their speech.
            The system will identify speaker changes and display each speaker's text in different colors.
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                audio_input = gr.Audio(
                    label="Upload Audio File",
                    type="filepath",
                    sources=["upload", "microphone"]
                )
                sensitivity_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.65,
                    step=0.05,
                    label="Speaker Change Sensitivity",
                    info="Lower values = more sensitive to speaker changes"
                )
                process_btn = gr.Button("🎯 Process Audio", variant="primary", size="lg")
                gr.Markdown(
                    """
                    ### Instructions:
                    1. Upload an audio file (WAV, MP3, etc.)
                    2. Adjust sensitivity if needed
                    3. Click "Process Audio"
                    4. View results with speaker colors

                    ### Tips:
                    - Works best with clear speech
                    - Supports multiple file formats
                    - Different speakers shown in different colors
                    - Processing may take a moment for longer files
                    """
                )
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("🎨 Speaker Diarization"):
                        diarized_output = gr.HTML(
                            label="Speaker Diarization Results",
                            elem_classes=["speaker-output"]
                        )
                    with gr.TabItem("📝 Full Transcription"):
                        full_transcription = gr.Textbox(
                            label="Complete Transcription",
                            lines=15,
                            max_lines=20,
                            show_copy_button=True
                        )
        # Event handlers
        process_btn.click(
            fn=process_audio,
            inputs=[audio_input, sensitivity_slider],
            outputs=[diarized_output, full_transcription],
            show_progress=True
        )

        # Auto-process when audio is uploaded
        audio_input.change(
            fn=process_audio,
            inputs=[audio_input, sensitivity_slider],
            outputs=[diarized_output, full_transcription],
            show_progress=True
        )
        gr.Markdown(
            """
            ---
            ### About
            This application uses:
            - **MFCC features** for speaker embedding extraction
            - **Cosine similarity** for speaker change detection
            - **OpenAI Whisper** for speech-to-text transcription
            - **Gradio** for the web interface

            **Note**: This is a simplified speaker diarization system. For production use,
            consider more advanced speaker embedding models like speechbrain or pyannote.audio.
            """
        )
    return demo
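
# To run locally: `python app.py`, then open http://localhost:7860
# (the launch() call below binds to 0.0.0.0 on port 7860).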
# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )