import clip
import math
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from einops.layers.torch import Rearrange
from einops import rearrange
import matplotlib.pyplot as plt
import os

# Custom LayerNorm class to handle fp16
class CustomLayerNorm(nn.LayerNorm):
    def forward(self, x: torch.Tensor):
        if self.weight.dtype == torch.float32:
            orig_type = x.dtype
            ret = super().forward(x.type(torch.float32))
            return ret.type(orig_type)
        else:
            return super().forward(x)


# Function to replace LayerNorm in CLIP model with CustomLayerNorm
def replace_layer_norm(model):
    for name, module in model.named_children():
        if isinstance(module, nn.LayerNorm):
            setattr(
                model,
                name,
                CustomLayerNorm(
                    module.normalized_shape,
                    elementwise_affine=module.elementwise_affine,
                ).cuda(),
            )
        else:
            replace_layer_norm(module)  # Recursively apply to all submodules
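
# Module-level buffers for logging attention maps during sampling:
# forward_with_cfg() appends a new list per denoising step, the attention
# layers append one matrix per layer, and vis_attn() renders them to disk.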
MONITOR_ATTN = []
SELF_ATTN = []


def vis_attn(att, out_path, step, layer, shape, type_="self", lines=True):
    if lines:
        plt.figure(figsize=(10, 3))
        for token_index in range(att.shape[1]):
            plt.plot(att[:, token_index], label=f"Token {token_index}")
        plt.title("Attention Values for Each Token")
        plt.xlabel("time")
        plt.ylabel("Attention Value")
        plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1))

        # save image and raw attention values
        savepath = os.path.join(
            out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_lines_{shape}.png"
        )
        os.makedirs(os.path.dirname(savepath), exist_ok=True)
        plt.savefig(savepath, bbox_inches="tight")
        np.save(savepath.replace(".png", ".npy"), att)
    else:
        plt.figure(figsize=(10, 10))
        plt.imshow(att.transpose(), cmap="viridis", aspect="auto")
        plt.colorbar()
        plt.title("Attention Matrix Heatmap")
        plt.ylabel("time")
        plt.xlabel("time")

        # save image and raw attention values
        savepath = os.path.join(
            out_path,
            f"vis-{type_}/step{str(step)}/layer{str(layer)}_heatmap_{shape}.png",
        )
        os.makedirs(os.path.dirname(savepath), exist_ok=True)
        plt.savefig(savepath, bbox_inches="tight")
        np.save(savepath.replace(".png", ".npy"), att)
    plt.close()  # avoid accumulating open figures across denoising steps
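
# Zero-initializing the last layer of a residual branch makes the block start
# out as an identity mapping, a common trick for stabilizing diffusion-model
# training (cf. guided-diffusion). FFN below uses it on its output projection:
#     self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
# so that x + FFN(x) == x at initialization.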
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class FFN(nn.Module):
    def __init__(self, latent_dim, ffn_dim, dropout):
        super().__init__()
        self.linear1 = nn.Linear(latent_dim, ffn_dim)
        self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        y = self.linear2(self.dropout(self.activation(self.linear1(x))))
        y = x + y
        return y

class Conv1dAdaGNBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> scale, shift --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4):
        super().__init__()
        self.out_channels = out_channels
        self.block = nn.Conv1d(
            inp_channels, out_channels, kernel_size, padding=kernel_size // 2
        )
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.activation = nn.Mish()

    def forward(self, x, scale, shift):
        """
        Args:
            x: [bs, nfeat, nframes]
            scale: [bs, out_feat, 1]
            shift: [bs, out_feat, 1]
        """
        x = self.block(x)
        batch_size, channels, horizon = x.size()
        x = rearrange(
            x, "batch channels horizon -> (batch horizon) channels"
        )  # [bs*seq, nfeats]
        x = self.group_norm(x)
        x = rearrange(
            x.reshape(batch_size, horizon, channels),
            "batch horizon channels -> batch channels horizon",
        )
        x = ada_shift_scale(x, shift, scale)
        return self.activation(x)
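
# SelfAttention operates on the motion sequence itself (queries, keys, and
# values all come from x). The edit_config hooks below implement MotionCLR's
# training-free editing operations (style transfer, example-based generation,
# time shifting, attention replacement); they assume the reference sample sits
# at batch index 0 and the edited sample(s) follow it.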
class SelfAttention(nn.Module):
    def __init__(
        self,
        latent_dim,
        text_latent_dim,
        num_heads: int = 8,
        dropout: float = 0.0,
        log_attn=False,
        edit_config=None,
    ):
        super().__init__()
        self.num_head = num_heads
        self.norm = nn.LayerNorm(latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(latent_dim, latent_dim)
        self.value = nn.Linear(latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.edit_config = edit_config
        self.log_attn = log_attn

    def forward(self, x):
        """
        Args:
            x: [B, T, D]
        """
        B, T, D = x.shape
        N = x.shape[1]
        assert N == T
        H = self.num_head

        # B, T, 1, D
        query = self.query(self.norm(x)).unsqueeze(2)
        # B, 1, N, D
        key = self.key(self.norm(x)).unsqueeze(1)
        query = query.view(B, T, H, -1)
        key = key.view(B, N, H, -1)

        # style transfer motion editing
        style_transfer = self.edit_config.style_tranfer.use  # note: config key is spelled "style_tranfer"
        if style_transfer:
            if (
                len(SELF_ATTN)
                <= self.edit_config.style_tranfer.style_transfer_steps_end
            ):
                query[1] = query[0]

        # example-based motion generation
        example_based = self.edit_config.example_based.use
        if example_based:
            if len(SELF_ATTN) == self.edit_config.example_based.example_based_steps_end:
                temp_seed = self.edit_config.example_based.temp_seed
                for id_ in range(query.shape[0] - 1):
                    with torch.random.fork_rng():
                        torch.manual_seed(temp_seed)
                        tensor = query[0]
                        chunks = torch.split(
                            tensor, self.edit_config.example_based.chunk_size, dim=0
                        )
                        shuffled_indices = torch.randperm(len(chunks))
                        shuffled_chunks = [chunks[i] for i in shuffled_indices]
                        shuffled_tensor = torch.cat(shuffled_chunks, dim=0)
                        query[1 + id_] = shuffled_tensor
                        temp_seed += self.edit_config.example_based.temp_seed_bar

        # time shift motion editing (q, k)
        time_shift = self.edit_config.time_shift.use
        if time_shift:
            if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end:
                part1 = int(
                    key.shape[1] * self.edit_config.time_shift.time_shift_ratio // 1
                )
                part2 = int(
                    key.shape[1]
                    * (1 - self.edit_config.time_shift.time_shift_ratio)
                    // 1
                )
                q_front_part = query[0, :part1, :, :]
                q_back_part = query[0, -part2:, :, :]
                new_q = torch.cat((q_back_part, q_front_part), dim=0)
                query[1] = new_q

                k_front_part = key[0, :part1, :, :]
                k_back_part = key[0, -part2:, :, :]
                new_k = torch.cat((k_back_part, k_front_part), dim=0)
                key[1] = new_k

        # B, T, N, H
        attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H)
        weight = self.dropout(F.softmax(attention, dim=2))

        # for counting the step and logging attention maps (best-effort)
        try:
            attention_matrix = (
                weight[0, :, :].mean(dim=-1).detach().cpu().numpy().astype(float)
            )
            SELF_ATTN[-1].append(attention_matrix)
        except Exception:
            pass

        # attention manipulation for replacement
        attention_manipulation = self.edit_config.manipulation.use
        if attention_manipulation:
            if len(SELF_ATTN) <= self.edit_config.manipulation.manipulation_steps_end:
                weight[1, :, :, :] = weight[0, :, :, :]

        value = self.value(self.norm(x)).view(B, N, H, -1)

        # time shift motion editing (v)
        if time_shift:
            if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end:
                v_front_part = value[0, :part1, :, :]
                v_back_part = value[0, -part2:, :, :]
                new_v = torch.cat((v_back_part, v_front_part), dim=0)
                value[1] = new_v
        y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D)
        return y
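
# Standard sinusoidal positional encoding, reused here to embed the diffusion
# timestep t: pe[t, 2i] = sin(t / 10000^(2i/d_model)) and
# pe[t, 2i+1] = cos(t / 10000^(2i/d_model)); forward() simply indexes the table.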
class TimestepEmbedder(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(TimestepEmbedder, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer("pe", pe)

    def forward(self, x):
        # index the precomputed table on whatever device the timesteps live on
        return self.pe.to(x.device)[x]
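
# The strided Conv1d halves the temporal resolution and the ConvTranspose1d
# doubles it. With four down/up levels the sequence length must be divisible
# by 2^4 = 16, which is why MotionCLR.forward pads inputs to a multiple of 16.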
class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, 3, 2, 1)

    def forward(self, x):
        self.conv = self.conv.cuda()
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim_in, dim_out=None):
        super().__init__()
        dim_out = dim_out or dim_in
        self.conv = nn.ConvTranspose1d(dim_in, dim_out, 4, 2, 1)

    def forward(self, x):
        self.conv = self.conv.cuda()
        return self.conv(x)

class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4, zero=False):
        super().__init__()
        self.out_channels = out_channels
        self.block = nn.Conv1d(
            inp_channels, out_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm = nn.GroupNorm(n_groups, out_channels)
        self.activation = nn.Mish()

        if zero:
            # zero-init the convolution
            nn.init.zeros_(self.block.weight)
            nn.init.zeros_(self.block.bias)

    def forward(self, x):
        """
        Args:
            x: [bs, nfeat, nframes]
        """
        x = self.block(x)
        batch_size, channels, horizon = x.size()
        x = rearrange(
            x, "batch channels horizon -> (batch horizon) channels"
        )  # [bs*seq, nfeats]
        x = self.norm(x)
        x = rearrange(
            x.reshape(batch_size, horizon, channels),
            "batch horizon channels -> batch channels horizon",
        )
        return self.activation(x)


def ada_shift_scale(x, shift, scale):
    return x * (1 + scale) + shift
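
# ResidualTemporalBlock: two Conv1d blocks with a 1x1 residual connection.
# When adagn=True the time embedding is projected to (scale, shift) and
# injected FiLM-style after the first convolution's GroupNorm via
# ada_shift_scale(x, shift, scale) = x * (1 + scale) + shift, following
# guided-diffusion's adaptive group normalization.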
class ResidualTemporalBlock(nn.Module):
    def __init__(
        self,
        inp_channels,
        out_channels,
        embed_dim,
        kernel_size=5,
        zero=True,
        n_groups=8,
        dropout: float = 0.1,
        adagn=True,
    ):
        super().__init__()
        self.adagn = adagn

        self.blocks = nn.ModuleList(
            [
                # adagn only on the first conv (following guided-diffusion)
                (
                    Conv1dAdaGNBlock(inp_channels, out_channels, kernel_size, n_groups)
                    if adagn
                    else Conv1dBlock(inp_channels, out_channels, kernel_size)
                ),
                Conv1dBlock(
                    out_channels, out_channels, kernel_size, n_groups, zero=zero
                ),
            ]
        )

        self.time_mlp = nn.Sequential(
            nn.Mish(),
            # adagn = scale and shift
            nn.Linear(embed_dim, out_channels * 2 if adagn else out_channels),
            Rearrange("batch t -> batch t 1"),
        )
        self.dropout = nn.Dropout(dropout)
        if zero:
            nn.init.zeros_(self.time_mlp[1].weight)
            nn.init.zeros_(self.time_mlp[1].bias)

        self.residual_conv = (
            nn.Conv1d(inp_channels, out_channels, 1)
            if inp_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x, time_embeds=None):
        """
        Args:
            x: [batch_size x inp_channels x nframes]
            time_embeds: [batch_size x embed_dim]
        Returns: [batch_size x out_channels x nframes]
        """
        if self.adagn:
            scale, shift = self.time_mlp(time_embeds).chunk(2, dim=1)
            out = self.blocks[0](x, scale, shift)
        else:
            out = self.blocks[0](x) + self.time_mlp(time_embeds)
        out = self.blocks[1](out)
        out = self.dropout(out)
        return out + self.residual_conv(x)
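
# CrossAttention attends from the motion sequence (queries) to the encoded
# text tokens (keys/values). The edit_config hooks below reweight, erase, or
# replace per-word attention; as in SelfAttention, batch index 0 is treated as
# the reference sample and index 1 as the edited one, and the "+1" offsets
# appear to skip the start-of-text token of the CLIP tokenizer.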
class CrossAttention(nn.Module):
    def __init__(
        self,
        latent_dim,
        text_latent_dim,
        num_heads: int = 8,
        dropout: float = 0.0,
        log_attn=False,
        edit_config=None,
    ):
        super().__init__()
        self.num_head = num_heads
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(text_latent_dim, latent_dim)
        self.value = nn.Linear(text_latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.edit_config = edit_config
        self.log_attn = log_attn

    def forward(self, x, xf):
        """
        Args:
            x: [B, T, D] motion features
            xf: [B, N, L] text features
        """
        B, T, D = x.shape
        N = xf.shape[1]
        H = self.num_head

        # B, T, 1, D
        query = self.query(self.norm(x)).unsqueeze(2)
        # B, 1, N, D
        key = self.key(self.text_norm(xf)).unsqueeze(1)
        query = query.view(B, T, H, -1)
        key = key.view(B, N, H, -1)

        # B, T, N, H
        attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H)
        weight = self.dropout(F.softmax(attention, dim=2))

        # attention reweighting for (de-)emphasizing motion
        if self.edit_config.reweighting_attn.use:
            reweighting_attn = self.edit_config.reweighting_attn.reweighting_attn_weight
            if self.edit_config.reweighting_attn.idx == -1:
                # read idxs from txt file
                with open("./assets/reweighting_idx.txt", "r") as f:
                    idxs = f.readlines()
            else:
                # gradio demo mode
                idxs = [0, self.edit_config.reweighting_attn.idx]
            idxs = [int(idx) for idx in idxs]
            for i in range(len(idxs)):
                weight[i, :, 1 + idxs[i]] = weight[i, :, 1 + idxs[i]] + reweighting_attn
                weight[i, :, 1 + idxs[i] + 1] = (
                    weight[i, :, 1 + idxs[i] + 1] + reweighting_attn
                )

        # for counting the step and logging attention maps (best-effort)
        try:
            attention_matrix = (
                weight[0, :, 1 : 1 + 3]
                .mean(dim=-1)
                .detach()
                .cpu()
                .numpy()
                .astype(float)
            )
            MONITOR_ATTN[-1].append(attention_matrix)
        except Exception:
            pass

        # erasing motion (actually de-emphasizing motion)
        erasing_motion = self.edit_config.erasing_motion.use
        if erasing_motion:
            reweighting_attn = self.edit_config.erasing_motion.erasing_motion_weight
            begin = self.edit_config.erasing_motion.time_start
            end = self.edit_config.erasing_motion.time_end
            idx = self.edit_config.erasing_motion.idx
            if abs(reweighting_attn) > 0.01:
                weight[1, int(T * begin) : int(T * end), idx] = (
                    weight[1, int(T * begin) : int(T * end), idx] * reweighting_attn
                )
                weight[1, int(T * begin) : int(T * end), idx + 1] = (
                    weight[1, int(T * begin) : int(T * end), idx + 1] * reweighting_attn
                )

        # attention manipulation for motion replacement
        manipulation = self.edit_config.manipulation.use
        if manipulation:
            if (
                len(MONITOR_ATTN)
                <= self.edit_config.manipulation.manipulation_steps_end_crossattn
            ):
                word_idx = self.edit_config.manipulation.word_idx
                weight[1, :, : 1 + word_idx, :] = weight[0, :, : 1 + word_idx, :]
                weight[1, :, 1 + word_idx + 1 :, :] = weight[
                    0, :, 1 + word_idx + 1 :, :
                ]

        value = self.value(self.text_norm(xf)).view(B, N, H, -1)
        y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D)
        return y
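
# ResidualCLRAttentionLayer wraps optional self-attention plus (full or
# linear) cross-attention with residual connections. `cond_indices` selects
# which samples in the batch actually receive the text condition, so samples
# dropped for classifier-free guidance skip cross-attention entirely.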
class ResidualCLRAttentionLayer(nn.Module):
    def __init__(
        self,
        dim1,
        dim2,
        num_heads: int = 8,
        dropout: float = 0.1,
        no_eff: bool = False,
        self_attention: bool = False,
        log_attn=False,
        edit_config=None,
    ):
        super(ResidualCLRAttentionLayer, self).__init__()
        self.dim1 = dim1
        self.dim2 = dim2
        self.num_heads = num_heads

        # Multi-Head Attention Layer
        if no_eff:
            self.cross_attention = CrossAttention(
                latent_dim=dim1,
                text_latent_dim=dim2,
                num_heads=num_heads,
                dropout=dropout,
                log_attn=log_attn,
                edit_config=edit_config,
            )
        else:
            self.cross_attention = LinearCrossAttention(
                latent_dim=dim1,
                text_latent_dim=dim2,
                num_heads=num_heads,
                dropout=dropout,
                log_attn=log_attn,
            )
        if self_attention:
            self.self_attn_use = True
            self.self_attention = SelfAttention(
                latent_dim=dim1,
                text_latent_dim=dim2,
                num_heads=num_heads,
                dropout=dropout,
                log_attn=log_attn,
                edit_config=edit_config,
            )
        else:
            self.self_attn_use = False

    def forward(self, input_tensor, condition_tensor, cond_indices):
        """
        Args:
            input_tensor: [B, D, L]
            condition_tensor: [B, L, D]
        """
        if cond_indices.numel() == 0:
            return input_tensor

        # self-attention
        if self.self_attn_use:
            x = input_tensor
            x = x.permute(0, 2, 1)  # (batch_size, seq_length, feat_dim)
            x = self.self_attention(x)
            x = x.permute(0, 2, 1)  # (batch_size, feat_dim, seq_length)
            input_tensor = input_tensor + x
        x = input_tensor

        # cross-attention
        x = x[cond_indices].permute(0, 2, 1)  # (batch_size, seq_length, feat_dim)
        x = self.cross_attention(x, condition_tensor[cond_indices])
        x = x.permute(0, 2, 1)  # (batch_size, feat_dim, seq_length)
        input_tensor[cond_indices] = input_tensor[cond_indices] + x

        return input_tensor
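
# CLRBlock = ResidualTemporalBlock (time-conditioned convolutions) followed by
# ResidualCLRAttentionLayer (self/cross attention) and an FFN; it is the basic
# unit repeated at every level of the U-Net below.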
class CLRBlock(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        cond_dim,
        time_dim,
        adagn=True,
        zero=True,
        no_eff=False,
        self_attention=False,
        dropout: float = 0.1,
        log_attn=False,
        edit_config=None,
    ) -> None:
        super().__init__()
        self.conv1d = ResidualTemporalBlock(
            dim_in, dim_out, embed_dim=time_dim, adagn=adagn, zero=zero, dropout=dropout
        )
        self.clr_attn = ResidualCLRAttentionLayer(
            dim1=dim_out,
            dim2=cond_dim,
            no_eff=no_eff,
            dropout=dropout,
            self_attention=self_attention,
            log_attn=log_attn,
            edit_config=edit_config,
        )
        self.ffn = FFN(dim_out, dim_out * 4, dropout=dropout)

    def forward(self, x, t, cond, cond_indices=None):
        x = self.conv1d(x, t)
        x = self.clr_attn(x, cond, cond_indices)
        x = self.ffn(x.permute(0, 2, 1)).permute(0, 2, 1)
        return x
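
# CondUnet1D: a 1D U-Net over the temporal axis. Each encoder level applies
# two CLRBlocks and a Downsample1d, the bottleneck applies two CLRBlocks, and
# each decoder level upsamples, concatenates the matching skip connection, and
# applies two more CLRBlocks before a final 1x1 convolution back to input_dim.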
class CondUnet1D(nn.Module):
    """
    Diffusion-style UNet with 1D convolutions and adaptive group normalization
    for motion sequence denoising; cross-attention introduces conditional
    prompts (e.g., text).
    """

    def __init__(
        self,
        input_dim,
        cond_dim,
        dim=128,
        dim_mults=(1, 2, 4, 8),
        dims=None,
        time_dim=512,
        adagn=True,
        zero=True,
        dropout=0.1,
        no_eff=False,
        self_attention=False,
        log_attn=False,
        edit_config=None,
    ):
        super().__init__()
        if not dims:
            dims = [input_dim, *map(lambda m: int(dim * m), dim_mults)]  # [input_dim, d, 2d, 4d, 8d] for the default dim_mults
        print("dims: ", dims, "mults: ", dim_mults)
        in_out = list(zip(dims[:-1], dims[1:]))

        self.time_mlp = nn.Sequential(
            TimestepEmbedder(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.Mish(),
            nn.Linear(time_dim * 4, time_dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])

        for ind, (dim_in, dim_out) in enumerate(in_out):
            self.downs.append(
                nn.ModuleList(
                    [
                        CLRBlock(
                            dim_in,
                            dim_out,
                            cond_dim,
                            time_dim,
                            adagn=adagn,
                            zero=zero,
                            no_eff=no_eff,
                            dropout=dropout,
                            self_attention=self_attention,
                            log_attn=log_attn,
                            edit_config=edit_config,
                        ),
                        CLRBlock(
                            dim_out,
                            dim_out,
                            cond_dim,
                            time_dim,
                            adagn=adagn,
                            zero=zero,
                            no_eff=no_eff,
                            dropout=dropout,
                            self_attention=self_attention,
                            log_attn=log_attn,
                            edit_config=edit_config,
                        ),
                        Downsample1d(dim_out),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = CLRBlock(
            dim_in=mid_dim,
            dim_out=mid_dim,
            cond_dim=cond_dim,
            time_dim=time_dim,
            adagn=adagn,
            zero=zero,
            no_eff=no_eff,
            dropout=dropout,
            self_attention=self_attention,
            log_attn=log_attn,
            edit_config=edit_config,
        )
        self.mid_block2 = CLRBlock(
            dim_in=mid_dim,
            dim_out=mid_dim,
            cond_dim=cond_dim,
            time_dim=time_dim,
            adagn=adagn,
            zero=zero,
            no_eff=no_eff,
            dropout=dropout,
            self_attention=self_attention,
            log_attn=log_attn,
            edit_config=edit_config,
        )

        last_dim = mid_dim
        for ind, dim_out in enumerate(reversed(dims[1:])):
            self.ups.append(
                nn.ModuleList(
                    [
                        Upsample1d(last_dim, dim_out),
                        CLRBlock(
                            dim_out * 2,
                            dim_out,
                            cond_dim,
                            time_dim,
                            adagn=adagn,
                            zero=zero,
                            no_eff=no_eff,
                            dropout=dropout,
                            self_attention=self_attention,
                            log_attn=log_attn,
                            edit_config=edit_config,
                        ),
                        CLRBlock(
                            dim_out,
                            dim_out,
                            cond_dim,
                            time_dim,
                            adagn=adagn,
                            zero=zero,
                            no_eff=no_eff,
                            dropout=dropout,
                            self_attention=self_attention,
                            log_attn=log_attn,
                            edit_config=edit_config,
                        ),
                    ]
                )
            )
            last_dim = dim_out

        self.final_conv = nn.Conv1d(dim_out, input_dim, 1)
        if zero:
            nn.init.zeros_(self.final_conv.weight)
            nn.init.zeros_(self.final_conv.bias)

    def forward(
        self,
        x,
        t,
        cond,
        cond_indices,
    ):
        self.time_mlp = self.time_mlp.cuda()
        temb = self.time_mlp(t)

        h = []
        for block1, block2, downsample in self.downs:
            block1 = block1.cuda()
            block2 = block2.cuda()
            x = block1(x, temb, cond, cond_indices)
            x = block2(x, temb, cond, cond_indices)
            h.append(x)
            x = downsample(x)

        self.mid_block1 = self.mid_block1.cuda()
        self.mid_block2 = self.mid_block2.cuda()
        x = self.mid_block1(x, temb, cond, cond_indices)
        x = self.mid_block2(x, temb, cond, cond_indices)

        for upsample, block1, block2 in self.ups:
            x = upsample(x)
            x = torch.cat((x, h.pop()), dim=1)  # concatenate the skip connection
            block1 = block1.cuda()
            block2 = block2.cuda()
            x = block1(x, temb, cond, cond_indices)
            x = block2(x, temb, cond, cond_indices)

        self.final_conv = self.final_conv.cuda()
        x = self.final_conv(x)
        return x
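
# MotionCLR wraps the U-Net with a frozen CLIP text encoder: CLIP token
# features are projected to text_latent_dim, refined by a small Transformer
# encoder, layer-normalized, and passed to the U-Net as the cross-attention
# condition.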
class MotionCLR(nn.Module):
    """
    Diffuser-style UNet for the text-to-motion task.
    """

    def __init__(
        self,
        input_feats,
        base_dim=128,
        dim_mults=(1, 2, 2, 2),
        dims=None,
        adagn=True,
        zero=True,
        dropout=0.1,
        no_eff=False,
        time_dim=512,
        latent_dim=256,
        cond_mask_prob=0.1,
        clip_dim=512,
        clip_version="ViT-B/32",
        text_latent_dim=256,
        text_ff_size=2048,
        text_num_heads=4,
        activation="gelu",
        num_text_layers=4,
        self_attention=False,
        vis_attn=False,
        edit_config=None,
        out_path=None,
    ):
        super().__init__()
        self.input_feats = input_feats
        self.dim_mults = dim_mults
        self.base_dim = base_dim
        self.latent_dim = latent_dim
        self.cond_mask_prob = cond_mask_prob
        self.vis_attn = vis_attn
        self.counting_map = []
        self.out_path = out_path

        print(
            f"The T2M UNet masks the text prompt with prob. {self.cond_mask_prob} during training"
        )

        # text encoder
        self.embed_text = nn.Linear(clip_dim, text_latent_dim)
        self.clip_version = clip_version
        self.clip_model = self.load_and_freeze_clip(clip_version)
        replace_layer_norm(self.clip_model)

        textTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=text_latent_dim,
            nhead=text_num_heads,
            dim_feedforward=text_ff_size,
            dropout=dropout,
            activation=activation,
        )
        self.textTransEncoder = nn.TransformerEncoder(
            textTransEncoderLayer, num_layers=num_text_layers
        )
        self.text_ln = nn.LayerNorm(text_latent_dim)

        self.unet = CondUnet1D(
            input_dim=self.input_feats,
            cond_dim=text_latent_dim,
            dim=self.base_dim,
            dim_mults=self.dim_mults,
            adagn=adagn,
            zero=zero,
            dropout=dropout,
            no_eff=no_eff,
            dims=dims,
            time_dim=time_dim,
            self_attention=self_attention,
            log_attn=self.vis_attn,
            edit_config=edit_config,
        )

        self.clip_model = self.clip_model.cuda()
        self.embed_text = self.embed_text.cuda()
        self.textTransEncoder = self.textTransEncoder.cuda()
        self.text_ln = self.text_ln.cuda()
        self.unet = self.unet.cuda()

    def encode_text(self, raw_text, device):
        self.clip_model.token_embedding = self.clip_model.token_embedding.to(device)
        self.clip_model.transformer = self.clip_model.transformer.to(device)
        self.clip_model.ln_final = self.clip_model.ln_final.to(device)

        with torch.no_grad():
            texts = clip.tokenize(raw_text, truncate=True).to(
                device
            )  # [bs, context_length] # if n_tokens > 77 -> will truncate
            x = (
                self.clip_model.token_embedding(texts)
                .type(self.clip_model.dtype)
                .to(device)
            )  # [batch_size, n_ctx, d_model]
            x = x + self.clip_model.positional_embedding.type(
                self.clip_model.dtype
            ).to(device)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.clip_model.transformer(x)
            x = self.clip_model.ln_final(x).type(
                self.clip_model.dtype
            )  # [len, batch_size, 512]

        self.embed_text = self.embed_text.to(device)
        x = self.embed_text(x)  # [len, batch_size, 256]
        self.textTransEncoder = self.textTransEncoder.to(device)
        x = self.textTransEncoder(x)
        self.text_ln = self.text_ln.to(device)
        x = self.text_ln(x)

        # T, B, D -> B, T, D
        xf_out = x.permute(1, 0, 2)

        ablation_text = False
        if ablation_text:
            xf_out[:, 1:, :] = xf_out[:, 0, :].unsqueeze(1)
        return xf_out

    def load_and_freeze_clip(self, clip_version):
        clip_model, _ = clip.load(  # clip_model.dtype=float32
            clip_version, device="cpu", jit=False
        )  # Must set jit=False for training

        # Freeze CLIP weights
        clip_model.eval()
        for p in clip_model.parameters():
            p.requires_grad = False
        return clip_model
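
    # mask_cond implements the training-time text dropout used for
    # classifier-free guidance: each sample's condition is dropped with
    # probability cond_mask_prob, and only the returned indices receive the
    # text via cross-attention. For example, with bs=4 a Bernoulli draw of
    # mask=[1, 0, 0, 1] (1 = drop) yields cond_indices=[1, 2].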
    def mask_cond(self, bs, force_mask=False):
        """
        Mask the motion condition; return the indices of conditioned samples in the batch.
        """
        if force_mask:
            cond_indices = torch.empty(0)
        elif self.training and self.cond_mask_prob > 0.0:
            mask = torch.bernoulli(
                torch.ones(
                    bs,
                )
                * self.cond_mask_prob
            )  # 1 -> use null cond, 0 -> use real cond
            mask = 1.0 - mask
            cond_indices = torch.nonzero(mask).squeeze(-1)
        else:
            cond_indices = torch.arange(bs)

        return cond_indices

    def forward(
        self,
        x,
        timesteps,
        text=None,
        uncond=False,
        enc_text=None,
    ):
        """
        Args:
            x: [batch_size, nframes, nfeats]
            timesteps: [batch_size] (int)
            text: list (batch_size length) of strings with input text prompts
            uncond: whether to drop the text condition
        Returns: [batch_size, nframes, nfeats]
        """
        B, T, _ = x.shape
        x = x.transpose(1, 2)  # [bs, nfeats, nframes]

        if enc_text is None:
            enc_text = self.encode_text(text, x.device)  # [bs, seqlen, text_dim]

        cond_indices = self.mask_cond(x.shape[0], force_mask=uncond)

        # NOTE: pad the sequence length to a multiple of 16 for the unet
        PADDING_NEEDED = (16 - (T % 16)) % 16
        padding = (0, PADDING_NEEDED)
        x = F.pad(x, padding, value=0)

        x = self.unet(
            x,
            t=timesteps,
            cond=enc_text,
            cond_indices=cond_indices,
        )  # [bs, nfeats, nframes]

        x = x[:, :, :T].transpose(1, 2)  # [bs, nframes, nfeats]

        return x
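
    # Classifier-free guidance sampling: the batch is duplicated so the first
    # half is denoised with the text condition and the second half without,
    # and the two predictions are combined as uncond + cfg_scale * (cond - uncond),
    # e.g. 2.5 * cond - 1.5 * uncond for the default cfg_scale of 2.5.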
    def forward_with_cfg(self, x, timesteps, text=None, enc_text=None, cfg_scale=2.5):
        """
        Args:
            x: [batch_size, nframes, nfeats]
            timesteps: [batch_size] (int)
            text: list (batch_size length) of strings with input text prompts
        Returns: [batch_size, nframes, nfeats]
        """
        global SELF_ATTN
        global MONITOR_ATTN
        MONITOR_ATTN.append([])
        SELF_ATTN.append([])

        B, T, _ = x.shape
        x = x.transpose(1, 2)  # [bs, nfeats, nframes]
        if enc_text is None:
            enc_text = self.encode_text(text, x.device)  # [bs, seqlen, text_dim]

        cond_indices = self.mask_cond(B)

        # NOTE: pad the sequence length to a multiple of 16 for the unet
        PADDING_NEEDED = (16 - (T % 16)) % 16
        padding = (0, PADDING_NEEDED)
        x = F.pad(x, padding, value=0)

        combined_x = torch.cat([x, x], dim=0)
        combined_t = torch.cat([timesteps, timesteps], dim=0)
        out = self.unet(
            x=combined_x,
            t=combined_t,
            cond=enc_text,
            cond_indices=cond_indices,
        )  # [bs, nfeats, nframes]

        out = out[:, :, :T].transpose(1, 2)  # [bs, nframes, nfeats]
        out_cond, out_uncond = torch.split(out, len(out) // 2, dim=0)

        if self.vis_attn:
            i = len(MONITOR_ATTN)
            attnlist = MONITOR_ATTN[-1]
            print(i, "cross", len(attnlist))
            for j, att in enumerate(attnlist):
                vis_attn(
                    att,
                    out_path=self.out_path,
                    step=i,
                    layer=j,
                    shape="_".join(map(str, att.shape)),
                    type_="cross",
                )

            attnlist = SELF_ATTN[-1]
            print(i, "self", len(attnlist))
            for j, att in enumerate(attnlist):
                vis_attn(
                    att,
                    out_path=self.out_path,
                    step=i,
                    layer=j,
                    shape="_".join(map(str, att.shape)),
                    type_="self",
                    lines=False,
                )

        if len(SELF_ATTN) % 10 == 0:
            SELF_ATTN = []
            MONITOR_ATTN = []

        return out_uncond + (cfg_scale * (out_cond - out_uncond))
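
# The block below is a quick smoke test: it renames "cross_attn" keys in an
# older checkpoint to the new "clr_attn" naming, saves the converted file, and
# then runs one conditioned forward pass and one classifier-free-guidance pass.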
if __name__ == "__main__":

    device = "cuda:0"
    n_feats = 263
    num_frames = 196
    text_latent_dim = 256
    dim_mults = [2, 2, 2, 2]
    base_dim = 512
    model = MotionCLR(
        input_feats=n_feats,
        text_latent_dim=text_latent_dim,
        base_dim=base_dim,
        dim_mults=dim_mults,
        adagn=True,
        zero=True,
        dropout=0.1,
        no_eff=True,
        cond_mask_prob=0.1,
        self_attention=True,
    )

    model = model.to(device)

    from utils.model_load import load_model_weights

    checkpoint_path = "/comp_robot/chenlinghao/StableMoFusion/checkpoints/t2m/self_attn—fulllayer-ffn-drop0_1-lr1e4/model/latest.tar"
    new_state_dict = {}
    checkpoint = torch.load(checkpoint_path)
    ckpt2 = checkpoint.copy()
    ckpt2["model_ema"] = {}
    ckpt2["encoder"] = {}
    for key, value in list(checkpoint["model_ema"].items()):
        new_key = key.replace(
            "cross_attn", "clr_attn"
        )  # Replace 'cross_attn' with 'clr_attn'
        ckpt2["model_ema"][new_key] = value
    for key, value in list(checkpoint["encoder"].items()):
        new_key = key.replace(
            "cross_attn", "clr_attn"
        )  # Replace 'cross_attn' with 'clr_attn'
        ckpt2["encoder"][new_key] = value

    torch.save(
        ckpt2,
        "/comp_robot/chenlinghao/CLRpreview/checkpoints/t2m/release/model/latest.tar",
    )

    dtype = torch.float32
    bs = 1
    x = torch.rand((bs, 196, 263), dtype=dtype).to(device)
    timesteps = torch.randint(low=0, high=1000, size=(bs,)).to(device)
    y = ["A man jumps to his left." for i in range(bs)]
    length = torch.randint(low=20, high=196, size=(bs,)).to(device)

    out = model(x, timesteps, text=y)
    print(out.shape)

    model.eval()
    out = model.forward_with_cfg(x, timesteps, text=y)
    print(out.shape)