"""Minimal Whisper-VQ encoder for MossSpeech codec.

This file provides only the components used by
`MossSpeechCodec/modeling_moss_speech_codec.py` during inference:
 - vector quantization helper
 - causal conv for streaming
 - SDPA attention for encoder
 - WhisperVQEncoderLayer and WhisperVQEncoder
"""

from dataclasses import dataclass
from typing import Optional, Tuple

import math
import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import EncoderDecoderCache
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel

from .utils import WhisperVQConfig


@dataclass
class QuantizedBaseModelOutput(BaseModelOutput):
    quantized_token_ids: Optional[torch.LongTensor] = None


@dataclass
class QuantizedBaseModelOutputWithCache(QuantizedBaseModelOutput):
    past_key_value: Optional[EncoderDecoderCache] = None
    conv1_cache: Optional[torch.Tensor] = None
    conv2_cache: Optional[torch.Tensor] = None


def vector_quantize(inputs: torch.Tensor, codebook: torch.Tensor):
    """Map each vector in `inputs` to its nearest codebook row (squared-L2 distance).

    Returns the quantized tensor (same shape as `inputs`), the flat indices of the
    selected rows, and the full distance matrix.
    """
    embedding_size = codebook.size(1)
    inputs_flatten = inputs.reshape(-1, embedding_size)
    codebook_sqr = torch.sum(codebook ** 2, dim=1)
    inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True)
    distances = torch.addmm(codebook_sqr + inputs_sqr, inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0)
    _, indices_flatten = torch.min(distances, dim=1)
    codes_flatten = torch.index_select(codebook, dim=0, index=indices_flatten)
    return codes_flatten.view_as(inputs), indices_flatten, distances
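
# Illustrative usage sketch (the shapes and the 16384-entry codebook size below are
# assumptions for the example, not values mandated by the codec):
#
#   codebook = torch.randn(16384, 1280)              # (vocab_size, d_model)
#   states = torch.randn(2, 50, 1280)                # (batch, seq_len, d_model)
#   quantized, indices, _ = vector_quantize(states, codebook)
#   quantized.shape  # torch.Size([2, 50, 1280]) -- nearest codebook vectors
#   indices.shape    # torch.Size([100]) -- flat ids; reshape to (2, 50) for token ids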


class CausalConv1d(nn.Conv1d):
    """Conv1d whose output at time t depends only on inputs up to t; any `padding`
    argument is ignored and causal left-padding is applied explicitly instead."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        causal_padding = (self.kernel_size[0] - 1) * self.dilation[0]
        x = nn.functional.pad(x, (causal_padding, 0))
        return super().forward(x)

    def forward_causal(self, inp: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Streaming forward: left-pad with `conv_cache` (zeros on the first chunk) and
        return the output together with the cache needed for the next chunk."""
        k, d = self.kernel_size[0], self.dilation[0]
        if conv_cache is None:
            # First chunk: zero-pad by the full causal receptive field, matching `forward`.
            inp_pad = nn.functional.pad(inp, ((k - 1) * d, 0))
        else:
            inp_pad = torch.cat((conv_cache, inp), dim=-1)
        out = super().forward(inp_pad)
        # Keep the trailing samples that serve as left context for the next chunk.
        new_cache = inp_pad[:, :, -(k - 1) * d :]
        return out, new_cache
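
# Streaming sketch (assumes stride 1 and dilation 1, as used for `conv1` below):
# feeding a long input in two chunks through `forward_causal`, carrying the returned
# cache, should match a single call to `forward` on the full input. The 80-channel
# size is an assumption for the example.
#
#   conv = CausalConv1d(80, 80, kernel_size=3)
#   x = torch.randn(1, 80, 40)
#   full = conv(x)
#   a, cache = conv.forward_causal(x[:, :, :20], None)
#   b, _ = conv.forward_causal(x[:, :, 20:], cache)
#   torch.testing.assert_close(torch.cat([a, b], dim=-1), full)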


def _prepare_4d_causal_attention_mask_with_cache_position(attention_mask, sequence_length, target_length, cache_position=None, dtype=torch.float32, device=None, min_dtype=None, batch_size=None):
    if batch_size is None:
        batch_size = attention_mask.shape[0] if attention_mask is not None else 1
    if device is None:
        device = attention_mask.device if attention_mask is not None else None
    if min_dtype is None:
        min_dtype = torch.finfo(dtype).min
    if cache_position is None:
        # Without explicit cache positions, fall back to a square mask over the current
        # sequence (or over the provided attention mask's length). Building a mask from
        # scratch below requires `cache_position`, so callers passing attention_mask=None
        # must also pass cache_position.
        target_length = sequence_length
        if attention_mask is not None:
            target_length = attention_mask.shape[-1]
    causal_mask = attention_mask
    if causal_mask is None:
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    else:
        causal_mask = causal_mask[:, None, None, :].expand(batch_size, 1, sequence_length, target_length).to(dtype)
        causal_mask = (1.0 - causal_mask) * min_dtype
    if attention_mask is not None:
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
    return causal_mask
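
# Shape sketch (illustrative): for 2 new query positions attending over 5 keys with
# cache_position=torch.tensor([3, 4]) and attention_mask=None, the helper returns a
# (batch_size, 1, 2, 5) mask whose entries at key index > the query's cache position
# hold the dtype minimum, with zeros elsewhere.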


class WhisperAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_causal: bool = False, layer_idx: Optional[int] = None, config: Optional[WhisperVQConfig] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        self.is_causal = is_causal
        self.layer_idx = layer_idx
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()


class WhisperSdpaAttention(WhisperAttention):
    def forward(self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.Tensor] = None):
        bsz, tgt_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

        is_cross_attention = key_value_states is not None
        current_states = key_value_states if is_cross_attention else hidden_states
        key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

        causal_mask = attention_mask
        use_builtin_causal = False
        if self.is_causal and causal_mask is None and tgt_len > 1:
            if cache_position is not None:
                dtype, device = query_states.dtype, query_states.device
                min_dtype = torch.finfo(dtype).min
                causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(None, query_states.shape[-2], key_states.shape[-2], cache_position=cache_position, dtype=dtype, device=device, min_dtype=min_dtype, batch_size=query_states.shape[0])
            else:
                # No cache positions were provided, so let SDPA apply its built-in causal mask.
                use_builtin_causal = True

        attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=use_builtin_causal)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)
        return attn_output, None, None
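
# Call sketch (illustrative; the 1280/20 head geometry is an assumption):
#
#   attn = WhisperSdpaAttention(embed_dim=1280, num_heads=20, is_causal=True)
#   x = torch.randn(1, 10, 1280)
#   out, _, _ = attn(x, cache_position=torch.arange(10))
#   out.shape  # torch.Size([1, 10, 1280])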


WHISPER_ATTENTION_CLASSES = {
    "sdpa": WhisperSdpaAttention,
}


class WhisperVQEncoderLayer(nn.Module):
    def __init__(self, config: WhisperVQConfig, is_causal=True, layer_idx=None):
        super().__init__()
        self.embed_dim = config.d_model
        self.kv_cache = True
        impl = getattr(config, "_attn_implementation", "sdpa")
        if impl not in WHISPER_ATTENTION_CLASSES:
            impl = "sdpa"
        self.self_attn = WHISPER_ATTENTION_CLASSES[impl](embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, is_causal=is_causal, layer_idx=layer_idx, config=config)
        self.is_causal = is_causal
        # `forward_causal` always applies this norm, so it must exist even for non-causal layers.
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward_causal(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, past_key_value: Optional[EncoderDecoderCache] = None, cache_position: Optional[torch.LongTensor] = None):
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask if not self.is_causal else None, layer_head_mask=layer_head_mask, output_attentions=output_attentions, past_key_value=past_key_value, use_cache=self.kv_cache, cache_position=cache_position)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if self.kv_cache:
            outputs += (present_key_value,)
        return outputs, cache_position
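
# Call sketch (illustrative; `cfg` stands for a WhisperVQConfig whose d_model=1280 is
# an assumption for the example):
#
#   layer = WhisperVQEncoderLayer(cfg, is_causal=True, layer_idx=0)
#   (h, _kv), _ = layer.forward_causal(torch.randn(1, 10, 1280),
#                                      cache_position=torch.arange(10))
#   h.shape  # torch.Size([1, 10, 1280])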


class WhisperPreTrainedModel(PreTrainedModel):
    config_class = WhisperVQConfig
    base_model_prefix = "model"
    main_input_name = "input_features"

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, WhisperVQEncoder):
            with torch.no_grad():
                embed_positions = module.embed_positions.weight
                embed_positions.copy_(sinusoids(*embed_positions.shape))


def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    if channels % 2 != 0:
        raise ValueError("channels must be even for sinusoidal positional embeddings")
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
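
# Example (illustrative): sinusoids(1500, 1280) returns a (1500, 1280) table whose
# first 640 columns are sines and last 640 are cosines, the concatenated layout
# Whisper uses for its fixed positional embeddings.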


class WhisperVQEncoder(WhisperPreTrainedModel):
    def __init__(self, config: WhisperVQConfig):
        super().__init__(config)
        self.config = config
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop
        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        if config.encoder_causal_convolution:
            conv_class = CausalConv1d
        else:
            conv_class = nn.Conv1d
        self.conv1 = conv_class(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = conv_class(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
        self.embed_positions.requires_grad_(False)
        if config.quantize_encoder_only:
            self.layers = nn.ModuleList([WhisperVQEncoderLayer(config, is_causal=config.encoder_causal_attention or config.quantize_causal_encoder, layer_idx=i) for i in range(config.quantize_position)])
        else:
            self.layers = nn.ModuleList([WhisperVQEncoderLayer(config, is_causal=config.encoder_causal_attention or (config.quantize_causal_encoder and layer_id < config.quantize_position), layer_idx=layer_id) for layer_id in range(config.encoder_layers)])
            self.layer_norm = nn.LayerNorm(config.d_model)

        self.pooling_layer = None
        if config.pooling_kernel_size is not None:
            self.pooling_layer = nn.AvgPool1d(kernel_size=config.pooling_kernel_size) if config.pooling_type == "avg" else nn.MaxPool1d(kernel_size=config.pooling_kernel_size)

        self.codebook = None
        self.embed_positions2 = None
        if config.quantize_vocab_size is not None:
            self.codebook = nn.Embedding(config.quantize_vocab_size, config.d_model)
            pos2_len = self.max_source_positions // max(int(config.pooling_kernel_size or 1), 1)
            self.embed_positions2 = nn.Embedding(pos2_len, config.d_model)
            self.embed_positions2.requires_grad_(False)

        self.post_init()

    def forward(self, input_features: torch.FloatTensor, attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, past_key_values: Optional[EncoderDecoderCache] = None, cache_position: Optional[torch.LongTensor] = None, quantized_token_ids: Optional[torch.LongTensor] = None, conv1_cache: Optional[torch.Tensor] = None, conv2_cache: Optional[torch.Tensor] = None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        device = input_features.device
        if input_features.dim() != 3:
            raise ValueError("`input_features` should be (batch, feature_size, seq_len)")

        if input_features.shape[-1] % 2 == 1:
            input_features = nn.functional.pad(input_features, (0, 1))
        if input_features.shape[1] != self.num_mel_bins:
            raise ValueError(f"Expected {self.num_mel_bins} mel bins, got {input_features.shape[1]}")

        if isinstance(self.conv1, CausalConv1d):
            conv1_output, new_conv1_cache = self.conv1.forward_causal(input_features, conv1_cache)
        else:
            conv1_output = self.conv1(input_features)
            new_conv1_cache = None
        x = nn.functional.gelu(conv1_output)
        if isinstance(self.conv2, CausalConv1d):
            conv2_output, new_conv2_cache = self.conv2.forward_causal(x, conv2_cache)
        else:
            conv2_output = self.conv2(x)
            new_conv2_cache = None
        x = nn.functional.gelu(conv2_output)
        x = x.permute(0, 2, 1)
        batch_size, seq_len, _ = x.shape
        if attention_mask is not None:
            attention_mask = attention_mask[:, :: self.conv1.stride[0] * self.conv2.stride[0]]
        if cache_position is None:
            cache_position = torch.arange(0, seq_len, device=device)
        embed_pos = self.embed_positions.weight
        hidden_states = x + embed_pos[cache_position]
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        if past_key_values is None:
            # Start from an empty encoder/decoder cache; layers can then populate it when streaming.
            past_key_values = EncoderDecoderCache.from_legacy_cache(None)
        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            layer_outputs, _ = layer.forward_causal(hidden_states, attention_mask=attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), output_attentions=output_attentions, past_key_value=past_key_values, cache_position=cache_position)
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
            if idx + 1 == self.config.pooling_position and self.pooling_layer is not None:
                hs = hidden_states.permute(0, 2, 1)
                if hs.shape[-1] % self.config.pooling_kernel_size != 0:
                    hs = nn.functional.pad(hs, (0, self.config.pooling_kernel_size - hs.shape[-1] % self.config.pooling_kernel_size))
                hidden_states = self.pooling_layer(hs).permute(0, 2, 1)
            if idx + 1 == self.config.quantize_position and self.codebook is not None:
                if quantized_token_ids is not None:
                    hidden_states = self.codebook(quantized_token_ids)
                else:
                    hidden_quantized, indices_flat, _ = vector_quantize(hidden_states, self.codebook.weight)
                    quantized_token_ids = indices_flat.reshape(batch_size, hidden_quantized.shape[1])
                    hidden_states = hidden_quantized
                hidden_states = hidden_states + self.embed_positions2.weight[: hidden_states.shape[1]]

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return QuantizedBaseModelOutputWithCache(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, quantized_token_ids=quantized_token_ids, past_key_value=past_key_values, conv1_cache=new_conv1_cache, conv2_cache=new_conv2_cache)