Commit 04be12f
Parent(s): ff769a6

disable deepspeed and cuda kernel

Files changed:
- indextts/infer_v2.py +2 -3
- indextts/s2mel/modules/.ipynb_checkpoints/audio-checkpoint.py +82 -0
- indextts/s2mel/modules/.ipynb_checkpoints/commons-checkpoint.py +610 -0
- indextts/s2mel/modules/.ipynb_checkpoints/diffusion_transformer-checkpoint.py +258 -0
- indextts/s2mel/modules/.ipynb_checkpoints/flow_matching-checkpoint.py +171 -0
- indextts/s2mel/modules/.ipynb_checkpoints/length_regulator-checkpoint.py +141 -0
- webui.py +3 -1
indextts/infer_v2.py  CHANGED

@@ -35,7 +35,7 @@ import torch.nn.functional as F
 class IndexTTS2:
     def __init__(
         self, cfg_path="checkpoints/config.yaml", model_dir="checkpoints", is_fp16=False, device=None,
-        use_cuda_kernel=None,
+        use_cuda_kernel=None,use_deepspeed=False
     ):
         """
         Args:
@@ -83,14 +83,13 @@ class IndexTTS2:
             try:
                 import deepspeed
 
-                use_deepspeed = True
             except (ImportError, OSError, CalledProcessError) as e:
                 use_deepspeed = False
                 print(f">> DeepSpeed加载失败,回退到标准推理: {e}")
 
             self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=True)
         else:
-            self.gpt.post_init_gpt2_config(use_deepspeed=
+            self.gpt.post_init_gpt2_config(use_deepspeed=use_deepspeed, kv_cache=True, half=False)
 
         if self.use_cuda_kernel:
             # preload the CUDA kernel for BigVGAN
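With this change DeepSpeed becomes opt-in through a constructor flag instead of being enabled whenever the import succeeds, and the non-fp16 path calls post_init_gpt2_config with half=False. A minimal sketch of how the updated constructor is expected to be called, mirroring the webui.py change further down; the paths are the defaults from the signature, not values verified here:

    from indextts.infer_v2 import IndexTTS2

    # Hypothetical call: both DeepSpeed and the BigVGAN CUDA kernel stay disabled.
    tts = IndexTTS2(
        cfg_path="checkpoints/config.yaml",
        model_dir="checkpoints",
        is_fp16=False,
        use_cuda_kernel=False,
        use_deepspeed=False,  # new keyword introduced by this commit
    )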
indextts/s2mel/modules/.ipynb_checkpoints/audio-checkpoint.py  ADDED

@@ -0,0 +1,82 @@
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+    output = dynamic_range_decompression_torch(magnitudes)
+    return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    # if torch.min(y) < -1.0:
+    #     print("min value is ", torch.min(y))
+    # if torch.max(y) > 1.0:
+    #     print("max value is ", torch.max(y))
+
+    global mel_basis, hann_window  # pylint: disable=global-statement
+    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+    )
+    y = y.squeeze(1)
+
+    spec = torch.view_as_real(
+        torch.stft(
+            y,
+            n_fft,
+            hop_length=hop_size,
+            win_length=win_size,
+            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
+            center=center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+    )
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec
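For orientation, a small sketch of calling mel_spectrogram from this module; the STFT settings below (22.05 kHz, 1024-point FFT, hop 256, 80 mels) are illustrative assumptions, not values taken from the repository configuration:

    import torch

    wav = torch.randn(1, 22050)  # one second of placeholder audio at an assumed 22.05 kHz
    mel = mel_spectrogram(
        wav, n_fft=1024, num_mels=80, sampling_rate=22050,
        hop_size=256, win_size=1024, fmin=0, fmax=8000,
    )
    # mel has shape (1, 80, n_frames) after log compression in spectral_normalize_torch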
indextts/s2mel/modules/.ipynb_checkpoints/commons-checkpoint.py  ADDED

@@ -0,0 +1,610 @@
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from munch import Munch
+import json
+import argparse
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ("yes", "true", "t", "y", "1"):
+        return True
+    elif v.lower() in ("no", "false", "f", "n", "0"):
+        return False
+    else:
+        raise argparse.ArgumentTypeError("Boolean value expected.")
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def intersperse(lst, item):
+    result = [item] * (len(lst) * 2 + 1)
+    result[1::2] = lst
+    return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+    """KL(P||Q)"""
+    kl = (logs_q - logs_p) - 0.5
+    kl += (
+        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+    )
+    return kl
+
+
+def rand_gumbel(shape):
+    """Sample from the Gumbel distribution, protect from overflows."""
+    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+    return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+    return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+
+def slice_segments_audio(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, idx_str:idx_end]
+    return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
+        dtype=torch.long
+    )
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+    position = torch.arange(length, dtype=torch.float)
+    num_timescales = channels // 2
+    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+        num_timescales - 1
+    )
+    inv_timescales = min_timescale * torch.exp(
+        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+    )
+    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+    signal = F.pad(signal, [0, 0, 0, channels % 2])
+    signal = signal.view(1, channels, length)
+    return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def shift_1d(x):
+    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+    return x
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def avg_with_mask(x, mask):
+    assert mask.dtype == torch.float, "Mask should be float"
+
+    if mask.ndim == 2:
+        mask = mask.unsqueeze(1)
+
+    if mask.shape[1] == 1:
+        mask = mask.expand_as(x)
+
+    return (x * mask).sum() / mask.sum()
+
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+    device = duration.device
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path.unsqueeze(1).transpose(2, 3) * mask
+    return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    if clip_value is not None:
+        clip_value = float(clip_value)
+
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item() ** norm_type
+        if clip_value is not None:
+            p.grad.data.clamp_(min=-clip_value, max=clip_value)
+    total_norm = total_norm ** (1.0 / norm_type)
+    return total_norm
+
+
+def log_norm(x, mean=-4, std=4, dim=2):
+    """
+    normalized log mel -> mel -> norm -> log(norm)
+    """
+    x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
+    return x
+
+
+def load_F0_models(path):
+    # load F0 model
+    from .JDC.model import JDCNet
+
+    F0_model = JDCNet(num_class=1, seq_len=192)
+    params = torch.load(path, map_location="cpu")["net"]
+    F0_model.load_state_dict(params)
+    _ = F0_model.train()
+
+    return F0_model
+
+
+def modify_w2v_forward(self, output_layer=15):
+    """
+    change forward method of w2v encoder to get its intermediate layer output
+    :param self:
+    :param layer:
+    :return:
+    """
+    from transformers.modeling_outputs import BaseModelOutput
+
+    def forward(
+        hidden_states,
+        attention_mask=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        conv_attention_mask = attention_mask
+        if attention_mask is not None:
+            # make sure padded tokens output 0
+            hidden_states = hidden_states.masked_fill(
+                ~attention_mask.bool().unsqueeze(-1), 0.0
+            )
+
+            # extend attention_mask
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
+                dtype=hidden_states.dtype
+            )
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+            attention_mask = attention_mask.expand(
+                attention_mask.shape[0],
+                1,
+                attention_mask.shape[-1],
+                attention_mask.shape[-1],
+            )
+
+        hidden_states = self.dropout(hidden_states)
+
+        if self.embed_positions is not None:
+            relative_position_embeddings = self.embed_positions(hidden_states)
+        else:
+            relative_position_embeddings = None
+
+        deepspeed_zero3_is_enabled = False
+
+        for i, layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            dropout_probability = torch.rand([])
+
+            skip_the_layer = (
+                True
+                if self.training and (dropout_probability < self.config.layerdrop)
+                else False
+            )
+            if not skip_the_layer or deepspeed_zero3_is_enabled:
+                # under deepspeed zero3 all gpus must run in sync
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        layer.__call__,
+                        hidden_states,
+                        attention_mask,
+                        relative_position_embeddings,
+                        output_attentions,
+                        conv_attention_mask,
+                    )
+                else:
+                    layer_outputs = layer(
+                        hidden_states,
+                        attention_mask=attention_mask,
+                        relative_position_embeddings=relative_position_embeddings,
+                        output_attentions=output_attentions,
+                        conv_attention_mask=conv_attention_mask,
+                    )
+                hidden_states = layer_outputs[0]
+
+            if skip_the_layer:
+                layer_outputs = (None, None)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+            if i == output_layer - 1:
+                break
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions]
+                if v is not None
+            )
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+    return forward
+
+
+MATPLOTLIB_FLAG = False
+
+
+def plot_spectrogram_to_numpy(spectrogram):
+    global MATPLOTLIB_FLAG
+    if not MATPLOTLIB_FLAG:
+        import matplotlib
+        import logging
+
+        matplotlib.use("Agg")
+        MATPLOTLIB_FLAG = True
+        mpl_logger = logging.getLogger("matplotlib")
+        mpl_logger.setLevel(logging.WARNING)
+    import matplotlib.pylab as plt
+    import numpy as np
+
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+    plt.xlabel("Frames")
+    plt.ylabel("Channels")
+    plt.tight_layout()
+
+    fig.canvas.draw()
+    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    plt.close()
+    return data
+
+
+def normalize_f0(f0_sequence):
+    # Remove unvoiced frames (replace with -1)
+    voiced_indices = np.where(f0_sequence > 0)[0]
+    f0_voiced = f0_sequence[voiced_indices]
+
+    # Convert to log scale
+    log_f0 = np.log2(f0_voiced)
+
+    # Calculate mean and standard deviation
+    mean_f0 = np.mean(log_f0)
+    std_f0 = np.std(log_f0)
+
+    # Normalize the F0 sequence
+    normalized_f0 = (log_f0 - mean_f0) / std_f0
+
+    # Create the normalized F0 sequence with unvoiced frames
+    normalized_sequence = np.zeros_like(f0_sequence)
+    normalized_sequence[voiced_indices] = normalized_f0
+    normalized_sequence[f0_sequence <= 0] = -1  # Assign -1 to unvoiced frames
+
+    return normalized_sequence
+
+
+class MyModel(nn.Module):
+    def __init__(self, args):
+        super(MyModel, self).__init__()
+        from modules.flow_matching import CFM
+        from modules.length_regulator import InterpolateRegulator
+
+        length_regulator = InterpolateRegulator(
+            channels=args.length_regulator.channels,
+            sampling_ratios=args.length_regulator.sampling_ratios,
+            is_discrete=args.length_regulator.is_discrete,
+            in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
+            vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
+            codebook_size=args.length_regulator.content_codebook_size,
+            n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
+            quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
+            f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
+            n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
+        )
+
+        self.models = nn.ModuleDict({
+            'cfm': CFM(args),
+            'length_regulator': length_regulator
+        })
+
+    def forward(self, x, target_lengths, prompt_len, cond, y):
+        x = self.models['cfm'](x, target_lengths, prompt_len, cond, y)
+        return x
+
+    def forward2(self, S_ori, target_lengths, F0_ori):
+        x = self.models['length_regulator'](S_ori, ylens=target_lengths, f0=F0_ori)
+        return x
+
+def build_model(args, stage="DiT"):
+    if stage == "DiT":
+        from modules.flow_matching import CFM
+        from modules.length_regulator import InterpolateRegulator
+
+        length_regulator = InterpolateRegulator(
+            channels=args.length_regulator.channels,
+            sampling_ratios=args.length_regulator.sampling_ratios,
+            is_discrete=args.length_regulator.is_discrete,
+            in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
+            vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
+            codebook_size=args.length_regulator.content_codebook_size,
+            n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
+            quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
+            f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
+            n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
+        )
+        cfm = CFM(args)
+        nets = Munch(
+            cfm=cfm,
+            length_regulator=length_regulator,
+        )
+
+    elif stage == 'codec':
+        from dac.model.dac import Encoder
+        from modules.quantize import (
+            FAquantizer,
+        )
+
+        encoder = Encoder(
+            d_model=args.DAC.encoder_dim,
+            strides=args.DAC.encoder_rates,
+            d_latent=1024,
+            causal=args.causal,
+            lstm=args.lstm,
+        )
+
+        quantizer = FAquantizer(
+            in_dim=1024,
+            n_p_codebooks=1,
+            n_c_codebooks=args.n_c_codebooks,
+            n_t_codebooks=2,
+            n_r_codebooks=3,
+            codebook_size=1024,
+            codebook_dim=8,
+            quantizer_dropout=0.5,
+            causal=args.causal,
+            separate_prosody_encoder=args.separate_prosody_encoder,
+            timbre_norm=args.timbre_norm,
+        )
+
+        nets = Munch(
+            encoder=encoder,
+            quantizer=quantizer,
+        )
+
+    elif stage == "mel_vocos":
+        from modules.vocos import Vocos
+        decoder = Vocos(args)
+        nets = Munch(
+            decoder=decoder,
+        )
+
+    else:
+        raise ValueError(f"Unknown stage: {stage}")
+
+    return nets
+
+
+def load_checkpoint(
+    model,
+    optimizer,
+    path,
+    load_only_params=True,
+    ignore_modules=[],
+    is_distributed=False,
+    load_ema=False,
+):
+    state = torch.load(path, map_location="cpu")
+    params = state["net"]
+    if load_ema and "ema" in state:
+        print("Loading EMA")
+        for key in model:
+            i = 0
+            for param_name in params[key]:
+                if "input_pos" in param_name:
+                    continue
+                assert params[key][param_name].shape == state["ema"][key][0][i].shape
+                params[key][param_name] = state["ema"][key][0][i].clone()
+                i += 1
+    for key in model:
+        if key in params and key not in ignore_modules:
+            if not is_distributed:
+                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
+                for k in list(params[key].keys()):
+                    if k.startswith("module."):
+                        params[key][k[len("module.") :]] = params[key][k]
+                        del params[key][k]
+            model_state_dict = model[key].state_dict()
+            # 过滤出形状匹配的键值对
+            filtered_state_dict = {
+                k: v
+                for k, v in params[key].items()
+                if k in model_state_dict and v.shape == model_state_dict[k].shape
+            }
+            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
+            if skipped_keys:
+                print(
+                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
+                )
+            print("%s loaded" % key)
+            model[key].load_state_dict(filtered_state_dict, strict=False)
+    _ = [model[key].eval() for key in model]
+
+    if not load_only_params:
+        epoch = state["epoch"] + 1
+        iters = state["iters"]
+        optimizer.load_state_dict(state["optimizer"])
+        optimizer.load_scheduler_state_dict(state["scheduler"])
+
+    else:
+        epoch = 0
+        iters = 0
+
+    return model, optimizer, epoch, iters
+
+def load_checkpoint2(
+    model,
+    optimizer,
+    path,
+    load_only_params=True,
+    ignore_modules=[],
+    is_distributed=False,
+    load_ema=False,
+):
+    state = torch.load(path, map_location="cpu")
+    params = state["net"]
+    if load_ema and "ema" in state:
+        print("Loading EMA")
+        for key in model.models:
+            i = 0
+            for param_name in params[key]:
+                if "input_pos" in param_name:
+                    continue
+                assert params[key][param_name].shape == state["ema"][key][0][i].shape
+                params[key][param_name] = state["ema"][key][0][i].clone()
+                i += 1
+    for key in model.models:
+        if key in params and key not in ignore_modules:
+            if not is_distributed:
+                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
+                for k in list(params[key].keys()):
+                    if k.startswith("module."):
+                        params[key][k[len("module.") :]] = params[key][k]
+                        del params[key][k]
+            model_state_dict = model.models[key].state_dict()
+            # 过滤出形状匹配的键值对
+            filtered_state_dict = {
+                k: v
+                for k, v in params[key].items()
+                if k in model_state_dict and v.shape == model_state_dict[k].shape
+            }
+            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
+            if skipped_keys:
+                print(
+                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
+                )
+            print("%s loaded" % key)
+            model.models[key].load_state_dict(filtered_state_dict, strict=False)
+    model.eval()
+    # _ = [model[key].eval() for key in model]
+
+    if not load_only_params:
+        epoch = state["epoch"] + 1
+        iters = state["iters"]
+        optimizer.load_state_dict(state["optimizer"])
+        optimizer.load_scheduler_state_dict(state["scheduler"])
+
+    else:
+        epoch = 0
+        iters = 0
+
+    return model, optimizer, epoch, iters
+
+def recursive_munch(d):
+    if isinstance(d, dict):
+        return Munch((k, recursive_munch(v)) for k, v in d.items())
+    elif isinstance(d, list):
+        return [recursive_munch(v) for v in d]
+    else:
+        return d
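Several of the new checkpoint modules import sequence_mask from modules.commons; a quick sketch of its behaviour with made-up lengths:

    import torch

    lengths = torch.tensor([3, 5])
    mask = sequence_mask(lengths)  # shape (2, 5); True where position < length
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True]])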
indextts/s2mel/modules/.ipynb_checkpoints/diffusion_transformer-checkpoint.py  ADDED

@@ -0,0 +1,258 @@
+import torch
+from torch import nn
+import math
+
+from modules.gpt_fast.model import ModelArgs, Transformer
+# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
+from modules.wavenet import WN
+from modules.commons import sequence_mask
+
+from torch.nn.utils import weight_norm
+
+def modulate(x, shift, scale):
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+#               Embedding Layers for Timesteps and Class Labels                #
+#################################################################################
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, hidden_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+        self.max_period = 10000
+        self.scale = 1000
+
+        half = frequency_embedding_size // 2
+        freqs = torch.exp(
+            -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        )
+        self.register_buffer("freqs", freqs)
+
+    def timestep_embedding(self, t):
+        """
+        Create sinusoidal timestep embeddings.
+        :param t: a 1-D Tensor of N indices, one per batch element.
+                  These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an (N, D) Tensor of positional embeddings.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+
+        args = self.scale * t[:, None].float() * self.freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if self.frequency_embedding_size % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+        return embedding
+
+    def forward(self, t):
+        t_freq = self.timestep_embedding(t)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+class StyleEmbedder(nn.Module):
+    """
+    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+    """
+    def __init__(self, input_size, hidden_size, dropout_prob):
+        super().__init__()
+        use_cfg_embedding = dropout_prob > 0
+        self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
+        self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
+        self.input_size = input_size
+        self.dropout_prob = dropout_prob
+
+    def forward(self, labels, train, force_drop_ids=None):
+        use_dropout = self.dropout_prob > 0
+        if (train and use_dropout) or (force_drop_ids is not None):
+            labels = self.token_drop(labels, force_drop_ids)
+        else:
+            labels = self.style_in(labels)
+        embeddings = labels
+        return embeddings
+
+class FinalLayer(nn.Module):
+    """
+    The final layer of DiT.
+    """
+    def __init__(self, hidden_size, patch_size, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+        )
+
+    def forward(self, x, c):
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+        x = modulate(self.norm_final(x), shift, scale)
+        x = self.linear(x)
+        return x
+
+class DiT(torch.nn.Module):
+    def __init__(
+        self,
+        args
+    ):
+        super(DiT, self).__init__()
+        self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
+        self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
+        self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
+        model_args = ModelArgs(
+            block_size=16384,  # args.DiT.block_size,
+            n_layer=args.DiT.depth,
+            n_head=args.DiT.num_heads,
+            dim=args.DiT.hidden_dim,
+            head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
+            vocab_size=1024,
+            uvit_skip_connection=self.uvit_skip_connection,
+            time_as_token=self.time_as_token,
+        )
+        self.transformer = Transformer(model_args)
+        self.in_channels = args.DiT.in_channels
+        self.out_channels = args.DiT.in_channels
+        self.num_heads = args.DiT.num_heads
+
+        self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
+
+        self.content_type = args.DiT.content_type  # 'discrete' or 'continuous'
+        self.content_codebook_size = args.DiT.content_codebook_size  # for discrete content
+        self.content_dim = args.DiT.content_dim  # for continuous content
+        self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim)  # discrete content
+        self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True)  # continuous content
+
+        self.is_causal = args.DiT.is_causal
+
+        self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
+
+        # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
+        # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
+
+        input_pos = torch.arange(16384)
+        self.register_buffer("input_pos", input_pos)
+
+        self.final_layer_type = args.DiT.final_layer_type  # mlp or wavenet
+        if self.final_layer_type == 'wavenet':
+            self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
+            self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
+            self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
+            self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
+                              kernel_size=args.wavenet.kernel_size,
+                              dilation_rate=args.wavenet.dilation_rate,
+                              n_layers=args.wavenet.num_layers,
+                              gin_channels=args.wavenet.hidden_dim,
+                              p_dropout=args.wavenet.p_dropout,
+                              causal=False)
+            self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
+            self.res_projection = nn.Linear(args.DiT.hidden_dim,
+                                            args.wavenet.hidden_dim)  # residual connection from tranformer output to final output
+            self.wavenet_style_condition = args.wavenet.style_condition
+            assert args.DiT.style_condition == args.wavenet.style_condition
+        else:
+            self.final_mlp = nn.Sequential(
+                nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
+                nn.SiLU(),
+                nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
+            )
+        self.transformer_style_condition = args.DiT.style_condition
+
+
+        self.class_dropout_prob = args.DiT.class_dropout_prob
+        self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
+
+        self.long_skip_connection = args.DiT.long_skip_connection
+        self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
+
+        self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
+                                             args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
+                                             args.DiT.hidden_dim)
+        if self.style_as_token:
+            self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
+
+    def setup_caches(self, max_batch_size, max_seq_length):
+        self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
+
+    def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False):
+        """
+        x (torch.Tensor): random noise
+        prompt_x (torch.Tensor): reference mel + zero mel
+            shape: (batch_size, 80, 795+1068)
+        x_lens (torch.Tensor): mel frames output
+            shape: (batch_size, mel_timesteps)
+        t (torch.Tensor): radshape:
+            shape: (batch_size)
+        style (torch.Tensor): reference global style
+            shape: (batch_size, 192)
+        cond (torch.Tensor): semantic info of reference audio and altered audio
+            shape: (batch_size, mel_timesteps(795+1069), 512)
+
+        """
+        class_dropout = False
+        if self.training and torch.rand(1) < self.class_dropout_prob:
+            class_dropout = True
+        if not self.training and mask_content:
+            class_dropout = True
+        # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
+        cond_in_module = self.cond_projection
+
+        B, _, T = x.size()
+
+
+        t1 = self.t_embedder(t)  # (N, D) # t1 [2, 512]
+        cond = cond_in_module(cond)  # cond [2,1863,512]->[2,1863,512]
+
+        x = x.transpose(1, 2)  # [2,1863,80]
+        prompt_x = prompt_x.transpose(1, 2)  # [2,1863,80]
+
+        x_in = torch.cat([x, prompt_x, cond], dim=-1)  # 80+80+512=672 [2, 1863, 672]
+
+        if self.transformer_style_condition and not self.style_as_token:  # True and True
+            x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)  # [2, 1863, 864]
+
+        if class_dropout:  # False
+            x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0  # 80维后全置为0
+
+        x_in = self.cond_x_merge_linear(x_in)  # (N, T, D) [2, 1863, 512]
+
+        if self.style_as_token:  # False
+            style = self.style_in(style)
+            style = torch.zeros_like(style) if class_dropout else style
+            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
+
+        if self.time_as_token:  # False
+            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
+
+        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)  # torch.Size([1, 1, 1863]) True
+        input_pos = self.input_pos[:x_in.size(1)]  # (T,) range(0,1863)
+        x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None  # torch.Size([1, 1, 1863, 1863])
+        x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded)  # [2, 1863, 512]
+        x_res = x_res[:, 1:] if self.time_as_token else x_res
+        x_res = x_res[:, 1:] if self.style_as_token else x_res
+
+        if self.long_skip_connection:  # True
+            x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
+        if self.final_layer_type == 'wavenet':
+            x = self.conv1(x_res)
+            x = x.transpose(1, 2)
+            t2 = self.t_embedder2(t)
+            x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
+                x_res)  # long residual connection
+            x = self.final_layer(x, t1).transpose(1, 2)
+            x = self.conv2(x)
+        else:
+            x = self.final_mlp(x_res)
+            x = x.transpose(1, 2)
+        # x [2,80,1863]
+        return x
indextts/s2mel/modules/.ipynb_checkpoints/flow_matching-checkpoint.py  ADDED

@@ -0,0 +1,171 @@
+from abc import ABC
+
+import torch
+import torch.nn.functional as F
+
+from modules.diffusion_transformer import DiT
+from modules.commons import sequence_mask
+
+from tqdm import tqdm
+
+class BASECFM(torch.nn.Module, ABC):
+    def __init__(
+        self,
+        args,
+    ):
+        super().__init__()
+        self.sigma_min = 1e-6
+
+        self.estimator = None
+
+        self.in_channels = args.DiT.in_channels
+
+        self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()
+
+        if hasattr(args.DiT, 'zero_prompt_speech_token'):
+            self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
+        else:
+            self.zero_prompt_speech_token = False
+
+    @torch.inference_mode()
+    def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
+        """Forward diffusion
+
+        Args:
+            mu (torch.Tensor): semantic info of reference audio and altered audio
+                shape: (batch_size, mel_timesteps(795+1069), 512)
+            x_lens (torch.Tensor): mel frames output
+                shape: (batch_size, mel_timesteps)
+            prompt (torch.Tensor): reference mel
+                shape: (batch_size, 80, 795)
+            style (torch.Tensor): reference global style
+                shape: (batch_size, 192)
+            f0: None
+            n_timesteps (int): number of diffusion steps
+            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, 80, mel_timesteps)
+        """
+        B, T = mu.size(0), mu.size(1)
+        z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+        # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
+        return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)
+
+    def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
+        """
+        Fixed euler solver for ODEs.
+        Args:
+            x (torch.Tensor): random noise
+            t_span (torch.Tensor): n_timesteps interpolated
+                shape: (n_timesteps + 1,)
+            mu (torch.Tensor): semantic info of reference audio and altered audio
+                shape: (batch_size, mel_timesteps(795+1069), 512)
+            x_lens (torch.Tensor): mel frames output
+                shape: (batch_size, mel_timesteps)
+            prompt (torch.Tensor): reference mel
+                shape: (batch_size, 80, 795)
+            style (torch.Tensor): reference global style
+                shape: (batch_size, 192)
+        """
+        t, _, _ = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+        # Or in future might add like a return_all_steps flag
+        sol = []
+        # apply prompt
+        prompt_len = prompt.size(-1)
+        prompt_x = torch.zeros_like(x)
+        prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
+        x[..., :prompt_len] = 0
+        if self.zero_prompt_speech_token:
+            mu[..., :prompt_len] = 0
+        for step in tqdm(range(1, len(t_span))):
+            dt = t_span[step] - t_span[step - 1]
+            if inference_cfg_rate > 0:
+                # Stack original and CFG (null) inputs for batched processing
+                stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
+                stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
+                stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)
+                stacked_x = torch.cat([x, x], dim=0)
+                stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0)
+
+                # Perform a single forward pass for both original and CFG inputs
+                stacked_dphi_dt = self.estimator(
+                    stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu,
+                )
+
+                # Split the output back into the original and CFG components
+                dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)
+
+                # Apply CFG formula
+                dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
+            else:
+                dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)
+
+            x = x + dt * dphi_dt
+            t = t + dt
+            sol.append(x)
+            if step < len(t_span) - 1:
+                dt = t_span[step + 1] - t
+            x[:, :, :prompt_len] = 0
+
+        return sol[-1]
+    def forward(self, x1, x_lens, prompt_lens, mu, style):
+        """Computes diffusion loss
+
+        Args:
+            mu (torch.Tensor): semantic info of reference audio and altered audio
+                shape: (batch_size, mel_timesteps(795+1069), 512)
+            x1: mel
+            x_lens (torch.Tensor): mel frames output
+                shape: (batch_size, mel_timesteps)
+            prompt (torch.Tensor): reference mel
+                shape: (batch_size, 80, 795)
+            style (torch.Tensor): reference global style
+                shape: (batch_size, 192)
+
+        Returns:
+            loss: conditional flow matching loss
+            y: conditional flow
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+        b, _, t = x1.shape
+
+        # random timestep
+        t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
+        # sample noise p(x_0)
+        z = torch.randn_like(x1)
+
+        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+        u = x1 - (1 - self.sigma_min) * z
+
+        prompt = torch.zeros_like(x1)
+        for bib in range(b):
+            prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
+            # range covered by prompt are set to 0
+            y[bib, :, :prompt_lens[bib]] = 0
+            if self.zero_prompt_speech_token:
+                mu[bib, :, :prompt_lens[bib]] = 0
+
+        estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens)
+        loss = 0
+        for bib in range(b):
+            loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
+        loss /= b
+
+        return loss, estimator_out + (1 - self.sigma_min) * z
+
+
+
+class CFM(BASECFM):
+    def __init__(self, args):
+        super().__init__(
+            args
+        )
+        if args.dit_type == "DiT":
+            self.estimator = DiT(args)
+        else:
+            raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
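The classifier-free guidance branch of solve_euler combines the conditional and null-conditioned velocity estimates before each Euler update. A standalone sketch of that single step, with illustrative tensor shapes only:

    import torch

    inference_cfg_rate = 0.5
    dphi_dt = torch.randn(1, 80, 100)      # conditional velocity estimate
    cfg_dphi_dt = torch.randn(1, 80, 100)  # null-prompt (unconditional) estimate
    guided = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt

    x, dt = torch.randn(1, 80, 100), 0.1
    x = x + dt * guided  # one Euler step along the guided flow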
indextts/s2mel/modules/.ipynb_checkpoints/length_regulator-checkpoint.py  ADDED

@@ -0,0 +1,141 @@
+from typing import Tuple
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from modules.commons import sequence_mask
+import numpy as np
+from dac.nn.quantize import VectorQuantize
+
+# f0_bin = 256
+f0_max = 1100.0
+f0_min = 50.0
+f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
+def f0_to_coarse(f0, f0_bin):
+    f0_mel = 1127 * (1 + f0 / 700).log()
+    a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
+    b = f0_mel_min * a - 1.
+    f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
+    # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
+    f0_coarse = torch.round(f0_mel).long()
+    f0_coarse = f0_coarse * (f0_coarse > 0)
+    f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
+    f0_coarse = f0_coarse * (f0_coarse < f0_bin)
+    f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
+    return f0_coarse
+
+class InterpolateRegulator(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        sampling_ratios: Tuple,
+        is_discrete: bool = False,
+        in_channels: int = None,  # only applies to continuous input
+        vector_quantize: bool = False,  # whether to use vector quantization, only applies to continuous input
+        codebook_size: int = 1024,  # for discrete only
+        out_channels: int = None,
+        groups: int = 1,
+        n_codebooks: int = 1,  # number of codebooks
+        quantizer_dropout: float = 0.0,  # dropout for quantizer
+        f0_condition: bool = False,
+        n_f0_bins: int = 512,
+    ):
+        super().__init__()
+        self.sampling_ratios = sampling_ratios
+        out_channels = out_channels or channels
+        model = nn.ModuleList([])
+        if len(sampling_ratios) > 0:
+            self.interpolate = True
+            for _ in sampling_ratios:
+                module = nn.Conv1d(channels, channels, 3, 1, 1)
+                norm = nn.GroupNorm(groups, channels)
+                act = nn.Mish()
+                model.extend([module, norm, act])
+        else:
+            self.interpolate = False
+        model.append(
+            nn.Conv1d(channels, out_channels, 1, 1)
+        )
+        self.model = nn.Sequential(*model)
+        self.embedding = nn.Embedding(codebook_size, channels)
+        self.is_discrete = is_discrete
+
+        self.mask_token = nn.Parameter(torch.zeros(1, channels))
+
+        self.n_codebooks = n_codebooks
+        if n_codebooks > 1:
+            self.extra_codebooks = nn.ModuleList([
+                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
+            ])
+            self.extra_codebook_mask_tokens = nn.ParameterList([
+                nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
+            ])
+        self.quantizer_dropout = quantizer_dropout
+
+        if f0_condition:
+            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
+            self.f0_condition = f0_condition
+            self.n_f0_bins = n_f0_bins
+            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
+            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
+        else:
+            self.f0_condition = False
+
+        if not is_discrete:
+            self.content_in_proj = nn.Linear(in_channels, channels)
+            if vector_quantize:
+                self.vq = VectorQuantize(channels, codebook_size, 8)
+
+    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
+        # apply token drop
+        if self.training:
+            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
+            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
+            n_dropout = int(x.shape[0] * self.quantizer_dropout)
+            n_quantizers[:n_dropout] = dropout[:n_dropout]
+            n_quantizers = n_quantizers.to(x.device)
+            # decide whether to drop for each sample in batch
+        else:
+            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
+        if self.is_discrete:
+            if self.n_codebooks > 1:
+                assert len(x.size()) == 3
+                x_emb = self.embedding(x[:, 0])
+                for i, emb in enumerate(self.extra_codebooks):
+                    x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
+                    # add mask token if not using this codebook
+                    # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
+                x = x_emb
+            elif self.n_codebooks == 1:
+                if len(x.size()) == 2:
+                    x = self.embedding(x)
+                else:
+                    x = self.embedding(x[:, 0])
+        else:
+            x = self.content_in_proj(x)
+        # x in (B, T, D)
+        mask = sequence_mask(ylens).unsqueeze(-1)
+        if self.interpolate:
+            x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+        else:
+            x = x.transpose(1, 2).contiguous()
+            mask = mask[:, :x.size(2), :]
+            ylens = ylens.clamp(max=x.size(2)).long()
+        if self.f0_condition:
+            if f0 is None:
+                x = x + self.f0_mask.unsqueeze(-1)
+            else:
+                # quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
+                quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
+                quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
+                f0_emb = self.f0_embedding(quantized_f0)
+                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+                x = x + f0_emb
+        out = self.model(x).transpose(1, 2).contiguous()
+        if hasattr(self, 'vq'):
+            out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2))
+            out_q = out_q.transpose(1, 2)
+            return out_q * mask, ylens, codes, commitment_loss, codebook_loss
+        olens = ylens
+        return out * mask, olens, None, None, None
webui.py  CHANGED

@@ -38,7 +38,9 @@ from modelscope.hub import api
 
 i18n = I18nAuto(language="Auto")
 MODE = 'local'
-tts = IndexTTS2(model_dir=cmd_args.model_dir,
+tts = IndexTTS2(model_dir=cmd_args.model_dir,
+                cfg_path=os.path.join(cmd_args.model_dir, "config.yaml"),
+                is_fp16=False,use_cuda_kernel=False)
 
 # 支持的语言列表
 LANGUAGES = {