from dataclasses import dataclass, field

import torch
import torch.nn as nn

from seva.modules.layers import (
    Downsample,
    GroupNorm32,
    ResBlock,
    TimestepEmbedSequential,
    Upsample,
    timestep_embedding,
)
from seva.modules.transformer import MultiviewTransformer


@dataclass
class SevaParams:
    in_channels: int = 11
    model_channels: int = 320
    out_channels: int = 4
    num_frames: int = 21
    num_res_blocks: int = 2
    attention_resolutions: list[int] = field(default_factory=lambda: [4, 2, 1])
    channel_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
    num_head_channels: int = 64
    transformer_depth: list[int] = field(default_factory=lambda: [1, 1, 1, 1])
    context_dim: int = 1024
    dense_in_channels: int = 6
    dropout: float = 0.0
    unflatten_names: list[str] = field(
        default_factory=lambda: ["middle_ds8", "output_ds4", "output_ds2"]
    )

    def __post_init__(self):
        # One transformer depth must be given per resolution level.
        assert len(self.channel_mult) == len(self.transformer_depth)


class Seva(nn.Module):
    def __init__(self, params: SevaParams) -> None:
        super().__init__()
        self.params = params
        self.model_channels = params.model_channels
        self.out_channels = params.out_channels
        self.num_head_channels = params.num_head_channels

        # MLP that lifts the sinusoidal timestep embedding to 4x width.
        time_embed_dim = params.model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(params.model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # Encoder: a conv stem, then ResBlocks with multiview attention at
        # the downsample factors listed in attention_resolutions.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    nn.Conv2d(params.in_channels, params.model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = params.model_channels
        input_block_chans = [params.model_channels]
        ch = params.model_channels
        ds = 1
        for level, mult in enumerate(params.channel_mult):
            for _ in range(params.num_res_blocks):
                input_layers: list[ResBlock | MultiviewTransformer | Downsample] = [
                    ResBlock(
                        channels=ch,
                        emb_channels=time_embed_dim,
                        out_channels=mult * params.model_channels,
                        dense_in_channels=params.dense_in_channels,
                        dropout=params.dropout,
                    )
                ]
                ch = mult * params.model_channels
                if ds in params.attention_resolutions:
                    num_heads = ch // params.num_head_channels
                    dim_head = params.num_head_channels
                    input_layers.append(
                        MultiviewTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            name=f"input_ds{ds}",
                            depth=params.transformer_depth[level],
                            context_dim=params.context_dim,
                            unflatten_names=params.unflatten_names,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*input_layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample between levels (skipped after the last level).
            if level != len(params.channel_mult) - 1:
                ds *= 2
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(Downsample(ch, out_channels=out_ch))
                )
                ch = out_ch
                input_block_chans.append(ch)
                self._feature_size += ch

        # Bottleneck at the coarsest resolution: ResBlock -> multiview
        # attention -> ResBlock.
        num_heads = ch // params.num_head_channels
        dim_head = params.num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                channels=ch,
                emb_channels=time_embed_dim,
                out_channels=None,
                dense_in_channels=params.dense_in_channels,
                dropout=params.dropout,
            ),
            MultiviewTransformer(
                ch,
                num_heads,
                dim_head,
                name=f"middle_ds{ds}",
                depth=params.transformer_depth[-1],
                context_dim=params.context_dim,
                unflatten_names=params.unflatten_names,
            ),
            ResBlock(
                channels=ch,
                emb_channels=time_embed_dim,
                out_channels=None,
                dense_in_channels=params.dense_in_channels,
                dropout=params.dropout,
            ),
        )
        self._feature_size += ch

        # Decoder: mirrors the encoder, consuming the skip channels recorded
        # in input_block_chans in reverse order.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(params.channel_mult))[::-1]:
            for i in range(params.num_res_blocks + 1):
                ich = input_block_chans.pop()
                output_layers: list[ResBlock | MultiviewTransformer | Upsample] = [
                    ResBlock(
                        channels=ch + ich,
                        emb_channels=time_embed_dim,
                        out_channels=params.model_channels * mult,
                        dense_in_channels=params.dense_in_channels,
                        dropout=params.dropout,
                    )
                ]
                ch = params.model_channels * mult
                if ds in params.attention_resolutions:
                    num_heads = ch // params.num_head_channels
                    dim_head = params.num_head_channels
                    output_layers.append(
                        MultiviewTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            name=f"output_ds{ds}",
                            depth=params.transformer_depth[level],
                            context_dim=params.context_dim,
                            unflatten_names=params.unflatten_names,
                        )
                    )
                if level and i == params.num_res_blocks:
                    out_ch = ch
                    ds //= 2
                    output_layers.append(Upsample(ch, out_ch))
                self.output_blocks.append(TimestepEmbedSequential(*output_layers))
                self._feature_size += ch

        # Output head: normalize, activate, and project to out_channels.
        self.out = nn.Sequential(
            GroupNorm32(32, ch),
            nn.SiLU(),
            nn.Conv2d(self.model_channels, params.out_channels, 3, padding=1),
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor,
        dense_y: torch.Tensor,
        num_frames: int | None = None,
    ) -> torch.Tensor:
        num_frames = num_frames or self.params.num_frames

        t_emb = timestep_embedding(t, self.model_channels)
        t_emb = self.time_embed(t_emb)

        # Encoder pass, stashing activations for the decoder's skips.
        hs = []
        h = x
        for module in self.input_blocks:
            h = module(
                h,
                emb=t_emb,
                context=y,
                dense_emb=dense_y,
                num_frames=num_frames,
            )
            hs.append(h)
        h = self.middle_block(
            h,
            emb=t_emb,
            context=y,
            dense_emb=dense_y,
            num_frames=num_frames,
        )
        # Decoder pass: concatenate each skip along channels, then refine.
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(
                h,
                emb=t_emb,
                context=y,
                dense_emb=dense_y,
                num_frames=num_frames,
            )
        h = h.type(x.dtype)
        return self.out(h)


class SGMWrapper(nn.Module):
    def __init__(self, module: Seva):
        super().__init__()
        self.module = module

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        # Channel-concatenate the "concat" conditioning (if provided) onto
        # the noisy input before the UNet sees it.
        x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
        return self.module(
            x,
            t=t,
            y=c["crossattn"],
            dense_y=c["dense_vector"],
            **kwargs,
        )
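

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. The shapes
    # below are assumptions chosen for illustration only: 4-channel latents on
    # a 32x32 grid, 7 "concat" conditioning channels to reach in_channels=11,
    # a single 1024-dim cross-attention token, and a 6-channel dense
    # conditioning map matching dense_in_channels.
    params = SevaParams()
    model = SGMWrapper(Seva(params))

    b, hw = params.num_frames, 32  # one num_frames-sized multiview window
    x = torch.randn(b, params.out_channels, hw, hw)  # noisy latents
    t = torch.randint(0, 1000, (b,)).float()  # diffusion timesteps
    c = {
        "concat": torch.randn(b, params.in_channels - params.out_channels, hw, hw),
        "crossattn": torch.randn(b, 1, params.context_dim),
        "dense_vector": torch.randn(b, params.dense_in_channels, hw, hw),
    }
    out = model(x, t, c, num_frames=params.num_frames)
    assert out.shape == (b, params.out_channels, hw, hw)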