|
|
import torch
|
|
|
|
|
|
import audiosr.hifigan as hifigan
|
|
|
|
|
|
|
|
|
def get_vocoder_config():
    """Return the HiFi-GAN settings for the 64-mel-bin / 16 kHz vocoder.

    A brand-new dictionary (including fresh nested lists and a fresh
    ``dist_config`` mapping) is built on every call, so callers may mutate
    the result without affecting later calls.
    """
    distributed = dict(
        dist_backend="nccl",
        dist_url="tcp://localhost:54321",
        world_size=1,
    )
    return dict(
        # Optimisation / training settings.
        resblock="1",
        num_gpus=6,
        batch_size=16,
        learning_rate=0.0002,
        adam_b1=0.8,
        adam_b2=0.99,
        lr_decay=0.999,
        seed=1234,
        # Generator layout. The upsample rates multiply to 160, which is
        # the same value as ``hop_size`` below.
        upsample_rates=[5, 4, 2, 2, 2],
        upsample_kernel_sizes=[16, 16, 8, 4, 4],
        upsample_initial_channel=1024,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5] for _ in range(3)],
        # Audio / spectrogram settings.
        segment_size=8192,
        num_mels=64,
        num_freq=1025,
        n_fft=1024,
        hop_size=160,
        win_size=1024,
        sampling_rate=16000,
        fmin=0,
        fmax=8000,
        fmax_for_loss=None,
        # Data loading and distributed settings.
        num_workers=4,
        dist_config=distributed,
    )
|
|
|
|
|
|
|
|
|
def get_vocoder_config_48k():
    """Return the HiFi-GAN settings for the 256-mel-bin / 48 kHz vocoder.

    A brand-new dictionary (including fresh nested lists and a fresh
    ``dist_config`` mapping) is built on every call, so callers may mutate
    the result without affecting later calls.
    """
    distributed = dict(
        dist_backend="nccl",
        dist_url="tcp://localhost:18273",
        world_size=1,
    )
    return dict(
        # Optimisation / training settings.
        resblock="1",
        num_gpus=8,
        batch_size=128,
        learning_rate=0.0001,
        adam_b1=0.8,
        adam_b2=0.99,
        lr_decay=0.999,
        seed=1234,
        # Generator layout. The upsample rates multiply to 480, which is
        # the same value as ``hop_size`` below.
        upsample_rates=[6, 5, 4, 2, 2],
        upsample_kernel_sizes=[12, 10, 8, 4, 4],
        upsample_initial_channel=1536,
        resblock_kernel_sizes=[3, 7, 11, 15],
        resblock_dilation_sizes=[[1, 3, 5] for _ in range(4)],
        # Audio / spectrogram settings.
        segment_size=15360,
        num_mels=256,
        n_fft=2048,
        hop_size=480,
        win_size=2048,
        sampling_rate=48000,
        fmin=20,
        fmax=24000,
        fmax_for_loss=None,
        # Data loading and distributed settings.
        num_workers=8,
        dist_config=distributed,
    )
|
|
|
|
|
|
|
|
|
def get_available_checkpoint_keys(model, ckpt):
    """Return the checkpoint entries that can be loaded into ``model``.

    The checkpoint at ``ckpt`` is read with ``torch.load`` and its
    ``"state_dict"`` entry is filtered: a key is kept only when ``model``'s
    own state dict has the same key with a value of the same size. Every
    rejected key is reported on stdout, followed by a summary line.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        ckpt: Checkpoint location accepted by ``torch.load``; the loaded
            object must contain a ``"state_dict"`` mapping.

    Returns:
        dict: The subset of the checkpoint's state dict whose keys and
        sizes match ``model``, in checkpoint order.
    """
    # Deserialize onto the CPU so that a checkpoint saved from CUDA tensors
    # can still be read on a machine without a (matching) GPU.
    state_dict = torch.load(ckpt, map_location="cpu")["state_dict"]
    current_state_dict = model.state_dict()

    new_state_dict = {}
    for k, value in state_dict.items():
        if k in current_state_dict and current_state_dict[k].size() == value.size():
            new_state_dict[k] = value
        else:
            print("==> WARNING: Skipping %s" % k)

    print(
        "%s out of %s keys are matched"
        % (len(new_state_dict), len(state_dict))
    )
    return new_state_dict
