|
|
"""Minimal Whisper-VQ encoder for the MossSpeech codec.

This file provides only the components that
`MossSpeechCodec/modeling_moss_speech_codec.py` uses during inference:

- output dataclasses for the quantized encoder
- a vector-quantization helper (nearest codebook entry)
- a causal 1-D convolution with a cached, chunk-by-chunk forward for streaming
- a 4-D attention-mask builder and a sinusoidal position-table helper
- SDPA-based self-attention for the encoder layers
- WhisperVQEncoderLayer and WhisperVQEncoder
"""
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
import math |
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from transformers.activations import ACT2FN |
|
|
from transformers.cache_utils import EncoderDecoderCache |
|
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
|
|
from .utils import WhisperVQConfig |
|
|
|
|
|
|
|
|
@dataclass
class QuantizedBaseModelOutput(BaseModelOutput):
    """Encoder output that also carries vector-quantization token ids.

    Extends ``BaseModelOutput`` (``last_hidden_state``, ``hidden_states``,
    ``attentions``) with the codebook indices chosen by the encoder.
    """

    # Codebook indices selected at the quantization layer, as produced by
    # WhisperVQEncoder.forward (shape (batch, quantized_len) there); None when
    # no quantization happened.
    quantized_token_ids: Optional[torch.LongTensor] = None
|
|
|
|
|
|
|
|
@dataclass
class QuantizedBaseModelOutputWithCache(QuantizedBaseModelOutput):
    """Quantized encoder output plus the per-call state returned for chunked encoding.

    NOTE(review): the caches appear intended to be fed back into the next
    WhisperVQEncoder.forward call when streaming — confirm against the caller.
    """

    # Attention cache object passed to (or created by) WhisperVQEncoder.forward.
    past_key_value: Optional[EncoderDecoderCache] = None
    # Trailing padded-input frames kept by conv1 when it is a CausalConv1d; None otherwise.
    conv1_cache: Optional[torch.Tensor] = None
    # Trailing padded-input frames kept by conv2 when it is a CausalConv1d; None otherwise.
    conv2_cache: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
def vector_quantize(inputs: torch.Tensor, codebook: torch.Tensor):
    """Snap every vector in ``inputs`` to its nearest codebook row (squared L2).

    Args:
        inputs: tensor whose total size is a multiple of ``codebook.size(1)``;
            it is viewed as rows of that width.
        codebook: (num_codes, dim) matrix of code vectors.

    Returns:
        A tuple ``(codes, indices, distances)``:
        ``codes`` has the shape of ``inputs`` and holds the chosen code vectors;
        ``indices`` is the flat 1-D index of the chosen code per row;
        ``distances`` is the (num_rows, num_codes) squared-distance matrix.
    """
    dim = codebook.shape[1]
    rows = inputs.reshape(-1, dim)

    # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c, computed with one fused addmm.
    row_norms = rows.pow(2).sum(dim=1, keepdim=True)
    code_norms = codebook.pow(2).sum(dim=1)
    distances = torch.addmm(code_norms + row_norms, rows, codebook.t(), beta=1.0, alpha=-2.0)

    nearest = torch.min(distances, dim=1).indices
    chosen = codebook.index_select(0, nearest)
    return chosen.view_as(inputs), nearest, distances
