import torch
import torch.nn as nn
import math

from modules.audio_detokenizer.flow_matching.dit_block import DiTBlock, FinalLayer

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
                         interpolation_factor: int = 1, max_seq_length: int = 4096):
    print(f'using rope base theta = {theta}, interpolation factor = {interpolation_factor}')
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # RoPE type-A extension:
    # we choose interpolation rather than extrapolation for better position encoding.
    # For scaling purposes, t should be a float tensor.
    t = torch.arange(end, device=freqs.device).float()
    scale = 1.0 / float(interpolation_factor)
    t *= scale
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    # Sometimes we don't need as many RoPE embeddings as `end` allows, because the actual
    # sequence length is smaller (e.g. RoPE precomputed for 1M positions but seq_len is 32k),
    # which would waste GPU memory.
    if max_seq_length < end:
        freqs_cis = freqs_cis[:max_seq_length].clone()
    return freqs_cis
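
# Illustrative helper (not part of the original module): one common way the complex
# `freqs_cis` produced above is consumed. The rotation actually applied inside
# DiTBlock may differ; this is only a sketch of the standard complex-multiplication
# formulation of rotary position embedding.
def apply_rotary_emb_example(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq_len, n_heads, head_dim); freqs_cis: (>= seq_len, head_dim // 2), complex
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # broadcast the per-position rotations over the batch and head dimensions
    rot = freqs_cis[: x.shape[1]].view(1, x.shape[1], 1, -1)
    x_rotated = torch.view_as_real(x_complex * rot).flatten(3)
    return x_rotated.type_as(x)
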
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).float().to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding
    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
        return t_emb

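# Usage sketch (illustrative; the sizes are assumptions, not taken from the original
# configuration): fractional diffusion timesteps are mapped to sinusoidal features
# and projected by the MLP to the model's hidden size.
#
#     t_embedder = TimestepEmbedder(hidden_size=1024)
#     t = torch.tensor([0.0, 0.25, 0.5, 1.0])   # (N,)
#     t_emb = t_embedder(t)                      # (N, 1024)
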
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """
    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))
    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2  # d/2
        emb = math.log(10000) / (half_dim - 1)  # 2*log(10000)/(d-2)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)  # 10000^(-2i/(d-2)); i from 0 to (d-2)/2; shape: (d/2,)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)  # pos / 10000^(2i/(d-2)); shape: (num_embeddings, d/2)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)  # shape: (num_embeddings, d)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb
    def forward(self, input, incremental_state=None, timestep=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)
        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
        positions = self.make_positions(input, self.padding_idx)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()  # (B, T, dim)
    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number

    def make_positions(self, tensor, padding_idx):
        """Replace non-padding symbols with their position numbers.

        Position numbers begin at padding_idx+1. Padding symbols are ignored.
        """
        # The series of casts and type-conversions here are carefully
        # balanced to both work with ONNX export and XLA. In particular XLA
        # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
        # how to handle the dtype kwarg in cumsum.
        mask = tensor.ne(padding_idx).int()
        return (
            torch.cumsum(mask, dim=1).type_as(mask) * mask
        ).long() + padding_idx

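# Behavior sketch (illustrative, not from the original file): with padding_idx = 0,
# padded entries keep position 0 and real tokens are numbered from padding_idx + 1.
# `pos_emb` is assumed to be an instance of the class above.
#
#     tokens = torch.tensor([[5, 7, 0, 0],
#                            [3, 3, 3, 0]])
#     pos_emb.make_positions(tokens, padding_idx=0)
#     # -> [[1, 2, 0, 0],
#     #     [1, 2, 3, 0]]
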
class DiTPrefix(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size,
        output_size,
        semantic_vocab_size,
        hidden_size=1024,
        depth=12,
        num_heads=4,
        # mlp related
        mlp_ratio=4.0,
        ffn_type="conv1d_conv1d",
        ffn_gated_glu=True,
        ffn_act_layer="gelu",
        ffn_conv_kernel_size=5,
        # rope
        use_rope=False,
        rope_params={
            "max_position_embeddings": 4096,
            "rope_base": 10000.0,
            "rope_interpolation_factor": 1.0,
        },
        position_embedding_type="sincos",
        max_seq_len=4096,
        prompt_cfg_dropout=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.prompt_cfg_dropout = prompt_cfg_dropout

        self.t_embedder = TimestepEmbedder(hidden_size)
        self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size)
        self.input_linear = nn.Linear(input_size, hidden_size)

        # position embedding
        if position_embedding_type == "learnable":
            self.position_embedding = nn.Embedding(max_seq_len + 1, hidden_size)
        elif position_embedding_type == "sincos":
            self.position_embedding = SinusoidalPositionalEmbedding(hidden_size, 0, max_seq_len + 1)
        elif position_embedding_type == "skip":
            self.position_embedding = None
        else:
            raise NotImplementedError("Position embedding type: {} not implemented.".format(position_embedding_type))

        self.use_rope = use_rope
        if self.use_rope:
            assert hidden_size % num_heads == 0, "Hidden size must be divisible by num_heads for rope position embedding."
            rope_dim = hidden_size // num_heads
            self.rotary_pos_emb = precompute_freqs_cis(
                rope_dim, rope_params["max_position_embeddings"],
                theta=rope_params["rope_base"],
                interpolation_factor=rope_params["rope_interpolation_factor"],
                max_seq_length=max_seq_len,
            )

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
                     ffn_type=ffn_type, ffn_conv_kernel_size=ffn_conv_kernel_size,
                     ffn_gated_glu=ffn_gated_glu, ffn_act_layer=ffn_act_layer)
            for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, output_size)

        self.initialize_weights()
    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
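        # Note: zeroing the adaLN modulation layers and the final projection means each
        # DiT block starts as an identity mapping and the output head starts at zero,
        # which is the initialization scheme used in the original DiT paper.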
    def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, cu_maxlen,
                cu_seqlens_k, cu_maxlen_k, mask, incremental_state=None, nopadding=True):
        """
        Forward pass of DiT.
        x: (N, T, C) tensor of inputs (latent representations of speech)
        position_ids: (N, T) tensor of positional indices
        t: (N,) tensor of diffusion timesteps
        condition: (N, T) tensor of semantic tokens
        seq_len: (N,) tensor of sequence lengths
        """
        condition = self.semantic_token_embedding(condition)  # (N, T, D)
        x = self.input_linear(x)

        if self.position_embedding is not None:
            position_emb = self.position_embedding(position_ids)
            x = x + position_emb

        # RoPE
        if self.use_rope:
            bsz, seqlen = position_ids.shape
            if self.rotary_pos_emb.device != position_ids.device:
                self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device)
            # gather the per-position rotations for every sequence in the batch
            rotary_pos_emb = torch.zeros((bsz, seqlen, self.rotary_pos_emb.shape[1]),
                                         dtype=self.rotary_pos_emb.dtype,
                                         device=self.rotary_pos_emb.device)
            for b in range(bsz):
                cur_rope = rotary_pos_emb[b]
                cur_position_ids = position_ids[b]
                cur_rope[:] = self.rotary_pos_emb[cur_position_ids]
        else:
            rotary_pos_emb = None

        t = self.t_embedder(t)  # (N, D)
        c = t.unsqueeze(1) + condition  # (N, T, D)

        for block_idx, block in enumerate(self.blocks):
            # x = block(x, c, attn_mask)  # (N, T, D)
            # XXX mask could be None because we always use full mask
            if incremental_state is not None:
                if block_idx not in incremental_state:
                    incremental_state[block_idx] = {}
                incr = incremental_state[block_idx]
            else:
                incr = None
            x = block(x=x, c=c, seq_len=seq_len, cu_seqlens=cu_seqlens, cu_maxlen=cu_maxlen,
                      cu_seqlens_k=cu_seqlens_k, cu_maxlen_k=cu_maxlen_k, mask=mask,
                      rotary_pos_emb=rotary_pos_emb, incremental_state=incr, nopadding=nopadding)

        x = self.final_layer(x, c)  # (N, T, C)
        return x
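
# Usage sketch (illustrative; the argument values below are assumptions, not taken from
# the original training configuration). The packed-attention arguments (cu_seqlens,
# cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask) follow whatever convention DiTBlock
# expects, so they are left elided here:
#
#     model = DiTPrefix(input_size=80, output_size=80, semantic_vocab_size=8192,
#                       hidden_size=1024, depth=12, num_heads=16, use_rope=True)
#     x = torch.randn(2, 128, 80)                      # noisy latents (N, T, C)
#     position_ids = torch.arange(128).expand(2, -1)   # (N, T)
#     t = torch.rand(2)                                # diffusion timesteps (N,)
#     condition = torch.randint(0, 8192, (2, 128))     # semantic tokens (N, T)
#     out = model(x, position_ids, t, condition, seq_len=..., cu_seqlens=...,
#                 cu_maxlen=..., cu_seqlens_k=..., cu_maxlen_k=..., mask=...)
#     # out: (N, T, output_size)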