# Run_code_api/inference.py
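"""Wav2Vec2 CTC speech-to-text inference utilities.

Wraps a HuggingFace Wav2Vec2/CTC checkpoint for transcribing audio files or
raw buffers, with CPU/CUDA/MPS device selection and optional DeepSpeed
inference acceleration.
"""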
import torch
from transformers import (
Wav2Vec2ForCTC,
Wav2Vec2Processor,
AutoProcessor,
AutoModelForCTC,
)
# DeepSpeed is optional; fall back to plain PyTorch inference if it is missing
try:
    import deepspeed
except ImportError:
    deepspeed = None
import librosa
import numpy as np
from typing import Optional, List, Union
def get_model_name(model_name: Optional[str] = None) -> str:
"""Helper function to get model name with default fallback"""
if model_name is None:
return "facebook/wav2vec2-large-robust-ft-libri-960h"
return model_name
class Wave2Vec2Inference:
def __init__(
self,
model_name: Optional[str] = None,
use_gpu: bool = True,
use_deepspeed: bool = True,
):
"""
Initialize Wav2Vec2 model for inference with optional DeepSpeed optimization.
Args:
model_name: HuggingFace model name or None for default
use_gpu: Whether to use GPU acceleration
use_deepspeed: Whether to use DeepSpeed optimization
"""
# Get the actual model name using helper function
self.model_name = get_model_name(model_name)
self.use_deepspeed = use_deepspeed
# Auto-detect device
if use_gpu:
if torch.backends.mps.is_available():
self.device = "mps"
elif torch.cuda.is_available():
self.device = "cuda"
else:
self.device = "cpu"
else:
self.device = "cpu"
print(f"Using device: {self.device}")
print(f"Loading model: {self.model_name}")
print(f"DeepSpeed enabled: {self.use_deepspeed}")
# Check if model is XLSR and use appropriate processor/model
is_xlsr = "xlsr" in self.model_name.lower()
if is_xlsr:
print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
else:
print("Using AutoProcessor and AutoModelForCTC")
self.processor = AutoProcessor.from_pretrained(self.model_name)
self.model = AutoModelForCTC.from_pretrained(self.model_name)
# Initialize DeepSpeed if enabled
if self.use_deepspeed:
self._init_deepspeed()
else:
self.model.to(self.device)
self.model.eval()
self.ds_engine = None
# Disable gradients for inference
torch.set_grad_enabled(False)
def _init_deepspeed(self):
"""Initialize DeepSpeed inference engine"""
try:
# DeepSpeed configuration based on device
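            # Kernel injection (DeepSpeed's fused inference kernels) is only
            # supported on CUDA, so it is disabled for the CPU/MPS config below.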
if self.device == "cuda":
ds_config = {
"tensor_parallel": {"tp_size": 1},
"dtype": torch.float32,
"replace_with_kernel_inject": True,
"enable_cuda_graph": False,
}
else:
ds_config = {
"tensor_parallel": {"tp_size": 1},
"dtype": torch.float32,
"replace_with_kernel_inject": False,
"enable_cuda_graph": False,
}
print("Initializing DeepSpeed inference engine...")
self.ds_engine = deepspeed.init_inference(self.model, **ds_config)
self.ds_engine.module.to(self.device)
except Exception as e:
print(f"DeepSpeed initialization failed: {e}")
print("Falling back to standard PyTorch inference...")
self.use_deepspeed = False
self.ds_engine = None
self.model.to(self.device)
self.model.eval()
def _get_model(self):
"""Get the appropriate model for inference"""
if self.use_deepspeed and self.ds_engine is not None:
return self.ds_engine.module
return self.model
def buffer_to_text(
self, audio_buffer: Union[np.ndarray, torch.Tensor, List]
) -> str:
"""
Convert audio buffer to text transcription.
Args:
audio_buffer: Audio data as numpy array, tensor, or list
Returns:
str: Transcribed text
"""
if len(audio_buffer) == 0:
return ""
# Convert to tensor
if isinstance(audio_buffer, np.ndarray):
audio_tensor = torch.from_numpy(audio_buffer).float()
elif isinstance(audio_buffer, list):
audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
else:
audio_tensor = audio_buffer.float()
# Process audio
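        # The feature extractor expects 16 kHz mono audio; file_to_text
        # resamples with librosa before calling this method.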
inputs = self.processor(
audio_tensor,
sampling_rate=16_000,
return_tensors="pt",
padding=True,
)
# Move to device
input_values = inputs.input_values.to(self.device)
attention_mask = (
inputs.attention_mask.to(self.device)
if "attention_mask" in inputs
else None
)
# Get the appropriate model
model = self._get_model()
# Inference
with torch.no_grad():
if attention_mask is not None:
outputs = model(input_values, attention_mask=attention_mask)
else:
outputs = model(input_values)
# Handle different output formats
if hasattr(outputs, "logits"):
logits = outputs.logits
else:
logits = outputs
# Decode
predicted_ids = torch.argmax(logits, dim=-1)
if self.device != "cpu":
predicted_ids = predicted_ids.cpu()
transcription = self.processor.batch_decode(predicted_ids)[0]
return transcription.lower().strip()
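    # Minimal usage sketch (hypothetical input: one second of 16 kHz silence):
    #   asr.buffer_to_text(np.zeros(16_000, dtype=np.float32))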
def file_to_text(self, filename: str) -> str:
"""
Transcribe audio file to text.
Args:
filename: Path to audio file
Returns:
str: Transcribed text
"""
try:
audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
return self.buffer_to_text(audio_input)
except Exception as e:
print(f"Error loading audio file {filename}: {e}")
return ""
def batch_file_to_text(self, filenames: List[str]) -> List[str]:
"""
Transcribe multiple audio files to text.
Args:
filenames: List of audio file paths
Returns:
List[str]: List of transcribed texts
"""
results = []
for i, filename in enumerate(filenames):
print(f"Processing file {i+1}/{len(filenames)}: {filename}")
transcription = self.file_to_text(filename)
results.append(transcription)
if transcription:
print(f"Transcription: {transcription}")
else:
print("Failed to transcribe")
return results
def transcribe_with_confidence(
self, audio_buffer: Union[np.ndarray, torch.Tensor]
) -> tuple:
"""
Transcribe audio and return confidence scores.
Args:
audio_buffer: Audio data
Returns:
tuple: (transcription, confidence_scores)
"""
if len(audio_buffer) == 0:
return "", []
# Convert to tensor
if isinstance(audio_buffer, np.ndarray):
audio_tensor = torch.from_numpy(audio_buffer).float()
else:
audio_tensor = audio_buffer.float()
# Process audio
inputs = self.processor(
audio_tensor,
sampling_rate=16_000,
return_tensors="pt",
padding=True,
)
input_values = inputs.input_values.to(self.device)
attention_mask = (
inputs.attention_mask.to(self.device)
if "attention_mask" in inputs
else None
)
model = self._get_model()
# Inference
with torch.no_grad():
if attention_mask is not None:
outputs = model(input_values, attention_mask=attention_mask)
else:
outputs = model(input_values)
if hasattr(outputs, "logits"):
logits = outputs.logits
else:
logits = outputs
# Get probabilities and confidence scores
probs = torch.nn.functional.softmax(logits, dim=-1)
predicted_ids = torch.argmax(logits, dim=-1)
# Calculate confidence as max probability for each prediction
max_probs = torch.max(probs, dim=-1)[0]
        # Drop the batch dimension so the scores align with the single decoded utterance
        confidence_scores = max_probs[0].cpu().numpy().tolist()
if self.device != "cpu":
predicted_ids = predicted_ids.cpu()
transcription = self.processor.batch_decode(predicted_ids)[0]
return transcription.lower().strip(), confidence_scores
def cleanup(self):
"""Clean up resources"""
if hasattr(self, "ds_engine") and self.ds_engine is not None:
del self.ds_engine
if hasattr(self, "model"):
del self.model
if hasattr(self, "processor"):
del self.processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    def __del__(self):
        """Destructor; ignores errors raised during interpreter shutdown."""
        try:
            self.cleanup()
        except Exception:
            pass
# Example usage
if __name__ == "__main__":
    # Initialize on CPU without DeepSpeed
asr = Wave2Vec2Inference(
model_name="facebook/wav2vec2-large-robust-ft-libri-960h",
use_gpu=False,
use_deepspeed=False,
)
# Single file transcription
result = asr.file_to_text("./test_audio/hello_how_are_you_today.wav")
print(f"Transcription: {result}")
# # Batch processing
# files = ["audio1.wav", "audio2.wav", "audio3.wav"]
# batch_results = asr.batch_file_to_text(files)
# # Transcription with confidence scores
# audio_data, _ = librosa.load("path/to/audio.wav", sr=16000)
# transcription, confidence = asr.transcribe_with_confidence(audio_data)
# print(f"Transcription: {transcription}")
# print(f"Average confidence: {np.mean(confidence):.3f}")
    # Cleanup
    asr.cleanup()
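    # Sketch of GPU + DeepSpeed usage (assumes a CUDA device and that the
    # optional deepspeed package is installed; uses the same test file as above):
    # asr_ds = Wave2Vec2Inference(use_gpu=True, use_deepspeed=True)
    # print(asr_ds.file_to_text("./test_audio/hello_how_are_you_today.wav"))
    # asr_ds.cleanup()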