import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import RMSNorm

from .config import DiaConfig
from .state import DecoderInferenceState, EncoderInferenceState, KVCache

def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)

class DenseGeneral(nn.Module):
    """
    PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.

    Stores the weight (`kernel`) in the same layout as Jax and uses torch.tensordot
    for the generalized matrix multiplication. The kernel shape is calculated and
    the parameter created during initialization from the given input/output shapes.
    No bias is used; `bias` is registered as None for state-dict compatibility.

    Attributes:
        axis (Tuple[int, ...]): Input axis or axes to contract.
        in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
        out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
        weight (nn.Parameter): The kernel parameter of shape `in_shapes + out_features`.
    """

    def __init__(
        self,
        in_shapes: tuple[int, ...],
        out_features: tuple[int, ...],
        axis: tuple[int, ...] = (-1,),
        weight_dtype: torch.dtype | None = None,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.in_shapes = in_shapes
        self.out_features = out_features
        self.axis = axis
        self.kernel_shape = self.in_shapes + self.out_features

        factory_kwargs = {"device": device, "dtype": weight_dtype}
        self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
        self.register_parameter("bias", None)

    def forward(self, inputs: Tensor) -> Tensor:
        norm_axis = _normalize_axes(self.axis, inputs.ndim)
        kernel_contract_axes = tuple(range(len(norm_axis)))

        output = torch.tensordot(
            inputs.to(self.weight.dtype),
            self.weight,
            dims=(norm_axis, kernel_contract_axes),
        ).to(inputs.dtype)
        return output
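
# Illustrative sketch (not part of the model): DenseGeneral contracts the axes named in
# `axis` against the leading dims of `weight`, so a kernel of shape in_shapes + out_features
# maps (..., *in_shapes) -> (..., *out_features). The dimensions below are made up, and the
# helper name `_dense_general_example` is hypothetical.
def _dense_general_example() -> torch.Tensor:
    proj = DenseGeneral(in_shapes=(8,), out_features=(4, 16), axis=(-1,), weight_dtype=torch.float32)
    x = torch.randn(2, 3, 8)  # (batch, seq, features)
    y = proj(x)               # (batch, seq, 4, 16): last axis contracted, two output axes appended
    return y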

class MlpBlock(nn.Module):
    """MLP block using DenseGeneral."""

    def __init__(self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype):
        super().__init__()
        self.dtype = compute_dtype

        self.wi_fused = DenseGeneral(
            in_shapes=(embed_dim,),
            out_features=(2, intermediate_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        self.wo = DenseGeneral(
            in_shapes=(intermediate_dim,),
            out_features=(embed_dim,),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        fused_x = self.wi_fused(x)

        gate = fused_x[..., 0, :]
        up = fused_x[..., 1, :]

        hidden = torch.mul(F.silu(gate), up).to(self.dtype)

        output = self.wo(hidden)
        return output
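
# Illustrative sketch (not part of the model): the fused wi projection packs the gate and up
# projections of a SwiGLU-style feed-forward into one kernel; the two slices are recombined
# as silu(gate) * up before the output projection. Dimensions below are made up, and the
# helper name `_mlp_block_example` is hypothetical.
def _mlp_block_example() -> torch.Tensor:
    mlp = MlpBlock(embed_dim=16, intermediate_dim=64, compute_dtype=torch.float32)
    x = torch.randn(2, 5, 16)  # (batch, seq, embed_dim)
    return mlp(x)              # (batch, seq, embed_dim)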

class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) implementation in PyTorch."""

    def __init__(
        self,
        embedding_dims: int,
        min_timescale: int = 1,
        max_timescale: int = 10000,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        if embedding_dims % 2 != 0:
            raise ValueError("Embedding dim must be even for RoPE.")

        self.embedding_dims = embedding_dims
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale
        self.dtype = dtype

        half_embedding_dim = embedding_dims // 2
        fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
        self.register_buffer(
            "timescale",
            self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
            persistent=False,
        )

    def extra_repr(self) -> str:
        s = f"{self.timescale.shape}"
        return s

    def forward(self, inputs: torch.Tensor, position: torch.Tensor):
        """Applies RoPE."""
        position = position.unsqueeze(-1).unsqueeze(-1)
        timescale = self.timescale.to(inputs.device)
        sinusoid_inp = position / timescale
        sin = torch.sin(sinusoid_inp).to(inputs.dtype)
        cos = torch.cos(sinusoid_inp).to(inputs.dtype)
        first_half, second_half = torch.chunk(inputs, 2, dim=-1)
        first_part = first_half * cos - second_half * sin
        second_part = second_half * cos + first_half * sin
        return torch.cat((first_part, second_part), dim=-1)
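
# Illustrative sketch (not part of the model): RoPE pairs each channel in the first half
# with its counterpart in the second half and rotates the pair by an angle determined by
# the token position and a per-channel timescale. Shapes below are made up ((B, T, N, H)
# with H even), and the helper name `_rope_example` is hypothetical.
def _rope_example() -> torch.Tensor:
    rope = RotaryEmbedding(embedding_dims=64)
    q = torch.randn(2, 7, 4, 64)                       # (batch, seq, heads, head_dim)
    pos = torch.arange(7).unsqueeze(0).expand(2, -1)   # (batch, seq) absolute positions
    return rope(q, position=pos)                       # same shape, rotated per position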

class Attention(nn.Module):
    """Attention using DenseGeneral."""

    def __init__(
        self,
        config: DiaConfig,
        q_embed_dim: int,
        kv_embed_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        head_dim: int,
        compute_dtype: torch.dtype,
        is_cross_attn: bool = False,
        out_embed_dim: int | None = None,
    ):
        super().__init__()
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.is_cross_attn = is_cross_attn
        self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
        self.projected_query_dim = num_query_heads * head_dim
        if num_query_heads % num_kv_heads != 0:
            raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
        self.num_gqa_groups = num_query_heads // num_kv_heads

        # --- Projection Layers using DenseGeneral ---
        self.q_proj = DenseGeneral(
            in_shapes=(q_embed_dim,),
            out_features=(num_query_heads, head_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        self.k_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        self.v_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )
        self.o_proj = DenseGeneral(
            in_shapes=(num_query_heads, head_dim),
            out_features=(self.output_dim,),
            axis=(-2, -1),
            weight_dtype=compute_dtype,
        )

        # --- Rotary Embedding ---
        self.rotary_emb = RotaryEmbedding(
            embedding_dims=self.head_dim,
            min_timescale=config.model.rope_min_timescale,
            max_timescale=config.model.rope_max_timescale,
            dtype=compute_dtype,
        )

    def forward(
        self,
        Xq: torch.Tensor,  # (B, T, D) T = 1 in AR generation
        Xkv: torch.Tensor,  # (B, S, E) S = 1 in AR generation
        q_positions: torch.Tensor,  # (B, T)
        kv_positions: torch.Tensor | None = None,  # (B, S)
        attn_mask: torch.Tensor | None = None,  # None in decoder self-attention, valid mask otherwise
        cache: KVCache | None = None,  # None in encoder, KVCache in decoder
        prefill: bool = False,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Performs attention calculation with optional KV caching.

        Args:
            Xq: Query tensor (B, T, D). T=1 during single-step decoding.
            Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
            q_positions: Positions for queries (B, T).
            kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
            attn_mask: Attention mask.
            cache: KVCache. For cross-attention it already holds the precomputed K/V; for
                self-attention it is prefilled or updated in place with the new K/V.
            prefill: If True, write the projected K/V into the cache in prefill mode.
            is_causal: If True, apply a causal mask inside scaled_dot_product_attention.

        Returns:
            output: The attention output tensor (B, T, output_dim).
        """
        if kv_positions is None:
            kv_positions = q_positions
        original_dtype = Xq.dtype

        Xq_BxTxNxH = self.q_proj(Xq)
        Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
        Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

        attn_k: torch.Tensor | None = None
        attn_v: torch.Tensor | None = None

        if self.is_cross_attn:
            attn_k, attn_v = cache.k, cache.v
        else:
            Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
            Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
            Xk_BxSxKxH = self.rotary_emb(Xk_BxSxKxH, position=kv_positions)  # (B, S, K, H)

            Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
            Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)

            if cache is None:
                attn_k = Xk_BxKxSxH
                attn_v = Xv_BxKxSxH
            else:
                if prefill:
                    attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
                    cache.prefill(attn_k, attn_v)
                else:
                    attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)

        attn_output = F.scaled_dot_product_attention(
            Xq_BxNxTxH,
            attn_k,
            attn_v,
            attn_mask=attn_mask,
            scale=1.0,
            enable_gqa=self.num_gqa_groups > 1,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
        output = self.o_proj(attn_output)

        return output.to(original_dtype)
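
# Illustrative sketch (not part of the model): with grouped-query attention, K/V carry
# fewer heads than Q and each K/V head is shared by num_gqa_groups query heads.
# `enable_gqa=True` asks scaled_dot_product_attention to do this grouping internally
# (available in recent PyTorch); expanding K/V explicitly, as below, is the equivalent
# computation. All shapes are made up, and `_gqa_expansion_example` is hypothetical.
def _gqa_expansion_example() -> torch.Tensor:
    B, N, K, T, S, H = 1, 8, 2, 4, 6, 16  # 8 query heads, 2 KV heads -> 4 groups
    q = torch.randn(B, N, T, H)
    k = torch.randn(B, K, S, H)
    v = torch.randn(B, K, S, H)
    groups = N // K
    k_expanded = k.repeat_interleave(groups, dim=1)  # (B, N, S, H)
    v_expanded = v.repeat_interleave(groups, dim=1)  # (B, N, S, H)
    return F.scaled_dot_product_attention(q, k_expanded, v_expanded, scale=1.0)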

class EncoderLayer(nn.Module):
    """Transformer Encoder Layer using DenseGeneral."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder
        embed_dim = enc_config.n_embd

        self.pre_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.self_attention = Attention(
            config,
            q_embed_dim=embed_dim,
            kv_embed_dim=embed_dim,
            num_query_heads=enc_config.n_head,
            num_kv_heads=enc_config.n_head,
            head_dim=enc_config.head_dim,
            compute_dtype=compute_dtype,
            is_cross_attn=False,
            out_embed_dim=embed_dim,
        )
        self.post_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.mlp = MlpBlock(embed_dim=embed_dim, intermediate_dim=enc_config.n_hidden, compute_dtype=compute_dtype)

    def forward(
        self,
        x: torch.Tensor,
        state: EncoderInferenceState,
    ) -> torch.Tensor:
        residual = x
        x_norm = self.pre_sa_norm(x)
        sa_out = self.self_attention(
            Xq=x_norm,
            Xkv=x_norm,
            q_positions=state.positions,
            kv_positions=state.positions,
            attn_mask=state.attn_mask,
        )
        x = residual + sa_out

        residual = x
        x_norm = self.post_sa_norm(x)
        mlp_out = self.mlp(x_norm)
        x = residual + mlp_out

        return x

class Encoder(nn.Module):
    """Transformer Encoder Stack using DenseGeneral."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder

        self.embedding = nn.Embedding(
            model_config.src_vocab_size,
            enc_config.n_embd,
            dtype=compute_dtype,
        )
        self.layers = nn.ModuleList([EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)])
        self.norm = RMSNorm(
            enc_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

    def forward(
        self,
        x_ids: torch.Tensor,
        state: EncoderInferenceState,
    ) -> torch.Tensor:
        x = self.embedding(x_ids)

        for layer in self.layers:
            x = layer(x, state)

        x = self.norm(x)
        return x

class DecoderLayer(nn.Module):
    """Transformer Decoder Layer using DenseGeneral."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        enc_config = config.model.encoder
        dec_embed_dim = dec_config.n_embd
        enc_embed_dim = enc_config.n_embd

        # Norms
        self.pre_sa_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_ca_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_mlp_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        # Self-Attention (GQA) with Causal Masking
        self.self_attention = Attention(
            config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=dec_embed_dim,
            num_query_heads=dec_config.gqa_query_heads,
            num_kv_heads=dec_config.kv_heads,
            head_dim=dec_config.gqa_head_dim,
            compute_dtype=compute_dtype,
            is_cross_attn=False,
            out_embed_dim=dec_embed_dim,
        )
        # Cross-Attention (MHA)
        self.cross_attention = Attention(
            config=config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=enc_embed_dim,  # Note kv_embed_dim
            num_query_heads=dec_config.cross_query_heads,
            num_kv_heads=dec_config.cross_query_heads,
            head_dim=dec_config.cross_head_dim,
            compute_dtype=compute_dtype,
            is_cross_attn=True,
            out_embed_dim=dec_embed_dim,
        )
        # MLP
        self.mlp = MlpBlock(
            embed_dim=dec_embed_dim,
            intermediate_dim=dec_config.n_hidden,
            compute_dtype=compute_dtype,
        )

    def forward(
        self,
        x: torch.Tensor,
        state: DecoderInferenceState,
        self_attn_cache: KVCache | None = None,
        cross_attn_cache: KVCache | None = None,
        prefill: bool = False,
    ) -> torch.Tensor:
        residual = x
        x_norm = self.pre_sa_norm(x)

        sa_out = self.self_attention(
            Xq=x_norm,  # (2, 1, D)
            Xkv=x_norm,  # (2, 1, D)
            q_positions=state.dec_positions,  # (2, 1)
            kv_positions=state.dec_positions,  # (2, 1)
            attn_mask=None,
            cache=self_attn_cache,
            prefill=prefill,
            is_causal=prefill,
        )
        x = residual + sa_out

        residual = x
        x_norm = self.pre_ca_norm(x)
        ca_out = self.cross_attention(
            Xq=x_norm,
            Xkv=state.enc_out,
            q_positions=state.dec_positions,
            kv_positions=state.enc_positions,
            attn_mask=state.dec_cross_attn_mask,
            cache=cross_attn_cache,
        )
        x = residual + ca_out

        residual = x
        x_norm = self.pre_mlp_norm(x)
        mlp_out = self.mlp(x_norm)
        x = residual + mlp_out

        return x

class Decoder(nn.Module):
    """Transformer Decoder Stack using DenseGeneral."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        data_config = config.data
        self.num_channels = data_config.channels
        self.num_layers = dec_config.n_layer

        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype)
                for _ in range(self.num_channels)
            ]
        )
        self.layers = nn.ModuleList(
            [DecoderLayer(config=config, compute_dtype=compute_dtype) for _ in range(self.num_layers)]
        )
        self.norm = RMSNorm(
            dec_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.logits_dense = DenseGeneral(
            in_shapes=(dec_config.n_embd,),
            out_features=(self.num_channels, model_config.tgt_vocab_size),
            axis=(-1,),
            weight_dtype=compute_dtype,
        )

    def precompute_cross_attn_cache(
        self,
        enc_out: torch.Tensor,  # (B, S, E)
        enc_positions: torch.Tensor,  # (B, S)
    ) -> list[KVCache]:
        """
        Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
        """
        per_layer_kv_cache: list[KVCache] = []

        for layer in self.layers:
            cross_attn_module = layer.cross_attention
            k_proj = cross_attn_module.k_proj(enc_out)
            v_proj = cross_attn_module.v_proj(enc_out)

            k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
            k = k_proj.transpose(1, 2)
            v = v_proj.transpose(1, 2)

            per_layer_kv_cache.append(KVCache.from_kv(k, v))

        return per_layer_kv_cache

    def decode_step(
        self,
        tgt_ids_Bx1xC: torch.Tensor,  # (B, 1, C)
        state: DecoderInferenceState,
    ) -> torch.Tensor:
        """
        Performs a single decoding step, managing KV caches layer by layer.

        Returns:
            logits_Bx1xCxV: The output logits for the current step (B, 1, C, V), cast to float32.
        """
        x = None
        for i in range(self.num_channels):
            channel_tokens = tgt_ids_Bx1xC[..., i]
            channel_embed = self.embeddings[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed

        for i, layer in enumerate(self.layers):
            self_cache = state.self_attn_cache[i]
            cross_cache = state.cross_attn_cache[i]
            x = layer(
                x,  # (2, 1, D)
                state,
                self_attn_cache=self_cache,
                cross_attn_cache=cross_cache,
            )

        x = self.norm(x)
        logits_Bx1xCxV = self.logits_dense(x)

        return logits_Bx1xCxV.to(torch.float32)

    def forward(self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState) -> torch.Tensor:
        """
        Forward pass for the Decoder stack, used for prefill. Reads the per-layer
        self- and cross-attention KV caches from `state`.

        Args:
            tgt_ids_BxTxC: Target token IDs (B, T, C).
            state: DecoderInferenceState carrying the encoder output, positions, the
                cross-attention mask, and one self-attention and one cross-attention
                KVCache per layer.

        Returns:
            logits_BxTxCxV: The final output logits (B, T, C, V), cast to float32.
        """
        _, _, num_channels_in = tgt_ids_BxTxC.shape
        assert num_channels_in == self.num_channels, "Input channels mismatch"

        # Embeddings: sum the per-channel token embeddings
        x = None
        for i in range(self.num_channels):
            channel_tokens = tgt_ids_BxTxC[..., i]
            channel_embed = self.embeddings[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed

        for i, layer in enumerate(self.layers):
            self_cache = state.self_attn_cache[i]
            cross_cache = state.cross_attn_cache[i]
            x = layer(x, state, self_attn_cache=self_cache, cross_attn_cache=cross_cache, prefill=True)

        # Final Norm
        x = self.norm(x)
        logits_BxTxCxV = self.logits_dense(x)

        return logits_BxTxCxV.to(torch.float32)

class DiaModel(nn.Module):
    """PyTorch Dia Model using DenseGeneral."""

    def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
        super().__init__()
        self.config = config
        self.encoder = Encoder(config, compute_dtype)
        self.decoder = Decoder(config, compute_dtype)
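
# Illustrative sketch (not part of the model): the intended inference call order, assuming
# the caller has already built EncoderInferenceState / DecoderInferenceState (their
# constructors live in .state and are not shown here) and that the decoder state is the
# place where the precomputed cross-attention caches are attached. The helper name
# `_inference_flow_sketch` and the `dec_state.cross_attn_cache = ...` assignment are
# assumptions, not confirmed API.
def _inference_flow_sketch(
    model: "DiaModel",
    text_ids: torch.Tensor,           # (B, S) source token IDs
    enc_state: EncoderInferenceState,
    dec_state: DecoderInferenceState,  # assumed to be built around enc_out, positions, masks
    tgt_prefix_BxTxC: torch.Tensor,   # (B, T, C) audio-token prefix for prefill
    tgt_step_Bx1xC: torch.Tensor,     # (B, 1, C) tokens for one autoregressive step
) -> torch.Tensor:
    enc_out = model.encoder(text_ids, enc_state)                         # (B, S, E)
    cross_caches = model.decoder.precompute_cross_attn_cache(
        enc_out, enc_state.positions                                     # computed once, reused every step
    )
    dec_state.cross_attn_cache = cross_caches                            # assumed attachment point
    _prefill_logits = model.decoder(tgt_prefix_BxTxC, dec_state)         # (B, T, C, V) prefill pass
    step_logits = model.decoder.decode_step(tgt_step_Bx1xC, dec_state)   # (B, 1, C, V) single AR step
    return step_logits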