|
|
|
|
|
|
|
|
|
def get_param_num(model):
    """Count the scalar elements across everything ``model.parameters()`` yields.

    Args:
        model: Object exposing ``parameters()``, each item of which provides
            ``numel()``.

    Returns:
        The sum of ``numel()`` over all parameters (``0`` when there are none).
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
|
|
|
|
|
|
|
|
|
def torch_version_orig_mod_remove(state_dict):
    """Copy the ``"generator"`` weights with ``"_orig_mod."`` removed from keys.

    Only the ``"generator"`` entry of ``state_dict`` is carried over; any
    other top-level entries are not present in the result. The input mapping
    is left unmodified.

    NOTE(review): the ``"_orig_mod."`` marker presumably comes from wrapped
    or compiled modules in some torch versions (inferred from the function
    name) -- confirm against the checkpoints this is used with.

    Args:
        state_dict: Mapping with a ``"generator"`` entry that maps parameter
            names to values.

    Returns:
        dict: ``{"generator": {...}}`` where every occurrence of
        ``"_orig_mod."`` has been stripped from each key.
    """
    # str.replace leaves keys without the marker untouched, so a single
    # expression covers both the "has marker" and "no marker" cases.
    cleaned = {
        name.replace("_orig_mod.", ""): weight
        for name, weight in state_dict["generator"].items()
    }
    return {"generator": cleaned}
|
|
|
|
|
|
|
|
|
def get_vocoder(config, device, mel_bins):
    """Build a HiFi-GAN generator in eval mode and move it to ``device``.

    The hyper-parameters come from ``get_vocoder_config()`` when
    ``mel_bins == 64`` and from ``get_vocoder_config_48k()`` for any other
    value. No checkpoint is read here; the generator is returned as
    constructed by ``hifigan.Generator_old``.

    Args:
        config: Ignored. The argument has always been replaced by one of the
            built-in configurations; it is kept so existing callers keep
            working.
        device: Target passed to ``vocoder.to(...)``.
        mel_bins: Number of mel bins; ``64`` selects the 16 kHz settings,
            anything else selects the 48 kHz settings.

    Returns:
        The ``hifigan.Generator_old`` instance after ``eval()``,
        ``remove_weight_norm()`` and ``to(device)``.
    """
    # A former "MelGAN" code path was removed: the vocoder name was
    # hard-coded to "HiFi-GAN", so that path could never execute.
    if mel_bins == 64:
        vocoder_config = get_vocoder_config()
    else:
        vocoder_config = get_vocoder_config_48k()

    vocoder = hifigan.Generator_old(hifigan.AttrDict(vocoder_config))
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.to(device)
    return vocoder
|
|
|
|
|
|
|
|
|
def vocoder_infer(mels, vocoder, lengths=None):
    """Run ``vocoder`` on ``mels`` and return 16-bit integer samples.

    The vocoder output has its dimension 1 squeezed, is scaled by 32768,
    clamped to the ``int16`` range and converted to a NumPy ``int16`` array.

    Args:
        mels: Input forwarded unchanged to ``vocoder``.
        vocoder: Callable returning a ``torch.Tensor``; presumably shaped
            ``(batch, 1, samples)`` with values in ``[-1, 1]`` -- confirm
            against the generator in use.
        lengths: Optional cut-off applied as ``wavs[:, :lengths]``, i.e. one
            value shared by every row of the batch. ``None`` keeps all
            samples.

    Returns:
        numpy.ndarray: ``int16`` array of waveform samples.
    """
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)
        # Clamp before the integer cast: a sample of exactly 1.0 scales to
        # 32768, which does not fit in int16 and would otherwise wrap around
        # to a large negative value (an audible click).
        wavs = torch.clamp(wavs * 32768, min=-32768, max=32767)

    wavs = wavs.cpu().numpy().astype("int16")

    if lengths is not None:
        wavs = wavs[:, :lengths]

    return wavs
|
|
|
|