import torch
import onnx
import onnxruntime
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from typing import Dict, List, Optional, Tuple
import librosa
import os


# Graph input/output names shared by every export path in this module.
_INPUT_NAMES = ["input_values"]
_OUTPUT_NAMES = ["logits"]


def _dynamic_axes() -> Dict[str, Dict[int, str]]:
    """Return a fresh dynamic-axes mapping for ``torch.onnx.export``.

    The logits time axis gets its own symbolic name ("num_frames"): wav2vec2
    emits far fewer output frames than input samples, and reusing
    "sequence_length" for both would declare the two axes equal in the
    exported graph.
    """
    return {
        "input_values": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "num_frames"},
    }


def _disable_dropout(model: torch.nn.Module) -> None:
    """Switch *model* to eval mode and zero every ``torch.nn.Dropout`` probability.

    eval() already makes dropout a no-op; zeroing ``p`` as well is a
    belt-and-braces measure. Note that this mutates the model permanently.
    """
    model.eval()
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0.0


def _export_onnx(model: torch.nn.Module,
                 dummy_input: torch.Tensor,
                 onnx_path: str,
                 opset_version: int,
                 **extra_kwargs) -> None:
    """Run ``torch.onnx.export`` with the settings common to this module.

    Args:
        model: Model to trace.
        dummy_input: Example input used for tracing.
        onnx_path: Destination file.
        opset_version: ONNX opset to target.
        **extra_kwargs: Additional keyword arguments forwarded to
            ``torch.onnx.export`` (e.g. ``operator_export_type``).
    """
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=list(_INPUT_NAMES),
        output_names=list(_OUTPUT_NAMES),
        dynamic_axes=_dynamic_axes(),
        opset_version=opset_version,
        do_constant_folding=True,
        verbose=False,
        export_params=True,
        training=torch.onnx.TrainingMode.EVAL,
        **extra_kwargs,
    )


class Wav2Vec2ONNXConverter:
    """Convert Wav2Vec2 model to ONNX format"""

    def __init__(self, model_name: str = "facebook/wav2vec2-base-960h"):
        """Load the processor and CTC model named *model_name* and put the model in eval mode."""
        print(f"Loading Wav2Vec2 model: {model_name}")
        self.model_name = model_name
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name)

        # Disable flash attention for ONNX compatibility (the config flag is
        # only present on some transformers versions, hence the hasattr).
        if hasattr(self.model.config, 'use_flash_attention_2'):
            self.model.config.use_flash_attention_2 = False

        # eval() turns every dropout layer into an identity op, so no
        # per-layer dropout surgery is needed before export.
        self.model.eval()
        self.sample_rate = 16000
        print("Model loaded successfully")

    def convert_to_onnx(self,
                        onnx_path: str = "wav2vec2_model.onnx",
                        input_length: int = 160000,  # 10 seconds at 16kHz
                        opset_version: int = 14) -> str:
        """
        Convert the Wav2Vec2 model to ONNX format and verify the result.

        Args:
            onnx_path: Path to save the ONNX model
            input_length: Length of the dummy input audio used for tracing (samples)
            opset_version: ONNX opset version

        Returns:
            Path to the saved ONNX model

        Raises:
            Whatever ``torch.onnx.export`` or the verification step raises;
            the error is printed before being re-raised.
        """
        print("Converting model to ONNX format...")

        # Random audio is sufficient for tracing; the batch and time axes are
        # declared dynamic, so the exported graph accepts other lengths.
        dummy_input = torch.randn(1, input_length, dtype=torch.float32)

        try:
            with torch.no_grad():
                _disable_dropout(self.model)
                _export_onnx(
                    self.model,
                    dummy_input,
                    onnx_path,
                    opset_version,
                    operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
                )

            print(f"Model successfully exported to: {onnx_path}")

            # Verify the exported model against the PyTorch reference.
            self._verify_onnx_model(onnx_path, dummy_input)

            return onnx_path

        except Exception as e:
            print(f"Error during ONNX conversion: {e}")
            raise

    def _verify_onnx_model(self, onnx_path: str, test_input: torch.Tensor):
        """Check the ONNX graph and compare ONNX Runtime logits to PyTorch logits.

        Prints the maximum absolute difference; a difference of 1e-4 or more
        only produces a warning, it does not raise.
        """
        print("Verifying ONNX model...")

        try:
            # Structural validation of the saved graph.
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
            print("✓ ONNX model structure is valid")

            # Pin the CPU provider: GPU-enabled onnxruntime builds raise when
            # `providers` is omitted, and CPU keeps the comparison against the
            # CPU-resident PyTorch model like-for-like.
            ort_session = onnxruntime.InferenceSession(
                onnx_path, providers=["CPUExecutionProvider"]
            )

            input_name = ort_session.get_inputs()[0].name
            output_name = ort_session.get_outputs()[0].name
            print(f"✓ Input name: {input_name}")
            print(f"✓ Output name: {output_name}")

            onnx_logits = ort_session.run(
                [output_name], {input_name: test_input.numpy()}
            )[0]

            with torch.no_grad():
                torch_logits = self.model(test_input).logits

            max_diff = np.max(np.abs(torch_logits.numpy() - onnx_logits))
            print(f"✓ Maximum difference between PyTorch and ONNX: {max_diff:.6f}")

            if max_diff < 1e-4:
                print("✓ ONNX model verification successful!")
            else:
                print("⚠ Warning: Large difference detected between models")

        except Exception as e:
            print(f"Error during verification: {e}")
            raise

    def transcribe_to_characters(self, audio_path: str) -> Dict:
        """Transcribe *audio_path* with the PyTorch model (greedy CTC decoding).

        Used by ``compare_models`` as the PyTorch reference.

        Returns:
            Dict with "character_transcript" (decoded text) and
            "predicted_ids" (argmax token ids of the first batch item).
        """
        speech, _ = librosa.load(audio_path, sr=self.sample_rate)
        input_values = self.processor(
            speech,
            sampling_rate=self.sample_rate,
            return_tensors="pt"
        ).input_values

        with torch.no_grad():
            logits = self.model(input_values).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = self.processor.batch_decode(predicted_ids)[0]
        return {
            "character_transcript": transcription,
            "predicted_ids": predicted_ids[0].tolist(),
        }


