# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...modules import NormConv2d
from .base import MultiDiscriminator, MultiDiscriminatorOutputType


def get_padding(kernel_size: int, dilation: int = 1) -> int:
    return int((kernel_size * dilation - dilation) / 2)
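
# For example, get_padding(5) == 2 and get_padding(3, dilation=2) == 2: "same" padding
# for odd kernels, so a stride-1 convolution keeps the time dimension unchanged.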


class PeriodDiscriminator(nn.Module):
    """Period sub-discriminator.

    Args:
        period (int): Period between samples of audio.
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        n_layers (int): Number of convolutional layers.
        kernel_sizes (list of int): Kernel sizes for convolutions.
        stride (int): Stride for convolutions.
        filters (int): Initial number of filters in convolutions.
        filters_scale (int): Multiplier of number of filters as we increase depth.
        max_filters (int): Maximum number of filters.
        norm (str): Normalization method.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
    """
    def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
                 n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
                 filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
                 norm: str = 'weight_norm', activation: str = 'LeakyReLU',
                 activation_params: dict = {'negative_slope': 0.2}):
        super().__init__()
        self.period = period
        self.n_layers = n_layers
        self.activation = getattr(torch.nn, activation)(**activation_params)
        self.convs = nn.ModuleList()
        in_chs = in_channels
        for i in range(self.n_layers):
            out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
            eff_stride = 1 if i == self.n_layers - 1 else stride
            self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
                                         padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
            in_chs = out_chs
        self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
                                    padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)

    def forward(self, x: torch.Tensor):
        fmap = []
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), 'reflect')
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = conv(x)
            x = self.activation(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        # x = torch.flatten(x, 1, -1)
        return x, fmap
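
# Shape sketch for PeriodDiscriminator.forward (illustrative, using the defaults above):
# an input (B, 1, T) is reflect-padded so T becomes a multiple of `period`, reshaped to
# (B, 1, T // period, period), and each (k, 1) convolution then only mixes samples that
# lie `period` steps apart in the original waveform. With filters=8, filters_scale=4 and
# max_filters=1024, the channel widths grow as 32, 128, 512, 1024, 1024 across the 5 layers.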


class MultiPeriodDiscriminator(MultiDiscriminator):
    """Multi-Period (MPD) Discriminator.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
        **kwargs: Additional args for `PeriodDiscriminator`
    """
    def __init__(self, in_channels: int = 1, out_channels: int = 1,
                 periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
        super().__init__()
        self.discriminators = nn.ModuleList([
            PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
        ])

    @property
    def num_discriminators(self):
        return len(self.discriminators)

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        logits = []
        fmaps = []
        for disc in self.discriminators:
            logit, fmap = disc(x)
            logits.append(logit)
            fmaps.append(fmap)
        return logits, fmaps
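

if __name__ == '__main__':
    # Minimal smoke-test sketch. Because of the relative imports above, this only runs when
    # the file is executed as a module inside its parent package, e.g.
    # `python -m <package>.adversarial.discriminators.mpd` (the package name is left unspecified).
    mpd = MultiPeriodDiscriminator()
    wav = torch.randn(2, 1, 16_000)  # (batch, channels, time) dummy waveform
    with torch.no_grad():
        logits, fmaps = mpd(wav)
    assert len(logits) == len(mpd.discriminators) == 5  # one logit per period in [2, 3, 5, 7, 11]
    for logit, fmap in zip(logits, fmaps):
        # Each logit keeps the (batch, out_channels, time', period) layout; `fmap` holds one
        # feature map per convolutional layer plus the post-convolution output.
        print(tuple(logit.shape), [tuple(f.shape) for f in fmap])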