import math
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init as init
from einops import rearrange
from torch import nn


def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    grid_size: int of the grid height and width, or an explicit (height, width) tuple.

    Returns:
        pos_embed: [grid_size*grid_size, embed_dim], or
        [extra_tokens + grid_size*grid_size, embed_dim] when cls_token is set
        (the extra token rows are prepended as zeros).
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed
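

# Illustrative usage sketch (an addition for clarity, not part of the original
# module): a 16x16 patch grid with a 64-dim embedding yields one row per cell,
# and `extra_tokens` prepends all-zero rows (e.g. for a class token).
def _demo_get_2d_sincos_pos_embed():
    pos = get_2d_sincos_pos_embed(64, 16)
    assert pos.shape == (16 * 16, 64)
    pos_with_cls = get_2d_sincos_pos_embed(64, 16, cls_token=True, extra_tokens=1)
    assert pos_with_cls.shape == (1 + 16 * 16, 64)
    assert np.all(pos_with_cls[0] == 0.0)  # the extra token row is all zeros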


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of the dimensions to encode grid_h, the other half for grid_w
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position.
    pos: a list of positions to be encoded, of size (M,).
    out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
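

# Quick sanity sketch (an addition): position m is encoded with frequencies
# w_d = 1 / 10000**(2d / D), sines first and cosines second, so position 0
# maps to four zeros followed by four ones when D = 8.
def _demo_get_1d_sincos_pos_embed_from_grid():
    emb = get_1d_sincos_pos_embed_from_grid(8, np.array([0.0, 1.0, 2.0]))
    assert emb.shape == (3, 8)
    assert np.allclose(emb[0], [0, 0, 0, 0, 1, 1, 1, 1])  # sin(0) = 0, cos(0) = 1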


class Patch1D(nn.Module):
    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        stride: int = 2,
        padding: int = 0,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        if use_conv:
            # kernel_size == stride, so each output sees one non-overlapping window
            self.conv = nn.Conv1d(self.channels, self.out_channels, stride, stride=stride, padding=padding)
            # Initialize to an average over each channel's own window: the
            # diagonal weights are 1/stride and all cross-channel weights are
            # zero, so the conv starts out equivalent to average pooling.
            init.constant_(self.conv.weight, 0.0)
            with torch.no_grad():
                for i in range(len(self.conv.weight)):
                    self.conv.weight[i, i] = 1 / stride
            init.constant_(self.conv.bias, 0.0)
        else:
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)
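

# Sanity sketch (an addition): with the identity-average initialization above,
# the learnable conv path starts out numerically equal to the pooling path.
def _demo_patch1d_init_matches_avgpool():
    x = torch.randn(2, 4, 8)  # (batch, channels, length)
    conv_patch = Patch1D(4, use_conv=True, stride=2)
    pool_patch = Patch1D(4, use_conv=False, stride=2)
    assert torch.allclose(conv_patch(x), pool_patch(x), atol=1e-6)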


class UnPatch1D(nn.Module):
    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        self.conv = None
        if use_conv_transpose:
            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert inputs.shape[1] == self.channels

        if self.use_conv_transpose:
            return self.conv(inputs)

        # Nearest-neighbor upsampling by 2x, optionally refined by a conv.
        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
        if self.use_conv:
            outputs = self.conv(outputs)
        return outputs
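

# Shape sketch (an addition): every configuration doubles the length axis.
def _demo_unpatch1d_doubles_length():
    x = torch.randn(2, 4, 8)
    assert UnPatch1D(4)(x).shape == (2, 4, 16)  # interpolate only
    assert UnPatch1D(4, use_conv=True)(x).shape == (2, 4, 16)  # interpolate + conv
    assert UnPatch1D(4, use_conv_transpose=True)(x).shape == (2, 4, 16)  # transposed conv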


class Upsampler(nn.Module):
    def __init__(
        self,
        spatial_upsample_factor: int = 1,
        temporal_upsample_factor: int = 1,
    ):
        super().__init__()
        self.spatial_upsample_factor = spatial_upsample_factor
        self.temporal_upsample_factor = temporal_upsample_factor


class TemporalUpsampler3D(Upsampler):
    def __init__(self):
        super().__init__(
            spatial_upsample_factor=1,
            temporal_upsample_factor=2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[2] > 1:
            # Keep the first frame as-is and upsample only the remaining ones,
            # so T input frames become 1 + 2 * (T - 1) output frames.
            first_frame, x = x[:, :, :1], x[:, :, 1:]
            x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
            x = torch.cat([first_frame, x], dim=2)
        return x
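

# Frame-count sketch (an addition): 5 frames become 1 + 2 * 4 = 9 frames,
# while a single-frame (image) input passes through unchanged.
def _demo_temporal_upsampler3d():
    up = TemporalUpsampler3D()
    assert up(torch.randn(1, 3, 5, 4, 4)).shape == (1, 3, 9, 4, 4)
    assert up(torch.randn(1, 3, 1, 4, 4)).shape == (1, 3, 1, 4, 4)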


def cast_tuple(t, length=1):
    return t if isinstance(t, tuple) else ((t,) * length)


def divisible_by(num, den):
    return (num % den) == 0


def is_odd(n):
    return not divisible_by(n, 2)
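

# Tiny sketch (an addition): these helpers normalize scalar-or-tuple
# hyperparameters, presumably for callers of this module.
def _demo_tuple_helpers():
    assert cast_tuple(3, length=3) == (3, 3, 3)
    assert cast_tuple((1, 2, 3)) == (1, 2, 3)  # tuples pass through unchanged
    assert is_odd(3) and not is_odd(4)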


class CausalConv3d(nn.Conv3d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,  # : int | tuple[int, int, int]
        stride=1,  # : int | tuple[int, int, int]
        padding=1,  # : int | tuple[int, int, int], TODO: change it to 0.
        dilation=1,  # : int | tuple[int, int, int]
        **kwargs,
    ):
        kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
        assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."

        stride = stride if isinstance(stride, tuple) else (stride,) * 3
        assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."

        dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
        assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."

        t_ks, h_ks, w_ks = kernel_size
        _, h_stride, w_stride = stride
        t_dilation, h_dilation, w_dilation = dilation

        # The whole temporal receptive field is padded on the left (see
        # forward), which is what makes the convolution causal in time.
        t_pad = (t_ks - 1) * t_dilation
        # TODO: align with SD
        if padding is None:
            h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
            w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
        elif isinstance(padding, int):
            h_pad = w_pad = padding
        else:
            raise NotImplementedError(f"Unsupported padding: {padding}")

        self.temporal_padding = t_pad

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=(0, h_pad, w_pad),
            **kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W). F.pad's pad tuple is ordered
        # (W_left, W_right, H_left, H_right, T_left, T_right),
        # so only the left side of the temporal axis is padded.
        x = F.pad(
            x,
            pad=(0, 0, 0, 0, self.temporal_padding, 0),
            mode="replicate",  # TODO: check if this is necessary
        )
        return super().forward(x)
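

# Causality sketch (an addition): editing a future frame must not change the
# output at earlier time steps, since padding is applied on the left only.
def _demo_causal_conv3d():
    conv = CausalConv3d(2, 2, kernel_size=3, padding=1)
    x = torch.randn(1, 2, 5, 4, 4)
    x_perturbed = x.clone()
    x_perturbed[:, :, -1] += 1.0  # edit only the last frame
    y, y_perturbed = conv(x), conv(x_perturbed)
    assert torch.allclose(y[:, :, :-1], y_perturbed[:, :, :-1], atol=1e-6)
    assert y.shape == x.shape  # stride 1 preserves the temporal length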


class PatchEmbed3D(nn.Module):
    """3D Image to Patch Embedding"""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        time_patch_size=4,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()
        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=(time_patch_size, patch_size, patch_size),
            stride=(time_patch_size, patch_size, patch_size),
            bias=bias,
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        # Note: int(num_patches**0.5) assumes a square patch grid.
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
        latent = self.proj(latent)
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # (B*F, C, H, W) -> (B*F, N, C)
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
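

# Shape sketch (an addition): a (B, C, T, H, W) latent becomes B * T' token
# sequences of length (H/p) * (W/p), where T' = T // time_patch_size.
def _demo_patch_embed3d():
    embed = PatchEmbed3D(height=32, width=32, patch_size=16, time_patch_size=4, in_channels=3, embed_dim=64)
    tokens = embed(torch.randn(2, 3, 4, 32, 32))
    assert tokens.shape == (2 * 1, (32 // 16) * (32 // 16), 64)  # (B*T', N, D)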


class PatchEmbedF3D(nn.Module):
    """Fake 3D Image to Patch Embedding: 2D spatial patching followed by 1D temporal patching."""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()
        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.proj_t = Patch1D(embed_dim, use_conv=True, stride=patch_size)
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
        b, c, f, h, w = latent.size()

        # Spatial patching: fold frames into the batch and project 2D patches.
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        latent = self.proj(latent)
        latent = rearrange(latent, "(b f) c h w -> b c f h w", f=f)

        # Temporal patching: fold spatial positions into the batch and patchify
        # time. After self.proj the spatial grid is h // patch_size per side.
        latent = rearrange(latent, "b c f h w -> (b h w) c f")
        latent = self.proj_t(latent)
        latent = rearrange(
            latent, "(b h w) c f -> b c f h w", h=h // self.patch_size, w=w // self.patch_size
        )

        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # (B*F, C, H, W) -> (B*F, N, C)
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
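

# Shape sketch (an addition): with patch_size=2, both the spatial and the
# temporal axes shrink by 2, so 4 frames of 8x8 become 2 frames of 4x4 tokens.
def _demo_patch_embed_f3d():
    embed = PatchEmbedF3D(height=8, width=8, patch_size=2, in_channels=3, embed_dim=64)
    tokens = embed(torch.randn(1, 3, 4, 8, 8))
    assert tokens.shape == (1 * 2, (8 // 2) * (8 // 2), 64)  # (B*T', N, D)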


class CasualPatchEmbed3D(nn.Module):
    """3D Image to Patch Embedding with a causal temporal convolution"""

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        time_patch_size=4,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
    ):
        super().__init__()
        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm

        self.proj = CausalConv3d(
            in_channels,
            embed_dim,
            kernel_size=(time_patch_size, patch_size, patch_size),
            stride=(time_patch_size, patch_size, patch_size),
            bias=bias,
            padding=None,
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # See:
        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale
        pos_embed = get_2d_sincos_pos_embed(
            embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
        )
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)

    def forward(self, latent):
        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
        latent = self.proj(latent)
        latent = rearrange(latent, "b c f h w -> (b f) c h w")
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # (B*F, C, H, W) -> (B*F, N, C)
        if self.layer_norm:
            latent = self.norm(latent)

        # Interpolate positional embeddings if needed.
        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
        if self.height != height or self.width != width:
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim=self.pos_embed.shape[-1],
                grid_size=(height, width),
                base_size=self.base_size,
                interpolation_scale=self.interpolation_scale,
            )
            pos_embed = torch.from_numpy(pos_embed)
            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
        else:
            pos_embed = self.pos_embed

        return (latent + pos_embed).to(latent.dtype)
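

# Shape sketch (an addition): the left-only temporal padding makes the output
# length ceil(T / time_patch_size), so a single-frame input still yields one
# token map, whereas the plain PatchEmbed3D needs T >= time_patch_size.
def _demo_casual_patch_embed3d():
    embed = CasualPatchEmbed3D(height=32, width=32, patch_size=16, time_patch_size=4, in_channels=3, embed_dim=64)
    assert embed(torch.randn(1, 3, 4, 32, 32)).shape == (1, 4, 64)
    assert embed(torch.randn(1, 3, 1, 32, 32)).shape == (1, 4, 64)  # T=1 also works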