import torch
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    AutoProcessor,
    AutoModelForCTC,
)
import librosa
import numpy as np
from typing import Optional, List, Union

# deepspeed is imported lazily inside _init_deepspeed so it stays an optional dependency


def get_model_name(model_name: Optional[str] = None) -> str:
    """Helper function to get model name with default fallback."""
    if model_name is None:
        return "facebook/wav2vec2-large-robust-ft-libri-960h"
    return model_name
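

# Example: get_model_name() returns the default checkpoint above, while
# get_model_name("facebook/wav2vec2-base-960h") returns the given name
# unchanged ("facebook/wav2vec2-base-960h" is only an illustrative checkpoint id).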


class Wave2Vec2Inference:
    def __init__(
        self,
        model_name: Optional[str] = None,
        use_gpu: bool = True,
        use_deepspeed: bool = True,
    ):
        """
        Initialize a Wav2Vec2 model for inference with optional DeepSpeed optimization.

        Args:
            model_name: HuggingFace model name, or None for the default checkpoint
            use_gpu: Whether to use GPU acceleration
            use_deepspeed: Whether to use DeepSpeed optimization
        """
        # Resolve the model name, falling back to the default checkpoint
        self.model_name = get_model_name(model_name)
        self.use_deepspeed = use_deepspeed

        # Auto-detect device
        if use_gpu:
            if torch.backends.mps.is_available():
                self.device = "mps"
            elif torch.cuda.is_available():
                self.device = "cuda"
            else:
                self.device = "cpu"
        else:
            self.device = "cpu"

        print(f"Using device: {self.device}")
        print(f"Loading model: {self.model_name}")
        print(f"DeepSpeed enabled: {self.use_deepspeed}")

        # Check if the model is an XLSR checkpoint and use the matching processor/model classes
        is_xlsr = "xlsr" in self.model_name.lower()
        if is_xlsr:
            print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
            self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
        else:
            print("Using AutoProcessor and AutoModelForCTC")
            self.processor = AutoProcessor.from_pretrained(self.model_name)
            self.model = AutoModelForCTC.from_pretrained(self.model_name)

        # Initialize DeepSpeed if enabled
        if self.use_deepspeed:
            self._init_deepspeed()
        else:
            self.model.to(self.device)
            self.model.eval()
            self.ds_engine = None

        # Disable gradients for inference
        torch.set_grad_enabled(False)

    def _init_deepspeed(self):
        """Initialize the DeepSpeed inference engine."""
        try:
            # Imported here so deepspeed remains an optional dependency
            import deepspeed

            # DeepSpeed configuration based on device
            if self.device == "cuda":
                ds_config = {
                    "tensor_parallel": {"tp_size": 1},
                    "dtype": torch.float32,
                    "replace_with_kernel_inject": True,
                    "enable_cuda_graph": False,
                }
            else:
                ds_config = {
                    "tensor_parallel": {"tp_size": 1},
                    "dtype": torch.float32,
                    "replace_with_kernel_inject": False,
                    "enable_cuda_graph": False,
                }

            print("Initializing DeepSpeed inference engine...")
            self.ds_engine = deepspeed.init_inference(self.model, **ds_config)
            self.ds_engine.module.to(self.device)
        except Exception as e:
            print(f"DeepSpeed initialization failed: {e}")
            print("Falling back to standard PyTorch inference...")
            self.use_deepspeed = False
            self.ds_engine = None
            self.model.to(self.device)
            self.model.eval()
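
    # Note: deepspeed.init_inference returns an engine object that wraps the
    # original model; the wrapped module is exposed as ds_engine.module, which
    # is why _get_model below unwraps it before running inference.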

    def _get_model(self):
        """Get the appropriate model for inference."""
        if self.use_deepspeed and self.ds_engine is not None:
            return self.ds_engine.module
        return self.model

    def buffer_to_text(
        self, audio_buffer: Union[np.ndarray, torch.Tensor, List]
    ) -> str:
        """
        Convert an audio buffer to a text transcription.

        Args:
            audio_buffer: Audio data as a numpy array, tensor, or list

        Returns:
            str: Transcribed text
        """
        if len(audio_buffer) == 0:
            return ""

        # Convert to tensor
        if isinstance(audio_buffer, np.ndarray):
            audio_tensor = torch.from_numpy(audio_buffer).float()
        elif isinstance(audio_buffer, list):
            audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
        else:
            audio_tensor = audio_buffer.float()

        # Process audio
        inputs = self.processor(
            audio_tensor,
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True,
        )

        # Move to device
        input_values = inputs.input_values.to(self.device)
        attention_mask = (
            inputs.attention_mask.to(self.device)
            if "attention_mask" in inputs
            else None
        )

        # Get the appropriate model
        model = self._get_model()

        # Inference
        with torch.no_grad():
            if attention_mask is not None:
                outputs = model(input_values, attention_mask=attention_mask)
            else:
                outputs = model(input_values)

        # Handle different output formats
        if hasattr(outputs, "logits"):
            logits = outputs.logits
        else:
            logits = outputs

        # Decode
        predicted_ids = torch.argmax(logits, dim=-1)
        if self.device != "cpu":
            predicted_ids = predicted_ids.cpu()

        transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription.lower().strip()
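
    # Usage sketch (hypothetical paths/values): transcribe a raw 16 kHz mono buffer.
    #
    #   asr = Wave2Vec2Inference(use_gpu=False, use_deepspeed=False)
    #   audio, _ = librosa.load("speech.wav", sr=16000, dtype=np.float32)
    #   print(asr.buffer_to_text(audio))
    #
    # "speech.wav" is a placeholder; any 16 kHz mono recording should work.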

    def file_to_text(self, filename: str) -> str:
        """
        Transcribe an audio file to text.

        Args:
            filename: Path to audio file

        Returns:
            str: Transcribed text
        """
        try:
            audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
            return self.buffer_to_text(audio_input)
        except Exception as e:
            print(f"Error loading audio file {filename}: {e}")
            return ""

    def batch_file_to_text(self, filenames: List[str]) -> List[str]:
        """
        Transcribe multiple audio files to text.

        Args:
            filenames: List of audio file paths

        Returns:
            List[str]: List of transcribed texts
        """
        results = []
        for i, filename in enumerate(filenames):
            print(f"Processing file {i+1}/{len(filenames)}: {filename}")
            transcription = self.file_to_text(filename)
            results.append(transcription)
            if transcription:
                print(f"Transcription: {transcription}")
            else:
                print("Failed to transcribe")
        return results

    def transcribe_with_confidence(
        self, audio_buffer: Union[np.ndarray, torch.Tensor]
    ) -> tuple:
        """
        Transcribe audio and return confidence scores.

        Args:
            audio_buffer: Audio data

        Returns:
            tuple: (transcription, confidence_scores)
        """
        if len(audio_buffer) == 0:
            return "", []

        # Convert to tensor
        if isinstance(audio_buffer, np.ndarray):
            audio_tensor = torch.from_numpy(audio_buffer).float()
        else:
            audio_tensor = audio_buffer.float()

        # Process audio
        inputs = self.processor(
            audio_tensor,
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True,
        )
        input_values = inputs.input_values.to(self.device)
        attention_mask = (
            inputs.attention_mask.to(self.device)
            if "attention_mask" in inputs
            else None
        )

        model = self._get_model()

        # Inference
        with torch.no_grad():
            if attention_mask is not None:
                outputs = model(input_values, attention_mask=attention_mask)
            else:
                outputs = model(input_values)

        if hasattr(outputs, "logits"):
            logits = outputs.logits
        else:
            logits = outputs

        # Get probabilities and confidence scores
        probs = torch.nn.functional.softmax(logits, dim=-1)
        predicted_ids = torch.argmax(logits, dim=-1)

        # Calculate confidence as max probability for each prediction
        max_probs = torch.max(probs, dim=-1)[0]
        confidence_scores = max_probs.cpu().numpy().tolist()

        if self.device != "cpu":
            predicted_ids = predicted_ids.cpu()

        transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription.lower().strip(), confidence_scores

    def cleanup(self):
        """Clean up resources."""
        if hasattr(self, "ds_engine") and self.ds_engine is not None:
            del self.ds_engine
        if hasattr(self, "model"):
            del self.model
        if hasattr(self, "processor"):
            del self.processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def __del__(self):
        """Destructor to clean up resources."""
        self.cleanup()


# Example usage
if __name__ == "__main__":
    # Initialize on CPU without DeepSpeed
    asr = Wave2Vec2Inference(
        model_name="facebook/wav2vec2-large-robust-ft-libri-960h",
        use_gpu=False,
        use_deepspeed=False,
    )

    # Single file transcription
    result = asr.file_to_text("./test_audio/hello_how_are_you_today.wav")
    print(f"Transcription: {result}")

    # # Batch processing
    # files = ["audio1.wav", "audio2.wav", "audio3.wav"]
    # batch_results = asr.batch_file_to_text(files)

    # # Transcription with confidence scores
    # audio_data, _ = librosa.load("path/to/audio.wav", sr=16000)
    # transcription, confidence = asr.transcribe_with_confidence(audio_data)
    # print(f"Transcription: {transcription}")
    # print(f"Average confidence: {np.mean(confidence):.3f}")

    # Cleanup
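    asr.cleanup()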