from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ

from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
@dataclass
class FSQResult:
    z: torch.Tensor
    codes: torch.Tensor
    latents: torch.Tensor
class DownsampleFiniteScalarQuantize(nn.Module):
    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        n_groups: int = 1,
        levels: tuple[int, ...] = (8, 5, 5, 5),  # Approximate 2**10 codes per quantizer
        downsample_factor: tuple[int, ...] = (2, 2),
        downsample_dims: tuple[int, ...] | None = None,
    ):
        super().__init__()

        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        all_dims = (input_dim,) + tuple(downsample_dims)

        # Grouped residual FSQ operates on the most downsampled features
        self.residual_fsq = GroupedResidualFSQ(
            dim=all_dims[-1],
            levels=levels,
            num_quantizers=n_codebooks,
            groups=n_groups,
        )

        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims

        # One strided conv + ConvNeXt block per stage, shrinking time by each factor
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishConvNet(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )

        # Mirror of the downsample stack, applied in reverse stage order
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    FishTransConvNet(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    def forward(self, z) -> FSQResult:
        # z: (batch, input_dim, time)
        original_shape = z.shape
        z = self.downsample(z)

        # GroupedResidualFSQ expects (batch, time, dim), hence the .mT transposes
        quantized, indices = self.residual_fsq(z.mT)
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=z,
        )
        result.z = self.upsample(result.z)

        # Pad or crop z along time so it matches the original length
        diff = original_shape[-1] - result.z.shape[-1]
        if diff > 0:
            left = diff // 2
            result.z = F.pad(result.z, (left, diff - left))
        elif diff < 0:
            left = -diff // 2
            result.z = result.z[..., left : left + original_shape[-1]]

        return result
    def encode(self, z):
        z = self.downsample(z)
        _, indices = self.residual_fsq(z.mT)
        # Flatten the (group, quantizer) axes into a single codebook axis
        indices = rearrange(indices, "g b l r -> b (g r) l")
        return indices
    def decode(self, indices: torch.Tensor):
        # Undo encode()'s flattening before looking the codes up
        indices = rearrange(
            indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups
        )
        z_q = self.residual_fsq.get_output_from_indices(indices)
        z_q = self.upsample(z_q.mT)
        return z_q