# Run_code_api / test.py
# import torch
# import librosa
# from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
#
# # Configuration
# # MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
# MODEL_ID = "facebook/wav2vec2-large-xlsr-53"
# AUDIO_FILE_PATH = "./hello_how_are_you_today.wav"  # Change this path
#
# # Load model and processor
# processor = Wav2Vec2Processor.from_pretrained(MODEL_ID)
# model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
#
#
# def transcribe_audio_file(audio_path):
#     """
#     Transcribe an audio file to text using Wav2Vec2
#     """
#     # Read the audio file
#     try:
#         speech_array, sampling_rate = librosa.load(audio_path, sr=16_000)
#         print(f"Loaded audio file: {audio_path}")
#         print(f"Audio length: {len(speech_array)/16_000:.2f} seconds")
#     except Exception as e:
#         print(f"Error reading audio file: {e}")
#         return None
#
#     # Preprocess
#     inputs = processor(
#         speech_array,
#         sampling_rate=16_000,
#         return_tensors="pt",
#         padding=True
#     )
#
#     # Predict
#     with torch.no_grad():
#         logits = model(
#             inputs.input_values,
#             attention_mask=inputs.attention_mask
#         ).logits
#
#     # Decode the result
#     predicted_ids = torch.argmax(logits, dim=-1)
#     predicted_sentence = processor.batch_decode(predicted_ids)[0]
#     return predicted_sentence
#
#
# # Test with your own audio file
# if __name__ == "__main__":
#     # Change the paths to your own audio files
#     audio_files = [
#         "./hello_world.wav",  # Change this file name
#         # "another_file.mp3",  # You can add more files
#     ]
#     for audio_file in audio_files:
#         print("=" * 80)
#         print(f"Processing: {audio_file}")
#         print("=" * 80)
#         prediction = transcribe_audio_file(audio_file)
#         if prediction:
#             print(f"Recognition result: {prediction}")
#         else:
#             print("Could not process this file")
#         print()
#
#
# # Simpler version - just change the file path
# def quick_transcribe(audio_path):
#     """Quick version to transcribe a single file"""
#     speech_array, _ = librosa.load(audio_path, sr=16_000)
#     inputs = processor(speech_array, sampling_rate=16_000, return_tensors="pt", padding=True)
#     with torch.no_grad():
#         logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
#     predicted_ids = torch.argmax(logits, dim=-1)
#     return processor.batch_decode(predicted_ids)[0]
#
#
# # Quick usage:
# result = quick_transcribe("./hello_how_are_you_today.wav")
# print(result)

import torch
from transformers import (
    AutoModelForCTC,
    AutoProcessor,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
)
import onnxruntime as rt
import numpy as np
import librosa
import warnings
import os

warnings.filterwarnings("ignore")

# Available Wave2Vec2 models
WAVE2VEC2_MODELS = {
    "english_large": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
    "multilingual": "facebook/wav2vec2-large-xlsr-53",
    "english_960h": "facebook/wav2vec2-large-960h-lv60-self",
    "base_english": "facebook/wav2vec2-base-960h",
    "large_english": "facebook/wav2vec2-large-960h",
    "xlsr_english": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
    "xlsr_multilingual": "facebook/wav2vec2-large-xlsr-53",
}

# Default model
DEFAULT_MODEL = "jonatasgrosman/wav2vec2-large-xlsr-53-english"


def get_available_models():
    """Return dictionary of available Wave2Vec2 models"""
    return WAVE2VEC2_MODELS.copy()


def get_model_name(model_key=None):
    """
    Get model name from key or return default

    Args:
        model_key: Key from WAVE2VEC2_MODELS or full model name

    Returns:
        str: Full model name
    """
    if model_key is None:
        return DEFAULT_MODEL
    if model_key in WAVE2VEC2_MODELS:
        return WAVE2VEC2_MODELS[model_key]
    # If it's already a full model name, return as is
    return model_key
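
# A sketch of how the lookup above resolves names (keys and ids are the ones
# defined in WAVE2VEC2_MODELS; unknown strings are passed through unchanged):
#   get_model_name("english_large")               -> "jonatasgrosman/wav2vec2-large-xlsr-53-english"
#   get_model_name("facebook/wav2vec2-base-960h") -> "facebook/wav2vec2-base-960h"
#   get_model_name()                              -> DEFAULT_MODEL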


class Wave2Vec2Inference:
    def __init__(self, model_name=None, use_gpu=True):
        # Get the actual model name using helper function
        self.model_name = get_model_name(model_name)

        # Auto-detect device
        if use_gpu:
            if torch.backends.mps.is_available():
                self.device = "mps"
            elif torch.cuda.is_available():
                self.device = "cuda"
            else:
                self.device = "cpu"
        else:
            self.device = "cpu"

        print(f"Using device: {self.device}")
        print(f"Loading model: {self.model_name}")

        # Check if model is XLSR and use appropriate processor/model
        is_xlsr = "xlsr" in self.model_name.lower()
        if is_xlsr:
            print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
            self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
        else:
            print("Using AutoProcessor and AutoModelForCTC")
            self.processor = AutoProcessor.from_pretrained(self.model_name)
            self.model = AutoModelForCTC.from_pretrained(self.model_name)

        self.model.to(self.device)
        self.model.eval()

        # Disable gradients for inference
        torch.set_grad_enabled(False)

    def buffer_to_text(self, audio_buffer):
        if len(audio_buffer) == 0:
            return ""

        # Convert to tensor
        if isinstance(audio_buffer, np.ndarray):
            audio_tensor = torch.from_numpy(audio_buffer).float()
        else:
            audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)

        # Process audio
        inputs = self.processor(
            audio_tensor,
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True,
        )

        # Move to device
        input_values = inputs.input_values.to(self.device)
        attention_mask = (
            inputs.attention_mask.to(self.device)
            if "attention_mask" in inputs
            else None
        )

        # Inference
        with torch.no_grad():
            if attention_mask is not None:
                logits = self.model(input_values, attention_mask=attention_mask).logits
            else:
                logits = self.model(input_values).logits

        # Decode
        predicted_ids = torch.argmax(logits, dim=-1)
        if self.device != "cpu":
            predicted_ids = predicted_ids.cpu()
        transcription = self.processor.batch_decode(predicted_ids)[0]
        return transcription.lower().strip()

    def file_to_text(self, filename):
        try:
            audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
            return self.buffer_to_text(audio_input)
        except Exception as e:
            print(f"Error loading audio file {filename}: {e}")
            return ""


class Wave2Vec2ONNXInference:
    def __init__(self, model_name=None, onnx_path=None, use_gpu=True):
        # Get the actual model name using helper function
        self.model_name = get_model_name(model_name)
        print(f"Loading ONNX model: {self.model_name}")

        # Always use Wav2Vec2Processor for ONNX (works for all models)
        self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)

        # Setup ONNX Runtime
        options = rt.SessionOptions()
        options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL

        # Choose providers based on GPU availability
        providers = []
        if use_gpu and rt.get_available_providers():
            if "CUDAExecutionProvider" in rt.get_available_providers():
                providers.append("CUDAExecutionProvider")
        providers.append("CPUExecutionProvider")

        self.model = rt.InferenceSession(onnx_path, options, providers=providers)
        self.input_name = self.model.get_inputs()[0].name
        print(f"ONNX model loaded with providers: {self.model.get_providers()}")

    def buffer_to_text(self, audio_buffer):
        if len(audio_buffer) == 0:
            return ""

        # Convert to tensor
        if isinstance(audio_buffer, np.ndarray):
            audio_tensor = torch.from_numpy(audio_buffer).float()
        else:
            audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)

        # Process audio
        inputs = self.processor(
            audio_tensor,
            sampling_rate=16_000,
            return_tensors="np",
            padding=True,
        )

        # ONNX inference
        input_values = inputs.input_values.astype(np.float32)
        onnx_outputs = self.model.run(None, {self.input_name: input_values})[0]

        # Decode
        prediction = np.argmax(onnx_outputs, axis=-1)
        transcription = self.processor.decode(prediction.squeeze().tolist())
        return transcription.lower().strip()

    def file_to_text(self, filename):
        try:
            audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
            return self.buffer_to_text(audio_input)
        except Exception as e:
            print(f"Error loading audio file {filename}: {e}")
            return ""


def convert_to_onnx(model_id_or_path, onnx_model_name):
    """Convert PyTorch model to ONNX format"""
    print(f"Converting {model_id_or_path} to ONNX...")
    model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
    model.eval()

    # Create dummy input
    audio_len = 250000
    dummy_input = torch.randn(1, audio_len, requires_grad=True)

    torch.onnx.export(
        model,
        dummy_input,
        onnx_model_name,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {1: "audio_len"},
            "output": {1: "audio_len"},
        },
    )
    print(f"ONNX model saved to: {onnx_model_name}")


def quantize_onnx_model(onnx_model_path, quantized_model_path):
    """Quantize ONNX model for faster inference"""
    print("Starting quantization...")
    from onnxruntime.quantization import quantize_dynamic, QuantType

    quantize_dynamic(
        onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
    )
    print(f"Quantized model saved to: {quantized_model_path}")


def export_to_onnx(model_name, quantize=False):
    """
    Export model to ONNX format with optional quantization

    Args:
        model_name: HuggingFace model name
        quantize: Whether to also create quantized version

    Returns:
        tuple: (onnx_path, quantized_path or None)
    """
    onnx_filename = f"{model_name.split('/')[-1]}.onnx"
    convert_to_onnx(model_name, onnx_filename)

    quantized_path = None
    if quantize:
        quantized_path = onnx_filename.replace(".onnx", ".quantized.onnx")
        quantize_onnx_model(onnx_filename, quantized_path)

    return onnx_filename, quantized_path
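

# Export sketch: this writes "<model-basename>.onnx" (and optionally a
# ".quantized.onnx" sibling) into the current working directory. Downloading and
# exporting a checkpoint can take a while, so it is left commented out here.
# onnx_path, quantized_path = export_to_onnx("facebook/wav2vec2-base-960h", quantize=True)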


def create_inference(
    model_name=None, use_onnx=False, onnx_path=None, use_gpu=True, use_onnx_quantize=False
):
    """
    Create optimized inference instance

    Args:
        model_name: Model key from WAVE2VEC2_MODELS or full HuggingFace model name
            (default: uses DEFAULT_MODEL)
        use_onnx: Whether to use ONNX runtime
        onnx_path: Path to ONNX model file
        use_gpu: Whether to use GPU if available
        use_onnx_quantize: Whether to use quantized ONNX model

    Returns:
        Inference instance
    """
    # Get the actual model name
    actual_model_name = get_model_name(model_name)

    if use_onnx:
        if not onnx_path or not os.path.exists(onnx_path):
            # Convert to ONNX if path not provided or doesn't exist
            onnx_filename = f"{actual_model_name.split('/')[-1]}.onnx"
            convert_to_onnx(actual_model_name, onnx_filename)
            onnx_path = onnx_filename

        if use_onnx_quantize:
            quantized_path = onnx_path.replace(".onnx", ".quantized.onnx")
            if not os.path.exists(quantized_path):
                quantize_onnx_model(onnx_path, quantized_path)
            onnx_path = quantized_path

        print(f"Using ONNX model: {onnx_path}")
        return Wave2Vec2ONNXInference(model_name, onnx_path, use_gpu)
    else:
        print("Using PyTorch model")
        return Wave2Vec2Inference(model_name, use_gpu)
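

# Typical entry point, sketched with the test file used below. create_inference
# hides the backend choice and, when use_onnx=True, exports an ONNX file on first
# use if onnx_path is missing.
# asr = create_inference("english_large", use_onnx=False, use_gpu=True)
# print(asr.file_to_text("./hello_how_are_you_today.wav"))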
if __name__ == "__main__":
import time
# Display available models
print("Available Wave2Vec2 models:")
for key, model_name in get_available_models().items():
print(f" {key}: {model_name}")
print(f"\nDefault model: {DEFAULT_MODEL}")
print()
# Test with different models
test_models = ["english_large", "multilingual", "english_960h"]
test_file = "./hello_how_are_you_today.wav"
if not os.path.exists(test_file):
print(f"Test file {test_file} not found. Please provide a valid audio file.")
print("Creating example usage without actual file...")
# Example usage without file
print("\n=== Example Usage ===")
# Using default model
print("1. Using default model:")
asr_default = create_inference()
print(f" Model loaded: {asr_default.model_name}")
# Using model key
print("\n2. Using model key 'english_large':")
asr_key = create_inference("english_large")
print(f" Model loaded: {asr_key.model_name}")
# Using full model name
print("\n3. Using full model name:")
asr_full = create_inference("facebook/wav2vec2-base-960h")
print(f" Model loaded: {asr_full.model_name}")
exit(0)
# Test different model configurations
for model_key in test_models:
print(f"\n=== Testing model: {model_key} ===")
# Test different configurations
configs = [
{"use_onnx": False, "use_gpu": True},
{"use_onnx": True, "use_gpu": True, "use_onnx_quantize": False},
]
for config in configs:
print(f"\nConfig: {config}")
# Create inference instance with model selection
asr = create_inference(model_key, **config)
# Warm up
asr.file_to_text(test_file)
# Test performance
times = []
for i in range(3):
start_time = time.time()
text = asr.file_to_text(test_file)
end_time = time.time()
execution_time = end_time - start_time
times.append(execution_time)
print(f"Run {i+1}: {execution_time:.3f}s - {text[:50]}...")
avg_time = sum(times) / len(times)
print(f"Average time: {avg_time:.3f}s")