import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel

class GEGLU(nn.Module):
    """Gated GELU: project to twice the output width, gate one half with GELU of the other."""

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)

class FeedForward(nn.Module):
    """GEGLU MLP: dim -> dim * mult -> dim_out (dim_out defaults to dim)."""

    def __init__(
        self,
        dim: int,
        dim_out: int | None = None,
        mult: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out or dim
        self.net = nn.Sequential(
            GEGLU(dim, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class Attention(nn.Module):
    """Multi-head attention: self-attention when `context` is None, cross-attention otherwise."""

    def __init__(
        self,
        query_dim: int,
        context_dim: int | None = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads
        context_dim = context_dim or query_dim

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(
        self, x: torch.Tensor, context: torch.Tensor | None = None
    ) -> torch.Tensor:
        q = self.to_q(x)
        # Fall back to self-attention when no context is given.
        context = context if context is not None else x
        k = self.to_k(context)
        v = self.to_v(context)

        # (b, l, h * d) -> (b, h, l, d) for scaled_dot_product_attention.
        q, k, v = map(
            lambda t: rearrange(t, "b l (h d) -> b h l d", h=self.heads),
            (q, k, v),
        )
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, "b h l d -> b l (h d)")
        out = self.to_out(out)
        return out
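
# A minimal usage sketch (hypothetical sizes, not part of the original source):
# cross-attention maps queries of width query_dim against a context of width
# context_dim and returns a tensor with the query's sequence layout. The
# FLASH_ATTENTION backend selected in forward() needs a sufficiently recent
# CUDA GPU and half-precision tensors, so the sketch is guarded accordingly.
def _attention_shape_sketch() -> None:
    if not torch.cuda.is_available():
        return
    attn = Attention(query_dim=320, context_dim=1024, heads=8, dim_head=64).cuda().half()
    x = torch.randn(2, 64, 320, device="cuda", dtype=torch.float16)     # (batch, query tokens, query_dim)
    ctx = torch.randn(2, 77, 1024, device="cuda", dtype=torch.float16)  # (batch, context tokens, context_dim)
    out = attn(x, context=ctx)
    assert out.shape == (2, 64, 320)  # output keeps the query sequence shape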

class TransformerBlock(nn.Module):
    """Pre-norm block: self-attention, cross-attention to `context`, feed-forward, each with a residual."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        context_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.attn1 = Attention(
            query_dim=dim,
            context_dim=None,  # self-attention
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.ff = FeedForward(dim, dropout=dropout)
        self.attn2 = Attention(
            query_dim=dim,
            context_dim=context_dim,  # cross-attention
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
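
# A minimal sketch (hypothetical sizes, not from the original source) showing
# that a TransformerBlock preserves the (batch, tokens, dim) shape while
# cross-attending to a context of width context_dim. Like the sketch above,
# it needs CUDA and fp16 because of the FLASH_ATTENTION backend.
def _transformer_block_shape_sketch() -> None:
    if not torch.cuda.is_available():
        return
    block = TransformerBlock(dim=512, n_heads=8, d_head=64, context_dim=1024).cuda().half()
    x = torch.randn(2, 64, 512, device="cuda", dtype=torch.float16)
    ctx = torch.randn(2, 77, 1024, device="cuda", dtype=torch.float16)
    assert block(x, context=ctx).shape == x.shape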

class TransformerBlockTimeMix(nn.Module):
    """Temporal counterpart of TransformerBlock: attends across frames at each spatial location.

    Note: the residual around `ff_in` requires dim == n_heads * d_head.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int,
        context_dim: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        inner_dim = n_heads * d_head
        self.norm_in = nn.LayerNorm(dim)
        self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout)
        self.attn1 = Attention(
            query_dim=inner_dim,
            context_dim=None,  # temporal self-attention
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout)
        self.attn2 = Attention(
            query_dim=inner_dim,
            context_dim=context_dim,  # cross-attention to the time context
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(inner_dim)
        self.norm2 = nn.LayerNorm(inner_dim)
        self.norm3 = nn.LayerNorm(inner_dim)

    def forward(
        self, x: torch.Tensor, context: torch.Tensor, num_frames: int
    ) -> torch.Tensor:
        _, s, _ = x.shape
        # Fold spatial positions into the batch so attention runs over the frame axis.
        x = rearrange(x, "(b t) s c -> (b s) t c", t=num_frames)
        x = self.ff_in(self.norm_in(x)) + x
        x = self.attn1(self.norm1(x), context=None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x))
        # Restore the (batch * frames, spatial, channels) layout.
        x = rearrange(x, "(b s) t c -> (b t) s c", s=s)
        return x
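
# Shape walkthrough for the time-mixing block (a sketch with hypothetical
# sizes, not part of the original source): two sequences of four frames on an
# 8x8 latent grid, with channels equal to n_heads * d_head. The context must
# already have one row per spatial position, which is what MultiviewTransformer
# prepares below. Requires CUDA + fp16 because of the FLASH_ATTENTION backend.
def _time_mix_shape_sketch() -> None:
    if not torch.cuda.is_available():
        return
    b, t, h, w, c = 2, 4, 8, 8, 512  # c = n_heads * d_head = 8 * 64
    block = TransformerBlockTimeMix(dim=c, n_heads=8, d_head=64, context_dim=1024).cuda().half()
    x = torch.randn(b * t, h * w, c, device="cuda", dtype=torch.float16)
    ctx = torch.randn(b * h * w, 77, 1024, device="cuda", dtype=torch.float16)
    out = block(x, context=ctx, num_frames=t)
    # Internally x becomes (b * h * w, t, c) for attention over frames,
    # then the (batch * frames, spatial, channels) layout is restored.
    assert out.shape == (b * t, h * w, c)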

class SkipConnect(nn.Module):
    """Merges the spatial and temporal branches by simple addition."""

    def __init__(self):
        super().__init__()

    def forward(
        self, x_spatial: torch.Tensor, x_temporal: torch.Tensor
    ) -> torch.Tensor:
        return x_spatial + x_temporal

class MultiviewTransformer(nn.Module):
    """Interleaves spatial transformer blocks with time-mixing blocks over a multi-view batch.

    Inputs are flattened so that frames live in the batch dimension. When `name` is listed in
    `unflatten_names`, the spatial blocks attend jointly over all frames and spatial positions;
    otherwise they attend per frame.
    """

    def __init__(
        self,
        in_channels: int,
        n_heads: int,
        d_head: int,
        name: str,
        unflatten_names: list[str] = [],
        depth: int = 1,
        context_dim: int = 1024,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.name = name
        self.unflatten_names = unflatten_names

        inner_dim = n_heads * d_head
        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6)
        self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)
        self.time_mixer = SkipConnect()
        self.time_mix_blocks = nn.ModuleList(
            [
                TransformerBlockTimeMix(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self, x: torch.Tensor, context: torch.Tensor, num_frames: int
    ) -> torch.Tensor:
        assert context.ndim == 3
        _, _, h, w = x.shape
        x_in = x

        # The time-mixing blocks see the first frame's context, repeated once per spatial position.
        time_context = context
        time_context_first_timestep = time_context[::num_frames]
        time_context = repeat(
            time_context_first_timestep, "b ... -> (b n) ...", n=h * w
        )
        if self.name in self.unflatten_names:
            context = context[::num_frames]

        x = self.norm(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        x = self.proj_in(x)

        for block, mix_block in zip(self.transformer_blocks, self.time_mix_blocks):
            if self.name in self.unflatten_names:
                # Attend jointly over all frames and spatial positions.
                x = rearrange(
                    x, "(b t) (h w) c -> b (t h w) c", t=num_frames, h=h, w=w
                )
            x = block(x, context=context)
            if self.name in self.unflatten_names:
                x = rearrange(
                    x, "b (t h w) c -> (b t) (h w) c", t=num_frames, h=h, w=w
                )
            x_mix = mix_block(x, context=time_context, num_frames=num_frames)
            x = self.time_mixer(x_spatial=x, x_temporal=x_mix)

        x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        out = x + x_in
        return out
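
# End-to-end smoke test (a sketch with assumed sizes, not part of the original
# source): two sequences of four frames of a 32-channel, 8x8 latent, with a
# 1024-dim context per frame. Frames are flattened into the batch dimension,
# i.e. x is (batch * num_frames, channels, h, w) and context is
# (batch * num_frames, tokens, context_dim). The module name "mid" and all
# sizes here are hypothetical. Requires CUDA + fp16 because of the
# FLASH_ATTENTION backend used in Attention.forward.
if __name__ == "__main__":
    if torch.cuda.is_available():
        model = (
            MultiviewTransformer(in_channels=32, n_heads=8, d_head=64, name="mid", depth=1)
            .cuda()
            .half()
        )
        b, t, h, w = 2, 4, 8, 8
        x = torch.randn(b * t, 32, h, w, device="cuda", dtype=torch.float16)
        ctx = torch.randn(b * t, 77, 1024, device="cuda", dtype=torch.float16)
        out = model(x, context=ctx, num_frames=t)
        assert out.shape == x.shape  # the residual connection preserves the input shape
        print("MultiviewTransformer output:", tuple(out.shape))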