Spaces:
Running
Running
"""Search for a good noise schedule for WaveGrad for a given number of inference iterations."""
| import argparse | |
| from itertools import product as cartesian_product | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from TTS.config import load_config | |
| from TTS.utils.audio import AudioProcessor | |
| from TTS.vocoder.datasets.preprocess import load_wav_data | |
| from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset | |
| from TTS.vocoder.models import setup_model | |
def build_arg_parser():
    """Create the command-line parser for the WaveGrad noise-schedule search.

    Returns:
        argparse.ArgumentParser: parser whose path arguments and ``--num_iter`` are required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to model checkpoint.")
    parser.add_argument("--config_path", type=str, required=True, help="Path to model config file.")
    parser.add_argument("--data_path", type=str, required=True, help="Path to data directory.")
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="path for output file including file name and extension.",
    )
    parser.add_argument(
        "--num_iter",
        type=int,
        required=True,
        help="Number of model inference iterations that you like to optimize noise schedule for.",
    )
    parser.add_argument("--use_cuda", action="store_true", help="enable CUDA.")
    parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
    parser.add_argument(
        "--search_depth",
        type=int,
        default=3,
        help="Search granularity. Increasing this increases the run-time exponentially.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for sampling the candidate base values (default: unseeded).",
    )
    return parser


def compute_schedule_error(model, loader, ap, beta, use_cuda=False):
    """Return the mean squared mel-spectrogram error of ``model`` under noise schedule ``beta``.

    The squared error is pooled over every element of every batch yielded by
    ``loader``, so each candidate schedule is scored on the whole evaluation set
    rather than on a single clip.

    Args:
        model: vocoder exposing ``compute_noise_level(beta)`` and ``inference(mel)``.
        loader: iterable of ``(mel, audio)`` batches; only ``mel`` is used.
        ap: audio processor whose ``melspectrogram(wav)`` returns a numpy array.
        beta: noise schedule handed to ``model.compute_noise_level``.
        use_cuda: if True, the input mel is moved to the GPU before inference.

    Returns:
        float: total squared error divided by the number of compared mel elements.

    Raises:
        ValueError: if ``loader`` yields no data.
    """
    model.compute_noise_level(beta)
    squared_error = 0.0
    num_elements = 0
    with torch.no_grad():
        for mel, _ in loader:
            # ``mel`` itself stays on the CPU; only the inference input is moved.
            y_hat = model.inference(mel.cuda() if use_cuda else mel)
            y_hat = y_hat.cpu().numpy()
            # Re-extract a mel from each generated waveform (channel 0). The last
            # frame is dropped, presumably to match the target mel length — TODO confirm.
            mel_hat = torch.stack([torch.from_numpy(ap.melspectrogram(wav[0])[:, :-1]) for wav in y_hat])
            diff = mel - mel_hat
            squared_error += torch.sum(diff**2).item()
            num_elements += diff.numel()
    if num_elements == 0:
        raise ValueError("The data loader yielded no samples to evaluate the noise schedule on.")
    return squared_error / num_elements


def main():
    """Grid-search WaveGrad noise schedules and save the best one to ``--output_path``.

    Every combination of ``--search_depth`` random base values over ``--num_iter``
    positions is scaled by log-spaced exponents to form a candidate ``beta``. The
    candidate with the lowest pooled mel MSE is saved with ``np.save`` as
    ``{"beta": beta}``.
    """
    parser = build_arg_parser()
    args = parser.parse_args()
    if args.num_iter < 1:
        parser.error("--num_iter must be a positive integer.")
    if args.search_depth < 1:
        parser.error("--search_depth must be a positive integer.")
    if args.num_samples < 1:
        parser.error("--num_samples must be a positive integer.")

    # load config
    config = load_config(args.config_path)
    # setup audio processor
    ap = AudioProcessor(**config.audio)

    # load dataset
    _, train_data = load_wav_data(args.data_path, 0)
    train_data = train_data[: args.num_samples]
    if not train_data:
        parser.error(f"No audio samples found under {args.data_path}.")
    dataset = WaveGradDataset(
        ap=ap,
        items=train_data,
        seq_len=-1,
        hop_len=ap.hop_length,
        pad_short=config.pad_short,
        conv_pad=config.conv_pad,
        is_training=True,
        return_segments=False,
        use_noise_augment=False,
        use_cache=False,
        verbose=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=dataset.collate_full_clips,
        drop_last=False,
        num_workers=config.num_loader_workers,
        pin_memory=False,
    )

    # setup the model and load the trained weights; without this the schedule
    # would be tuned for a randomly initialised network.
    model = setup_model(config)
    model.load_checkpoint(config, args.model_path, eval=True)
    if args.use_cuda:
        model.cuda()

    # setup optimization parameters
    rng = np.random.default_rng(args.seed)
    base_values = sorted(10 * rng.uniform(size=args.search_depth))
    print(f" > base values: {base_values}")
    exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
    best_error = float("inf")
    best_schedule = None
    total_search_iter = len(base_values) ** args.num_iter
    for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
        beta = exponents * np.asarray(base)
        error = compute_schedule_error(model, loader, ap, beta, use_cuda=args.use_cuda)
        if error < best_error:
            best_error = error
            best_schedule = {"beta": beta}
            print(f" > Found a better schedule. - MSE: {error}")
            # Save on every improvement so an interrupted search still leaves its best result on disk.
            np.save(args.output_path, best_schedule)

    if best_schedule is None:
        # Only reachable when no error compared below inf (e.g. every error was NaN).
        raise SystemExit(" > No valid noise schedule was found; nothing was saved.")


if __name__ == "__main__":
    main()