# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) ByteDance, Inc. and its affiliates.
# Copyright (c) Chutong Meng
#
# This source code is licensed under the CC BY-NC license found in the
# LICENSE file in the root directory of this source tree.
#
# Based on AudioDec (https://github.com/facebookresearch/AudioDec)

import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantize(nn.Module):
    """Vector quantization w/ exponential moving averages (EMA)."""

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        decay=0.8,
        commitment=1.0,
        eps=1e-5,
        n_embed=None,
    ):
        super().__init__()
        n_embed = self.default(n_embed, codebook_size)

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps
        self.commitment = commitment

        # The codebook is stored as (dim, n_embed); the EMA buffers track
        # cluster sizes and running sums of the vectors assigned to each code.
        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    @property
    def codebook(self):
        # (n_embed, dim) view, as expected by F.embedding and by
        # ResidualVQ.initial(), which reads this as an attribute.
        return self.embed.transpose(0, 1)

    def exists(self, val):
        return val is not None

    def default(self, val, d):
        return val if self.exists(val) else d

    def ema_inplace(self, moving_avg, new, decay):
        moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))

    def laplace_smoothing(self, x, n_categories, eps=1e-5):
        return (x + eps) / (x.sum() + n_categories * eps)

    def forward(self, input):
        dtype = input.dtype
        flatten = input.reshape(-1, self.dim)
        # Squared Euclidean distance to every code:
        # ||x||^2 - 2 x.e + ||e||^2, computed without materialising all pairs.
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))

        if self.training:
            # EMA codebook update; no gradient flows into the codebook.
            self.ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot
            self.ema_inplace(self.embed_avg, embed_sum, self.decay)
            cluster_size = (
                self.laplace_smoothing(self.cluster_size, self.n_embed, self.eps)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        loss = F.mse_loss(quantize.detach(), input) * self.commitment
        # Straight-through estimator: the forward pass uses the quantized
        # value, the backward pass copies gradients to the encoder unchanged.
        quantize = input + (quantize - input).detach()

        avg_probs = torch.mean(embed_onehot, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        return quantize, loss, perplexity

    def forward_index(self, input):
        dtype = input.dtype
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = F.embedding(embed_ind, self.embed.transpose(0, 1))
        quantize = input + (quantize - input).detach()
        return quantize, embed_ind
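

# ---------------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the upstream code): a
# minimal driver for VectorQuantize. The helper name `_demo_vector_quantize`
# and the toy sizes (dim=8, codebook_size=16) are assumptions for the example.
# ---------------------------------------------------------------------------
def _demo_vector_quantize():
    vq = VectorQuantize(dim=8, codebook_size=16)
    vq.train()  # EMA codebook updates only happen in training mode
    x = torch.randn(2, 10, 8, requires_grad=True)  # (B, T, dim)
    quantized, loss, perplexity = vq(x)
    assert quantized.shape == x.shape
    # The straight-through estimator lets gradients reach x:
    (quantized.sum() + loss).backward()
    assert x.grad is not None
    return loss, perplexity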


class ResidualVQ(nn.Module):
    """Residual VQ following Algorithm 1 in https://arxiv.org/pdf/2107.03312.pdf"""

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantize(**kwargs) for _ in range(num_quantizers)]
        )

    def forward(self, x):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_perplexities = []

        for layer in self.layers:
            quantized, loss, perplexity = layer(residual)
            # Issue: https://github.com/lucidrains/vector-quantize-pytorch/issues/33
            # We found that considering only the 1st-layer VQ's gradient
            # results in better performance.
            # residual = residual - quantized.detach()  # considering all layers' gradients
            residual = (
                residual - quantized
            )  # considering only the first layer's gradient
            quantized_out = quantized_out + quantized

            all_losses.append(loss)
            all_perplexities.append(perplexity)

        all_losses, all_perplexities = map(torch.stack, (all_losses, all_perplexities))
        return quantized_out, all_losses, all_perplexities

    def forward_index(self, x, flatten_idx=False):
        """
        all_indices: [num_of_quantizers, B, T]
        """
        quantized_out = 0.0
        residual = x
        all_indices = []
        for i, layer in enumerate(self.layers):
            quantized, indices = layer.forward_index(residual)
            # residual = residual - quantized.detach()
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            if flatten_idx:
                # Offset each quantizer's indices into the flattened codebook
                # built by initial(), which must have been called beforehand.
                indices += self.codebook_size * i
            all_indices.append(indices)

        all_indices = torch.stack(all_indices)
        return quantized_out, all_indices

    def initial(self):
        # Stack all per-layer codebooks into one flattened
        # (num_quantizers * codebook_size, dim) table for index lookup.
        self.codebook = []
        for layer in self.layers:
            self.codebook.append(layer.codebook)
        self.codebook_size = self.codebook[0].size(0)
        self.codebook = torch.stack(self.codebook)
        self.codebook = self.codebook.reshape(-1, self.codebook.size(-1))

    def lookup(self, indices):
        quantized_out = F.embedding(indices, self.codebook)  # Num x T x C
        return torch.sum(quantized_out, dim=0, keepdim=True)
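

# ---------------------------------------------------------------------------
# Illustrative sketch (editor's addition): flattened indices and codebook
# lookup. `_demo_residual_vq_lookup` and all sizes here are example choices.
# ---------------------------------------------------------------------------
def _demo_residual_vq_lookup():
    rvq = ResidualVQ(num_quantizers=2, dim=8, codebook_size=16)
    rvq.initial()  # must run first: builds the flattened codebook / codebook_size
    x = torch.randn(1, 10, 8)  # (B, T, dim)
    zq, indices = rvq.forward_index(x, flatten_idx=True)  # indices: (2, 1, 10)
    # lookup() expects (num_quantizers, T) indices, so drop the batch axis;
    # the summed lookup matches zq up to floating-point rounding.
    recon = rvq.lookup(indices[:, 0, :])  # (1, 10, 8)
    assert torch.allclose(recon, zq, atol=1e-5)
    return recon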


class Quantizer(nn.Module):
    def __init__(
        self,
        code_dim: int,
        codebook_num: int,
        codebook_size: int,
    ):
        super().__init__()
        self.codebook = ResidualVQ(
            dim=code_dim, num_quantizers=codebook_num, codebook_size=codebook_size
        )

    def initial(self):
        self.codebook.initial()

    def forward(self, z):
        zq, vqloss, perplexity = self.codebook(z.transpose(2, 1))
        zq = zq.transpose(2, 1)
        return zq, vqloss, perplexity

    def inference(self, z):
        zq, indices = self.codebook.forward_index(z.transpose(2, 1))
        zq = zq.transpose(2, 1)
        return zq, indices

    def encode(self, z):
        zq, indices = self.codebook.forward_index(z.transpose(2, 1), flatten_idx=True)
        return zq, indices

    def decode(self, indices):
        z = self.codebook.lookup(indices)
        return z
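

# ---------------------------------------------------------------------------
# Illustrative sketch (editor's addition): the Quantizer works channel-first,
# matching the (B, C, T) layout of the Projector output. The helper name and
# sizes below are example assumptions.
# ---------------------------------------------------------------------------
def _demo_quantizer():
    q = Quantizer(code_dim=8, codebook_num=2, codebook_size=16)
    z = torch.randn(1, 8, 10)  # (B, code_dim, T)
    zq, vqloss, perplexity = q(z)  # zq: (1, 8, 10); per-quantizer losses: (2,)
    # encode()/decode() additionally need the flattened codebook:
    q.initial()
    _, indices = q.encode(z)  # (codebook_num, B, T), offset per quantizer
    recon = q.decode(indices[:, 0, :])  # (1, T, code_dim), summed over quantizers
    return zq, vqloss, perplexity, recon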


class Conv1d1x1(nn.Conv1d):
    """1x1 Conv1d."""

    def __init__(self, in_channels, out_channels, bias=True):
        super(Conv1d1x1, self).__init__(
            in_channels, out_channels, kernel_size=1, bias=bias
        )


class Conv1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = -1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        if padding < 0:
            # Default to "same"-style padding for odd kernels at stride 1.
            padding = (kernel_size - 1) // 2 * dilation
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C, T).
        """
        x = self.conv(x)
        return x


class ConvTranspose1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding=-1,
        output_padding=-1,
        groups=1,
        bias=True,
    ):
        super().__init__()
        # The defaults give exact x stride upsampling when
        # kernel_size == 2 * stride, for both even and odd strides.
        if padding < 0:
            padding = (stride + 1) // 2
        if output_padding < 0:
            output_padding = 1 if stride % 2 else 0
        self.deconv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        x = self.deconv(x)
        return x
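

# ---------------------------------------------------------------------------
# Illustrative check (editor's addition): with the defaults above and
# kernel_size = 2 * stride (as used by DecoderBlock below), the output length
# is exactly stride * T for both even and odd strides.
# ---------------------------------------------------------------------------
def _demo_convtranspose_lengths():
    for stride in (2, 3, 4):
        up = ConvTranspose1d(4, 4, kernel_size=2 * stride, stride=stride)
        y = up(torch.randn(1, 4, 10))
        assert y.shape[-1] == 10 * stride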


class ResidualUnit(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        dilation=1,
        bias=False,
        nonlinear_activation="ELU",
        nonlinear_activation_params={},
    ):
        super().__init__()
        self.activation = getattr(nn, nonlinear_activation)(
            **nonlinear_activation_params
        )
        self.conv1 = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        y = self.conv1(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y


class Projector(nn.Module):
    def __init__(
        self, input_channels: int, code_dim: int, kernel_size=3, stride=1, bias=False
    ):
        super().__init__()
        self.project = Conv1d(
            input_channels, code_dim, kernel_size=kernel_size, stride=stride, bias=bias
        )

    def forward(self, x):
        return self.project(x)


class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.res_units = torch.nn.ModuleList()
        for dilation in dilations:
            self.res_units += [
                ResidualUnit(
                    in_channels,
                    in_channels,
                    kernel_size=unit_kernel_size,
                    dilation=dilation,
                )
            ]
        self.num_res = len(self.res_units)

        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(
                3 if stride == 1 else (2 * stride)
            ),  # special case: stride=1, do not use kernel=2
            stride=stride,
            bias=bias,
        )

    def forward(self, x):
        for idx in range(self.num_res):
            x = self.res_units[idx](x)
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    def __init__(
        self,
        input_channels: int,
        encode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv = Conv1d(
            in_channels=input_channels,
            out_channels=encode_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=False,
        )
        self.conv_blocks = torch.nn.ModuleList()
        in_channels = encode_channels
        for idx, stride in enumerate(strides):
            # channel_ratios[idx] may be a float; cast the product back to int.
            out_channels = int(encode_channels * channel_ratios[idx])
            self.conv_blocks += [
                EncoderBlock(
                    in_channels,
                    out_channels,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                )
            ]
            in_channels = out_channels
        self.num_blocks = len(self.conv_blocks)
        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv(x)
        for i in range(self.num_blocks):
            x = self.conv_blocks[i](x)
        return x
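

# ---------------------------------------------------------------------------
# Illustrative shape trace (editor's addition): the overall downsampling
# factor is the product of `strides`, and the output channels follow the last
# channel ratio. Sizes here are example assumptions.
# ---------------------------------------------------------------------------
def _demo_encoder_shapes():
    enc = Encoder(
        input_channels=8,
        encode_channels=8,
        channel_ratios=(1, 2),
        strides=(1, 2),
    )
    y = enc(torch.randn(1, 8, 20))  # (B, C, T) in
    assert y.shape == (1, 16, 10)  # channels: 8 * 2; time: 20 / 2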


class DecoderBlock(nn.Module):
    """Decoder block (up-samples only when stride > 1)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        dilations=(1, 1),
        unit_kernel_size=3,
        bias=True,
    ):
        super().__init__()
        if stride == 1:
            self.conv = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,  # fix kernel=3 when stride=1 for unchanged shape
                stride=stride,
                bias=bias,
            )
        else:
            self.conv = ConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )

        self.res_units = torch.nn.ModuleList()
        for idx, dilation in enumerate(dilations):
            self.res_units += [
                ResidualUnit(
                    out_channels,
                    out_channels,
                    kernel_size=unit_kernel_size,
                    dilation=dilation,
                )
            ]
        self.num_res = len(self.res_units)

    def forward(self, x):
        x = self.conv(x)
        for idx in range(self.num_res):
            x = self.res_units[idx](x)
        return x


class Decoder(nn.Module):
    def __init__(
        self,
        code_dim: int,
        output_channels: int,
        decode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv1 = Conv1d(
            in_channels=code_dim,
            out_channels=int(decode_channels * channel_ratios[0]),
            kernel_size=kernel_size,
            stride=1,
            bias=False,
        )

        self.conv_blocks = torch.nn.ModuleList()
        for idx, stride in enumerate(strides):
            in_channels = int(decode_channels * channel_ratios[idx])
            if idx < (len(channel_ratios) - 1):
                out_channels = int(decode_channels * channel_ratios[idx + 1])
            else:
                out_channels = decode_channels
            self.conv_blocks += [
                DecoderBlock(
                    in_channels,
                    out_channels,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                )
            ]
        self.num_blocks = len(self.conv_blocks)

        self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)

    def forward(self, z):
        x = self.conv1(z)
        for i in range(self.num_blocks):
            x = self.conv_blocks[i](x)
        x = self.conv2(x)
        return x


class VevoRepCodec(nn.Module):
    def __init__(
        self,
        input_channels=768,
        output_channels=768,
        encode_channels=768,
        decode_channels=768,
        code_dim=768,
        codebook_num=1,
        codebook_size=1024,
        bias=True,
        enc_ratios=(1, 1),
        dec_ratios=(1, 1),
        enc_strides=(1, 1),
        dec_strides=(1, 1),
        enc_kernel_size=3,
        dec_kernel_size=3,
        enc_block_dilations=(1, 1),
        enc_block_kernel_size=3,
        dec_block_dilations=(1, 1),
        dec_block_kernel_size=3,
    ):
        super().__init__()
        self.input_channels = input_channels

        self.encoder = Encoder(
            input_channels=input_channels,
            encode_channels=encode_channels,
            channel_ratios=enc_ratios,
            strides=enc_strides,
            kernel_size=enc_kernel_size,
            bias=bias,
            block_dilations=enc_block_dilations,
            unit_kernel_size=enc_block_kernel_size,
        )

        self.decoder = Decoder(
            code_dim=code_dim,
            output_channels=output_channels,
            decode_channels=decode_channels,
            channel_ratios=dec_ratios,
            strides=dec_strides,
            kernel_size=dec_kernel_size,
            bias=bias,
            block_dilations=dec_block_dilations,
            unit_kernel_size=dec_block_kernel_size,
        )

        self.projector = Projector(
            input_channels=self.encoder.out_channels,
            code_dim=code_dim,
            kernel_size=3,
            stride=1,
            bias=False,
        )

        self.quantizer = Quantizer(
            code_dim=code_dim, codebook_num=codebook_num, codebook_size=codebook_size
        )

    def forward(self, x):
        x = self.encoder(x)
        z = self.projector(x)
        zq, vqloss, perplexity = self.quantizer(z)
        y = self.decoder(zq)
        return y, zq, z, vqloss, perplexity
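

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (editor's addition). The 768-dim defaults target
# SSL feature streams (e.g. HuBERT-style representations); the smaller sizes
# below are assumptions chosen only to keep the demo cheap.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    codec = VevoRepCodec(
        input_channels=16,
        output_channels=16,
        encode_channels=16,
        decode_channels=16,
        code_dim=16,
        codebook_num=2,
        codebook_size=32,
    )
    codec.eval()
    feats = torch.randn(1, 16, 50)  # (B, C, T) feature sequence
    recon, zq, z, vqloss, perplexity = codec(feats)
    print("recon:", recon.shape, "zq:", zq.shape, "vqloss:", vqloss.shape)

    # Discrete token extraction for downstream use:
    codec.quantizer.initial()
    _, tokens = codec.quantizer.encode(codec.projector(codec.encoder(feats)))
    print("tokens:", tokens.shape)  # (codebook_num, B, T)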