File size: 4,335 Bytes
3f3ad46 36087fa 3f3ad46 36087fa 3f3ad46 36087fa 3f3ad46 36087fa 3f3ad46 47f9dbe 36087fa 47f9dbe 36087fa 3f3ad46 47f9dbe 3f3ad46 47f9dbe 3f3ad46 47f9dbe 3f3ad46 487e129 3f3ad46 36087fa 487e129 3f3ad46 36087fa 3f3ad46 36087fa 3f3ad46 36087fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from typing import Optional, Union
import torch
import transformers
from transformers import ProcessorMixin
try:
from .asr_config import ASRConfig
except ImportError:
from asr_config import ASRConfig # type: ignore[no-redef]
class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models.

    Pairs an audio feature extractor with a tokenizer and builds chat-style
    prompts in which the audio clip is represented by a run of ``<audio>``
    placeholder tokens — one per projector output frame — that the model
    later swaps for audio embeddings.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    # Placeholder token the model replaces with projected audio embeddings.
    AUDIO_TOKEN = "<audio>"
    TRANSCRIBE_PROMPT = "Transcribe: "
    # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]

    def __init__(
        self,
        feature_extractor,
        tokenizer,
        projector=None,
        encoder_conv_layers: Optional[list] = None,
    ):
        """Build the processor.

        Args:
            feature_extractor: Audio feature extractor (e.g. Whisper's).
            tokenizer: Text tokenizer; must know the ``<audio>`` token.
            projector: Object exposing ``get_output_length(int) -> int`` that
                maps encoder output length to the number of audio tokens.
                Required whenever ``__call__`` receives audio.
            encoder_conv_layers: Optional ``[(pad, kernel, stride), ...]``
                override for the encoder's downsampling convolutions.
        """
        # NOTE(review): attributes are assigned directly rather than through
        # ProcessorMixin.__init__ — confirm save_pretrained/from_pretrained
        # round-trips behave as intended with this setup.
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
        self.projector = projector
        # Falsy input (None or []) falls back to the Whisper defaults.
        self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS

    def _compute_encoder_output_length(self, mel_length: int) -> int:
        """Compute encoder output length using conv layer formulas.

        Applies the standard 1-D convolution length formula
        ``L_out = (L_in + 2*pad - (kernel - 1) - 1) // stride + 1``
        once per configured conv layer, in order.
        """
        length = mel_length
        for padding, kernel_size, stride in self.encoder_conv_layers:
            length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
        return length

    def __call__(
        self,
        audio: Optional[Union[list, "torch.Tensor"]] = None,
        text: Optional[str] = None,
        system_prompt: Optional[str] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict:
        """Process audio and text inputs for inference.

        Args:
            audio: Raw audio waveform(s)
            text: Target transcription (optional, for training - but use DataCollator instead)
            system_prompt: Optional system prompt
            return_tensors: Return format ("pt" for PyTorch)

        Returns:
            Dict with input_features, input_ids, attention_mask (and, when
            audio is given, audio_attention_mask).

        Raises:
            ValueError: If ``audio`` is given but the processor was built
                without a projector (needed to size the placeholder run).
        """
        result = {}
        # Process audio
        if audio is not None:
            # BUGFIX: previously this path crashed with an opaque
            # AttributeError when projector was None (its default).
            if self.projector is None:
                raise ValueError(
                    "ASRProcessor needs a projector to process audio; "
                    "pass one at construction time."
                )
            audio_inputs = self.feature_extractor(
                audio,
                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
                return_attention_mask=True,
                return_tensors=return_tensors,
                **kwargs,
            )
            result["input_features"] = audio_inputs["input_features"]
            result["audio_attention_mask"] = audio_inputs["attention_mask"]
            # Use actual audio length (from attention mask), not the padded
            # feature length, so the placeholder count matches real content.
            real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
            encoder_output_len = self._compute_encoder_output_length(real_mel_len)
            num_audio_tokens = self.projector.get_output_length(encoder_output_len)
        else:
            num_audio_tokens = 0

        # Build the user turn: prompt text followed by one placeholder token
        # per projected audio frame.
        user_content = self.TRANSCRIBE_PROMPT
        if num_audio_tokens > 0:
            user_content += self.AUDIO_TOKEN * num_audio_tokens

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        if text is not None:
            # Training-style layout; prefer the DataCollator for real training.
            messages.append({"role": "assistant", "content": text})

        # Tokenize; only open a generation prompt when no target text is given.
        tokenized = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=(text is None),
            return_tensors=return_tensors,
        )
        # apply_chat_template may return a tensor or a BatchEncoding/dict.
        if isinstance(tokenized, torch.Tensor):
            input_ids = tokenized
        else:
            input_ids = tokenized["input_ids"] if "input_ids" in tokenized else tokenized.input_ids
        # Normalize to a batch dimension of 1.
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        result["input_ids"] = input_ids
        # No padding at this point, so every position is attended.
        result["attention_mask"] = torch.ones_like(input_ids)
        return result
# Register the processor for trust_remote_code auto-class export and map it
# to ASRConfig so AutoProcessor.from_pretrained can resolve it.
ASRProcessor.register_for_auto_class()
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
|