Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import numpy as np | |
| import torch | |
| from scipy.interpolate import interp1d | |
| import torchaudio | |
| from fairseq.tasks.text_to_speech import ( | |
| batch_compute_distortion, compute_rms_dist | |
| ) | |
def batch_mel_spectral_distortion(
    y1, y2, sr, normalize_type="path", mel_fn=None
):
    """Compute mel spectral distortion for a batch of waveform pairs.

    The metric follows the mel cepstral distortion recipe but compares
    log-mel spectrograms instead (https://arxiv.org/pdf/2011.03568.pdf).

    Args:
        y1: first batch of waveforms; ``y1[0].device`` decides where a
            freshly built mel transform is placed.
        y2: second batch of waveforms, forwarded unchanged.
        sr: sample rate; sizes the mel transform and is forwarded.
        normalize_type: forwarded to ``batch_compute_distortion``.
        mel_fn: optional mel-spectrogram transform to reuse. A new one is
            built when it is ``None`` or its ``sample_rate`` differs
            from ``sr``.

    Returns:
        The result of ``batch_compute_distortion`` evaluated with the
        log-mel feature extractor and ``compute_rms_dist``.
    """
    needs_new_transform = mel_fn is None or mel_fn.sample_rate != sr
    if needs_new_transform:
        mel_fn = torchaudio.transforms.MelSpectrogram(
            sr,
            n_fft=int(0.05 * sr),
            win_length=int(0.05 * sr),
            hop_length=int(0.0125 * sr),
            f_min=20,
            n_mels=80,
            window_fn=torch.hann_window,
        ).to(y1[0].device)

    # Small constant added before the log so zero-energy bins stay finite.
    log_floor = 1e-6

    def _log_mel(wav):
        # The last two axes of the log-mel output are swapped before it is
        # handed to the distortion routine.
        return torch.log(mel_fn(wav) + log_floor).transpose(-1, -2)

    return batch_compute_distortion(
        y1, y2, sr, _log_mel, compute_rms_dist, normalize_type
    )
| # This code is based on | |
| # "https://github.com/bastibe/MAPS-Scripts/blob/master/helper.py" | |
def _same_t_in_true_and_est(func):
    """Decorator that moves the estimated track onto the true time axis.

    The returned wrapper accepts ``(true_t, true_f, est_t, est_f)`` and
    calls ``func(true_t, true_f, true_t, interpolated_f)``, where
    ``interpolated_f`` is ``est_f`` sampled at ``true_t`` by
    nearest-neighbour interpolation. Time stamps in ``true_t`` that fall
    outside the range of ``est_t`` receive the value 0.

    Args:
        func: callable taking ``(true_t, true_f, est_t, est_f)``.

    Returns:
        The wrapping function; it keeps ``func``'s name and docstring.

    Raises:
        TypeError: (from the wrapper) if any of the four arguments is not
            a ``numpy.ndarray``.
    """
    # Function-scope import so the block does not rely on a file-level one.
    import functools

    @functools.wraps(func)
    def new_func(true_t, true_f, est_t, est_f):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        arguments = (
            ("true_t", true_t),
            ("true_f", true_f),
            ("est_t", est_t),
            ("est_f", est_f),
        )
        for label, value in arguments:
            if not isinstance(value, np.ndarray):
                raise TypeError(
                    f"{label} must be a numpy.ndarray, "
                    f"got {type(value).__name__}"
                )

        interpolated_f = interp1d(
            est_t, est_f, bounds_error=False, kind='nearest', fill_value=0
        )(true_t)
        return func(true_t, true_f, true_t, interpolated_f)

    return new_func
def gross_pitch_error(true_t, true_f, est_t, est_f):
    """Return the gross pitch error (GPE) of an estimated pitch track.

    Only frames where both the ground truth and the estimate are pitched
    (non-zero) are taken into account. The result is the number of those
    frames flagged by ``_gross_pitch_error_frames`` divided by the number
    of such frames. It is a ratio; it is not multiplied by 100.
    """
    voiced_count = np.sum(
        _true_voiced_frames(true_t, true_f, est_t, est_f)
    )
    gross_error_count = np.sum(
        _gross_pitch_error_frames(true_t, true_f, est_t, est_f)
    )
    return gross_error_count / voiced_count
def _gross_pitch_error_frames(true_t, true_f, est_t, est_f, eps=1e-8):
    """Flag frames voiced in both tracks whose pitch is off by more than 20%.

    Args:
        true_t: ground-truth time stamps (only forwarded to the helper).
        true_f: ground-truth pitch values; 0 marks an unvoiced frame.
        est_t: estimated time stamps (only forwarded to the helper).
        est_f: estimated pitch values; 0 marks an unvoiced frame.
        eps: added to ``true_f`` before dividing so that frames with a
            true pitch of 0 do not divide by zero.

    Returns:
        Boolean mask that is True where both tracks are non-zero and
        ``|est_f / (true_f + eps) - 1| > 0.2``.
    """
    voiced_frames = _true_voiced_frames(true_t, true_f, est_t, est_f)
    # Vectorised denominator; np.asarray keeps non-ndarray sequences working.
    shifted_true_f = np.asarray(true_f) + eps
    pitch_error_frames = np.abs(est_f / shifted_true_f - 1) > 0.2
    # Frames that are unvoiced in either track are masked out here.
    return voiced_frames & pitch_error_frames
def _true_voiced_frames(true_t, true_f, est_t, est_f):
    """Mask of frames that are non-zero in both the estimate and the truth.

    ``true_t`` and ``est_t`` are accepted for a uniform signature but are
    not used.
    """
    est_is_voiced = est_f != 0
    true_is_voiced = true_f != 0
    return est_is_voiced & true_is_voiced
def _voicing_decision_error_frames(true_t, true_f, est_t, est_f):
    """Mask of frames where exactly one of the two tracks is non-zero.

    ``true_t`` and ``est_t`` are accepted for a uniform signature but are
    not used.
    """
    est_is_voiced = est_f != 0
    true_is_voiced = true_f != 0
    return est_is_voiced != true_is_voiced
def f0_frame_error(true_t, true_f, est_t, est_f):
    """Return the F0 frame error of an estimated pitch track.

    Adds the number of gross-pitch-error frames to the number of
    voicing-decision-error frames and divides the total by ``len(true_t)``.
    """
    pitch_error_count = np.sum(
        _gross_pitch_error_frames(true_t, true_f, est_t, est_f)
    )
    voicing_error_count = np.sum(
        _voicing_decision_error_frames(true_t, true_f, est_t, est_f)
    )
    frame_count = len(true_t)
    return (pitch_error_count + voicing_error_count) / frame_count
def voicing_decision_error(true_t, true_f, est_t, est_f):
    """Return the voicing decision error of an estimated pitch track.

    This is the number of frames on which the two tracks disagree about
    being voiced (non-zero), divided by ``len(true_t)``.
    """
    mismatch_mask = _voicing_decision_error_frames(
        true_t, true_f, est_t, est_f
    )
    mismatch_count = np.sum(mismatch_mask)
    frame_count = len(true_t)
    return mismatch_count / frame_count