import gradio as gr
from faster_whisper import WhisperModel
import os
import time
import torch
from pathlib import Path
import warnings
from jiwer import wer, cer
import re
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Global variables for models
WHISPER_MODELS = {}
DEVICE = None

# Model configurations - Hebrew-focused models. Keys are either Hugging Face
# repo IDs or built-in faster-whisper model sizes.
AVAILABLE_WHISPER_MODELS = {
    "ivrit-ai/faster-whisper-v2-d4": "Hebrew Faster-Whisper V2-D4 (Recommended)",
    "ivrit-ai/faster-whisper-v2-d3": "Hebrew Faster-Whisper V2-D3",
    "ivrit-ai/faster-whisper-v2-d2": "Hebrew Faster-Whisper V2-D2",
    "large-v3": "OpenAI Whisper Large V3 (Multilingual)",
    "large-v2": "OpenAI Whisper Large V2 (Multilingual)",
    "medium": "OpenAI Whisper Medium (Multilingual)",
    "small": "OpenAI Whisper Small (Multilingual)",
}

# Default audio and transcription
DEFAULT_AUDIO = "heb.wav"
DEFAULT_TRANSCRIPTION = "שלום! אנחנו נרגשים להציג לכם את יכולות הדיבור הטבעי שלנו. כאן תוכלו לביים קול, ליצור דיאלוגים מציאותיים ועוד הרבה יותר. ערכו את המקומות הללו כדי להתחיל."

# Predefined audio files bundled with the Space
PREDEFINED_AUDIO_FILES = {
    "heb.wav": {
        "file": "heb.wav",
        "description": "Regular quality Hebrew audio",
        "transcription": DEFAULT_TRANSCRIPTION,
    },
    "noise.wav": {
        "file": "noise.wav",
        "description": "Noisy Hebrew audio",
        "transcription": "אז כך, קרנות החיסכון האלה כאילו מנסות לבנות מנדט לכל הסטארט-אפים הפרטיים..",
    },
}
def normalize_hebrew_text(text):
    """Normalize Hebrew text for WER calculation."""
    if not text:
        return ""
    # Remove diacritics: U+0591-U+05C7 spans the Hebrew cantillation marks
    # and vowel points (niqqud)
    hebrew_diacritics = {chr(i) for i in range(0x0591, 0x05C8)}
    text = "".join(c for c in text if c not in hebrew_diacritics)
    # Replace punctuation with spaces
    text = re.sub(r'[^\w\s]', ' ', text)
    # Collapse whitespace and lowercase (a no-op for Hebrew letters, but it
    # keeps mixed Hebrew/Latin references consistent)
    text = ' '.join(text.split()).lower()
    return text
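
# Illustrative sanity check for the normalizer. The sample string below is an
# assumption for demonstration, not a project fixture; the check only runs
# when WER_TOOL_SELFTEST is set, so importing this module stays side-effect free.
if os.environ.get("WER_TOOL_SELFTEST"):
    assert normalize_hebrew_text("שָׁלוֹם, עוֹלָם!") == "שלום עולם"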
def calculate_wer_cer(reference, hypothesis):
    """Calculate WER and CER for Hebrew text."""
    try:
        # Normalize both texts
        ref_normalized = normalize_hebrew_text(reference)
        hyp_normalized = normalize_hebrew_text(hypothesis)
        if not ref_normalized or not hyp_normalized:
            return float('inf'), float('inf'), ref_normalized, hyp_normalized
        # Calculate WER and CER with jiwer
        word_error_rate = wer(ref_normalized, hyp_normalized)
        char_error_rate = cer(ref_normalized, hyp_normalized)
        return word_error_rate, char_error_rate, ref_normalized, hyp_normalized
    except Exception as e:
        print(f"Error calculating WER/CER: {e}")
        return float('inf'), float('inf'), "", ""
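
# Quick illustrative check of the metric helper on toy strings (assumed
# examples, not real data): substituting one of three reference words gives
# WER = 1/3. Guarded behind the same opt-in flag as above.
if os.environ.get("WER_TOOL_SELFTEST"):
    _w, _c, _, _ = calculate_wer_cer("שלום עולם טוב", "שלום עולם רע")
    assert abs(_w - 1 / 3) < 1e-6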
def initialize_whisper_model(model_id, progress=gr.Progress()):
    """Initialize a specific Whisper model with progress indication."""
    global WHISPER_MODELS, DEVICE
    try:
        # Skip if model is already loaded
        if model_id in WHISPER_MODELS and WHISPER_MODELS[model_id] is not None:
            print(f"✅ Model {model_id} already loaded")
            return True
        # Determine device once; float16 on GPU, int8 quantization on CPU
        if DEVICE is None:
            DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if DEVICE == "cuda" else "int8"
        print(f"🔧 Loading Whisper model: {model_id} on {DEVICE}")
        progress(0.3, desc=f"Loading {model_id}...")
        # Initialize the faster-whisper model (accepts built-in sizes or
        # Hugging Face repo IDs)
        WHISPER_MODELS[model_id] = WhisperModel(
            model_id,
            device=DEVICE,
            compute_type=compute_type
        )
        progress(1.0, desc=f"Loaded {model_id} successfully!")
        print(f"✅ Model {model_id} initialized successfully!")
        return True
    except Exception as e:
        print(f"❌ Error initializing model {model_id}: {str(e)}")
        WHISPER_MODELS[model_id] = None
        return False
def transcribe_audio_with_model(audio_file, model_id, language="he"):
    """Transcribe audio using a specific Whisper model."""
    try:
        # Initialize model if needed
        if model_id not in WHISPER_MODELS or WHISPER_MODELS[model_id] is None:
            success = initialize_whisper_model(model_id)
            if not success:
                return "", f"Failed to load model {model_id}"
        model = WHISPER_MODELS[model_id]
        print(f"🎤 Transcribing with {model_id}: {Path(audio_file).name}")
        # Transcribe with faster-whisper (deterministic beam-search decoding)
        segments, info = model.transcribe(
            audio_file,
            language=language,
            beam_size=5,
            best_of=5,
            temperature=0.0
        )
        # Collect all segments (the segments generator is consumed lazily)
        transcript_text = ""
        for segment in segments:
            transcript_text += segment.text + " "
        transcript_text = transcript_text.strip()
        print(f"✅ Transcription completed with {model_id}. Length: {len(transcript_text)} characters")
        return transcript_text, f"Success - Duration: {info.duration:.1f}s"
    except Exception as e:
        print(f"❌ Error transcribing with {model_id}: {str(e)}")
        return "", f"Error: {str(e)}"
def evaluate_all_models(audio_file, reference_text, selected_models, progress=gr.Progress()):
    """Evaluate all selected models and calculate WER/CER."""
    if not audio_file or not reference_text.strip():
        return "❌ Please provide both audio file and reference transcription", []
    if not selected_models:
        return "❌ Please select at least one model to evaluate", []

    results = []
    print(f"🎯 Starting WER evaluation with {len(selected_models)} models...")
    for i, model_id in enumerate(selected_models):
        progress((i + 1) / len(selected_models), desc=f"Evaluating {model_id}...")
        print(f"\n🔄 Evaluating model: {model_id}")
        # Transcribe with the current model and time it
        start_time = time.time()
        transcript, status = transcribe_audio_with_model(audio_file, model_id)
        transcription_time = time.time() - start_time
        if transcript:
            # Calculate WER and CER against the reference
            word_error_rate, char_error_rate, ref_norm, hyp_norm = calculate_wer_cer(reference_text, transcript)
            results.append({
                'model': model_id,
                'model_name': AVAILABLE_WHISPER_MODELS.get(model_id, model_id),
                'transcript': transcript,
                'wer': word_error_rate,
                'cer': char_error_rate,
                'time': transcription_time,
                'status': status,
                'ref_normalized': ref_norm,
                'hyp_normalized': hyp_norm
            })
            print(f"✅ {model_id}: WER={word_error_rate:.3f}, CER={char_error_rate:.3f}")
        else:
            print(f"❌ {model_id}: Transcription failed")
            results.append({
                'model': model_id,
                'model_name': AVAILABLE_WHISPER_MODELS.get(model_id, model_id),
                'transcript': 'FAILED',
                'wer': float('inf'),
                'cer': float('inf'),
                'time': transcription_time,
                'status': status,
                'ref_normalized': '',
                'hyp_normalized': ''
            })

    # Sort results by WER (best first; failed runs sort last via inf)
    results.sort(key=lambda x: x['wer'])

    # Create summary report
    summary_report = "# 📊 WER Evaluation Results\n\n"
    summary_report += f"**Audio File:** {os.path.basename(audio_file)}\n"
    summary_report += f"**Reference Text:** {reference_text[:100]}...\n"
    summary_report += f"**Models Tested:** {len(selected_models)}\n"
    summary_report += f"**Device:** {DEVICE}\n\n"
    summary_report += "## Results Summary (sorted by WER)\n\n"
    for i, result in enumerate(results):
        if result['wer'] == float('inf'):
            wer_display = "FAILED"
            cer_display = "FAILED"
        else:
            wer_display = f"{result['wer']:.3f} ({result['wer']*100:.1f}%)"
            cer_display = f"{result['cer']:.3f} ({result['cer']*100:.1f}%)"
        summary_report += f"**{i+1}. {result['model_name']}**\n"
        summary_report += f"- WER: {wer_display}\n"
        summary_report += f"- CER: {cer_display}\n"
        summary_report += f"- Processing Time: {result['time']:.2f}s\n\n"

    # Create table data for the Gradio Dataframe, with the ground truth first
    table_data = [["Ground Truth", reference_text, "N/A", "N/A"]]
    for result in results:
        if result['wer'] == float('inf'):
            wer_display = "FAILED"
            cer_display = "FAILED"
        else:
            wer_display = f"{result['wer']:.3f}"
            cer_display = f"{result['cer']:.3f}"
        table_data.append([
            result['model_name'],
            result['transcript'],
            wer_display,
            cer_display
        ])
    print("✅ WER evaluation completed!")
    return summary_report, table_data
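
# Headless usage sketch (assumption: the default gr.Progress() behaves as a
# no-op when called outside a Gradio event loop). Opt-in only.
if os.environ.get("WER_TOOL_SELFTEST_EVAL"):
    _report, _rows = evaluate_all_models("heb.wav", DEFAULT_TRANSCRIPTION, ["small"])
    print(_report)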
def create_gradio_interface():
    """Create and configure the Gradio interface."""
    # Initialize device info
    global DEVICE
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    status_msg = f"""✅ Hebrew STT WER Evaluation Tool Ready!
🔧 Device: {DEVICE}
📱 Available Models: {len(AVAILABLE_WHISPER_MODELS)}
🎯 Purpose: Compare WER performance across Hebrew STT models"""

    # Create Gradio interface
    with gr.Blocks(
        title="Hebrew STT WER Evaluation",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container { max-width: 1600px !important; }
        .evaluation-section {
            border: 2px solid #e0e0e0;
            border-radius: 10px;
            padding: 15px;
            margin: 10px 0;
        }
        """
    ) as demo:
        gr.Markdown("""
        # 📊 Hebrew STT WER Evaluation Tool
        Upload an audio file and reference transcription to test the performance of different Whisper models on Hebrew speech-to-text tasks.
        """)

        # Status section
        with gr.Row():
            status_display = gr.Textbox(
                label="🔧 System Status",
                value=status_msg,
                interactive=False,
                lines=4
            )
        # Input section
        with gr.Row():
            # Audio and reference input
            with gr.Column(scale=1, elem_classes=["evaluation-section"]):
                gr.Markdown("### 📁 Evaluation Inputs")
                # Predefined audio selection (defaults to the bundled heb.wav)
                predefined_audio_dropdown = gr.Dropdown(
                    label="🎵 Select Predefined Audio File",
                    choices=[(f"{k} - {v['description']}", k) for k, v in PREDEFINED_AUDIO_FILES.items()],
                    value="heb.wav",
                    interactive=True
                )
                # OR upload custom audio
                gr.Markdown("**OR**")
                audio_input = gr.Audio(
                    label="🎵 Upload Custom Audio File - Hebrew audio to transcribe",
                    type="filepath",
                    value=None
                )
                reference_text = gr.Textbox(
                    label="📝 Reference Transcription (Ground Truth) - the correct transcription used for WER calculation",
                    placeholder="Enter the correct transcription of the audio file...",
                    value=DEFAULT_TRANSCRIPTION,
                    lines=5
                )
                # Model selection
                model_selection = gr.CheckboxGroup(
                    label="🤖 Select Models to Test - choose which models to evaluate (2-4 recommended)",
                    choices=list(AVAILABLE_WHISPER_MODELS.keys()),
                    value=["ivrit-ai/faster-whisper-v2-d4", "large-v3"]
                )
                with gr.Row():
                    load_models_btn = gr.Button(
                        "🔧 Pre-load Selected Models (Optional)",
                        variant="secondary"
                    )
                    evaluate_btn = gr.Button(
                        "🎯 Run WER Evaluation",
                        variant="primary"
                    )
            # Quick info panel
            with gr.Column(scale=1, elem_classes=["evaluation-section"]):
                gr.Markdown("### ℹ️ How the Evaluation Works")
                gr.Markdown("""
                **What is WER?**
                Word Error Rate measures transcription accuracy at the word level.

                **How it works:**
                1. Upload a Hebrew audio file
                2. Enter the correct transcription
                3. Select models to test
                4. The tool transcribes with each model
                5. It calculates WER & CER for each model
                6. Models are ranked by performance

                **Evaluation Metrics:**
                - **WER**: Word-level errors (%)
                - **CER**: Character-level errors (%)
                - **Processing Time**: Transcription speed

                **Tips:**
                - Use high-quality audio
                - Ensure the reference transcription is accurate
                - Select 2-4 models for comparison
                - Lower WER = better performance
                """)
        # Results section
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📊 WER Evaluation Results")
                results_output = gr.Markdown(
                    value="Evaluation results will appear here after running the test..."
                )
                results_table = gr.Dataframe(
                    label="Transcription Comparison",
                    headers=["Model", "Transcription", "WER", "CER"],
                    datatype=["str", "str", "str", "str"],
                    col_count=(4, "fixed")
                )
        # Event handlers
        def load_predefined_audio(selected_file):
            """Load a predefined audio file and its transcription."""
            if selected_file and selected_file in PREDEFINED_AUDIO_FILES:
                audio_data = PREDEFINED_AUDIO_FILES[selected_file]
                return audio_data["file"], audio_data["transcription"]
            return None, DEFAULT_TRANSCRIPTION

        def load_selected_models(selected_models, progress=gr.Progress()):
            """Pre-load selected models."""
            if not selected_models:
                return "❌ No models selected"
            status_msg = f"🔧 Loading {len(selected_models)} models...\n\n"
            for model_id in selected_models:
                try:
                    status_msg += f"⏳ Loading {model_id}...\n"
                    success = initialize_whisper_model(model_id, progress)
                    if success:
                        status_msg += f"✅ {model_id} loaded successfully\n"
                    else:
                        status_msg += f"❌ Error loading {model_id}\n"
                    status_msg += "\n"
                except Exception as e:
                    status_msg += f"❌ Error loading {model_id}: {str(e)}\n\n"
            loaded_count = len([m for m in selected_models if m in WHISPER_MODELS and WHISPER_MODELS[m] is not None])
            status_msg += f"✅ Model loading complete! Available: {loaded_count}/{len(selected_models)}"
            return status_msg
        def run_wer_evaluation(audio_file, reference, selected_models, predefined_file, progress=gr.Progress()):
            """Run the complete WER evaluation."""
            # Fall back to the predefined file if no custom audio was uploaded
            if not audio_file and predefined_file in PREDEFINED_AUDIO_FILES:
                audio_file = PREDEFINED_AUDIO_FILES[predefined_file]["file"]
            if not audio_file:
                return "❌ Please select a predefined audio file or upload a custom one", []
            if not reference or not reference.strip():
                return "❌ Please enter a reference transcription", []
            if not selected_models:
                return "❌ Please select at least one model", []
            # Run evaluation
            return evaluate_all_models(audio_file, reference, selected_models, progress)
        # Connect events
        predefined_audio_dropdown.change(
            fn=load_predefined_audio,
            inputs=[predefined_audio_dropdown],
            outputs=[audio_input, reference_text]
        )
        load_models_btn.click(
            fn=load_selected_models,
            inputs=[model_selection],
            outputs=[status_display]
        )
        evaluate_btn.click(
            fn=run_wer_evaluation,
            inputs=[audio_input, reference_text, model_selection, predefined_audio_dropdown],
            outputs=[results_output, results_table]
        )
        # Footer
        gr.Markdown("""
        ---
        ### 🔧 Technical Information
        - **STT Engine**: Faster-Whisper (including Hebrew-tuned ivrit-ai models)
        - **Evaluation Metrics**: WER (Word Error Rate) and CER (Character Error Rate)
        - **Text Normalization**: Removes diacritics, punctuation, and extra whitespace
        - **Purpose**: Compare the performance of different transcription models on Hebrew speech

        ### 📦 Setup Instructions
        ```bash
        # Install dependencies
        pip install gradio faster-whisper torch jiwer

        # For GPU support (recommended)
        pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
        ```

        ### 📊 Output Format
        The tool displays:
        - Model ranking by WER
        - Detailed results for each model
        - Processing times
        - Normalized transcription comparison
        """)
    return demo
# Launch the app
if __name__ == "__main__":
    print("🎯 Launching Hebrew STT WER Evaluation Tool...")
    demo = create_gradio_interface()
    # Launch the demo
    demo.launch(
        share=False,  # Set to True to create a public link
        debug=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )
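
# Usage note (assumed filename app.py): run the optional self-checks and then
# launch the UI with
#   WER_TOOL_SELFTEST=1 python app.py
# or simply `python app.py` to skip the checks.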