class Wav2Vec2ONNXInference:
    """ONNX inference class for Wav2Vec2"""

    def __init__(self,
                 onnx_path: str,
                 processor_name: str = "facebook/wav2vec2-base-960h",
                 providers: Optional[List[str]] = None):
        """Load the processor and create an ONNX Runtime session.

        Args:
            onnx_path: Path to the exported ONNX model.
            processor_name: Hugging Face name of the matching processor
                (feature extractor + CTC tokenizer).
            providers: ONNX Runtime execution providers. Defaults to
                ``["CPUExecutionProvider"]``; passing it explicitly also keeps
                GPU-enabled onnxruntime builds from rejecting the session.
        """
        print(f"Loading ONNX model from: {onnx_path}")

        # Load processor for feature extraction and decoding
        self.processor = Wav2Vec2Processor.from_pretrained(processor_name)

        if providers is None:
            providers = ["CPUExecutionProvider"]
        self.session = onnxruntime.InferenceSession(onnx_path, providers=providers)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        self.sample_rate = 16000
        print("ONNX model loaded successfully")

    def transcribe(self, audio_path: str) -> Dict:
        """Transcribe audio using the ONNX model (greedy CTC decoding).

        Best effort: any error is printed and an empty result is returned.

        Returns:
            Dict with "transcription", "confidence_scores" (per-frame max
            softmax probability, first 100 frames) and "predicted_ids".
        """
        try:
            # librosa resamples to the model's sample rate on load.
            speech, _ = librosa.load(audio_path, sr=self.sample_rate)

            # ONNX Runtime rejects float64 for a float32 input, so pin the dtype.
            input_values = self.processor(
                speech,
                sampling_rate=self.sample_rate,
                return_tensors="np"
            ).input_values.astype(np.float32)

            logits = self.session.run(
                [self.output_name], {self.input_name: input_values}
            )[0]

            predicted_ids = np.argmax(logits, axis=-1)
            transcription = self.processor.batch_decode(predicted_ids)[0]

            confidence_scores = np.max(self._softmax(logits), axis=-1)[0]

            return {
                "transcription": transcription,
                "confidence_scores": confidence_scores[:100].tolist(),  # Limit for JSON
                "predicted_ids": predicted_ids[0].tolist()
            }

        except Exception as e:
            print(f"Transcription error: {e}")
            return {
                "transcription": "",
                "confidence_scores": [],
                "predicted_ids": []
            }

    def _softmax(self, x):
        """Apply a numerically stable softmax over the last axis."""
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)


# Example usage and testing
def main():
    """Example usage: export (with a fallback path) and load the ONNX model."""
    # Method 1: Try standard conversion
    try:
        print("Method 1: Standard conversion...")
        converter = Wav2Vec2ONNXConverter("facebook/wav2vec2-base-960h")
        onnx_path = converter.convert_to_onnx(
            onnx_path="wav2vec2_asr.onnx",
            input_length=160000,  # 10 seconds
            opset_version=14  # Updated to version 14 for compatibility
        )
        print("✓ Standard conversion successful!")

    except Exception as e:
        print(f"✗ Standard conversion failed: {e}")
        print("\nMethod 2: Trying fallback approach...")

        try:
            # Method 2: Use compatible model creation
            model, processor = create_compatible_model("facebook/wav2vec2-base-960h")
            onnx_path = export_with_fallback(
                model, processor, "wav2vec2_asr_fallback.onnx", input_length=160000
            )
            print("✓ Fallback conversion successful!")
        except Exception as e2:
            print(f"✗ All conversion methods failed: {e2}")
            return

    # Test ONNX inference
    print("\nTesting ONNX inference...")
    try:
        onnx_inference = Wav2Vec2ONNXInference(onnx_path)
        print("✓ ONNX model loaded successfully for inference")

        # Create a test audio file (or use your own)
        # result = onnx_inference.transcribe("test_audio.wav")
        # print("Transcription:", result["transcription"])

    except Exception as e:
        print(f"✗ ONNX inference test failed: {e}")

    print("Conversion process completed!")


