from typing import Optional, Union

import torch
import transformers
from transformers import ProcessorMixin

try:
    from .asr_config import ASRConfig
except ImportError:
    from asr_config import ASRConfig  # type: ignore[no-redef]


class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models."""

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
    AUDIO_TOKEN = "<audio>"
    TRANSCRIBE_PROMPT = "Transcribe: "
    # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]

    def __init__(
        self,
        feature_extractor,
        tokenizer,
        projector=None,
        encoder_conv_layers: Optional[list] = None,
    ):
        super().__init__(feature_extractor, tokenizer)
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
        self.projector = projector
        self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS

    def _compute_encoder_output_length(self, mel_length: int) -> int:
        """Compute encoder output length using conv layer formulas."""
        length = mel_length
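        # Standard conv1d length formula, applied per layer:
        #   out = (in + 2 * pad - (kernel - 1) - 1) // stride + 1
        # With the default layers [(1, 3, 1), (1, 3, 2)]: 3000 mel frames
        # -> 3000 after the stride-1 conv -> 1500 after the stride-2 conv.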
        for padding, kernel_size, stride in self.encoder_conv_layers:
            length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
        return length

    def __call__(
        self,
        audio: Optional[Union[list, torch.Tensor]] = None,
        text: Optional[str] = None,
        system_prompt: Optional[str] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict:
        """Process audio and text inputs for inference.

        Args:
            audio: Raw audio waveform(s)
            text: Target transcription (optional; for training, prefer a DataCollator)
            system_prompt: Optional system prompt
            return_tensors: Return format ("pt" for PyTorch)

        Returns:
            Dict with input_ids and attention_mask, plus input_features and
            audio_attention_mask when audio is provided
        """
        result = {}

        # Process audio
        if audio is not None:
            audio_inputs = self.feature_extractor(
                audio,
                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
                return_attention_mask=True,
                return_tensors=return_tensors,
                **kwargs,
            )
            result["input_features"] = audio_inputs["input_features"]
            result["audio_attention_mask"] = audio_inputs["attention_mask"]

            # Use actual audio length (from attention mask) for token count
            real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
            encoder_output_len = self._compute_encoder_output_length(real_mel_len)
            if self.projector is not None:
                num_audio_tokens = self.projector.get_output_length(encoder_output_len)
            else:
                # Assumption: with no projector there is no extra downsampling,
                # so emit one placeholder token per encoder output frame.
                num_audio_tokens = encoder_output_len
        else:
            num_audio_tokens = 0

        # Build prompt with audio token placeholders
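        # e.g. "Transcribe: <audio><audio>...<audio>", one placeholder per
        # projected encoder frame; the model is expected to substitute the
        # projected audio embeddings at these positions.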
        user_content = self.TRANSCRIBE_PROMPT
        if num_audio_tokens > 0:
            user_content += self.AUDIO_TOKEN * num_audio_tokens

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        if text is not None:
            messages.append({"role": "assistant", "content": text})

        # Tokenize
        tokenized = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=(text is None),
            return_tensors=return_tensors,
        )

        # Handle both tensor and BatchEncoding returns
        if isinstance(tokenized, torch.Tensor):
            input_ids = tokenized
        else:
            # BatchEncoding or dict-like object
            input_ids = tokenized["input_ids"] if "input_ids" in tokenized else tokenized.input_ids

        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        result["input_ids"] = input_ids
        result["attention_mask"] = torch.ones_like(input_ids)

        return result


ASRProcessor.register_for_auto_class()
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
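

# Usage sketch (illustrative only; "waveform" and "model" are hypothetical,
# and the tokenizer is assumed to already contain the "<audio>" token and a
# chat template):
#
#   processor = ASRProcessor(feature_extractor, tokenizer, projector=projector)
#   inputs = processor(audio=waveform, system_prompt="You are an ASR model.")
#   output_ids = model.generate(**inputs)
#   text = processor.tokenizer.decode(output_ids[0], skip_special_tokens=True)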