"""Minimal Whisper-VQ encoder for MossSpeech codec. This file provides only the components used by `MossSpeechCodec/modeling_moss_speech_codec.py` during inference: - vector quantization helper - causal conv for streaming - SDPA attention for encoder - WhisperVQEncoderLayer and WhisperVQEncoder """ from dataclasses import dataclass from typing import Optional, Tuple import math import torch from torch import nn from transformers.activations import ACT2FN from transformers.cache_utils import EncoderDecoderCache from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_utils import PreTrainedModel from .utils import WhisperVQConfig @dataclass class QuantizedBaseModelOutput(BaseModelOutput): quantized_token_ids: Optional[torch.LongTensor] = None @dataclass class QuantizedBaseModelOutputWithCache(QuantizedBaseModelOutput): past_key_value: Optional[EncoderDecoderCache] = None conv1_cache: Optional[torch.Tensor] = None conv2_cache: Optional[torch.Tensor] = None def vector_quantize(inputs: torch.Tensor, codebook: torch.Tensor): embedding_size = codebook.size(1) inputs_flatten = inputs.reshape(-1, embedding_size) codebook_sqr = torch.sum(codebook ** 2, dim=1) inputs_sqr = torch.sum(inputs_flatten ** 2, dim=1, keepdim=True) distances = torch.addmm(codebook_sqr + inputs_sqr, inputs_flatten, codebook.t(), alpha=-2.0, beta=1.0) _, indices_flatten = torch.min(distances, dim=1) codes_flatten = torch.index_select(codebook, dim=0, index=indices_flatten) return codes_flatten.view_as(inputs), indices_flatten, distances class CausalConv1d(nn.Conv1d): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs): super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias, **kwargs) def forward(self, x: torch.Tensor) -> torch.Tensor: causal_padding = (self.kernel_size[0] - 1) * self.dilation[0] x = nn.functional.pad(x, (causal_padding, 0)) return super().forward(x) def forward_causal(self, inp: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: k, d = self.kernel_size[0], self.dilation[0] if conv_cache is None: inp_pad = nn.functional.pad(inp, (k - 1, 0)) else: inp_pad = torch.cat((conv_cache, inp), dim=-1) out = super().forward(inp_pad) new_cache = inp_pad[:, :, -(k - 1) * d :] return out, new_cache def _prepare_4d_causal_attention_mask_with_cache_position(attention_mask, sequence_length, target_length, cache_position=None, dtype=torch.float32, device=None, min_dtype=None, batch_size=None): if batch_size is None: batch_size = attention_mask.shape[0] if attention_mask is not None else 1 if device is None: device = attention_mask.device if attention_mask is not None else None if min_dtype is None: min_dtype = torch.finfo(dtype).min if cache_position is None: target_length = sequence_length sequence_length = target_length if attention_mask is not None: mask_length = attention_mask.shape[-1] target_length = mask_length causal_mask = attention_mask if causal_mask is None: causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) causal_mask = torch.triu(causal_mask, diagonal=1) causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) else: causal_mask = causal_mask[:, None, None, :].expand(batch_size, 1, sequence_length, target_length).to(dtype) causal_mask = (1.0 - causal_mask) * min_dtype if attention_mask is not None: mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype) return causal_mask class WhisperAttention(nn.Module): def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_causal: bool = False, layer_idx: Optional[int] = None, config: Optional[WhisperVQConfig] = None): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads self.config = config self.is_causal = is_causal self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() class WhisperSdpaAttention(WhisperAttention): def forward(self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.Tensor] = None): bsz, tgt_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() is_cross_attention = key_value_states is not None current_states = key_value_states if is_cross_attention else hidden_states key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous() value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous() causal_mask = attention_mask sign = False if self.is_causal and causal_mask is None and tgt_len > 1: if cache_position is not None: dtype, device = query_states.dtype, query_states.device min_dtype = torch.finfo(dtype).min causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(None, query_states.shape[-2], key_states.shape[-2], cache_position=cache_position, dtype=dtype, device=device, min_dtype=min_dtype, batch_size=query_states.shape[0]) else: sign = True attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=sign) attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) return attn_output, None, None WHISPER_ATTENTION_CLASSES = { "sdpa": WhisperSdpaAttention, } class WhisperVQEncoderLayer(nn.Module): def __init__(self, config: WhisperVQConfig, is_causal=True, layer_idx=None): super().__init__() self.embed_dim = config.d_model self.kv_cache = True impl = getattr(config, "_attn_implementation", "sdpa") if impl not in WHISPER_ATTENTION_CLASSES: impl = "sdpa" self.self_attn = WHISPER_ATTENTION_CLASSES[impl](embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, is_causal=is_causal, layer_idx=layer_idx, config=config) self.is_causal = is_causal if self.is_causal: self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward_causal(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, past_key_value: Optional[EncoderDecoderCache] = None, cache_position: Optional[torch.LongTensor] = None): residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, self_attn_weights, present_key_value = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask if not self.is_causal else None, layer_head_mask=layer_head_mask, output_attentions=output_attentions, past_key_value=past_key_value, use_cache=self.kv_cache, cache_position=cache_position) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if self.kv_cache: outputs += (present_key_value,) return outputs, cache_position class WhisperPreTrainedModel(PreTrainedModel): config_class = WhisperVQConfig base_model_prefix = "model" main_input_name = "input_features" def _init_weights(self, module): std = self.config.init_std if isinstance(module, (nn.Linear, nn.Conv1d)): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, WhisperVQEncoder): with torch.no_grad(): embed_positions = module.embed_positions.weight embed_positions.copy_(sinusoids(*embed_positions.shape)) def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor: if channels % 2 != 0: raise ValueError("channels must be even for sinusoidal positional embeddings") log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1) return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) class WhisperVQEncoder(WhisperPreTrainedModel): def __init__(self, config: WhisperVQConfig): super().__init__(config) self.config = config self.dropout = config.dropout self.layerdrop = config.encoder_layerdrop embed_dim = config.d_model self.num_mel_bins = config.num_mel_bins self.padding_idx = config.pad_token_id self.max_source_positions = config.max_source_positions if config.encoder_causal_convolution: conv_class = CausalConv1d else: conv_class = nn.Conv1d self.conv1 = conv_class(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) self.conv2 = conv_class(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) self.embed_positions.requires_grad_(False) if config.quantize_encoder_only: self.layers = nn.ModuleList([WhisperVQEncoderLayer(config, is_causal=config.encoder_causal_attention or config.quantize_causal_encoder, layer_idx=i) for i in range(config.quantize_position)]) else: self.layers = nn.ModuleList([WhisperVQEncoderLayer(config, is_causal=config.encoder_causal_attention or (config.quantize_causal_encoder and layer_id < config.quantize_position), layer_idx=layer_id) for layer_id in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) self.pooling_layer = None if config.pooling_kernel_size is not None: self.pooling_layer = nn.AvgPool1d(kernel_size=config.pooling_kernel_size) if config.pooling_type == "avg" else nn.MaxPool1d(kernel_size=config.pooling_kernel_size) self.codebook = None self.embed_positions2 = None if config.quantize_vocab_size is not None: self.codebook = nn.Embedding(config.quantize_vocab_size, config.d_model) pos2_len = self.max_source_positions // max(int(config.pooling_kernel_size or 1), 1) self.embed_positions2 = nn.Embedding(pos2_len, config.d_model) self.embed_positions2.requires_grad_(False) self.post_init() def forward(self, input_features: torch.FloatTensor, attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, past_key_values: Optional[EncoderDecoderCache] = None, cache_position: Optional[torch.LongTensor] = None, quantized_token_ids: Optional[torch.LongTensor] = None, conv1_cache: Optional[torch.Tensor] = None, conv2_cache: Optional[torch.Tensor] = None): return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states device = input_features.device if input_features.dim() != 3: raise ValueError("`input_features` should be (batch, feature_size, seq_len)") if input_features.shape[-1] % 2 == 1: input_features = nn.functional.pad(input_features, (0, 1)) if input_features.shape[1] != self.num_mel_bins: raise ValueError(f"Expected {self.num_mel_bins} mel bins, got {input_features.shape[1]}") if isinstance(self.conv1, CausalConv1d): conv1_output, new_conv1_cache = self.conv1.forward_causal(input_features, conv1_cache) else: conv1_output = self.conv1(input_features) new_conv1_cache = None x = nn.functional.gelu(conv1_output) if isinstance(self.conv2, CausalConv1d): conv2_output, new_conv2_cache = self.conv2.forward_causal(x, conv2_cache) else: conv2_output = self.conv2(x) new_conv2_cache = None x = nn.functional.gelu(conv2_output) x = x.permute(0, 2, 1) batch_size, seq_len, _ = x.shape if attention_mask is not None: attention_mask = attention_mask[:, :: self.conv1.stride[0] * self.conv2.stride[0]] if cache_position is None: cache_position = torch.arange(0, seq_len, device=device) embed_pos = self.embed_positions.weight hidden_states = x + embed_pos[cache_position] hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None if past_key_values is None: past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) for idx, layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) layer_outputs, _ = layer.forward_causal(hidden_states, attention_mask=attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), output_attentions=output_attentions, past_key_value=past_key_values if past_key_values is not None else None, cache_position=cache_position) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if idx + 1 == self.config.pooling_position and self.pooling_layer is not None: hs = hidden_states.permute(0, 2, 1) if hs.shape[-1] % self.config.pooling_kernel_size != 0: hs = nn.functional.pad(hs, (0, self.config.pooling_kernel_size - hs.shape[-1] % self.config.pooling_kernel_size)) hidden_states = self.pooling_layer(hs).permute(0, 2, 1) if idx + 1 == self.config.quantize_position and self.codebook is not None: if quantized_token_ids is not None: hidden_states = self.codebook(quantized_token_ids) else: hidden_quantized, indices_flat, _ = vector_quantize(hidden_states, self.codebook.weight) quantized_token_ids = indices_flat.reshape(batch_size, hidden_quantized.shape[1]) hidden_states = hidden_quantized hidden_states = hidden_states + self.embed_positions2.weight[: hidden_states.shape[1]] if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return QuantizedBaseModelOutputWithCache(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, quantized_token_ids=quantized_token_ids, past_key_value=past_key_values, conv1_cache=new_conv1_cache, conv2_cache=new_conv2_cache)