# Additional utility functions
def create_compatible_model(model_name: str = "facebook/wav2vec2-base-960h"):
    """Load a float32 Wav2Vec2 CTC model and its processor for ONNX export.

    Returns:
        ``(model, processor)`` tuple.
    """
    from transformers import Wav2Vec2Config

    # Load config and modify for ONNX compatibility
    config = Wav2Vec2Config.from_pretrained(model_name)

    # Disable features that may cause ONNX issues (flags exist only on
    # some transformers versions).
    if hasattr(config, 'use_flash_attention_2'):
        config.use_flash_attention_2 = False
    if hasattr(config, 'torch_dtype'):
        config.torch_dtype = torch.float32

    model = Wav2Vec2ForCTC.from_pretrained(model_name, config=config, torch_dtype=torch.float32)
    processor = Wav2Vec2Processor.from_pretrained(model_name)

    return model, processor


def export_with_fallback(model,
                         processor,
                         onnx_path: str,
                         input_length: int = 160000,
                         opset_versions: Tuple[int, ...] = (14, 13, 12, 11)) -> str:
    """Export *model* to ONNX, trying each opset in *opset_versions* in order.

    Args:
        model: Model to export (dropout is disabled in place).
        processor: Unused; kept for interface compatibility.
        onnx_path: Destination file.
        input_length: Length of the dummy tracing input (samples).
        opset_versions: Opsets to try, most preferred first.

    Returns:
        *onnx_path* after the first successful export.

    Raises:
        RuntimeError: if every opset fails.
    """
    dummy_input = torch.randn(1, input_length, dtype=torch.float32)

    for opset_version in opset_versions:
        try:
            print(f"Trying ONNX export with opset version {opset_version}...")

            with torch.no_grad():
                _disable_dropout(model)
                _export_onnx(model, dummy_input, onnx_path, opset_version)

            print(f"✓ Successfully exported with opset version {opset_version}")
            return onnx_path

        except Exception as e:
            print(f"✗ Failed with opset {opset_version}: {str(e)[:100]}...")

    raise RuntimeError("Failed to export with all attempted opset versions")


def optimize_onnx_model(onnx_path: str,
                        optimized_path: Optional[str] = None,
                        num_heads: int = 12,
                        hidden_size: int = 768) -> str:
    """Optimize an ONNX model with the ONNX Runtime transformer optimizer.

    Args:
        onnx_path: Source ONNX model.
        optimized_path: Destination; defaults to ``<stem>_optimized.onnx``
            next to the source.
        num_heads: Attention heads (12 matches wav2vec2-base).
        hidden_size: Hidden size (768 matches wav2vec2-base).

    Returns:
        The optimized model path, or *onnx_path* if the optimizer is
        unavailable or fails (best effort; the error is printed).
    """
    try:
        from onnxruntime.transformers import optimizer
    except ImportError:
        print("ONNX Runtime tools not available for optimization")
        return onnx_path

    if optimized_path is None:
        # splitext only touches the real extension; str.replace would rewrite
        # every ".onnx" in the path and leave extension-less paths unchanged
        # (overwriting the input).
        stem, _ = os.path.splitext(onnx_path)
        optimized_path = f"{stem}_optimized.onnx"

    try:
        opt_model = optimizer.optimize_model(
            onnx_path,
            model_type="bert",  # Similar architecture
            num_heads=num_heads,
            hidden_size=hidden_size
        )
        opt_model.save_model_to_file(optimized_path)
    except Exception as e:
        print(f"Optimization error: {e}")
        return onnx_path

    print(f"Optimized model saved to: {optimized_path}")
    return optimized_path


def compare_models(original_converter, onnx_inference, test_audio_path: str) -> bool:
    """Compare PyTorch and ONNX transcriptions of *test_audio_path*.

    Args:
        original_converter: Object exposing ``transcribe_to_characters``
            (e.g. ``Wav2Vec2ONNXConverter``).
        onnx_inference: Object exposing ``transcribe``
            (e.g. ``Wav2Vec2ONNXInference``).
        test_audio_path: Audio file to transcribe with both.

    Returns:
        True when the two transcriptions are identical.
    """
    print("Comparing PyTorch vs ONNX outputs...")

    torch_text = original_converter.transcribe_to_characters(test_audio_path)['character_transcript']
    onnx_text = onnx_inference.transcribe(test_audio_path)['transcription']

    print(f"PyTorch transcription: {torch_text}")
    print(f"ONNX transcription: {onnx_text}")

    match = torch_text == onnx_text
    if match:
        print("✓ Transcriptions match exactly!")
    else:
        print("⚠ Transcriptions differ")
    return match


if __name__ == "__main__":
    main()