import sys
from typing import Literal, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from cube3d.model.transformers.norm import RMSNorm


class SphericalVectorQuantizer(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_codes: int,
        width: Optional[int] = None,
        codebook_regularization: Literal["batch_norm", "kl"] = "batch_norm",
    ):
| """ | |
| Initializes the SphericalVQ module. | |
| Args: | |
| embed_dim (int): The dimensionality of the embeddings. | |
| num_codes (int): The number of codes in the codebook. | |
| width (Optional[int], optional): The width of the input. Defaults to None. | |
| Raises: | |
| ValueError: If beta is not in the range [0, 1]. | |
| """ | |
        super().__init__()
        self.num_codes = num_codes

        self.codebook = nn.Embedding(num_codes, embed_dim)
        self.codebook.weight.data.uniform_(-1.0 / num_codes, 1.0 / num_codes)

        width = width or embed_dim
        if width != embed_dim:
            self.c_in = nn.Linear(width, embed_dim)
            self.c_x = nn.Linear(width, embed_dim)  # shortcut
            self.c_out = nn.Linear(embed_dim, width)
        else:
            self.c_in = self.c_out = self.c_x = nn.Identity()
        self.norm = RMSNorm(embed_dim, elementwise_affine=False)

        self.cb_reg = codebook_regularization
        if self.cb_reg == "batch_norm":
            self.cb_norm = nn.BatchNorm1d(embed_dim, track_running_stats=False)
        else:
            # learnable per-channel scale and shift applied to the raw codebook weights
            self.cb_weight = nn.Parameter(torch.ones([embed_dim]))
            self.cb_bias = nn.Parameter(torch.zeros([embed_dim]))
            self.cb_norm = lambda x: x.mul(self.cb_weight).add_(self.cb_bias)
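
    # Construction sketch (added note): with width != embed_dim the module adds linear
    # projections into and out of the codebook space; with width == embed_dim (or
    # width=None) those projections are identities. Illustrative values, not the ones
    # used in training:
    #
    #     vq_a = SphericalVectorQuantizer(embed_dim=32, num_codes=256, width=64)  # uses c_in/c_out
    #     vq_b = SphericalVectorQuantizer(embed_dim=32, num_codes=256)            # identity projections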

    def get_codebook(self):
        """
        Retrieves the normalized codebook weights.

        The raw codebook weights are first passed through the codebook regularization
        (batch norm or the learnable affine) and then RMS-normalized, so every returned
        row has the same norm.

        Returns:
            torch.Tensor: The normalized codebook weights of shape (num_codes, embed_dim).
        """
        return self.norm(self.cb_norm(self.codebook.weight))

    def lookup_codebook(self, q: torch.Tensor):
        """
        Perform a lookup in the codebook and process the result.

        This method takes a tensor of indices, retrieves the corresponding normalized
        embeddings from the codebook, and projects them back to the output width.

        Args:
            q (torch.Tensor): A tensor containing indices to look up in the codebook.

        Returns:
            torch.Tensor: The projected embeddings retrieved from the codebook.
        """
        # look up indices in the normalized codebook
        z_q = F.embedding(q, self.get_codebook())
        z_q = self.c_out(z_q)
        return z_q

    def lookup_codebook_latents(self, q: torch.Tensor):
        """
        Retrieves the latent representations from the codebook corresponding to the given indices.

        Args:
            q (torch.Tensor): A tensor containing the indices of the codebook entries to retrieve.
                The indices should be integers and correspond to the rows in the codebook.

        Returns:
            torch.Tensor: A tensor containing the latent representations retrieved from the codebook.
                The shape of the returned tensor depends on the shape of the input indices
                and the dimensionality of the codebook entries.
        """
        # look up indices in the normalized codebook
        z_q = F.embedding(q, self.get_codebook())
        return z_q
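
    # Shape sketch (added note): assuming `vq` is an instance of this module and `q` is
    # an integer index tensor of shape (B, T), the two lookup helpers above differ only
    # in the output projection:
    #
    #     z_lat = vq.lookup_codebook_latents(q)  # (B, T, embed_dim), stays in codebook space
    #     z_out = vq.lookup_codebook(q)          # (B, T, width), projected through c_out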

    def quantize(self, z: torch.Tensor):
        """
        Quantizes the latent codes z with the codebook.

        Args:
            z (Tensor): B x ... x F

        Returns:
            Tuple[torch.Tensor, Dict[str, torch.Tensor]]: The quantized codes with the same
                shape as z, and a dictionary containing the detached input ("z") and the
                selected codebook indices ("q").
        """
        # normalize codebook
        codebook = self.get_codebook()
        # the process of finding quantized codes is non differentiable
        with torch.no_grad():
            # flatten z
            z_flat = z.view(-1, z.shape[-1])
            # calculate distance and find the closest code
            d = torch.cdist(z_flat, codebook)
            q = torch.argmin(d, dim=1)  # num_ele
        z_q = codebook[q, :].reshape(*z.shape[:-1], -1)
        q = q.view(*z.shape[:-1])
        return z_q, {"z": z.detach(), "q": q}

    def straight_through_approximation(self, z, z_q):
        """Passes the gradient from z_q to z (straight-through estimator)."""
        z_q = z + (z_q - z).detach()
        return z_q
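
    # A minimal check of the straight-through estimator above (added sketch): the forward
    # value equals z_q while the gradient with respect to z is the identity.
    #
    #     z = torch.tensor(1.0, requires_grad=True)
    #     z_q = torch.tensor(3.0)
    #     out = z + (z_q - z).detach()  # out.item() == 3.0
    #     out.backward()
    #     assert z.grad.item() == 1.0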

    def forward(self, z: torch.Tensor):
        """
        Forward pass of the spherical vector quantizer.

        Args:
            z (torch.Tensor): Input tensor of shape (batch_size, ..., feature_dim).

        Returns:
            Tuple[torch.Tensor, Dict[str, Any]]:
                - z_q (torch.Tensor): The quantized output tensor after applying the
                  straight-through approximation and output projection.
                - ret_dict (Dict[str, Any]): A dictionary containing additional
                  information:
                    - "z" (torch.Tensor): Detached, normalized input projection.
                    - "q" (torch.Tensor): Indices of the selected codebook entries.
                    - "z_q" (torch.Tensor): Detached quantized tensor before the
                      straight-through approximation and output projection.
        """
        with torch.autocast(device_type=z.device.type, enabled=False):
            # work in full precision
            z = z.float()
            # project and normalize
            z_e = self.norm(self.c_in(z))
            z_q, ret_dict = self.quantize(z_e)
            ret_dict["z_q"] = z_q.detach()
            z_q = self.straight_through_approximation(z_e, z_q)
            z_q = self.c_out(z_q)
        return z_q, ret_dict
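

# Minimal usage sketch (added): constructs the quantizer with illustrative, not canonical,
# hyperparameters and runs a round trip through forward() and lookup_codebook(). Assumes
# the cube3d package (providing RMSNorm) is importable.
if __name__ == "__main__":
    vq = SphericalVectorQuantizer(embed_dim=32, num_codes=256, width=64)
    z = torch.randn(2, 16, 64)  # (batch, tokens, width)

    z_q, ret = vq(z)
    print(z_q.shape)  # torch.Size([2, 16, 64])
    print(ret["q"].shape)  # torch.Size([2, 16])

    # decode the code indices back to embeddings in the output width
    z_dec = vq.lookup_codebook(ret["q"])
    print(z_dec.shape)  # torch.Size([2, 16, 64])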