Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import typing as tp | |
| import numpy as np | |
| from torchaudio.transforms import MelSpectrogram | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from ..modules import pad_for_conv1d | |
class MelSpectrogramWrapper(nn.Module):
    """Mel-spectrogram front-end built on the torchaudio ``MelSpectrogram`` transform.

    The input is reflect-padded, padded again with ``pad_for_conv1d`` and then fed
    to a non-centered ``MelSpectrogram`` using a Hann window. The result can
    optionally be mapped to a log10 scale.

    Args:
        n_fft (int): Number of fft.
        hop_length (int): Hop size (cast to ``int`` on construction).
        win_length (int or None): Window length.
        n_mels (int): Number of mel bins.
        sample_rate (int): Sample rate.
        f_min (float): Minimum frequency.
        f_max (float or None): Maximum frequency.
        log (bool): Whether to scale with log.
        normalized (bool): Whether to normalize the melspectrogram.
        floor_level (float): Constant added to the mel spectrogram before the
            log10 is taken (default=1e-5).
    """

    def __init__(self, n_fft: int = 1024, hop_length: int = 256, win_length: tp.Optional[int] = None,
                 n_mels: int = 80, sample_rate: float = 22050, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
        super().__init__()
        self.n_fft = n_fft
        # Callers may hand in a float hop size; store and forward an integer.
        self.hop_length = int(hop_length)
        self.mel_transform = MelSpectrogram(
            n_mels=n_mels,
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=self.hop_length,
            win_length=win_length,
            f_min=f_min,
            f_max=f_max,
            normalized=normalized,
            window_fn=torch.hann_window,
            center=False,
        )
        self.floor_level = floor_level
        self.log = log

    def forward(self, x):
        """Return the (optionally log10-scaled) mel spectrogram of ``x``.

        A 2-dimensional input gets an extra dimension inserted at index 1. The
        4-dimensional output of the transform is returned with its second and
        third dimensions merged, i.e. with shape ``(B, C * freqs, frames)``.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        reflect_pad = int((self.n_fft - self.hop_length) // 2)
        x = F.pad(x, (reflect_pad, reflect_pad), "reflect")
        # Make sure that all the frames are full.
        # Together with the reflect padding above, `pad_for_conv1d` makes the
        # output have ceil(T / hop) frames.
        x = pad_for_conv1d(x, self.n_fft, self.hop_length)
        self.mel_transform.to(x.device)
        mel_spec = self.mel_transform(x)
        batch, channels, n_freqs, n_frames = mel_spec.shape
        if self.log:
            mel_spec = torch.log10(self.floor_level + mel_spec)
        return mel_spec.reshape(batch, channels * n_freqs, n_frames)
class MelSpectrogramL1Loss(torch.nn.Module):
    """L1 distance between the mel spectrograms of two signals.

    Both inputs go through the same ``MelSpectrogramWrapper`` and the resulting
    spectrograms are compared with ``torch.nn.L1Loss``.

    Args:
        sample_rate (int): Sample rate.
        n_fft (int): Number of fft.
        hop_length (int): Hop size.
        win_length (int): Window length.
        n_mels (int): Number of mel bins.
        f_min (float or None): Minimum frequency.
        f_max (float or None): Maximum frequency.
        log (bool): Whether to scale with log.
        normalized (bool): Whether to normalize the melspectrogram.
        floor_level (float): Floor level value based on human perception (default=1e-5).
    """

    def __init__(self, sample_rate: int, n_fft: int = 1024, hop_length: int = 256, win_length: int = 1024,
                 n_mels: int = 80, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 log: bool = True, normalized: bool = False, floor_level: float = 1e-5):
        super().__init__()
        self.l1 = torch.nn.L1Loss()
        # Every spectrogram option is forwarded unchanged to the wrapper.
        mel_kwargs = dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
            log=log,
            normalized=normalized,
            floor_level=floor_level,
        )
        self.melspec = MelSpectrogramWrapper(**mel_kwargs)

    def forward(self, x, y):
        """Return the L1 loss between the mel spectrograms of ``x`` and ``y``."""
        # Move the transform to the device of the first input before using it.
        self.melspec.to(x.device)
        return self.l1(self.melspec(x), self.melspec(y))
class MultiScaleMelSpectrogramLoss(nn.Module):
    """Multi-Scale spectrogram loss (msspec).

    One scale is built for every exponent ``i`` in ``range(range_start, range_end)``
    with ``n_fft = win_length = 2 ** i`` and ``hop_length = 2 ** i // 4``. Each scale
    contributes an L1 term on the linear mel spectrograms plus an alpha-weighted
    MSE term on the log10 mel spectrograms.

    Args:
        sample_rate (int): Sample rate.
        range_start (int): Power of 2 to use for the first scale.
        range_end (int): Exclusive upper bound on the power of 2; the last scale
            uses ``2 ** (range_end - 1)``.
        n_mels (int): Number of mel bins.
        f_min (float): Minimum frequency.
        f_max (float or None): Maximum frequency.
        normalized (bool): Whether to normalize the melspectrogram. When set, the
            summed loss is also divided by the accumulated weight ``self.total``.
        alphas (bool): If True, the log term of scale ``i`` is weighted by
            ``sqrt(2 ** i - 1)``; otherwise every weight is 1.
        floor_level (float): Floor level value based on human perception (default=1e-5).
    """

    def __init__(self, sample_rate: int, range_start: int = 6, range_end: int = 11,
                 n_mels: int = 64, f_min: float = 0.0, f_max: tp.Optional[float] = None,
                 normalized: bool = False, alphas: bool = True, floor_level: float = 1e-5):
        super().__init__()
        self.alphas = list()
        self.total = 0
        self.normalized = normalized
        linear_mels = list()
        log_mels = list()
        for i in range(range_start, range_end):
            n_fft = 2 ** i
            # Options shared by the linear and the log transform of this scale.
            # Integer division keeps hop_length an int (a quarter of the window).
            scale_kwargs = dict(
                n_fft=n_fft,
                hop_length=n_fft // 4,
                win_length=n_fft,
                n_mels=n_mels,
                sample_rate=sample_rate,
                f_min=f_min,
                f_max=f_max,
                normalized=normalized,
                floor_level=floor_level,
            )
            linear_mels.append(MelSpectrogramWrapper(log=False, **scale_kwargs))
            log_mels.append(MelSpectrogramWrapper(log=True, **scale_kwargs))
            alpha = np.sqrt(n_fft - 1) if alphas else 1
            self.alphas.append(alpha)
            # Each scale adds weight 1 (L1 term) plus alpha (log MSE term).
            self.total += alpha + 1
        self.l1s = nn.ModuleList(linear_mels)
        self.l2s = nn.ModuleList(log_mels)

    def forward(self, x, y):
        """Return the multi-scale loss between ``x`` and ``y``.

        For every scale the L1 loss between the linear mel spectrograms and the
        alpha-weighted MSE loss between the log mel spectrograms are summed. The
        sum is divided by ``self.total`` only when ``self.normalized`` is set.
        """
        loss = 0.0
        self.l1s.to(x.device)
        self.l2s.to(x.device)
        for alpha, linear_mel, log_mel in zip(self.alphas, self.l1s, self.l2s):
            linear_term = F.l1_loss(linear_mel(x), linear_mel(y))
            log_term = F.mse_loss(log_mel(x), log_mel(y))
            loss += linear_term + alpha * log_term
        if self.normalized:
            loss = loss / self.total
        return loss