import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))

def sample_vectors(samples, num):
    # samples: (N, D); pick `num` rows (with replacement when N < num)
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)
    return samples[indices].float()  # (num, D), ensure fp32

def kmeans(samples, num_clusters, num_iters=10):
    # samples: (N, D), N samples with D dimensions
    dim, dtype = samples.shape[-1], torch.float32  # force fp32
    means = sample_vectors(samples, num_clusters).float()  # (num_clusters, D), ensure fp32
    for _ in range(num_iters):
        dists = -(samples.float().pow(2).sum(1, keepdim=True)  # (N, 1), ensure fp32
                  - 2 * samples.float() @ means.t()            # (N, num_clusters)
                  + means.t().float().pow(2).sum(0, keepdim=True))  # (1, num_clusters)
        # dists: (N, num_clusters), negative squared distances
        buckets = dists.max(dim=-1).indices  # (N)
        bins = torch.bincount(buckets, minlength=num_clusters)  # (num_clusters)
        zero_mask = bins == 0  # (num_clusters)
        bins_min_clamped = bins.masked_fill(zero_mask, 1)  # (num_clusters)
        new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32)  # (num_clusters, D), ensure fp32
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples.float())  # (num_clusters, D)
        new_means = new_means / bins_min_clamped[..., None]  # (num_clusters, D)
        means = torch.where(zero_mask[..., None], means, new_means)  # (num_clusters, D)
    # Final cluster assignments for returning cluster sizes
    dists = -(samples.float().pow(2).sum(1, keepdim=True)
              - 2 * samples.float() @ means.t()
              + means.t().float().pow(2).sum(0, keepdim=True))  # (N, num_clusters), ensure fp32
    buckets = dists.max(dim=-1).indices  # (N)
    bins = torch.bincount(buckets, minlength=num_clusters).float()  # (num_clusters), ensure fp32
    return means, bins  # (num_clusters, D), (num_clusters)
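
# Illustrative sketch added for this write-up, not part of the original module:
# it clusters random features with kmeans() above so the expected shapes are easy
# to see. The sizes (512 samples, 8 dims, 16 clusters, 5 iterations) are arbitrary
# assumptions.
def _kmeans_example():
    feats = torch.randn(512, 8)  # (N, D) dummy features
    means, bins = kmeans(feats, num_clusters=16, num_iters=5)
    assert means.shape == (16, 8)  # (num_clusters, D)
    assert bins.shape == (16,)     # per-cluster counts (fp32)
    return means, bins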

class VectorQuantize(nn.Module):
    def __init__(
        self,
        input_dim,
        codebook_size,
        codebook_dim,
        commitment=1.0,
        decay=0.99,  # EMA decay
        epsilon=1e-5,  # Laplace smoothing epsilon
        threshold_ema_dead=2,  # Dead-code threshold
        kmeans_init=True,  # Use k-means initialization
        kmeans_iters=10,  # k-means iterations
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment
        self.decay = decay
        self.epsilon = epsilon
        self.threshold_ema_dead = threshold_ema_dead
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        if self.input_dim != self.codebook_dim:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)  # (B, D, T) -> (B, D', T)
            self.out_project = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1)  # (B, D', T) -> (B, D, T)
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()
        # Initialize codebook and EMA buffers
        init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y)
        self.register_buffer("codebook", init_fn(codebook_size, codebook_dim).float())  # (codebook_size, D'), fp32
        self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool))  # (1)
        self.register_buffer("cluster_size", torch.zeros(codebook_size).float())  # (codebook_size), fp32
        self.register_buffer("embed_avg", self.codebook.clone().float())  # (codebook_size, D'), fp32

    def ema_update(self, encodings, embed_onehot):
        """Update codebook using EMA."""
        # encodings: (B*T, D'), embed_onehot: (B*T, codebook_size)
        encodings = encodings.float()  # ensure fp32
        embed_onehot = embed_onehot.float()  # ensure fp32
        cluster_size_new = embed_onehot.sum(0)  # (codebook_size)
        embed_sum = encodings.t() @ embed_onehot  # (D', codebook_size)
        # Distributed reduction
        if dist.is_initialized():
            dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
            dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)
        ema_inplace(self.cluster_size, cluster_size_new, self.decay)  # (codebook_size)
        ema_inplace(self.embed_avg, embed_sum.t(), self.decay)  # (codebook_size, D')
        # Laplace smoothing
        cluster_size = (self.cluster_size + self.epsilon) / (self.cluster_size.sum() + self.codebook_size * self.epsilon)  # (codebook_size)
        cluster_size = cluster_size * self.cluster_size.sum()  # (codebook_size)
        self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1))  # (codebook_size, D')

    def replace_dead_codes(self, encodings):
        """Replace dead codes with random samples from the current batch."""
        # encodings: (B*T, D')
        if self.threshold_ema_dead == 0:
            return
        dead_mask = self.cluster_size < self.threshold_ema_dead  # (codebook_size)
        if dead_mask.any():
            # Sample replacements on rank 0 (or in single-process mode) and broadcast.
            if not dist.is_initialized() or dist.get_rank() == 0:
                samples = sample_vectors(encodings.float(), self.codebook_size)  # (codebook_size, D'), fp32
            else:
                samples = torch.zeros_like(self.codebook).float()  # placeholder, filled by broadcast
            if dist.is_initialized():
                dist.broadcast(samples, src=0)
            self.codebook[dead_mask] = samples[: dead_mask.sum()].to(self.codebook.dtype)  # update dead codes

    def init_codebook(self, encodings):
        """Initialize codebook with k-means and update cluster_size."""
        # encodings: (B*T, D')
        if self.inited.item():
            return
        # Run k-means on rank 0 (or in single-process mode) and broadcast the result.
        if not dist.is_initialized() or dist.get_rank() == 0:
            embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters)  # (codebook_size, D'), (codebook_size)
        else:
            embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float()
            cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32)
        if dist.is_initialized():
            dist.broadcast(embed, src=0)
            dist.broadcast(cluster_sizes, src=0)
        self.codebook.copy_(embed)  # (codebook_size, D')
        self.embed_avg.copy_(embed.clone())  # (codebook_size, D')
        self.cluster_size.copy_(cluster_sizes.float())  # (codebook_size)
        self.inited.fill_(True)

    def forward(self, z):  # z: (B, D, T)
        # logging.info(f"{self.cluster_size = }, {self.codebook = }, {self.embed_avg = }, {self.inited = }")
        z = z.float()  # ensure fp32
        z_e = self.in_project(z).float()  # (B, D', T)
        # Rearrange for quantization
        encodings = rearrange(z_e, "b d t -> (b t) d").float()  # (B*T, D')
        # Initialize codebook if needed
        if self.kmeans_init and not self.inited.item():
            self.init_codebook(encodings)
        # Quantization: squared distances to codebook entries
        # (local name `distances` avoids shadowing torch.distributed as dist)
        distances = (encodings.pow(2).sum(1, keepdim=True)  # (B*T, 1)
                     - 2 * encodings @ self.codebook.float().t()  # (B*T, codebook_size)
                     + self.codebook.float().pow(2).sum(1, keepdim=True).t())  # (1, codebook_size)
        # distances: (B*T, codebook_size)
        indices = (-distances).max(1)[1]  # (B*T)
        indices = rearrange(indices, "(b t) -> b t", b=z.size(0))  # (B, T)
        # Get quantized vectors
        z_q = self.decode_code(indices).float()  # (B, D', T)
        # Commitment loss
        commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment  # (B)
        # EMA updates and dead-code replacement during training
        if self.training and torch.is_grad_enabled():
            embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float()  # (B*T, codebook_size)
            self.ema_update(encodings, embed_onehot)
            self.replace_dead_codes(encodings)
        # Straight-through estimator
        z_q = z_e + (z_q - z_e).detach()  # (B, D', T)
        z_q = self.out_project(z_q).float()  # (B, D, T)
        return z_q, commit_loss, torch.tensor(0.0, device=z.device, dtype=torch.float32), indices, z  # (B, D, T), (B), scalar, (B, T), (B, D, T)

    def decode_code(self, embed_id):  # embed_id: (B, T)
        return F.embedding(embed_id, self.codebook).transpose(1, 2).float()  # (B, D', T), ensure fp32
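
# Illustrative sketch, not part of the original module: a single-quantizer forward
# pass in eval mode, with kmeans_init=False so no k-means warm-up is needed. The
# dimensions (64-d input, 128 codes of dimension 8, batch of 2, 50 frames) are
# assumptions chosen only to show the expected shapes.
def _vq_example():
    vq = VectorQuantize(input_dim=64, codebook_size=128, codebook_dim=8, kmeans_init=False)
    vq.eval()  # skip EMA updates and dead-code replacement
    x = torch.randn(2, 64, 50)  # (B, D, T)
    with torch.no_grad():
        z_q, commit_loss, _, indices, _ = vq(x)
    assert z_q.shape == x.shape       # quantized output, back in input space
    assert indices.shape == (2, 50)   # one code index per frame
    assert commit_loss.shape == (2,)  # per-sample commitment loss
    return z_q, indices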

class ResidualVQ(nn.Module):
    def __init__(
        self,
        input_dim: int = 1280,  # Input feature dimension (independent of the RVQ itself)
        rvq_dim=None,  # RVQ dimension. If different from input_dim/output_dim, adds input_dim->rvq_dim / rvq_dim->output_dim projections
        output_dim: int = None,  # Output feature dimension (independent of the RVQ itself)
        num_quantizers: int = 32,
        codebook_size: int = 1024,
        codebook_dim: int = 8,  # Dimension of each codebook. If different from rvq_dim, adds rvq_dim->codebook_dim and codebook_dim->rvq_dim projections
        quantizer_dropout: float = 0.5,
        decay=0.99,
        epsilon=1e-5,
        threshold_ema_dead=2,
        kmeans_init=True,
        kmeans_iters=10,
        skip_rvq_ratio: float = 0.0,  # Probability of skipping RVQ for a sample during training
        **kwargs,
    ):
        super().__init__()
        # Assumed fallback: when rvq_dim/output_dim are not given, use input_dim.
        rvq_dim = rvq_dim if rvq_dim is not None else input_dim
        output_dim = output_dim if output_dim is not None else input_dim
        self.input_dim = input_dim
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.skip_rvq_ratio = skip_rvq_ratio  # Store skip probability
        self.rvq_dim = rvq_dim
        self.input_proj = WNConv1d(input_dim, rvq_dim, kernel_size=1) if input_dim != rvq_dim else nn.Identity()
        self.output_proj = WNConv1d(rvq_dim, output_dim, kernel_size=1) if rvq_dim != output_dim else nn.Identity()
        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(
                    input_dim=rvq_dim,
                    codebook_size=codebook_size,
                    codebook_dim=codebook_dim,
                    decay=decay,
                    epsilon=epsilon,
                    threshold_ema_dead=threshold_ema_dead,
                    kmeans_init=kmeans_init,
                    kmeans_iters=kmeans_iters,
                    **kwargs,
                )
                for _ in range(num_quantizers)
            ]
        )

    def forward(self, z, input_length, n_quantizers: int = None):  # z: (B, D, T), input_length: (B)
        z = self.input_proj(z)
        with torch.autocast("cuda", enabled=False):
            batch_size, _, max_time = z.shape
            mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1)  # (B, T)
            quantized_out = torch.zeros_like(z, dtype=torch.float32)  # (B, D, T), fp32
            residual = z.clone().float()  # (B, D, T), fp32
            all_commit_losses = []
            all_indices = []
            all_quantized = []
            n_quantizers = n_quantizers or self.num_quantizers
            # Randomly decide whether to skip RVQ during training
            skip_mask = None
            if self.training and torch.is_grad_enabled() and self.skip_rvq_ratio > 0:
                # Generate random mask with skip_rvq_ratio probability
                skip_mask = torch.rand(batch_size, device=z.device) < self.skip_rvq_ratio  # (B)
                # If all samples are skipped, force the first sample to be unskipped
                if skip_mask.all():
                    skip_mask[0] = False
            if self.training and torch.is_grad_enabled():
                # Quantizer dropout: a fraction of samples uses a random number of quantizers
                n_quantizers_tensor = torch.ones((z.shape[0],), dtype=torch.float32, device=z.device) * self.num_quantizers + 1  # (B)
                dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],), dtype=torch.float32, device=z.device)  # (B)
                n_dropout = int(z.shape[0] * self.quantizer_dropout)
                n_quantizers_tensor[:n_dropout] = dropout[:n_dropout]  # (B)
            else:
                n_quantizers_tensor = torch.full((z.shape[0],), n_quantizers, dtype=torch.float32, device=z.device)  # (B)
            for i, quantizer in enumerate(self.quantizers):
                if not self.training and i >= n_quantizers:
                    break
                masked_residual = residual * mask.unsqueeze(1)  # (B, D, T)
                # If skipping RVQ, pass the input through unquantized for skipped samples
                if self.training and skip_mask is not None and skip_mask.any():
                    z_q_i = torch.zeros_like(masked_residual, dtype=torch.float32)  # (B, D, T)
                    commit_loss_i = torch.zeros(batch_size, device=z.device, dtype=torch.float32)  # (B)
                    indices_i = torch.zeros(batch_size, max_time, device=z.device, dtype=torch.long)  # (B, T)
                    z_e_i = torch.zeros_like(masked_residual, dtype=torch.float32)  # (B, D, T)
                    # Quantize only the non-skipped samples
                    non_skipped_mask = ~skip_mask  # (B)
                    if non_skipped_mask.any():
                        z_q_i_non_skipped, commit_loss_i_non_skipped, _, indices_i_non_skipped, z_e_i_non_skipped = quantizer(
                            masked_residual[non_skipped_mask].float()
                        )
                        z_q_i[non_skipped_mask] = z_q_i_non_skipped
                        commit_loss_i[non_skipped_mask] = commit_loss_i_non_skipped
                        indices_i[non_skipped_mask] = indices_i_non_skipped
                        z_e_i[non_skipped_mask] = z_e_i_non_skipped
                else:
                    z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(masked_residual.float())  # (B, D, T), (B), scalar, (B, T), (B, D, T)
                quantizer_mask = torch.full((z.shape[0],), i, device=z.device, dtype=torch.float32) < n_quantizers_tensor  # (B)
                update_mask = (mask & quantizer_mask.unsqueeze(-1)).unsqueeze(1)  # (B, 1, T)
                # For skipped samples the output is the input itself
                if skip_mask is not None:
                    skip_mask_expanded = skip_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1)
                    z_q_i = torch.where(skip_mask_expanded, masked_residual, z_q_i)  # (B, D, T)
                    commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i)  # (B)
                quantized_out = quantized_out + z_q_i * update_mask  # (B, D, T)
                residual = residual.to(torch.float32) - z_q_i.to(torch.float32) * update_mask  # (B, D, T), stay in fp32
                valid_mask = mask & quantizer_mask.unsqueeze(-1)  # (B, T)
                if valid_mask.any():
                    commit_loss_i = (commit_loss_i * quantizer_mask).sum() / quantizer_mask.sum()  # scalar
                else:
                    commit_loss_i = torch.tensor(0.0, device=z.device, dtype=torch.float32)  # scalar
                all_commit_losses.append(commit_loss_i)  # scalar
                all_indices.append(indices_i)  # (B, T)
                all_quantized.append(z_q_i)  # (B, D, T)
            all_commit_losses = torch.stack(all_commit_losses)  # (N)
            all_indices = torch.stack(all_indices)  # (N, B, T)
            all_quantized = torch.stack(all_quantized)  # (N, B, D, T)
            output_length = input_length  # (B)
        quantized_out = self.output_proj(quantized_out)
        return (
            quantized_out,  # (B, D, T)
            all_indices,  # (N, B, T)
            all_commit_losses,  # (N)
            all_quantized,  # (N, B, D, T)
            output_length,  # (B)
        )

    def decode_codes(self, codes):  # codes: (nq, B, T)
        """Decode codes from multiple quantizers back to embeddings.

        Args:
            codes: Tensor of shape (nq, B, T) containing code indices for each quantizer.
        Returns:
            emb: Tensor of shape (B, D, T) representing the decoded embeddings.
        """
        nq, B, T = codes.shape
        device = codes.device
        emb = torch.zeros(B, self.rvq_dim, T, device=device, dtype=torch.float32)  # (B, rvq_dim, T)
        for i, quantizer in enumerate(self.quantizers[:nq]):
            code_i = codes[i]  # (B, T)
            quantized_i = quantizer.decode_code(code_i)  # (B, D', T) in codebook space
            # Project back to rvq_dim before accumulating (mirrors out_project in VectorQuantize.forward)
            emb += quantizer.out_project(quantized_i)  # (B, rvq_dim, T)
        emb = self.output_proj(emb)  # (B, D, T), apply output projection
        return emb  # (B, D, T)

def ema_inplace(moving_avg, new, decay):
    """Update exponential moving average in place."""
    # moving_avg: (codebook_size) or (codebook_size, D'); new: same shape as moving_avg
    moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay))  # ensure fp32
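
if __name__ == "__main__":
    # Illustrative smoke test added for this write-up (not part of the original
    # module): run a ResidualVQ forward pass in eval mode and round-trip the code
    # indices through decode_codes(). All sizes below are assumptions.
    _kmeans_example()
    _vq_example()
    rvq = ResidualVQ(
        input_dim=1280,
        rvq_dim=256,
        output_dim=1280,
        num_quantizers=4,
        codebook_size=64,
        codebook_dim=8,
        kmeans_init=False,
        quantizer_dropout=0.0,
    )
    rvq.eval()
    x = torch.randn(2, 1280, 40)      # (B, input_dim, T)
    lengths = torch.tensor([40, 32])  # valid frames per batch item
    with torch.no_grad():
        quantized, codes, commit_losses, _, out_len = rvq(x, lengths)
        recon = rvq.decode_codes(codes)  # (B, output_dim, T)
    print(quantized.shape, codes.shape, commit_losses.shape, recon.shape, out_len)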