Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import random | |
| import torch | |
| from audiocraft.losses import ( | |
| MelSpectrogramL1Loss, | |
| MultiScaleMelSpectrogramLoss, | |
| MRSTFTLoss, | |
| SISNR, | |
| STFTLoss, | |
| ) | |
| from audiocraft.losses.loudnessloss import TFLoudnessRatio | |
| from audiocraft.losses.wmloss import WMMbLoss | |
| from tests.common_utils.wav_utils import get_white_noise | |
def test_mel_l1_loss():
    """Smoke-test MelSpectrogramL1Loss on random input of random length.

    The loss must come back as a tensor, and comparing a tensor against
    itself must give exactly zero.
    """
    batch, channels = 2, 2
    length = random.randrange(1000, 100_000)
    shape = (batch, channels, length)
    reference = torch.randn(shape)
    other = torch.randn(shape)

    criterion = MelSpectrogramL1Loss(sample_rate=22_050)
    distinct_loss = criterion(reference, other)
    identical_loss = criterion(reference, reference)

    for value in (distinct_loss, identical_loss):
        assert isinstance(value, torch.Tensor)
    # Identical inputs: the loss is expected to be exactly zero.
    assert identical_loss.item() == 0.0
def test_msspec_loss():
    """Smoke-test MultiScaleMelSpectrogramLoss on random input of random length.

    The loss must come back as a tensor, and comparing a tensor against
    itself must give exactly zero.
    """
    batch, channels = 2, 2
    length = random.randrange(1000, 100_000)
    shape = (batch, channels, length)
    reference = torch.randn(shape)
    other = torch.randn(shape)

    criterion = MultiScaleMelSpectrogramLoss(sample_rate=22_050)
    distinct_loss = criterion(reference, other)
    identical_loss = criterion(reference, reference)

    for value in (distinct_loss, identical_loss):
        assert isinstance(value, torch.Tensor)
    # Identical inputs: the loss is expected to be exactly zero.
    assert identical_loss.item() == 0.0
def test_mrstft_loss():
    """Check that MRSTFTLoss returns a tensor for two random inputs."""
    batch, channels = 2, 2
    length = random.randrange(1000, 100_000)
    shape = (batch, channels, length)
    first = torch.randn(shape)
    second = torch.randn(shape)

    criterion = MRSTFTLoss()
    result = criterion(first, second)

    assert isinstance(result, torch.Tensor)
def test_sisnr_loss():
    """Check that SISNR returns a tensor for two random inputs."""
    batch, channels = 2, 2
    length = random.randrange(1000, 100_000)
    shape = (batch, channels, length)
    first = torch.randn(shape)
    second = torch.randn(shape)

    criterion = SISNR()
    result = criterion(first, second)

    assert isinstance(result, torch.Tensor)
def test_stft_loss():
    """Check that STFTLoss returns a tensor for two random inputs."""
    batch, channels = 2, 2
    length = random.randrange(1000, 100_000)
    shape = (batch, channels, length)
    first = torch.randn(shape)
    second = torch.randn(shape)

    # Single-resolution STFT loss (distinct from the MRSTFTLoss test).
    criterion = STFTLoss()
    result = criterion(first, second)

    assert isinstance(result, torch.Tensor)
def test_wm_loss():
    """Check that WMMbLoss(0.3, "mse") returns a tensor on random input.

    The loss is called with a (N, 2 + nbits, T) tensor, ``None``, a
    (N, 1, T) tensor and a (N, nbits) tensor, in that order.
    """
    batch, nbits = 2, 16
    length = random.randrange(1000, 100_000)

    positive = torch.randn(batch, 2 + nbits, length)
    single_channel = torch.randn(batch, 1, length)
    message = torch.randn(batch, nbits)

    criterion = WMMbLoss(0.3, "mse")
    # The second positional argument is deliberately passed as None.
    result = criterion(positive, None, single_channel, message)

    assert isinstance(result, torch.Tensor)
def test_loudness_loss():
    """Check that TFLoudnessRatio returns a tensor when fed the same signal twice.

    The signal comes from ``get_white_noise(1, sr * duration)`` with an
    extra leading dimension added via ``unsqueeze(0)``.
    """
    sample_rate = 16_000
    seconds = 1.0
    num_samples = int(sample_rate * seconds)

    noise = get_white_noise(1, num_samples).unsqueeze(0)

    criterion = TFLoudnessRatio(sample_rate=sample_rate, n_bands=1)
    result = criterion(noise, noise)

    assert isinstance(result, torch.Tensor)