# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint_sequential

from scepter.modules.model.base_model import BaseModel
from scepter.modules.model.registry import BACKBONES
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS

from .layers import (
    Mlp,
    TimestepEmbedder,
    PatchEmbed,
    DiTACEBlock,
    T2IFinalLayer
)
from .pos_embed import rope_params


@BACKBONES.register_class()
class DiTACE(BaseModel):
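    """Diffusion Transformer backbone used by ACE.

    Latent frames (the target latent plus optional edit/reference latents,
    each with an extra mask channel) are patch-embedded and packed into a
    single variable-length token sequence per sample, with rotary position
    embeddings split across temporal/height/width axes. Text embeddings are
    projected by an MLP and injected through cross-attention in each block.
    """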
    para_dict = {
        'PATCH_SIZE': {
            'value': 2,
            'description': ''
        },
        'IN_CHANNELS': {
            'value': 4,
            'description': ''
        },
        'HIDDEN_SIZE': {
            'value': 1152,
            'description': ''
        },
        'DEPTH': {
            'value': 28,
            'description': ''
        },
        'NUM_HEADS': {
            'value': 16,
            'description': ''
        },
        'MLP_RATIO': {
            'value': 4.0,
            'description': ''
        },
        'PRED_SIGMA': {
            'value': True,
            'description': ''
        },
        'DROP_PATH': {
            'value': 0.,
            'description': ''
        },
        'WINDOW_SIZE': {
            'value': 0,
            'description': ''
        },
        'WINDOW_BLOCK_INDEXES': {
            'value': None,
            'description': ''
        },
        'Y_CHANNELS': {
            'value': 4096,
            'description': ''
        },
        'ATTENTION_BACKEND': {
            'value': None,
            'description': ''
        },
        'QK_NORM': {
            'value': True,
            'description': 'Whether to use RMSNorm for query and key.',
        },
    }
    para_dict.update(BaseModel.para_dict)

    def __init__(self, cfg, logger):
        super().__init__(cfg, logger=logger)
        self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None)
        if self.window_block_indexes is None:
            self.window_block_indexes = []
        self.pred_sigma = cfg.get('PRED_SIGMA', True)
        self.in_channels = cfg.get('IN_CHANNELS', 4)
        self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels
        self.patch_size = cfg.get('PATCH_SIZE', 2)
        self.num_heads = cfg.get('NUM_HEADS', 16)
        self.hidden_size = cfg.get('HIDDEN_SIZE', 1152)
        self.y_channels = cfg.get('Y_CHANNELS', 4096)
        self.drop_path = cfg.get('DROP_PATH', 0.)
        self.depth = cfg.get('DEPTH', 28)
        self.mlp_ratio = cfg.get('MLP_RATIO', 4.0)
        self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False)
        self.attention_backend = cfg.get('ATTENTION_BACKEND', None)
        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
        self.qk_norm = cfg.get('QK_NORM', False)
        self.ignore_keys = cfg.get('IGNORE_KEYS', [])
        # Referenced by the window-attention blocks below; without it a
        # non-empty WINDOW_BLOCK_INDEXES would raise AttributeError.
        self.window_size = cfg.get('WINDOW_SIZE', 0)
        assert (self.hidden_size % self.num_heads
                ) == 0 and (self.hidden_size // self.num_heads) % 2 == 0
        d = self.hidden_size // self.num_heads
        self.freqs = torch.cat(
            [
                rope_params(self.max_seq_len, d - 4 * (d // 6)),  # T (~1/3)
                rope_params(self.max_seq_len, 2 * (d // 6)),  # H (~1/3)
                rope_params(self.max_seq_len, 2 * (d // 6))  # W (~1/3)
            ],
            dim=1)

        # init embedder
        self.x_embedder = PatchEmbed(self.patch_size,
                                     self.in_channels + 1,
                                     self.hidden_size,
                                     bias=True,
                                     flatten=False)
        self.t_embedder = TimestepEmbedder(self.hidden_size)
        self.y_embedder = Mlp(in_features=self.y_channels,
                              hidden_features=self.hidden_size,
                              out_features=self.hidden_size,
                              act_layer=lambda: nn.GELU(approximate='tanh'),
                              drop=0)
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True))

        # init blocks
        drop_path = [
            x.item() for x in torch.linspace(0, self.drop_path, self.depth)
        ]
        self.blocks = nn.ModuleList([
            DiTACEBlock(self.hidden_size,
                        self.num_heads,
                        mlp_ratio=self.mlp_ratio,
                        drop_path=drop_path[i],
                        window_size=self.window_size
                        if i in self.window_block_indexes else 0,
                        backend=self.attention_backend,
                        use_condition=True,
                        qk_norm=self.qk_norm) for i in range(self.depth)
        ])
        self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size,
                                         self.out_channels)
        self.initialize_weights()

    def load_pretrained_model(self, pretrained_model):
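        """Load and remap a pretrained checkpoint.

        Keys are renamed from the fused attention layout (``attn.qkv``,
        ``cross_attn.kv_linear``, ``cross_attn.q_linear``, ``*.proj``) to the
        split ``q``/``k``/``v``/``o`` naming used by ``DiTACEBlock``. The
        ``x_embedder.proj.weight`` is zero-padded when the checkpoint was
        trained without the extra mask channel, and keys matching
        ``IGNORE_KEYS`` are skipped.
        """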
        if pretrained_model:
            with FS.get_from(pretrained_model, wait_finish=True) as local_path:
                model = torch.load(local_path, map_location='cpu')
            if 'state_dict' in model:
                model = model['state_dict']
            new_ckpt = OrderedDict()
            for k, v in model.items():
                if self.ignore_keys is not None:
                    if (isinstance(self.ignore_keys, str) and re.match(self.ignore_keys, k)) or \
                            (isinstance(self.ignore_keys, list) and k in self.ignore_keys):
                        continue
                k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.')
                k = k.replace('.cross_attn.proj.',
                              '.cross_attn.o.').replace(
                                  '.attn.proj.', '.attn.o.')
                if '.cross_attn.kv_linear.' in k:
                    k_p, v_p = torch.split(v, v.shape[0] // 2)
                    new_ckpt[k.replace('.cross_attn.kv_linear.',
                                       '.cross_attn.k.')] = k_p
                    new_ckpt[k.replace('.cross_attn.kv_linear.',
                                       '.cross_attn.v.')] = v_p
                elif '.attn.qkv.' in k:
                    q_p, k_p, v_p = torch.split(v, v.shape[0] // 3)
                    new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p
                    new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p
                    new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p
                elif 'y_embedder.y_proj.' in k:
                    new_ckpt[k.replace('y_embedder.y_proj.',
                                       'y_embedder.')] = v
                elif k == 'x_embedder.proj.weight':
                    model_p = self.state_dict()[k]
                    if v.shape != model_p.shape:
                        # Checkpoint was trained without the extra mask
                        # channel: keep the new input channels at zero and
                        # copy over the original four channels.
                        model_p.zero_()
                        model_p[:, :4, :, :].copy_(v)
                        new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
                    else:
                        new_ckpt[k] = v
                elif k == 'x_embedder.proj.bias':
                    new_ckpt[k] = v
                else:
                    new_ckpt[k] = v
            missing, unexpected = self.load_state_dict(new_ckpt,
                                                       strict=False)
            print(
                f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
            )
            if len(missing) > 0:
                print(f'Missing Keys:\n {missing}')
            if len(unexpected) > 0:
                print(f'\nUnexpected Keys:\n {unexpected}')

    def forward(self,
                x,
                t=None,
                cond=dict(),
                mask=None,
                text_position_embeddings=None,
                gc_seg=-1,
                **kwargs):
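        """Run the packed-sequence DiT forward pass.

        Args:
            x: batched target latents, each flattened along the spatial axis
                and reshaped with the per-sample shapes in cond['x_shapes'].
            t: per-sample diffusion timesteps.
            cond: dict carrying 'crossattn' text embeddings, 'x_shapes',
                'x_mask', and optionally 'edit'/'edit_mask' reference latents.
            mask: list of per-frame text masks used to trim padded text tokens.
            text_position_embeddings: optional frame-level position embeddings
                added to both text and image tokens after projection.
            gc_seg: segment count for checkpoint_sequential when gradient
                checkpointing is enabled; negative values disable it and 0
                uses one segment per block.
        """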
        if self.freqs.device != x.device:
            self.freqs = self.freqs.to(x.device)
        if isinstance(cond, dict):
            context = cond.get('crossattn', None)
        else:
            context = cond
        if text_position_embeddings is not None:
            # By default, use the text_position_embeddings from the state_dict;
            # if the state_dict does not include this key, fall back to the
            # text_position_embeddings argument.
            proj_position_embeddings = self.y_embedder(
                text_position_embeddings)
        else:
            proj_position_embeddings = None
        ctx_batch, txt_lens = [], []
        if mask is not None and isinstance(mask, list):
            for ctx, ctx_mask in zip(context, mask):
                for frame_id, one_ctx in enumerate(zip(ctx, ctx_mask)):
                    u, m = one_ctx
                    t_len = m.flatten().sum()  # l
                    u = u[:t_len]
                    u = self.y_embedder(u)
                    if frame_id == 0:
                        u = u + proj_position_embeddings[
                            len(ctx) -
                            1] if proj_position_embeddings is not None else u
                    else:
                        u = u + proj_position_embeddings[
                            frame_id -
                            1] if proj_position_embeddings is not None else u
                    ctx_batch.append(u)
                    txt_lens.append(t_len)
        else:
            raise TypeError
        y = torch.cat(ctx_batch, dim=0)
        txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True)
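        # y now packs every frame's text tokens into one flat sequence;
        # txt_lens records the token count contributed by each frame.
        # Next, build the image-side token sequence the same way.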
        batch_frames = []
        for u, shape, m in zip(x, cond['x_shapes'], cond['x_mask']):
            u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
            m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0)
            batch_frames.append([torch.cat([u, m], dim=0).unsqueeze(0)])
        if 'edit' in cond:
            for i, (edit, edit_mask) in enumerate(
                    zip(cond['edit'], cond['edit_mask'])):
                if edit is None:
                    continue
                for u, m in zip(edit, edit_mask):
                    u = u.squeeze(0)
                    m = torch.ones_like(
                        u[[0], :, :]) if m is None else m.squeeze(0)
                    batch_frames[i].append(
                        torch.cat([u, m], dim=0).unsqueeze(0))
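        # batch_frames[i] is a list of frames for sample i: the target latent
        # first, followed by any edit (reference) latents, each with a mask
        # concatenated as an extra channel.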
        patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], []
        for frames in batch_frames:
            patches, patch_shapes = [], []
            self_x_len.append(0)
            for frame_id, u in enumerate(frames):
                u = self.x_embedder(u)
                h, w = u.size(2), u.size(3)
                u = rearrange(u, '1 c h w -> (h w) c')
                if frame_id == 0:
                    u = u + proj_position_embeddings[
                        len(frames) -
                        1] if proj_position_embeddings is not None else u
                else:
                    u = u + proj_position_embeddings[
                        frame_id -
                        1] if proj_position_embeddings is not None else u
                patches.append(u)
                patch_shapes.append([h, w])
                cross_x_len.append(h * w)  # b*s, 1
                self_x_len[-1] += h * w  # b, 1
            # u = torch.cat(patches, dim=0)
            patch_batch.extend(patches)
            shape_batch.append(
                torch.LongTensor(patch_shapes).to(x.device, non_blocking=True))
        # repeat t to align with x
        t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)])
        self_x_len, cross_x_len = (torch.LongTensor(self_x_len).to(
            x.device, non_blocking=True), torch.LongTensor(cross_x_len).to(
                x.device, non_blocking=True))
        # x = pad_sequence(tuple(patch_batch), batch_first=True)  # b, s*max(cl), c
        x = torch.cat(patch_batch, dim=0)
        x_shapes = pad_sequence(tuple(shape_batch),
                                batch_first=True)  # b, max(len(frames)), 2
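        # At this point x is one packed token sequence for the whole batch:
        # self_x_len holds each sample's total token count and cross_x_len
        # each frame's token count; both are passed to the blocks below
        # together with txt_lens.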
        t = self.t_embedder(t)  # (N, D)
        t0 = self.t_block(t)
        # y = self.y_embedder(context)
        kwargs = dict(y=y,
                      t=t0,
                      x_shapes=x_shapes,
                      self_x_len=self_x_len,
                      cross_x_len=cross_x_len,
                      freqs=self.freqs,
                      txt_lens=txt_lens)
        if self.use_grad_checkpoint and gc_seg >= 0:
            x = checkpoint_sequential(
                functions=[partial(block, **kwargs) for block in self.blocks],
                segments=gc_seg if gc_seg > 0 else len(self.blocks),
                input=x,
                use_reentrant=False)
        else:
            for block in self.blocks:
                x = block(x, **kwargs)
        x = self.final_layer(x, t)  # b*s*n, d
        outs, cur_length = [], 0
        p = self.patch_size
        for seq_length, shape in zip(self_x_len, shape_batch):
            x_i = x[cur_length:cur_length + seq_length]
            h, w = shape[0].tolist()
            u = x_i[:h * w].view(h, w, p, p, -1)
            u = rearrange(u, 'h w p q c -> (h p w q) c'
                          )  # dump into sequence for following tensor ops
            cur_length = cur_length + seq_length
            outs.append(u)
        x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1)
        if self.pred_sigma:
            return x.chunk(2, dim=1)[0]
        else:
            return x

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        if hasattr(self, 'y_embedder'):
            nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)
            nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)
        # Zero-out the cross-attention output projections so the
        # cross-attention branch contributes nothing at initialization:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.o.weight, 0)
            nn.init.constant_(block.cross_attn.o.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @staticmethod
    def get_config_template():
        return dict_to_yaml('BACKBONE',
                            __class__.__name__,
                            DiTACE.para_dict,
                            set_name=True)
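

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes scepter's
# Config can be built from a plain dict via Config(cfg_dict=..., load=False);
# check scepter.modules.utils.config for the exact constructor signature of
# your scepter version.
#
#   from scepter.modules.utils.config import Config
#
#   cfg = Config(load=False,
#                cfg_dict={
#                    'NAME': 'DiTACE',
#                    'PATCH_SIZE': 2,
#                    'IN_CHANNELS': 4,
#                    'HIDDEN_SIZE': 1152,
#                    'DEPTH': 28,
#                    'NUM_HEADS': 16,
#                    'MLP_RATIO': 4.0,
#                    'PRED_SIGMA': True,
#                    'QK_NORM': True,
#                })
#   model = DiTACE(cfg, logger=None)
#   # or, through the registry this class is registered in:
#   # model = BACKBONES.build(cfg, logger=None)
# ---------------------------------------------------------------------------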