| """https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/vqgan/modules/wavenet.py""" | |
| import math | |
| from typing import Optional | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| class Mish(nn.Module): | |
| def forward(self, x): | |
| return x * torch.tanh(F.softplus(x)) | |
| class DiffusionEmbedding(nn.Module): | |
| """Diffusion Step Embedding""" | |
| def __init__(self, d_denoiser): | |
| super(DiffusionEmbedding, self).__init__() | |
| self.dim = d_denoiser | |
| def forward(self, x): | |
| device = x.device | |
| half_dim = self.dim // 2 | |
| emb = math.log(10000) / (half_dim - 1) | |
| emb = torch.exp(torch.arange(half_dim, device=device) * -emb) | |
| emb = x[:, None] * emb[None, :] | |
| emb = torch.cat((emb.sin(), emb.cos()), dim=-1) | |
| return emb | |
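
# Note (not in the original file), a minimal sketch of the embedding above as
# read from the code: for a batch of timesteps t with shape (B,),
# DiffusionEmbedding(dim)(t) returns a (B, dim) sinusoidal embedding
# (sin and cos halves concatenated), e.g.
#   DiffusionEmbedding(256)(torch.randint(0, 1000, (8,))).shape
#   -> torch.Size([8, 256])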

class LinearNorm(nn.Module):
    """LinearNorm Projection"""

    def __init__(self, in_features, out_features, bias=False):
        super(LinearNorm, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.constant_(self.linear.bias, 0.0)

    def forward(self, x):
        x = self.linear(x)
        return x


class ConvNorm(nn.Module):
    """1D Convolution"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal

class ResidualBlock(nn.Module):
    """Residual Block"""

    def __init__(
        self,
        residual_channels,
        use_linear_bias=False,
        dilation=1,
        condition_channels=None,
    ):
        super(ResidualBlock, self).__init__()
        self.conv_layer = ConvNorm(
            residual_channels,
            2 * residual_channels,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )

        if condition_channels is not None:
            self.diffusion_projection = LinearNorm(
                residual_channels, residual_channels, use_linear_bias
            )
            self.condition_projection = ConvNorm(
                condition_channels, 2 * residual_channels, kernel_size=1
            )

        self.output_projection = ConvNorm(
            residual_channels, 2 * residual_channels, kernel_size=1
        )

    def forward(self, x, condition=None, diffusion_step=None):
        y = x

        if diffusion_step is not None:
            diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
            y = y + diffusion_step

        y = self.conv_layer(y)

        if condition is not None:
            condition = self.condition_projection(condition)
            y = y + condition

        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)
        y = self.output_projection(y)

        residual, skip = torch.chunk(y, 2, dim=1)

        return (x + residual) / math.sqrt(2.0), skip
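
# Note (not in the original file), data flow of ResidualBlock.forward as read
# from the code above: x is (B, residual_channels, T); the optional
# diffusion_step is broadcast over time, the dilated conv doubles the channel
# count so it can be split into a sigmoid gate and a tanh filter (the WaveNet
# gated activation), and output_projection doubles it again so the result
# splits into a residual path (added to x and scaled by 1/sqrt(2)) and a skip
# path returned to the caller.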

class WaveNet(nn.Module):
    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        is_diffusion: bool = False,
        condition_channels: Optional[int] = None,
    ):
        super().__init__()

        # Input projection
        self.input_projection = None
        if input_channels is not None and input_channels != residual_channels:
            self.input_projection = ConvNorm(
                input_channels, residual_channels, kernel_size=1
            )

        if input_channels is None:
            input_channels = residual_channels

        self.input_channels = input_channels

        # Residual layers
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    residual_channels=residual_channels,
                    use_linear_bias=False,
                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
                    condition_channels=condition_channels,
                )
                for i in range(residual_layers)
            ]
        )

        # Skip projection
        self.skip_projection = ConvNorm(
            residual_channels, residual_channels, kernel_size=1
        )

        # Output projection
        self.output_projection = None
        if output_channels is not None and output_channels != residual_channels:
            self.output_projection = ConvNorm(
                residual_channels, output_channels, kernel_size=1
            )

        if is_diffusion:
            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
            self.mlp = nn.Sequential(
                LinearNorm(residual_channels, residual_channels * 4, False),
                Mish(),
                LinearNorm(residual_channels * 4, residual_channels, False),
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, t=None, condition=None):
        if self.input_projection is not None:
            x = self.input_projection(x)
            x = F.silu(x)

        if t is not None:
            t = self.diffusion_embedding(t)
            t = self.mlp(t)

        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, condition, t)
            skip.append(skip_connection)

        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)

        if self.output_projection is not None:
            x = F.silu(x)
            x = self.output_projection(x)

        return x
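
# A minimal usage sketch (not part of the original file). The channel sizes
# below are arbitrary example values, chosen only to exercise the diffusion
# and conditioning paths together; note that ResidualBlock only creates
# diffusion_projection when condition_channels is set, so a diffusion step is
# only passed alongside a condition here.
if __name__ == "__main__":
    batch, frames = 4, 100
    model = WaveNet(
        input_channels=128,      # projected up to residual_channels
        output_channels=128,     # projected back down at the output
        residual_channels=256,
        residual_layers=8,
        dilation_cycle=4,        # dilations cycle through 1, 2, 4, 8
        is_diffusion=True,       # adds DiffusionEmbedding + MLP for t
        condition_channels=256,
    )

    x = torch.randn(batch, 128, frames)          # (B, input_channels, T)
    t = torch.randint(0, 1000, (batch,))         # one diffusion step per item
    condition = torch.randn(batch, 256, frames)  # (B, condition_channels, T)

    y = model(x, t=t, condition=condition)
    print(y.shape)  # torch.Size([4, 128, 100])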