import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # x shape: (bsz, seqlen, n_local_heads, head_hidden_dim / 2)
    # the last dim is head_hidden_dim / 2 because x has already been viewed as complex
    assert x.ndim == 4
    assert freqs_cis.shape == (x.shape[0], x.shape[1], x.shape[-1]), \
        f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}'
    # reshape freqs_cis so it broadcasts over the head dimension in the pointwise multiply
    # new shape: (bsz, seq_len, 1, head_hidden_dim / 2)
    shape = [x.shape[0], x.shape[1], 1, x.shape[-1]]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
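
# The rotary helpers above expect `freqs_cis` as a complex tensor of shape
# (bsz, seqlen, head_dim // 2), matching the assert in reshape_for_broadcast.
# A minimal sketch of how such a table could be precomputed; the helper name and
# the theta base are assumptions in the common Llama style, not part of this module.
def precompute_freqs_cis(head_dim: int, seqlen: int, bsz: int, theta: float = 10000.0) -> torch.Tensor:
    # per-dimension rotation frequencies: (head_dim // 2,)
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seqlen).float()
    freqs = torch.outer(t, freqs)  # (seqlen, head_dim // 2) angles
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 unit phasors
    # broadcast the same table to every sequence in the batch
    return freqs_cis.unsqueeze(0).expand(bsz, -1, -1)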

class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
        flash_attention: bool = True
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = flash_attention

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qk_norm = qk_norm
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb=None, incremental_state=None, nopadding=True) -> torch.Tensor:
        B, N, C = x.shape
        if self.fused_attn:
            if nopadding:
                qkv = self.qkv(x)
                qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim)
                q, k, v = qkv.split([self.num_heads] * 3, dim=1)
                q, k = self.q_norm(q), self.k_norm(k)
                q = q.view(B, N, self.num_heads, self.head_dim)
                k = k.view(B, N, self.num_heads, self.head_dim)
                v = v.view(B, N, self.num_heads, self.head_dim)
                if rotary_pos_emb is not None:
                    q, k = apply_rotary_emb(q, k, rotary_pos_emb)
                if incremental_state is not None:
                    # prepend cached keys/values along the sequence dimension and refresh the cache
                    if "prev_k" in incremental_state:
                        prev_k = incremental_state["prev_k"]
                        k = torch.cat([prev_k, k], dim=1)
                    incremental_state["cur_k"] = k
                    if "prev_v" in incremental_state:
                        prev_v = incremental_state["prev_v"]
                        v = torch.cat([prev_v, v], dim=1)
                    incremental_state["cur_v"] = v
                q = q.view(B * N, self.num_heads, self.head_dim)
                k = k.view(-1, self.num_heads, self.head_dim)
                v = v.view(-1, self.num_heads, self.head_dim)
                x = flash_attn_varlen_func(
                    q=q,
                    k=k,
                    v=v,
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen,
                    max_seqlen_k=max_seqlen_k,
                    dropout_p=self.attn_drop.p if self.training else 0.,
                )
            else:
                if incremental_state is not None:
                    raise NotImplementedError("This path is designed for batched inference; AR-chunk decoding with a KV cache is not supported yet.")
                qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
                if self.qk_norm:
                    q, k, v = qkv.unbind(2)
                    q, k = self.q_norm(q), self.k_norm(k)
                    # re-bind into the packed layout expected by the qkv-packed kernel
                    qkv = torch.stack((q, k, v), dim=2)
                # pack qkv by trimming each sample to its valid length given by seq_len
                qkv_collect = []
                for i in range(qkv.shape[0]):
                    qkv_collect.append(qkv[i, :seq_len[i], :, :, :])
                qkv = torch.cat(qkv_collect, dim=0)
                x = flash_attn_varlen_qkvpacked_func(qkv=qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, dropout_p=self.attn_drop.p if self.training else 0.)
                # unpack per sample and pad back to a dense (B, max_len, ...) tensor with zeros
                x_collect = []
                for i in range(B):
                    x_collect.append(x[cu_seqlens[i]:cu_seqlens[i + 1], :, :])
                x = torch.nn.utils.rnn.pad_sequence(x_collect, batch_first=True, padding_value=0)
        else:
            # Non-flash fallback: standard scaled dot-product attention.
            # NOTE: the qkv projection and unbind below are an assumption following the
            # standard timm attention layout; q/k/v were not defined on this branch as extracted.
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
            q, k = self.q_norm(q), self.k_norm(k)
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
            x = x.transpose(1, 2)

        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
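
# `cu_seqlens` / `cu_seqlens_k` are the cumulative sequence-length tensors expected by the
# flash-attn varlen kernels: int32, shape (batch + 1,), starting at 0. A minimal sketch of
# how they could be built from per-sample lengths; the helper name is an assumption and not
# part of the original module.
def build_cu_seqlens(seq_len: torch.Tensor) -> torch.Tensor:
    # seq_len: (B,) tensor of valid lengths; returns [0, l0, l0 + l1, ...]
    return F.pad(torch.cumsum(seq_len, dim=0, dtype=torch.int32), (1, 0))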

def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=2)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="conv1d_conv1d", ffn_gated_glu=True, ffn_act_layer="gelu", ffn_conv_kernel_size=5, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        if ffn_type == "vanilla_mlp":
            from timm.models.vision_transformer import Mlp
            mlp_hidden_dim = int(hidden_size * mlp_ratio)
            approx_gelu = lambda: nn.GELU(approximate="tanh")
            self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        else:
            raise NotImplementedError(f"FFN type {ffn_type} is not implemented")
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, rotary_pos_emb=None, incremental_state=None, nopadding=True):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=2)
        x_ = modulate(self.norm1(x), shift_msa, scale_msa)
        if incremental_state is not None:
            if "attn_kvcache" not in incremental_state:
                incremental_state["attn_kvcache"] = {}
            inc_attn = incremental_state["attn_kvcache"]
        else:
            inc_attn = None
        x_ = self.attn(x_, seq_len=seq_len, cu_seqlens=cu_seqlens, max_seqlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, max_seqlen_k=cu_maxlen_k, rotary_pos_emb=rotary_pos_emb, incremental_state=inc_attn, nopadding=nopadding)
        if not nopadding:
            x_ = x_ * mask[:, :, None]
        x = x + gate_msa * x_
        x_ = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x_ = self.mlp(x_)
        if not nopadding:
            x_ = x_ * mask[:, :, None]
        x = x + gate_mlp * x_
        return x
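
# Minimal smoke test, assuming `timm` is installed and `flash_attn` is importable. It uses the
# non-flash fallback (flash_attention=False), so the flash kernels are never invoked; the shapes
# and hyper-parameters below are illustrative assumptions, not values from the original setup.
if __name__ == "__main__":
    B, N, hidden_size, num_heads = 2, 16, 256, 8
    block = DiTBlock(hidden_size, num_heads, ffn_type="vanilla_mlp", flash_attention=False)
    final = FinalLayer(hidden_size, out_channels=80)
    x = torch.randn(B, N, hidden_size)
    c = torch.randn(B, 1, hidden_size)  # conditioning, broadcast over the sequence dimension
    y = block(x, c, seq_len=None, cu_seqlens=None, cu_maxlen=None,
              cu_seqlens_k=None, cu_maxlen_k=None, mask=None)
    out = final(y, c)
    print(y.shape, out.shape)  # expect (2, 16, 256) and (2, 16, 80)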