import math
from inspect import isfunction
from pdb import set_trace as st

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange, repeat

from ldm.modules.attention import MemoryEfficientCrossAttention
from timm.models.vision_transformer import Mlp

from .dit_models_xformers import *
from .dit_trilatent import DiT_PCD_PixelArt
from vit.vit_triplane import XYZPosEmbed

# from apex.normalization import FusedLayerNorm as LayerNorm
from torch.nn import LayerNorm

# prefer the fused apex RMSNorm when available, fall back to the local implementation
try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
    from dit.norm import RMSNorm
class DiT_I23D(DiT):
    # DiT with 3D-aware operations

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlock,
                 final_layer_blk=T2IFinalLayer, enable_rope=False):
        # note: final_layer_blk is overridden with T2IFinalLayer here
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, roll_out, vit_blk,
                         T2IFinalLayer, enable_rope=enable_rope)
        assert self.roll_out

        self.clip_ctx_dim = 1024  # CLIP ViT-L spatial token dim

        # DINO spatial tokens (dino-vitb/14 for MV-cond, hard-coded for now) are appended
        # to self-attention, so they need a projection into the hidden size.
        self.dino_proj = CaptionEmbedder(context_dim,
                                         hidden_size,
                                         act_layer=approx_gelu)
        self.clip_spatial_proj = CaptionEmbedder(1024,  # CLIP image encoder (ViT-L)
                                                 hidden_size,
                                                 act_layer=approx_gelu)

    def init_PE_3D_aware(self):
        # one frozen sin-cos PE shared across the 3 planes: (1, plane_n * num_patches, D)
        self.pos_embed = nn.Parameter(torch.zeros(
            1, self.plane_n * self.x_embedder.num_patches, self.embed_dim),
                                      requires_grad=False)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        p = int(self.x_embedder.num_patches**0.5)
        D = self.pos_embed.shape[-1]
        grid_size = (self.plane_n, p * p)  # B n HW C
        pos_embed = get_2d_sincos_pos_embed(D, grid_size).reshape(
            self.plane_n * p * p, D)  # H*W, D
        self.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))

    def initialize_weights(self):
        super().initialize_weights()
        # ! add 3D-aware PE
        self.init_PE_3D_aware()

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: dict with pooled ('vector') and spatial ('crossattn') conditioning tokens
        """
        assert isinstance(context, dict)

        clip_cls_token = self.clip_text_proj(context['vector'])
        clip_spatial_token, dino_spatial_token = (
            context['crossattn'][..., :self.clip_ctx_dim],
            self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:]))

        t = self.t_embedder(timesteps) + clip_cls_token  # (N, D)

        # roll out the 3 planes into the batch dim before patchification
        x = rearrange(x, 'b (c n) h w -> (b n) c h w', n=3)  # downsample with same conv
        x = self.x_embedder(x)  # (b n) c h/f w/f
        x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
        x = x + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2

        for blk_idx, block in enumerate(self.blocks):
            x = block(x, t, dino_spatial_token=dino_spatial_token,
                      clip_spatial_token=clip_spatial_token)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)

        if self.roll_out:  # move n from L to B axis
            x = rearrange(x, 'b (n l) c -> (b n) l c', n=3)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        if self.roll_out:  # move n from L to B axis
            x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)

        # cast to float32 for better accuracy
        x = x.to(torch.float32).contiguous()
        return x

    # ! compat issue
    def forward_with_cfg(self, x, t, context, cfg_scale):
        """
        Forward pass of DiT, but also batches the unconditional forward pass
        for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        eps = self.forward(x, t, context)
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return eps
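
# --- usage sketch (illustrative; not called anywhere in this file) -------------------
# A minimal sketch of how `forward_with_cfg` is typically driven, following the GLIDE
# convention linked above: latents and timesteps are duplicated, with conditional
# context stacked on top of unconditional context. `null_context` is a hypothetical
# placeholder; how the unconditional embeddings are built depends on the conditioner.
def _cfg_call_sketch(model, x_half, t_half, cond_context, null_context, cfg_scale=4.0):
    x = torch.cat([x_half, x_half], dim=0)  # duplicate latents
    t = torch.cat([t_half, t_half], dim=0)  # duplicate timesteps
    context = {
        k: torch.cat([cond_context[k], null_context[k]], dim=0)  # cond first, uncond second
        for k in cond_context
    }
    return model.forward_with_cfg(x, t, context, cfg_scale)
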
class DiT_I23D_PixelArt(DiT_I23D):

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArtRMSNorm,
                 final_layer_blk=FinalLayer, create_cap_embedder=True,
                 enable_rope=False):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, vit_blk, final_layer_blk,
                         enable_rope=enable_rope)

        # ! a single adaLN modulation shared across all blocks, zero-initialized
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)

        del self.clip_text_proj

        if create_cap_embedder:
            self.cap_embedder = nn.Sequential(  # zero-init the output projection
                LayerNorm(pooling_ctx_dim),
                nn.Linear(pooling_ctx_dim, hidden_size),
            )
            nn.init.constant_(self.cap_embedder[-1].weight, 0)
            nn.init.constant_(self.cap_embedder[-1].bias, 0)
        else:
            self.cap_embedder = nn.Identity()  # placeholder

        print(self)  # check model arch

        self.attention_y_norm = RMSNorm(
            1024, eps=1e-5
        )  # https://github.com/Alpha-VLLM/Lumina-T2X/blob/0c8dd6a07a3b7c18da3d91f37b1e00e7ae661293/lumina_t2i/models/model.py#L570C9-L570C61

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: dict with pooled ('vector') and spatial ('crossattn') conditioning tokens
        """
        assert isinstance(context, dict)

        clip_cls_token = self.cap_embedder(context['vector'])
        clip_spatial_token, dino_spatial_token = (
            context['crossattn'][..., :self.clip_ctx_dim],
            self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:]))
        clip_spatial_token = self.attention_y_norm(clip_spatial_token)  # avoid re-normalization in each blk

        t = self.t_embedder(timesteps) + clip_cls_token  # (N, D)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        x = rearrange(x, 'b (c n) h w -> (b n) c h w', n=3)  # downsample with same conv
        x = self.x_embedder(x)  # (b n) c h/f w/f
        x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
        x = x + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2

        for blk_idx, block in enumerate(self.blocks):
            x = block(x, t0, dino_spatial_token=dino_spatial_token,
                      clip_spatial_token=clip_spatial_token)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)

        if self.roll_out:  # move n from L to B axis
            x = rearrange(x, 'b (n l) c -> (b n) l c', n=3)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        if self.roll_out:  # move n from L to B axis
            x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)

        # cast to float32 for better accuracy
        x = x.to(torch.float32).contiguous()
        return x
class DiT_I23D_PCD_PixelArt(DiT_I23D_PixelArt):

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArtRMSNorm,
                 final_layer_blk=FinalLayer, create_cap_embedder=True,
                 use_clay_ca=False):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, vit_blk, final_layer_blk)

        # point tokens: replace the patch embedder with a point-wise MLP and drop the 2D PE
        self.x_embedder = Mlp(in_features=in_channels,
                              hidden_features=hidden_size,
                              out_features=hidden_size,
                              act_layer=approx_gelu,
                              drop=0)
        del self.pos_embed

        self.use_clay_ca = use_clay_ca
        if use_clay_ca:
            del self.dino_proj  # no prepending required.

        # adding ln_pre / ln_post as in Point-E did not help (worse performance)

    @staticmethod
    def precompute_freqs_cis(
        dim: int,
        end: int,
        theta: float = 10000.0,
        rope_scaling_factor: float = 1.0,
        ntk_factor: float = 1.0,
    ):
        """
        Precompute the frequency tensor for complex exponentials (cis) with
        the given dimensions.

        This function calculates a frequency tensor with complex exponentials
        using the given dimension 'dim' and the end index 'end'. The 'theta'
        parameter scales the frequencies. The returned tensor contains complex
        values in complex64 data type.

        Args:
            dim (int): Dimension of the frequency tensor.
            end (int): End index for precomputing frequencies.
            theta (float, optional): Scaling factor for frequency computation.
                Defaults to 10000.0.

        Returns:
            torch.Tensor: Precomputed frequency tensor with complex
            exponentials.
        """
        theta = theta * ntk_factor
        print(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")

        freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float().cuda() / dim))
        t = torch.arange(end, device=freqs.device, dtype=torch.float)  # type: ignore
        t = t / rope_scaling_factor
        freqs = torch.outer(t, freqs).float()  # type: ignore
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
        return freqs_cis
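
    # --- illustrative sketch (not used by the model) ----------------------------------
    # How the complex RoPE frequencies above are typically applied to per-head query/key
    # tensors of shape (B, L, n_heads, head_dim), following the LLaMA / Lumina-T2X
    # convention. The actual application happens inside the attention blocks imported
    # from .dit_models_xformers; this helper only documents the expected math.
    @staticmethod
    def _apply_rotary_emb_sketch(xq, xk, freqs_cis):
        # view the last dim as complex pairs: (B, L, H, head_dim // 2), complex64
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
        # broadcast the per-position frequencies over batch and head dims
        freqs = freqs_cis[:xq_.shape[1]].reshape(1, xq_.shape[1], 1, xq_.shape[-1])
        xq_out = torch.view_as_real(xq_ * freqs).flatten(3)  # rotate and re-interleave
        xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)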
    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT over point-cloud latent tokens.
        x: (N, L, C) tensor of point-latent inputs
        timesteps: (N,) tensor of diffusion timesteps
        context: dict of image (and optionally caption) conditioning tokens
        """
        assert isinstance(context, dict)

        # global condition
        if 'caption_vector' in context:
            clip_cls_token = self.cap_embedder(context['caption_vector'])
        elif 'img_vector' in context:
            clip_cls_token = self.cap_embedder(context['img_vector'])
        else:
            clip_cls_token = 0

        # spatial condition
        clip_spatial_token, dino_spatial_token = (
            context['img_crossattn'][..., :self.clip_ctx_dim],
            context['img_crossattn'][..., self.clip_ctx_dim:])
        if not self.use_clay_ca:
            dino_spatial_token = self.dino_proj(dino_spatial_token)

        t = self.t_embedder(timesteps) + clip_cls_token  # (N, D)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        x = self.x_embedder(x)

        for blk_idx, block in enumerate(self.blocks):
            x = block(x, t0, dino_spatial_token=dino_spatial_token,
                      clip_spatial_token=clip_spatial_token,
                      clip_caption_token=context.get('caption_crossattn'))

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = x.to(torch.float32).contiguous()
        return x
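
# Expected `context` layout for DiT_I23D_PCD_PixelArt.forward; the shapes are indicative
# only, inferred from the slicing and embedders above (B = batch, L = token count):
#   context['img_crossattn']    : (B, L, 1024 + context_dim)  CLIP spatial tokens (first
#                                 self.clip_ctx_dim channels) concatenated with DINO tokens
#   context['img_vector']       : (B, pooling_ctx_dim)   pooled image embedding (optional)
#   context['caption_vector']   : (B, pooling_ctx_dim)   pooled caption embedding (optional)
#   context['caption_crossattn']: (B, L_text, D_text)    caption tokens for cross-attention (optional)
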
# DINO-only version (no CLIP branch)
class DiT_I23D_PCD_PixelArt_noclip(DiT_I23D_PixelArt):

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArtRMSNormNoClip,
                 final_layer_blk=FinalLayer, create_cap_embedder=True,
                 use_clay_ca=False, has_caption=False,
                 rope_scaling_factor: float = 1.0,
                 ntk_factor: float = 1.0,
                 enable_rope=False):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, vit_blk, final_layer_blk,
                         enable_rope=enable_rope)

        self.x_embedder = Mlp(in_features=in_channels,
                              hidden_features=hidden_size,
                              out_features=hidden_size,
                              act_layer=approx_gelu,
                              drop=0)
        del self.pos_embed
        del self.dino_proj

        self.enable_rope = enable_rope
        if self.enable_rope:  # implementation copied from the Lumina-T2X code base
            self.freqs_cis = DiT_I23D_PCD_PixelArt.precompute_freqs_cis(
                hidden_size // num_heads,
                40000,
                rope_scaling_factor=rope_scaling_factor,
                ntk_factor=ntk_factor,
            )
        else:
            self.freqs_cis = None
        self.rope_scaling_factor = rope_scaling_factor
        self.ntk_factor = ntk_factor

        self.use_clay_ca = use_clay_ca
        self.has_caption = has_caption

        pooled_vector_dim = context_dim
        if has_caption:
            pooled_vector_dim += 768

        self.pooled_vec_embedder = nn.Sequential(  # zero-init the output projection
            LayerNorm(pooled_vector_dim),
            nn.Linear(pooled_vector_dim, hidden_size),
        )
        nn.init.constant_(self.pooled_vec_embedder[-1].weight, 0)
        nn.init.constant_(self.pooled_vec_embedder[-1].bias, 0)

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT over point-cloud latent tokens (DINO-only conditioning).
        x: (N, L, C) tensor of point-latent inputs
        timesteps: (N,) tensor of diffusion timesteps
        context: dict of DINO (and optionally caption) conditioning tokens
        """
        assert isinstance(context, dict)

        dino_spatial_token = context['img_crossattn']
        dino_pooled_vector = context['img_vector']

        if self.has_caption:
            clip_caption_token = context.get('caption_crossattn')
            pooled_vector = torch.cat(
                [dino_pooled_vector, context.get('caption_vector')], -1)  # concat dino_vector
        else:
            clip_caption_token = None
            pooled_vector = dino_pooled_vector

        t = self.t_embedder(timesteps) + self.pooled_vec_embedder(pooled_vector)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        x = self.x_embedder(x)

        freqs_cis = None
        if self.enable_rope:
            freqs_cis = self.freqs_cis[:x.size(1)]

        for blk_idx, block in enumerate(self.blocks):
            x = block(x, t0, dino_spatial_token=dino_spatial_token,
                      clip_caption_token=clip_caption_token, freqs_cis=freqs_cis)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = x.to(torch.float32).contiguous()
        return x
# xyz-conditioned texture (KL-feature) diffusion
class DiT_I23D_PCD_PixelArt_xyz_cond_kl_diff(DiT_I23D_PCD_PixelArt):

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArtRMSNorm,
                 final_layer_blk=FinalLayer, create_cap_embedder=True,
                 use_pe_cond=False):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, vit_blk, final_layer_blk)

        self.use_pe_cond = use_pe_cond
        # without PE conditioning, the 3 xyz channels are concatenated to the input features
        self.x_embedder = Mlp(in_features=in_channels + 3 * (1 - use_pe_cond),
                              hidden_features=hidden_size,
                              out_features=hidden_size,
                              act_layer=approx_gelu,
                              drop=0)
        if use_pe_cond:
            self.xyz_pos_embed = XYZPosEmbed(hidden_size)

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT, conditioned on fps-sampled xyz positions.
        x: (N, L, C) tensor of KL-feature tokens
        timesteps: (N,) tensor of diffusion timesteps
        context: dict with CLIP/DINO conditioning and 'fps-xyz' point positions
        """
        assert isinstance(context, dict)

        clip_cls_token = self.cap_embedder(context['vector'])
        clip_spatial_token, dino_spatial_token = (
            context['crossattn'][..., :self.clip_ctx_dim],
            self.dino_proj(context['crossattn'][..., self.clip_ctx_dim:]))
        fps_xyz = context['fps-xyz']

        t = self.t_embedder(timesteps) + clip_cls_token  # (N, D)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        if self.use_pe_cond:
            x = self.x_embedder(x) + self.xyz_pos_embed(fps_xyz)  # point-wise addition
        else:  # use concat to add info
            x = torch.cat([fps_xyz, x], dim=-1)
            x = self.x_embedder(x)

        for blk_idx, block in enumerate(self.blocks):
            x = block(x, t0, dino_spatial_token=dino_spatial_token,
                      clip_spatial_token=clip_spatial_token)

        x = self.final_layer(x, t)  # no loss on the xyz side
        x = x.to(torch.float32).contiguous()
        return x
# xyz-conditioned texture diffusion, CLAY-style conditioning
class DiT_I23D_PCD_PixelArt_noclip_clay_stage2(DiT_I23D_PCD_PixelArt_noclip):

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArtRMSNorm,
                 final_layer_blk=FinalLayer, create_cap_embedder=True,
                 use_pe_cond=False, has_caption=False, use_clay_ca=False):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, vit_blk, final_layer_blk,
                         use_clay_ca=use_clay_ca, has_caption=has_caption)

        self.has_caption = False
        self.use_pe_cond = use_pe_cond
        self.x_embedder = Mlp(in_features=in_channels + 3 * (1 - use_pe_cond),
                              hidden_features=hidden_size,
                              out_features=hidden_size,
                              act_layer=approx_gelu,
                              drop=0)
        if use_pe_cond:
            self.xyz_pos_embed = XYZPosEmbed(hidden_size)

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT over KL-feature tokens, conditioned on fps-sampled xyz positions.
        x: (N, L, C) tensor of KL-feature tokens
        timesteps: (N,) tensor of diffusion timesteps
        context: dict of DINO conditioning tokens and 'fps-xyz' point positions
        """
        assert isinstance(context, dict)

        dino_spatial_token = context['img_crossattn']
        dino_pooled_vector = context['img_vector']

        if self.has_caption:
            clip_caption_token = context.get('caption_crossattn')
            pooled_vector = torch.cat(
                [dino_pooled_vector, context.get('caption_vector')], -1)  # concat dino_vector
        else:
            clip_caption_token = None
            pooled_vector = dino_pooled_vector

        t = self.t_embedder(timesteps) + self.pooled_vec_embedder(pooled_vector)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        fps_xyz = context['fps-xyz']
        if self.use_pe_cond:
            x = self.x_embedder(x) + self.xyz_pos_embed(fps_xyz)  # point-wise addition
        else:  # use concat to add info
            x = torch.cat([fps_xyz, x], dim=-1)
            x = self.x_embedder(x)

        for blk_idx, block in enumerate(self.blocks):
            x = block(x, t0, dino_spatial_token=dino_spatial_token,
                      clip_caption_token=clip_caption_token)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = x.to(torch.float32).contiguous()
        return x
class DiT_I23D_PixelArt_MVCond(DiT_I23D_PixelArt):

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArt,
                 final_layer_blk=FinalLayer, create_cap_embedder=False):
        # note: the block type is hard-coded to ImageCondDiTBlockPixelArtRMSNorm here
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, ImageCondDiTBlockPixelArtRMSNorm,
                         final_layer_blk,
                         create_cap_embedder=create_cap_embedder)

        # Support multi-view image conditioning.
        # DINO handles global pooling here; CLIP takes care of camera conditioning with ModLN.
        # The multi-view DINO features are concatenated (plus global pooling); InstantMesh
        # adopts DINO as well (but via cross-attention). Since conditioning is injected via
        # cross-attention, any number of context frames should be supported.
        del self.dino_proj

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: dict with 'vector', 'crossattn' and multi-view 'concat' conditioning
        """
        assert isinstance(context, dict)

        # observed context shapes (from a debugging session):
        #   context['vector']:    (B, 768)          pooled CLIP vector
        #   context['crossattn']: (B, 256, 1024)    CLIP spatial tokens
        #   context['concat']:    (B, V, 256, 768)  multi-view DINO spatial features
        #
        # ! CLIP spatial tokens are appended to self-attention, hence the projection
        # (self.clip_spatial_proj); DINO features are consumed via cross-attention, so no
        # extra projection is required (the cross-attention block already has KV linears).
        clip_cls_token, clip_spatial_token = (
            self.cap_embedder(context['vector']),
            self.clip_spatial_proj(context['crossattn']))  # QK norm in the blocks is enough, since self.ln_post(x) in vit
        dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c')  # flatten MV dino features

        t = self.t_embedder(timesteps) + clip_cls_token  # (N, D)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        x = rearrange(x, 'b (c n) h w -> (b n) c h w', n=3)  # downsample with same conv
        x = self.x_embedder(x)  # (b n) c h/f w/f
        x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
        x = x + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2

        for blk_idx, block in enumerate(self.blocks):
            # ! DINO tokens go to cross-attention, CLIP tokens are appended here;
            # the keyword arguments are intentionally swapped relative to the single-view class.
            x = block(x, t0, dino_spatial_token=clip_spatial_token,
                      clip_spatial_token=dino_spatial_token)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)

        if self.roll_out:  # move n from L to B axis
            x = rearrange(x, 'b (n l) c -> (b n) l c', n=3)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        if self.roll_out:  # move n from L to B axis
            x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)

        x = x.to(torch.float32).contiguous()
        return x
class DiT_I23D_PixelArt_MVCond_noClip(DiT_I23D_PixelArt):

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArt,
                 final_layer_blk=FinalLayer, create_cap_embedder=False):
        # note: the block type is hard-coded to ImageCondDiTBlockPixelArtRMSNormNoClip here
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, ImageCondDiTBlockPixelArtRMSNormNoClip,
                         final_layer_blk,
                         create_cap_embedder=create_cap_embedder)

        # Support multi-view image conditioning, DINO-only.
        # Since conditioning is injected via cross-attention, any number of context frames
        # should be supported.
        del self.dino_proj
        del self.clip_spatial_proj, self.cap_embedder  # no clip required

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timesteps: (N,) tensor of diffusion timesteps
        context: dict with multi-view DINO features under 'concat'
        """
        assert isinstance(context, dict)

        # observed context shape: context['concat'] is (B, V, 256, 768) multi-view DINO features
        dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c')  # flatten MV dino features

        t = self.t_embedder(timesteps)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        x = rearrange(x, 'b (c n) h w -> (b n) c h w', n=3)  # downsample with same conv
        x = self.x_embedder(x)  # (b n) c h/f w/f
        x = rearrange(x, '(b n) l c -> b (n l) c', n=3)
        x = x + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2

        for blk_idx, block in enumerate(self.blocks):
            # ! DINO tokens only, consumed via cross-attention
            x = block(x, t0, dino_spatial_token=dino_spatial_token)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)

        if self.roll_out:  # move n from L to B axis
            x = rearrange(x, 'b (n l) c -> (b n) l c', n=3)
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        if self.roll_out:  # move n from L to B axis
            x = rearrange(x, '(b n) c h w -> b (c n) h w', n=3)

        x = x.to(torch.float32).contiguous()
        return x
# pcd-structured latent ddpm
class DiT_pcd_I23D_PixelArt_MVCond(DiT_I23D_PixelArt_MVCond_noClip):

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArt,
                 final_layer_blk=FinalLayer, create_cap_embedder=False):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, ImageCondDiTBlockPixelArtRMSNorm,
                         final_layer_blk,
                         create_cap_embedder=create_cap_embedder)

        # ! First, normalize xyz from [-0.45, 0.45] to [-1, 1].
        # Then encode xyz with a point Fourier feature + MLP projection, which serves as the PE.
        # A separate MLP embeds the KL feature; the two are added in feature space,
        # and a single MLP (final_layer) maps them back to 16 + 3 dims.
        self.x_embedder = Mlp(in_features=in_channels,
                              hidden_features=hidden_size,
                              out_features=hidden_size,
                              act_layer=approx_gelu,
                              drop=0)
        del self.pos_embed

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT over point-cloud latent tokens.
        x: (N, L, C) tensor of point-latent inputs
        timesteps: (N,) tensor of diffusion timesteps
        context: dict with multi-view DINO features under 'concat'
        """
        assert isinstance(context, dict)

        dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c')  # flatten MV dino features

        t = self.t_embedder(timesteps)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        x = self.x_embedder(x)

        for blk_idx, block in enumerate(self.blocks):
            # ! DINO tokens only, consumed via cross-attention
            x = block(x, t0, dino_spatial_token=dino_spatial_token)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = x.to(torch.float32).contiguous()
        return x
class DiT_pcd_I23D_PixelArt_MVCond_clay(DiT_PCD_PixelArt):
    # fine-tune the MV model from the text-conditioned model

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArt,
                 final_layer_blk=FinalLayer, create_cap_embedder=False, **kwargs):
        # note: DiT_PCD_PixelArt does not take pooling_ctx_dim / create_cap_embedder
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim,
                         roll_out, vit_blk, final_layer_blk)

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT over point-cloud latent tokens.
        x: (N, L, C) tensor of point-latent inputs
        timesteps: (N,) tensor of diffusion timesteps
        context: dict with caption tokens and multi-view DINO features under 'concat'
        """
        assert context is not None

        clip_cls_token = self.cap_embedder(context['caption_vector'])  # pooled
        t = self.t_embedder(timesteps) + clip_cls_token  # (N, D)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        x = self.x_embedder(x)

        # ! spatial tokens
        dino_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c')  # flatten MV dino features

        for blk_idx, block in enumerate(self.blocks):
            x = block(x, t0, clip_caption_token=context['caption_crossattn'],
                      dino_spatial_token=dino_spatial_token)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        # cast to float32 for better accuracy
        x = x.to(torch.float32).contiguous()
        return x
# MV conditioning on top of the single-image pretrained CLAY model
class DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt(DiT_I23D_PCD_PixelArt_noclip):
    # fine-tune the MV model from the image-conditioned model

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArt,
                 final_layer_blk=FinalLayer, create_cap_embedder=False, **kwargs):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, vit_blk, final_layer_blk)
        self.has_caption = False

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT over point-cloud latent tokens.
        x: (N, L, C) tensor of point-latent inputs
        timesteps: (N,) tensor of diffusion timesteps
        context: dict with single-view DINO tokens and multi-view DINO features under 'concat'
        """
        assert isinstance(context, dict)

        dino_spatial_token = context['img_crossattn']
        dino_pooled_vector = context['img_vector']
        dino_mv_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c')  # flatten MV dino features

        if self.has_caption:
            clip_caption_token = context.get('caption_crossattn')
            pooled_vector = torch.cat(
                [dino_pooled_vector, context.get('caption_vector')], -1)  # concat dino_vector
        else:
            clip_caption_token = None
            pooled_vector = dino_pooled_vector

        t = self.t_embedder(timesteps) + self.pooled_vec_embedder(pooled_vector)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        x = self.x_embedder(x)

        for blk_idx, block in enumerate(self.blocks):
            x = block(x, t0, dino_spatial_token=dino_spatial_token,
                      dino_mv_spatial_token=dino_mv_spatial_token)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = x.to(torch.float32).contiguous()
        return x
# stage 2
class DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt_stage2(DiT_I23D_PCD_PixelArt_noclip):
    # fine-tune the MV model from the image-conditioned model

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArt,
                 final_layer_blk=FinalLayer, create_cap_embedder=False, **kwargs):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, vit_blk, final_layer_blk)
        self.has_caption = False
        self.use_pe_cond = True

        self.x_embedder = Mlp(in_features=in_channels + 3 * (1 - self.use_pe_cond),
                              hidden_features=hidden_size,
                              out_features=hidden_size,
                              act_layer=approx_gelu,
                              drop=0)
        if self.use_pe_cond:
            self.xyz_pos_embed = XYZPosEmbed(hidden_size)

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT over KL-feature tokens, conditioned on fps-sampled xyz positions.
        x: (N, L, C) tensor of KL-feature tokens
        timesteps: (N,) tensor of diffusion timesteps
        context: dict with single-view and multi-view DINO tokens plus 'fps-xyz' positions
        """
        assert isinstance(context, dict)

        dino_spatial_token = context['img_crossattn']
        dino_pooled_vector = context['img_vector']
        dino_mv_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c')  # flatten MV dino features

        if self.has_caption:
            clip_caption_token = context.get('caption_crossattn')
            pooled_vector = torch.cat(
                [dino_pooled_vector, context.get('caption_vector')], -1)  # concat dino_vector
        else:
            clip_caption_token = None
            pooled_vector = dino_pooled_vector

        t = self.t_embedder(timesteps) + self.pooled_vec_embedder(pooled_vector)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        fps_xyz = context['fps-xyz']
        if self.use_pe_cond:
            x = self.x_embedder(x) + self.xyz_pos_embed(fps_xyz)  # point-wise addition
        else:  # use concat to add info
            x = torch.cat([fps_xyz, x], dim=-1)
            x = self.x_embedder(x)

        for blk_idx, block in enumerate(self.blocks):
            x = block(x, t0, dino_spatial_token=dino_spatial_token,
                      dino_mv_spatial_token=dino_mv_spatial_token)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = x.to(torch.float32).contiguous()
        return x
class DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt_noi23d(DiT_I23D_PCD_PixelArt_noclip):
    # MV-only variant: fine-tune the MV model without the single-image (i23d) branch

    def __init__(self,
                 input_size=32, patch_size=2, in_channels=4, hidden_size=1152,
                 depth=28, num_heads=16, mlp_ratio=4, class_dropout_prob=0.1,
                 num_classes=1000, learn_sigma=True, mixing_logit_init=-3,
                 mixed_prediction=True, context_dim=False, pooling_ctx_dim=768,
                 roll_out=False, vit_blk=ImageCondDiTBlockPixelArt,
                 final_layer_blk=FinalLayer, create_cap_embedder=False, **kwargs):
        super().__init__(input_size, patch_size, in_channels, hidden_size,
                         depth, num_heads, mlp_ratio, class_dropout_prob,
                         num_classes, learn_sigma, mixing_logit_init,
                         mixed_prediction, context_dim, pooling_ctx_dim,
                         roll_out, vit_blk, final_layer_blk)
        self.has_caption = False
        del self.pooled_vec_embedder

    def forward(self, x, timesteps=None, context=None, y=None, get_attr='', **kwargs):
        """
        Forward pass of DiT over point-cloud latent tokens.
        x: (N, L, C) tensor of point-latent inputs
        timesteps: (N,) tensor of diffusion timesteps
        context: dict with multi-view DINO features under 'concat'
        """
        assert isinstance(context, dict)

        # this variant uses multi-view DINO tokens only; the single-image and caption
        # branches are disabled (has_caption is forced off in __init__)
        dino_mv_spatial_token = rearrange(context['concat'], 'b v l c -> b (v l) c')  # flatten MV dino features

        t = self.t_embedder(timesteps)
        t0 = self.adaLN_modulation(t)  # single adaLN, (N, 6 * hidden_size)

        x = self.x_embedder(x)

        for blk_idx, block in enumerate(self.blocks):
            x = block(x, t0, dino_mv_spatial_token=dino_mv_spatial_token)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = x.to(torch.float32).contiguous()
        return x
#################################################################################
#                               DiT_I23D Configs                                #
#################################################################################


def DiT_XL_2(**kwargs):
    return DiT_I23D(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)


def DiT_L_2(**kwargs):
    return DiT_I23D(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)


def DiT_B_2(**kwargs):
    return DiT_I23D(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)


def DiT_B_1(**kwargs):
    return DiT_I23D(depth=12, hidden_size=768, patch_size=1, num_heads=12, **kwargs)


def DiT_L_Pixelart_2(**kwargs):
    return DiT_I23D_PixelArt(depth=24, hidden_size=1024, patch_size=2, num_heads=16,
                             **kwargs)


def DiT_B_Pixelart_2(**kwargs):
    return DiT_I23D_PixelArt(depth=12, hidden_size=768, patch_size=2, num_heads=12,
                             **kwargs)


def DiT_L_Pixelart_MV_2(**kwargs):
    return DiT_I23D_PixelArt_MVCond(depth=24, hidden_size=1024, patch_size=2,
                                    num_heads=16, **kwargs)


def DiT_L_Pixelart_MV_2_noclip(**kwargs):
    return DiT_I23D_PixelArt_MVCond_noClip(depth=24, hidden_size=1024, patch_size=2,
                                           num_heads=16, **kwargs)


def DiT_XL_Pixelart_MV_2(**kwargs):
    return DiT_I23D_PixelArt_MVCond(depth=28, hidden_size=1152, patch_size=2,
                                    num_heads=16, **kwargs)


def DiT_B_Pixelart_MV_2(**kwargs):
    return DiT_I23D_PixelArt_MVCond(depth=12, hidden_size=768, patch_size=2,
                                    num_heads=12, **kwargs)


# pcd latent
def DiT_L_Pixelart_MV_pcd(**kwargs):
    return DiT_pcd_I23D_PixelArt_MVCond(depth=24, hidden_size=1024,
                                        patch_size=1,  # no spatial compression here
                                        num_heads=16, **kwargs)


# raw gs i23d
def DiT_L_Pixelart_pcd(**kwargs):
    return DiT_I23D_PCD_PixelArt(depth=24, hidden_size=1024,
                                 patch_size=1,  # no spatial compression here
                                 num_heads=16, **kwargs)


def DiT_L_Pixelart_clay_pcd(**kwargs):
    return DiT_I23D_PCD_PixelArt_noclip(depth=24,
                                        vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
                                        use_clay_ca=True, hidden_size=1024,
                                        patch_size=1,  # no spatial compression here
                                        num_heads=16, enable_rope=False, **kwargs)


def DiT_XL_Pixelart_clay_pcd(**kwargs):
    return DiT_I23D_PCD_PixelArt_noclip(depth=28,
                                        vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
                                        use_clay_ca=True, hidden_size=1152,
                                        patch_size=1,  # no spatial compression here
                                        num_heads=16, enable_rope=False, **kwargs)


def DiT_B_Pixelart_clay_pcd(**kwargs):
    return DiT_I23D_PCD_PixelArt_noclip(depth=12,
                                        vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
                                        use_clay_ca=True, hidden_size=768,
                                        patch_size=1,  # no spatial compression here
                                        num_heads=12, **kwargs)


def DiT_L_Pixelart_clay_pcd_stage2(**kwargs):
    return DiT_I23D_PCD_PixelArt_noclip_clay_stage2(
        depth=24, vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
        use_clay_ca=True, hidden_size=1024,
        patch_size=1,  # no spatial compression here
        num_heads=16, use_pe_cond=True, **kwargs)


def DiT_B_Pixelart_clay_pcd_stage2(**kwargs):
    return DiT_I23D_PCD_PixelArt_noclip_clay_stage2(
        depth=12, vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
        use_clay_ca=True, hidden_size=768,
        patch_size=1,  # no spatial compression here
        num_heads=12, use_pe_cond=True, **kwargs)


def DiT_L_Pixelart_clay_tandi_pcd(**kwargs):
    return DiT_I23D_PCD_PixelArt_noclip(depth=24,
                                        vit_blk=ImageCondDiTBlockPixelArtRMSNormClayText,
                                        use_clay_ca=True, hidden_size=1024,
                                        patch_size=1,  # no spatial compression here
                                        num_heads=16, has_caption=True, **kwargs)


def DiT_B_Pixelart_clay_tandi_pcd(**kwargs):
    return DiT_I23D_PCD_PixelArt_noclip(depth=12,
                                        vit_blk=ImageCondDiTBlockPixelArtRMSNormClayText,
                                        use_clay_ca=True, hidden_size=768,
                                        patch_size=1,  # no spatial compression here
                                        num_heads=12, has_caption=True, **kwargs)


def DiT_B_Pixelart_pcd(**kwargs):
    return DiT_I23D_PCD_PixelArt(depth=12, hidden_size=768,
                                 patch_size=1,  # no spatial compression here
                                 num_heads=12, **kwargs)


def DiT_B_Pixelart_pcd_cond_diff(**kwargs):
    return DiT_I23D_PCD_PixelArt_xyz_cond_kl_diff(
        depth=12, hidden_size=768,
        patch_size=1,  # no spatial compression here
        num_heads=12, **kwargs)


def DiT_B_Pixelart_pcd_cond_diff_pe(**kwargs):
    return DiT_I23D_PCD_PixelArt_xyz_cond_kl_diff(
        depth=12, hidden_size=768,
        patch_size=1,  # no spatial compression here
        vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
        num_heads=12, use_pe_cond=True, **kwargs)


def DiT_L_Pixelart_pcd_cond_diff_pe(**kwargs):
    return DiT_I23D_PCD_PixelArt_xyz_cond_kl_diff(
        depth=24, hidden_size=1024,
        patch_size=1,  # no spatial compression here
        vit_blk=ImageCondDiTBlockPixelArtRMSNormClayLRM,
        num_heads=16, use_pe_cond=True, **kwargs)


# mv version
def DiT_L_Pixelart_clay_mv_pcd(**kwargs):
    return DiT_pcd_I23D_PixelArt_MVCond_clay(
        depth=24, vit_blk=ImageCondDiTBlockPixelArtRMSNormClayText,
        use_clay_ca=True, hidden_size=1024,
        patch_size=1,  # no spatial compression here
        num_heads=16, **kwargs)


def DiT_L_Pixelart_clay_mv_i23dpt_pcd(**kwargs):
    return DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt(
        depth=24, vit_blk=ImageCondDiTBlockPixelArtRMSNormClayMV,
        use_clay_ca=True, hidden_size=1024,
        patch_size=1,  # no spatial compression here
        num_heads=16, **kwargs)


def DiT_L_Pixelart_clay_mv_i23dpt_pcd_noi23d(**kwargs):
    return DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt_noi23d(
        depth=24, vit_blk=ImageCondDiTBlockPixelArtRMSNormClayMV_noi23d,
        use_clay_ca=True, hidden_size=1024,
        patch_size=1,  # no spatial compression here
        num_heads=16, **kwargs)


def DiT_L_Pixelart_clay_mv_i23dpt_pcd_stage2(**kwargs):
    return DiT_pcd_I23D_PixelArt_MVCond_clay_i23dpt_stage2(
        depth=24, vit_blk=ImageCondDiTBlockPixelArtRMSNormClayMV,
        use_clay_ca=True, hidden_size=1024,
        patch_size=1,  # no spatial compression here
        num_heads=16, **kwargs)
DiT_models = {
    'DiT-XL/2': DiT_XL_2,
    'DiT-L/2': DiT_L_2,
    'DiT-B/2': DiT_B_2,
    'DiT-B/1': DiT_B_1,
    'DiT-PixArt-L/2': DiT_L_Pixelart_2,
    'DiT-PixArt-MV-XL/2': DiT_XL_Pixelart_MV_2,
    # 'DiT-PixArt-MV-L/2': DiT_L_Pixelart_MV_2,
    'DiT-PixArt-MV-L/2': DiT_L_Pixelart_MV_2_noclip,
    'DiT-PixArt-MV-PCD-L': DiT_L_Pixelart_MV_pcd,
    # raw xyz cond
    'DiT-PixArt-PCD-L': DiT_L_Pixelart_pcd,
    'DiT-PixArt-PCD-CLAY-XL': DiT_XL_Pixelart_clay_pcd,
    'DiT-PixArt-PCD-CLAY-L': DiT_L_Pixelart_clay_pcd,
    'DiT-PixArt-PCD-CLAY-B': DiT_B_Pixelart_clay_pcd,
    'DiT-PixArt-PCD-CLAY-stage2-B': DiT_B_Pixelart_clay_pcd_stage2,
    'DiT-PixArt-PCD-CLAY-stage2-L': DiT_L_Pixelart_clay_pcd_stage2,
    'DiT-PixArt-PCD-CLAY-TandI-L': DiT_L_Pixelart_clay_tandi_pcd,
    'DiT-PixArt-PCD-CLAY-TandI-B': DiT_B_Pixelart_clay_tandi_pcd,
    'DiT-PixArt-PCD-B': DiT_B_Pixelart_pcd,
    # xyz-conditioned KL feature diffusion
    'DiT-PixArt-PCD-cond-diff-B': DiT_B_Pixelart_pcd_cond_diff,
    'DiT-PixArt-PCD-cond-diff-pe-B': DiT_B_Pixelart_pcd_cond_diff_pe,
    'DiT-PixArt-PCD-cond-diff-pe-L': DiT_L_Pixelart_pcd_cond_diff_pe,
    'DiT-PixArt-MV-B/2': DiT_B_Pixelart_MV_2,
    'DiT-PixArt-B/2': DiT_B_Pixelart_2,
    # ! mv version following clay
    'DiT-PixArt-PCD-MV-L': DiT_L_Pixelart_clay_mv_pcd,
    'DiT-PixArt-PCD-MV-I23Dpt-L': DiT_L_Pixelart_clay_mv_i23dpt_pcd,
    'DiT-PixArt-PCD-MV-I23Dpt-L-noI23D': DiT_L_Pixelart_clay_mv_i23dpt_pcd_noi23d,
    'DiT-PixArt-PCD-MV-I23Dpt-L-stage2': DiT_L_Pixelart_clay_mv_i23dpt_pcd_stage2,
}
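
# --- usage sketch (illustrative; not called anywhere in this file) -------------------
# A minimal sketch of instantiating a registered config from the DiT_models table above.
# The keyword names mirror the constructors in this file; the concrete values here are
# placeholders for illustration, not trained settings.
def _build_example_model():
    model_fn = DiT_models['DiT-PixArt-PCD-CLAY-L']
    return model_fn(in_channels=16,       # point-latent channel dim (placeholder)
                    context_dim=768,      # DINO token dim fed to the conditioner (placeholder)
                    pooling_ctx_dim=768,  # pooled vector dim (constructor default)
                    mixed_prediction=False)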