|
|
|
|
|
|
|
|
class CausalConv1d(nn.Conv1d):
    """1-D convolution whose output at frame ``t`` depends only on inputs ``<= t``.

    All padding is applied on the left (``(kernel_size - 1) * dilation`` frames).
    The ``padding`` argument is accepted for signature compatibility with
    ``nn.Conv1d`` but ignored: the underlying convolution always uses padding 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs):
        # padding is forced to 0; the causal left padding is added manually in the forwards.
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias, **kwargs)

    def _left_context(self) -> int:
        """Number of past frames the receptive field reaches beyond the current one."""
        return (self.kernel_size[0] - 1) * self.dilation[0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve a whole sequence, zero-padding only on the left."""
        x = nn.functional.pad(x, (self._left_context(), 0))
        return super().forward(x)

    def forward_causal(self, inp: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convolve one chunk of a stream, continuing from the previous chunk.

        Args:
            inp: the new frames (time on the last dimension).
            conv_cache: tail returned by the previous call, or None for the first
                chunk (zero left padding of ``(kernel_size - 1) * dilation`` is used).

        Returns:
            ``(out, new_cache)`` where ``new_cache`` is the last
            ``(kernel_size - 1) * dilation`` frames of the padded input, to pass
            as ``conv_cache`` on the next call (empty when kernel_size == 1).

        Note:
            With ``stride > 1`` the chunk outputs concatenate to the full-sequence
            output only when every chunk length is a multiple of ``stride``.
        """
        context = self._left_context()
        if conv_cache is None:
            # Must match forward(): the dilation-scaled context, not just kernel_size - 1.
            inp_pad = nn.functional.pad(inp, (context, 0))
        else:
            inp_pad = torch.cat((conv_cache, inp), dim=-1)
        out = super().forward(inp_pad)
        # Slice from an absolute start index: a "-0:" slice would return the
        # whole tensor when context == 0 (kernel_size == 1).
        new_cache = inp_pad[..., inp_pad.shape[-1] - context:]
        return out, new_cache
|
|
|
|
|
|
|
|
def _prepare_4d_causal_attention_mask_with_cache_position(attention_mask, sequence_length, target_length, cache_position=None, dtype=torch.float32, device=None, min_dtype=None, batch_size=None):
    """Build an additive 4-D attention mask of shape (batch, 1, sequence_length, target_length).

    Masked entries hold ``min_dtype`` and visible entries hold 0.

    * ``attention_mask is None``: causal mask. Key ``j`` is masked for query ``i``
      when it lies above the diagonal AND ``j > cache_position[i]``. When
      ``cache_position`` is None it defaults to ``arange(sequence_length)`` and
      ``target_length`` is reset to ``sequence_length``, which gives a plain
      lower-triangular mask.
    * 2-D ``attention_mask`` of shape (batch, mask_len): padding mask only. Keys
      whose mask value is 0 are hidden from every query, and ``target_length``
      becomes ``mask_len``. No causal constraint is added in this case.
    * 4-D ``attention_mask``: assumed to be ready to use and returned unchanged.

    ``device`` defaults to the device of ``attention_mask``, else of
    ``cache_position``. ``min_dtype`` defaults to ``torch.finfo(dtype).min``.
    ``batch_size`` defaults to the mask's batch size, else 1.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    if batch_size is None:
        batch_size = attention_mask.shape[0] if attention_mask is not None else 1
    if device is None:
        if attention_mask is not None:
            device = attention_mask.device
        elif cache_position is not None:
            # Avoid mixing a CPU arange with a GPU cache_position below.
            device = cache_position.device
    if min_dtype is None:
        min_dtype = torch.finfo(dtype).min
    if cache_position is None:
        # Without positions there is no history: the keys are exactly the queries.
        target_length = sequence_length
        cache_position = torch.arange(sequence_length, device=device)

    if attention_mask is None:
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        return causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    # Padding-only mask: 1 -> 0 (visible), 0 -> min_dtype (hidden).
    target_length = attention_mask.shape[-1]
    keep = attention_mask[:, None, None, :].expand(batch_size, 1, sequence_length, target_length).to(dtype)
    return (1.0 - keep) * min_dtype
|
|
|
|
|
|
|
|
class WhisperAttention(nn.Module):
    """Projection layers and head geometry shared by Whisper attention variants.

    Holds the query/key/value/output projections (the key projection has no
    bias). Subclasses provide ``forward``.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, is_causal: bool = False, layer_idx: Optional[int] = None, config: Optional[WhisperVQConfig] = None):
        super().__init__()
        self.embed_dim, self.num_heads, self.dropout = embed_dim, num_heads, dropout
        self.head_dim = embed_dim // num_heads
        self.config = config
        self.is_causal = is_causal
        self.layer_idx = layer_idx

        def square_linear(with_bias: bool) -> nn.Linear:
            return nn.Linear(embed_dim, embed_dim, bias=with_bias)

        # Creation order k, v, q, out is preserved so parameter ordering and
        # random-init consumption are unchanged.
        self.k_proj = square_linear(False)
        self.v_proj = square_linear(True)
        self.q_proj = square_linear(True)
        self.out_proj = square_linear(True)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """View ``tensor`` as (bsz, seq_len, heads, head_dim) and move heads to dim 1 (contiguous)."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()
|
|
|
|
|
|
|
|
class WhisperSdpaAttention(WhisperAttention):
    """Whisper attention computed with ``torch.nn.functional.scaled_dot_product_attention``.

    Self-attention can use a KV cache so that successive chunks attend to earlier
    chunks. ``layer_head_mask`` and ``output_attentions`` are accepted for
    interface compatibility but not used; attention weights are never returned.
    """

    def forward(self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.Tensor] = None):
        """Attend from ``hidden_states`` to itself, or to ``key_value_states`` if given.

        Args:
            hidden_states: query input, (batch, tgt_len, embed_dim).
            key_value_states: optional cross-attention source; None means self-attention.
            past_key_value: cache object. For self-attention with ``use_cache`` and
                a ``layer_idx``, this chunk's keys/values are appended through its
                ``update`` method, and attention covers the cached history too.
            attention_mask: optional mask. A 2-D (batch, src_len) mask is treated as
                a keep-mask (non-zero = attend) and expanded to a boolean SDPA mask;
                masks of any other rank are passed to SDPA unchanged.
                NOTE(review): assumes 2-D masks follow the 1/0 keep convention.
            cache_position: forwarded to the cache ``update`` call.

        Returns:
            ``(attn_output, None, past_key_value)``.
        """
        bsz, tgt_len, _ = hidden_states.size()

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

        query_states = split_heads(self.q_proj(hidden_states))
        is_cross_attention = key_value_states is not None
        current_states = key_value_states if is_cross_attention else hidden_states
        key_states = split_heads(self.k_proj(current_states))
        value_states = split_heads(self.v_proj(current_states))

        if use_cache and past_key_value is not None and not is_cross_attention and self.layer_idx is not None:
            # EncoderDecoderCache keeps self-attention entries in `.self_attention_cache`;
            # fall back to the object itself when a plain cache is passed.
            self_cache = getattr(past_key_value, "self_attention_cache", past_key_value)
            key_states, value_states = self_cache.update(key_states, value_states, self.layer_idx, {"cache_position": cache_position})

        src_len = key_states.shape[-2]
        attn_mask = attention_mask
        use_builtin_causal = False
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_mask = self._padding_mask_to_4d(attn_mask, src_len)
        elif self.is_causal and tgt_len > 1:
            if src_len == tgt_len:
                use_builtin_causal = True
            else:
                # Cached history precedes the new chunk, so query i sits at key
                # index offset + i and may only see keys j <= offset + i. SDPA's
                # is_causal would align top-left, which is wrong when src_len > tgt_len.
                offset = src_len - tgt_len
                query_pos = torch.arange(offset, src_len, device=query_states.device)
                key_pos = torch.arange(src_len, device=query_states.device)
                attn_mask = key_pos[None, :] <= query_pos[:, None]

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=use_builtin_causal,
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, -1).contiguous()
        attn_output = self.out_proj(attn_output)
        return attn_output, None, past_key_value

    @staticmethod
    def _padding_mask_to_4d(mask: torch.Tensor, src_len: int) -> torch.Tensor:
        """Turn a (batch, len) keep-mask into a boolean (batch, 1, 1, src_len) SDPA mask."""
        keep = mask.to(torch.bool)
        missing = src_len - keep.shape[-1]
        if missing > 0:
            # Cached history frames are not covered by this chunk's mask; treat them as valid.
            keep = torch.cat([keep.new_ones(keep.shape[0], missing), keep], dim=-1)
        elif missing < 0:
            keep = keep[:, :src_len]
        return keep[:, None, None, :]
