# Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
import math
from torch import nn
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaConfig

from .llama_nar import LlamaNARDecoderLayer


class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

    def forward(self, text: int["b nt"]):  # noqa: F722
        text = self.text_embed(text)  # b n -> b n d
        return text


class InputEmbedding(nn.Module):
    def __init__(self, cond_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(cond_dim, cond_dim)
        self.proj_2 = nn.Linear(cond_dim, out_dim)

    def forward(self, x, style_emb, time_emb):  # noqa: F722
        style_emb = style_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
        x_orig = x
        x = x + style_emb + time_emb
        x = self.proj(x) + x_orig
        x = self.proj_2(x)
        return x


class AdaLayerNormZero_Final(nn.Module):
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(cond_dim, dim * 2)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        scale, shift = torch.chunk(emb, 2, dim=-1)
        x = self.norm(x) * (1 + scale) + shift
        return x


class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(-1) * emb.unsqueeze(0).unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

    def numel(self):
        return 0


class TimestepEmbedding(nn.Module):
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: float["b"]):  # noqa: F821
        time_hidden = self.time_embed(timestep)
        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)  # b d
        return time


class DiT(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        conv_layers=0,
        long_skip_connection=False,
        use_flex_attn=False,
        repa_depth=-1,
        repa_dims=[1024],
        **kwargs,
    ):
        super().__init__()

        cond_dim = 512
        self.time_embed = TimestepEmbedding(cond_dim)
        self.text_embed = TextEmbedding(text_num_embeds, cond_dim, conv_layers=conv_layers)
        self.input_embed = InputEmbedding(cond_dim, dim)
        self.latent_embed = torch.nn.Sequential(
            nn.Linear(mel_dim, cond_dim),
            nn.Linear(cond_dim, cond_dim),
        )

        self.dim = dim
        self.depth = depth
        self.use_flex_attn = use_flex_attn

        llama_config = LlamaConfig(
            hidden_size=dim,
            num_attention_heads=heads,
            intermediate_size=dim * ff_mult,
            hidden_act='silu',
            max_position_embeddings=4096,
        )
        self.rotary_embed = LlamaRotaryEmbedding(config=llama_config)
        llama_config._attn_implementation = 'sdpa'

        self.transformer_blocks = nn.ModuleList(
            [LlamaNARDecoderLayer(llama_config, layer_idx=i, use_flex_attn=self.use_flex_attn) for i in range(depth)]
        )
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

        self.norm_out = AdaLayerNormZero_Final(dim, cond_dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

        self.repa_depth = repa_depth
        self.repa_dims = repa_dims
        self.projectors = None
        if self.repa_depth > 0:
            self.projectors = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(self.dim, self.dim * 2),
                    nn.SiLU(),
                    nn.Linear(self.dim * 2, self.dim * 2),
                    nn.SiLU(),
                    nn.Linear(self.dim * 2, repa_dim),
                ) for repa_dim in self.repa_dims
            ])

    def forward(
        self,
        x: torch.Tensor,
        time: torch.Tensor,
        position_ids: torch.Tensor,
        style_prompt: torch.Tensor,
        attn_mask: torch.Tensor,
        output_attentions: bool = False,
        use_cache: bool = False,
        past_key_value=None,
    ):
        """
        Args:
            x: [b, n, d]
            time: [b, n, 1]
            position_ids: [b, n]
            style_prompt: [b, 512]
            attn_mask: [b, 1, n, n]
        """
        batch, seq_len = x.shape[0], x.shape[1]

        t = self.time_embed(time)
        c = t  # [B, T, dim]

        x = self.input_embed(x, style_prompt, c)

        if self.long_skip_connection is not None:
            residual = x

        position_embeddings = self.rotary_embed(x, position_ids)

        attn_weights = []
        if not use_cache:
            past_key_value = None
        repa_res = None

        for i, block in enumerate(self.transformer_blocks):
            res = block(
                x,
                attention_mask=attn_mask,
                position_embeddings=position_embeddings,
                output_attentions=output_attentions,
                past_key_value=past_key_value,
                use_cache=use_cache,
            )
            x = res.pop(0)
            if output_attentions:
                attn_weights.append(res.pop(0))
            if use_cache:
                past_key_value = res.pop(0)
            if i == self.repa_depth - 1:
                repa_res = x

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

        x = self.norm_out(x, c)
        output = self.proj_out(x)

        return output, attn_weights, past_key_value
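

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the model code itself).
# The shapes below are assumptions inferred from the module definitions:
#   * `x` is assumed to already live in the 512-dim conditioning space that
#     InputEmbedding expects (the mel_dim -> cond_dim latent_embed is not
#     applied inside forward());
#   * `time` is given per token with shape [b, n], since SinusPositionEmbedding
#     appends the frequency dimension itself;
#   * `attn_mask` is assumed to be an additive float mask of shape [b, 1, n, n]
#     (zeros = attend everywhere), as consumed by SDPA attention.
# Running this requires the sibling `llama_nar` module providing
# LlamaNARDecoderLayer and a transformers version that supports
# LlamaRotaryEmbedding(config=...).
if __name__ == "__main__":
    b, n, dim, mel_dim = 2, 16, 256, 100

    model = DiT(dim=dim, depth=2, heads=8, mel_dim=mel_dim)

    x = torch.randn(b, n, 512)                          # inputs in the cond_dim=512 space
    time = torch.rand(b, n)                             # per-token timesteps in [0, 1) (assumed)
    position_ids = torch.arange(n).unsqueeze(0).expand(b, -1)
    style_prompt = torch.randn(b, 512)                  # global style embedding
    attn_mask = torch.zeros(b, 1, n, n)                 # additive mask; all zeros = no masking

    out, attn_weights, past_kv = model(x, time, position_ids, style_prompt, attn_mask)
    print(out.shape)  # expected: [b, n, mel_dim]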