# tiny-audio-moe-shared / asr_processing.py
from typing import Optional, Union

import torch
import transformers
from transformers import ProcessorMixin

# Support loading both as a package module and as a standalone script.
try:
    from .asr_config import ASRConfig
except ImportError:
    from asr_config import ASRConfig  # type: ignore[no-redef]


class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models.

    Pairs a Whisper feature extractor with a chat-template tokenizer and
    expands one ``<audio>`` placeholder per expected encoder output frame.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    AUDIO_TOKEN = "<audio>"
    TRANSCRIBE_PROMPT = "Transcribe: "

    def __init__(self, feature_extractor, tokenizer):
        # Let ProcessorMixin wire up the attributes so that
        # save_pretrained / from_pretrained round-trip correctly.
        super().__init__(feature_extractor, tokenizer)
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)

    def __call__(
        self,
        audio: Optional[Union[list, "torch.Tensor"]] = None,
        text: Optional[str] = None,
        system_prompt: Optional[str] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict:
        """Process audio and text inputs for inference.

        Args:
            audio: Raw audio waveform(s), sampled at the feature extractor's rate.
            text: Target transcription (optional, for training; prefer a DataCollator).
            system_prompt: Optional system prompt prepended to the chat.
            return_tensors: Return format ("pt" for PyTorch tensors).

        Returns:
            Dict with input_ids and attention_mask, plus input_features when
            audio is provided.
        """
result = {}
# Process audio
if audio is not None:
audio_inputs = self.feature_extractor(
audio,
sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
return_tensors=return_tensors,
**kwargs,
)
result["input_features"] = audio_inputs["input_features"]
# Whisper encoder output length = mel_len // 2 (stride-2 conv)
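            # e.g. a clip padded to 30 s -> 3000 mel frames -> 1500 audio tokens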
num_audio_tokens = audio_inputs["input_features"].shape[-1] // 2
else:
num_audio_tokens = 0
# Build prompt with audio token placeholders
user_content = self.TRANSCRIBE_PROMPT
if num_audio_tokens > 0:
user_content += self.AUDIO_TOKEN * num_audio_tokens
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": user_content})
if text is not None:
messages.append({"role": "assistant", "content": text})
        # Tokenize via the tokenizer's chat template; only append the
        # generation prompt when there is no assistant turn to continue.
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=(text is None),
            return_tensors=return_tensors,
        )
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        result["input_ids"] = input_ids
        # Single-example path, no padding: the mask is all ones.
        result["attention_mask"] = torch.ones_like(input_ids)
        return result
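

# register_for_auto_class records an auto_map entry so the processor can be
# loaded with AutoProcessor.from_pretrained(..., trust_remote_code=True);
# the explicit register() call maps ASRConfig -> ASRProcessor in-process.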
ASRProcessor.register_for_auto_class()
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
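

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint names below are illustrative
    # assumptions, not pinned dependencies: any Whisper feature extractor plus
    # a chat-template tokenizer that knows the "<audio>" token will work.
    from transformers import AutoFeatureExtractor, AutoTokenizer

    feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
    tokenizer.add_special_tokens({"additional_special_tokens": ["<audio>"]})

    processor = ASRProcessor(feature_extractor, tokenizer)
    batch = processor(audio=torch.zeros(16000))  # 1 s of 16 kHz silence
    print(batch["input_features"].shape)  # (1, 80, 3000): mel padded to 30 s
    print(batch["input_ids"].shape)       # (1, seq_len), incl. 1500 <audio> tokens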