|
|
|
|
|
|
|
|
# Lookup from `config._attn_implementation` names to attention classes.
# Only the SDPA variant ships with this minimal module.
WHISPER_ATTENTION_CLASSES = dict(sdpa=WhisperSdpaAttention)
|
|
|
|
|
|
|
|
class WhisperVQEncoderLayer(nn.Module):
    """Pre-norm encoder block: LayerNorm -> self-attention -> residual, then LayerNorm -> MLP -> residual.

    Args:
        config: model config. Reads ``d_model``, ``encoder_attention_heads``,
            ``attention_dropout``, ``dropout``, ``activation_function``,
            ``activation_dropout``, ``encoder_ffn_dim`` and, if present,
            ``_attn_implementation`` (unknown values fall back to "sdpa").
        is_causal: whether self-attention is causal. Causal layers do not forward
            the padding ``attention_mask`` to the attention module.
        layer_idx: index of this layer, stored on the attention module.
    """

    def __init__(self, config: WhisperVQConfig, is_causal=True, layer_idx=None):
        super().__init__()
        self.embed_dim = config.d_model
        # The layer always asks its attention module to use the KV cache.
        self.kv_cache = True
        impl = getattr(config, "_attn_implementation", "sdpa")
        if impl not in WHISPER_ATTENTION_CLASSES:
            impl = "sdpa"
        self.self_attn = WHISPER_ATTENTION_CLASSES[impl](embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, is_causal=is_causal, layer_idx=layer_idx, config=config)
        self.is_causal = is_causal
        # forward_causal applies this norm unconditionally, so it must exist for
        # non-causal layers as well (creating it only when is_causal made those
        # layers fail at runtime and left their checkpoint weights unloadable).
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward_causal(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, past_key_value: Optional[EncoderDecoderCache] = None, cache_position: Optional[torch.LongTensor] = None):
        """Run the block once.

        Returns:
            ``(outputs, cache_position)`` where ``outputs`` is ``(hidden_states,)``,
            followed by the attention weights if ``output_attentions`` and then the
            attention module's third return value (always, since ``kv_cache`` is True).
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            # Causal layers rely on their own causal masking, not the padding mask.
            attention_mask=attention_mask if not self.is_causal else None,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_value,
            use_cache=self.kv_cache,
            cache_position=cache_position,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)
        if self.kv_cache:
            outputs += (present_key_value,)
        return outputs, cache_position
|
|
|
|
|
|
|
|
class WhisperPreTrainedModel(PreTrainedModel):
    """Hugging Face base class wiring WhisperVQConfig and weight initialisation."""

    config_class = WhisperVQConfig
    base_model_prefix = "model"
    main_input_name = "input_features"

    def _init_weights(self, module):
        """Initialise one submodule.

        Linear/Conv1d and Embedding weights get N(0, config.init_std), with biases
        and the padding row zeroed. For a WhisperVQEncoder, the frozen position
        tables (``embed_positions`` and, when present, ``embed_positions2``) are
        overwritten with sinusoids.
        """
        std = self.config.init_std
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, WhisperVQEncoder):
            with torch.no_grad():
                # Both tables have requires_grad disabled in WhisperVQEncoder.__init__,
                # so a random init would never be trained away. Give the post-quantization
                # table the same sinusoidal init as the main one.
                for table in (module.embed_positions, getattr(module, "embed_positions2", None)):
                    if table is not None:
                        table.weight.copy_(sinusoids(*table.weight.shape))
|
|
|
|
|
|
|
|
def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
    """Return a (length, channels) table of sinusoidal position embeddings.

    The first ``channels // 2`` columns are ``sin(t * inv_timescale)`` and the
    last ``channels // 2`` are the matching cosines. The inverse timescales are
    spaced geometrically from 1 down to ``1 / max_timescale``.

    Raises:
        ValueError: if ``channels`` is odd.
    """
    if channels % 2 != 0:
        raise ValueError("channels must be even for sinusoidal positional embeddings")
    half = channels // 2
    # With a single frequency only timescale 1 is used, so the spacing is irrelevant.
    # Guarding the denominator keeps channels == 2 from dividing by zero.
    log_timescale_increment = math.log(max_timescale) / (half - 1) if half > 1 else 0.0
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
|
|
|
|
|
|
|
|
class WhisperVQEncoder(WhisperPreTrainedModel):
    """Audio encoder over mel features with optional mid-stack pooling and vector quantization.

    Pipeline: conv1 (stride 1) + GELU, conv2 (stride 2) + GELU, absolute position
    embeddings, a stack of WhisperVQEncoderLayer blocks, optional pooling after
    layer ``config.pooling_position``, optional nearest-codebook quantization after
    layer ``config.quantize_position``, and a final LayerNorm when the full stack
    is built (``not config.quantize_encoder_only``).
    """

    def __init__(self, config: WhisperVQConfig):
        super().__init__(config)
        self.config = config
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop
        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions

        conv_class = CausalConv1d if config.encoder_causal_convolution else nn.Conv1d
        self.conv1 = conv_class(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = conv_class(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)

        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
        self.embed_positions.requires_grad_(False)

        if config.quantize_encoder_only:
            # Only the layers up to the quantization point are instantiated.
            self.layers = nn.ModuleList(
                [
                    WhisperVQEncoderLayer(config, is_causal=config.encoder_causal_attention or config.quantize_causal_encoder, layer_idx=i)
                    for i in range(config.quantize_position)
                ]
            )
        else:
            self.layers = nn.ModuleList(
                [
                    WhisperVQEncoderLayer(
                        config,
                        is_causal=config.encoder_causal_attention or (config.quantize_causal_encoder and layer_id < config.quantize_position),
                        layer_idx=layer_id,
                    )
                    for layer_id in range(config.encoder_layers)
                ]
            )
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.pooling_layer = None
        if config.pooling_kernel_size is not None:
            if config.pooling_type == "avg":
                self.pooling_layer = nn.AvgPool1d(kernel_size=config.pooling_kernel_size)
            else:
                self.pooling_layer = nn.MaxPool1d(kernel_size=config.pooling_kernel_size)

        self.codebook = None
        self.embed_positions2 = None
        if config.quantize_vocab_size is not None:
            self.codebook = nn.Embedding(config.quantize_vocab_size, config.d_model)
            pos2_len = self.max_source_positions // max(int(config.pooling_kernel_size or 1), 1)
            self.embed_positions2 = nn.Embedding(pos2_len, config.d_model)
            self.embed_positions2.requires_grad_(False)

        self.post_init()

    @staticmethod
    def _run_conv(conv: nn.Module, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor]):
        """Apply ``conv``, using its cached streaming path when it is a CausalConv1d."""
        if isinstance(conv, CausalConv1d):
            return conv.forward_causal(inputs, conv_cache)
        return conv(inputs), None

    def _pool_sequence(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], cache_position: Optional[torch.Tensor]):
        """Pool along time and keep the padding mask and positions aligned with the pooled axis."""
        kernel = self.config.pooling_kernel_size
        hs = hidden_states.permute(0, 2, 1)
        remainder = hs.shape[-1] % kernel
        if remainder:
            # Right-pad so the final partial window is still pooled.
            hs = nn.functional.pad(hs, (0, kernel - remainder))
        hidden_states = self.pooling_layer(hs).permute(0, 2, 1)
        # Both have length ceil(T / kernel) after striding, matching the pooled states.
        if attention_mask is not None:
            attention_mask = attention_mask[:, ::kernel]
        if cache_position is not None:
            cache_position = torch.div(cache_position[::kernel], kernel, rounding_mode="floor")
        return hidden_states, attention_mask, cache_position

    def _quantize_sequence(self, hidden_states: torch.Tensor, quantized_token_ids: Optional[torch.LongTensor], batch_size: int):
        """Replace states by codebook vectors (given or nearest) and add the second position table."""
        if quantized_token_ids is not None:
            hidden_states = self.codebook(quantized_token_ids)
        else:
            hidden_states, indices_flat, _ = vector_quantize(hidden_states, self.codebook.weight)
            quantized_token_ids = indices_flat.reshape(batch_size, hidden_states.shape[1])
        # NOTE(review): positions restart at 0 on every call; a streaming caller
        # that needs offset positions here would have to handle that separately.
        hidden_states = hidden_states + self.embed_positions2.weight[: hidden_states.shape[1]]
        return hidden_states, quantized_token_ids

    def forward(self, input_features: torch.FloatTensor, attention_mask: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, past_key_values: Optional[EncoderDecoderCache] = None, cache_position: Optional[torch.LongTensor] = None, quantized_token_ids: Optional[torch.LongTensor] = None, conv1_cache: Optional[torch.Tensor] = None, conv2_cache: Optional[torch.Tensor] = None):
        """Encode ``input_features`` (batch, num_mel_bins, frames).

        Args:
            attention_mask: optional (batch, frames) keep-mask. It is downsampled by
                the conv strides and again after pooling.
            cache_position: absolute frame positions (after conv2) used to index
                ``embed_positions``. Defaults to ``arange(seq_len)``.
            quantized_token_ids: if given, these codebook ids replace the computed
                quantization.
            past_key_values, conv1_cache, conv2_cache: state from a previous call.
                A fresh cache is created when ``past_key_values`` is None.

        Returns:
            ``QuantizedBaseModelOutputWithCache``, or a tuple of the non-None entries
            of (last_hidden_state, hidden_states, attentions) when
            ``return_dict`` is False.

        Raises:
            ValueError: if the input is not 3-D or has the wrong number of mel bins.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        device = input_features.device

        if input_features.dim() != 3:
            raise ValueError("`input_features` should be (batch, feature_size, seq_len)")
        # conv2 has stride 2: pad odd lengths with one zero frame.
        if input_features.shape[-1] % 2 == 1:
            input_features = nn.functional.pad(input_features, (0, 1))
        if input_features.shape[1] != self.num_mel_bins:
            raise ValueError(f"Expected {self.num_mel_bins} mel bins, got {input_features.shape[1]}")

        conv1_output, new_conv1_cache = self._run_conv(self.conv1, input_features, conv1_cache)
        x = nn.functional.gelu(conv1_output)
        conv2_output, new_conv2_cache = self._run_conv(self.conv2, x, conv2_cache)
        x = nn.functional.gelu(conv2_output)
        x = x.permute(0, 2, 1)
        batch_size, seq_len, _ = x.shape

        if attention_mask is not None:
            attention_mask = attention_mask[:, :: self.conv1.stride[0] * self.conv2.stride[0]]
        if cache_position is None:
            cache_position = torch.arange(0, seq_len, device=device)
        hidden_states = x + self.embed_positions.weight[cache_position]
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        if past_key_values is None:
            past_key_values = EncoderDecoderCache.from_legacy_cache(None)

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            layer_outputs, _ = layer.forward_causal(
                hidden_states,
                attention_mask=attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                output_attentions=output_attentions,
                past_key_value=past_key_values,
                cache_position=cache_position,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            if idx + 1 == self.config.pooling_position and self.pooling_layer is not None:
                hidden_states, attention_mask, cache_position = self._pool_sequence(hidden_states, attention_mask, cache_position)

            if idx + 1 == self.config.quantize_position and self.codebook is not None:
                hidden_states, quantized_token_ids = self._quantize_sequence(hidden_states, quantized_token_ids, batch_size)

        # The final LayerNorm belongs to the full encoder stack; when the model
        # stops at the quantizer, the quantized embeddings are returned as-is.
        if not self.config.quantize_encoder_only:
            hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return QuantizedBaseModelOutputWithCache(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
            quantized_token_ids=quantized_token_ids,
            past_key_value=past_key_values,
            conv1_cache=new_conv1_cache,
            conv2_cache=new_conv2_cache,
        )
|
|
|