# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py

from __future__ import annotations

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND


def modulate(x, scale):
    return x * (1 + scale.unsqueeze(1))
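# `modulate` applies an adaLN-style per-channel scale: `scale` has shape (N, D)
# and is broadcast over the sequence dimension of `x`, which has shape (N, L, D).
# Minimal shape check (illustrative only, not part of the original file):
#
#   x = torch.randn(2, 8, 16)      # (batch, tokens, dim)
#   scale = torch.randn(2, 16)     # (batch, dim)
#   assert modulate(x, scale).shape == (2, 8, 16)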
#############################################################################
#                             Core NextDiT Model                            #
#############################################################################


class JointAttention(nn.Module):
    """Multi-head attention module."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int],
        qk_norm: bool,
        operation_settings={},
    ):
        """
        Initialize the Attention module.

        Args:
            dim (int): Number of input dimensions.
            n_heads (int): Number of heads.
            n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
        """
        super().__init__()
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.n_local_heads = n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

        self.qkv = operation_settings.get("operations").Linear(
            dim,
            (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.out = operation_settings.get("operations").Linear(
            n_heads * self.head_dim,
            dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        if qk_norm:
            self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
            self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        else:
            self.q_norm = self.k_norm = nn.Identity()
    @staticmethod
    def apply_rotary_emb(
        x_in: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply rotary embeddings to the input tensor using the given frequency
        tensor.

        The query or key tensor 'x_in' is viewed as pairs of values, each pair
        is rotated by the precomputed cosine/sine components in 'freqs_cis'
        (shaped for broadcasting), and the result is returned with the
        original shape as a real tensor.

        Args:
            x_in (torch.Tensor): Query or key tensor to apply rotary embeddings to.
            freqs_cis (torch.Tensor): Precomputed frequency tensor of cosine and
                sine components.

        Returns:
            torch.Tensor: The input tensor with rotary embeddings applied.
        """
        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
        return t_out.reshape(*x_in.shape)
    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor of shape (batch, seq_len, dim).
            x_mask (torch.Tensor): Attention mask applied to the attention scores.
            freqs_cis (torch.Tensor): Precomputed rotary frequencies.

        Returns:
            torch.Tensor: Attention output of shape (batch, seq_len, dim).
        """
        bsz, seqlen, _ = x.shape

        xq, xk, xv = torch.split(
            self.qkv(x),
            [
                self.n_local_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
            ],
            dim=-1,
        )

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)

        n_rep = self.n_local_heads // self.n_local_kv_heads
        if n_rep >= 1:
            xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)

        return self.out(output)
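# Note on the grouped-query attention (GQA) path above: when n_kv_heads < n_heads,
# each key/value head is repeated n_rep = n_heads // n_kv_heads times so the
# tensors line up head-for-head before the fused attention call. Illustrative
# shape walk-through (not part of the original file):
#
#   xk = torch.randn(1, 10, 2, 64)                            # 2 kv heads
#   xk = xk.unsqueeze(3).repeat(1, 1, 1, 4, 1).flatten(2, 3)  # -> (1, 10, 8, 64)
#
# With n_rep == 1 the repeat is a no-op, so the same branch also covers plain MHA.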
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        operation_settings={},
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple
                of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden
                dimension. Defaults to None.
        """
        super().__init__()
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = operation_settings.get("operations").Linear(
            dim,
            hidden_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.w2 = operation_settings.get("operations").Linear(
            hidden_dim,
            dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.w3 = operation_settings.get("operations").Linear(
            dim,
            hidden_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

    # @torch.compile
    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
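# The feedforward above is the SwiGLU variant used by LLaMA-style blocks:
# w2(silu(w1(x)) * w3(x)). The hidden width is first scaled by
# ffn_dim_multiplier (if given), then rounded up to a multiple of `multiple_of`.
# Worked example with illustrative numbers (not the shipped defaults):
#
#   hidden_dim = 4 * 2304 = 9216; ffn_dim_multiplier = 2/3 -> 6144;
#   multiple_of = 256 -> 256 * ceil(6144 / 256) = 6144 (already aligned).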
class JointTransformerBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        operation_settings={},
    ) -> None:
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            dim (int): Embedding dimension of the input features.
            n_heads (int): Number of attention heads.
            n_kv_heads (Optional[int]): Number of attention heads in key and
                value features (if using GQA), or None to use the same number
                as the query heads.
            multiple_of (int): Round the feedforward hidden dimension up to a
                multiple of this value.
            ffn_dim_multiplier (float): Optional custom multiplier for the
                feedforward hidden dimension.
            norm_eps (float): Epsilon used by the RMSNorm layers.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
            operation_settings=operation_settings,
        )
        self.layer_id = layer_id
        self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

        self.modulation = modulation
        if modulation:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operation_settings.get("operations").Linear(
                    min(dim, 1024),
                    4 * dim,
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )
    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            x_mask (torch.Tensor): Attention mask for the input sequence.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            adaln_input (Optional[torch.Tensor]): Conditioning vector for adaLN
                modulation; required when modulation is enabled.

        Returns:
            torch.Tensor: Output tensor after applying attention and
                feedforward layers.
        """
        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

            x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
                self.attention(
                    modulate(self.attention_norm1(x), scale_msa),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
                self.feed_forward(
                    modulate(self.ffn_norm1(x), scale_mlp),
                )
            )
        else:
            assert adaln_input is None
            x = x + self.attention_norm2(
                self.attention(
                    self.attention_norm1(x),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + self.ffn_norm2(
                self.feed_forward(
                    self.ffn_norm1(x),
                )
            )
        return x
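# The modulated branch above follows the adaLN pattern: a single conditioning
# vector is projected to four per-channel vectors, so chunk(4, dim=1) yields
# scale_msa, gate_msa, scale_mlp, gate_mlp, each of shape (N, dim). The scales
# shift the pre-norm activations via `modulate`, while the gates pass through
# tanh so each residual branch's contribution stays bounded. Illustrative
# shapes (not part of the original file):
#
#   mods = adaLN_modulation(adaln_input)   # (N, 4 * dim)
#   gate_msa.unsqueeze(1)                  # (N, 1, dim), broadcast over tokens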
class FinalLayer(nn.Module):
    """
    The final layer of NextDiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
        super().__init__()
        self.norm_final = operation_settings.get("operations").LayerNorm(
            hidden_size,
            elementwise_affine=False,
            eps=1e-6,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.linear = operation_settings.get("operations").Linear(
            hidden_size,
            patch_size * patch_size * out_channels,
            bias=True,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operation_settings.get("operations").Linear(
                min(hidden_size, 1024),
                hidden_size,
                bias=True,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            ),
        )

    def forward(self, x, c):
        scale = self.adaLN_modulation(c)
        x = modulate(self.norm_final(x), scale)
        x = self.linear(x)
        return x
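# FinalLayer maps each token back to patch pixels: after a conditioned
# LayerNorm, the linear projects hidden_size -> patch_size * patch_size *
# out_channels, so with patch_size = 2 every token becomes a 2x2 block of
# out_channels latent values (illustrative reading of the code, not a
# statement about the shipped checkpoint configuration).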
class NextDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        dim: int = 4096,
        n_layers: int = 32,
        n_refiner_layers: int = 2,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        qk_norm: bool = False,
        cap_feat_dim: int = 5120,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (1, 512, 512),
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size

        self.x_embedder = operation_settings.get("operations").Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=dim,
            bias=True,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        self.noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.context_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)

        self.cap_embedder = nn.Sequential(
            operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
            operation_settings.get("operations").Linear(
                cap_feat_dim,
                dim,
                bias=True,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            ),
        )

        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)

        assert (dim // n_heads) == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
        self.dim = dim
        self.n_heads = n_heads
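    # The rotary embedding is split over three position axes (caption index,
    # image row, image column), so axes_dims must sum to the per-head width:
    # with the defaults above, dim // n_heads = 4096 // 32 = 128 and
    # 16 + 56 + 56 = 128, which is exactly what the assert checks.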
    def unpatchify(
        self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
    ) -> List[torch.Tensor]:
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        pH = pW = self.patch_size
        imgs = []
        for i in range(x.size(0)):
            H, W = img_size[i]
            begin = cap_size[i]
            end = begin + (H // pH) * (W // pW)
            imgs.append(
                x[i][begin:end]
                .view(H // pH, W // pW, pH, pW, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )

        if return_tensor:
            imgs = torch.stack(imgs, dim=0)
        return imgs
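    # unpatchify drops the caption tokens (the first cap_size[i] entries of each
    # sequence) and folds the remaining (H // pH) * (W // pW) image tokens back
    # into a (C, H, W) grid, inverting the patchify step below. Illustrative
    # shape trace for one sample with H = W = 4 and patch size 2 (not from the
    # original file):
    #
    #   tokens (4, 2*2*C) -> view (2, 2, 2, 2, C) -> permute (C, 2, 2, 2, 2)
    #   -> flatten(3, 4) (C, 2, 2, 4) -> flatten(1, 2) (C, 4, 4)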
    def patchify_and_embed(
        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
        bsz = len(x)
        pH = pW = self.patch_size
        device = x[0].device
        dtype = x[0].dtype

        if cap_mask is not None:
            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
        else:
            l_effective_cap_len = [num_tokens] * bsz

        if cap_mask is not None and not torch.is_floating_point(cap_mask):
            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max

        img_sizes = [(img.size(1), img.size(2)) for img in x]
        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]

        max_seq_len = max(
            (cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
        )
        max_cap_len = max(l_effective_cap_len)
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // pH, W // pW
            assert H_tokens * W_tokens == img_len

            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
            position_ids[i, cap_len:cap_len + img_len, 0] = cap_len
            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
            position_ids[i, cap_len:cap_len + img_len, 1] = row_ids
            position_ids[i, cap_len:cap_len + img_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)

        # build freqs_cis for cap and image individually
        cap_freqs_cis_shape = list(freqs_cis.shape)
        # cap_freqs_cis_shape[1] = max_cap_len
        cap_freqs_cis_shape[1] = cap_feats.shape[1]
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len + img_len]

        # refine context
        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

        # refine image
        flat_x = []
        for i in range(bsz):
            img = x[i]
            C, H, W = img.size()
            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
            flat_x.append(img)
        x = flat_x
        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
        for i in range(bsz):
            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max

        padded_img_embed = self.x_embedder(padded_img_embed)
        padded_img_mask = padded_img_mask.unsqueeze(1)
        for layer in self.noise_refiner:
            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)

        if cap_mask is not None:
            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
        else:
            mask = None

        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]

            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
            padded_full_embed[i, cap_len:cap_len + img_len] = padded_img_embed[i, :img_len]

        return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
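    # patchify_and_embed builds the joint sequence per sample as
    # [caption tokens | image tokens]: position axis 0 counts caption tokens and
    # then stays fixed at cap_len for the image part, while axes 1 and 2 carry
    # the patch row/column, so RoPE distinguishes text order from 2D image
    # layout. Captions and image patches are refined separately
    # (context_refiner / noise_refiner, with their own slices of freqs_cis)
    # before being packed into padded_full_embed for the main stack of layers.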
    # def forward(self, x, t, cap_feats, cap_mask):
    def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
        """
        Forward pass of NextDiT.

        x: (N, C, H, W) tensor of spatial inputs (latent images)
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, L, cap_feat_dim) tensor of caption features
        num_tokens: caption length to assume when attention_mask is None
        attention_mask: (N, L) caption mask
        """
        t = 1.0 - timesteps
        cap_feats = context
        cap_mask = attention_mask
        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        t = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        adaln_input = t

        cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

        x_is_tensor = isinstance(x, torch.Tensor)
        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
        freqs_cis = freqs_cis.to(x.device)

        for layer in self.layers:
            x = layer(x, mask, freqs_cis, adaln_input)

        x = self.final_layer(x, adaln_input)
        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]

        return -x
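

# Minimal usage sketch (not part of the original file). It assumes ComfyUI is
# importable (for the comfy.* imports above) and a PyTorch build whose torch.nn
# provides RMSNorm (>= 2.4), so that plain torch.nn can stand in for the
# `operations` namespace ComfyUI normally supplies (e.g. comfy.ops). The tiny
# hyperparameters below are made up for illustration and are not the Lumina 2.0
# defaults.
if __name__ == "__main__":
    model = NextDiT(
        patch_size=2,
        in_channels=4,
        dim=256,
        n_layers=2,
        n_refiner_layers=1,
        n_heads=4,
        multiple_of=32,
        norm_eps=1e-5,
        qk_norm=True,
        cap_feat_dim=64,
        axes_dims=(16, 24, 24),  # must sum to dim // n_heads = 64
        axes_lens=(1, 512, 512),
        operations=torch.nn,     # assumption: torch.nn.Linear/RMSNorm/LayerNorm suffice here
    )
    latents = torch.randn(1, 4, 32, 32)            # (N, C, H, W)
    timesteps = torch.full((1,), 0.5)              # (N,)
    cap_feats = torch.randn(1, 8, 64)              # (N, L, cap_feat_dim)
    cap_mask = torch.ones(1, 8, dtype=torch.int32)  # all caption tokens valid
    out = model(latents, timesteps, cap_feats, num_tokens=8, attention_mask=cap_mask)
    print(out.shape)  # expected: torch.Size([1, 4, 32, 32])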