Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,662 Bytes
56cfa73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import json
import string
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
from safetensors.torch import load_file
from torch.nn.utils.rnn import pad_sequence
from codec.models import PatchVAE
from tts.model.cache_utils import FLACache
from tts.text_processor import BasicTextProcessor
from tts.tools import sequence_mask
from tts.tts import ARTTSModel
@dataclass
class VelocityHeadSamplingParams:
    """
    Velocity head sampling parameters (autoregressive latent sampling).

    Field names double as keyword names downstream: instances are converted
    with ``dataclasses.asdict`` before being handed to the TTS model, so
    renaming a field is a breaking change.

    Attributes:
        cfg (float): CFG factor against unconditional prediction.
        cfg_ref (float): CFG factor against a reference (to be used with a
            cache of size 2*batch_size and unfold).
        temperature (float): scale factor of the initial noise z0 ~ N(0, 1).
        num_steps (int): number of ODE steps.
        solver (str): parameter passed to NeuralODE.
        sensitivity (str): parameter passed to NeuralODE.
    """

    cfg: float = 1.3
    cfg_ref: float = 1.5
    temperature: float = 0.9
    num_steps: int = 13
    solver: str = "euler"
    sensitivity: str = "adjoint"
@dataclass
class PatchVAESamplingParams:
    """
    PatchVAE sampling parameters (latent-to-waveform decoding).

    Instances are converted with ``dataclasses.asdict`` and splatted as keyword
    arguments into ``PatchVAE.decode``, so every field name must match a
    keyword parameter of that method.

    Attributes:
        cfg (float): CFG factor against unconditional prediction.
        temperature (float): scale factor of the initial noise z0 ~ N(0, 1).
        num_steps (int): number of ODE steps.
        solver (str): parameter passed to NeuralODE.
        sensitivity (str): parameter passed to NeuralODE.
    """

    cfg: float = 2.0
    temperature: float = 1.0
    num_steps: int = 10
    solver: str = "euler"
    sensitivity: str = "adjoint"
class PardiSpeech:
    """
    Text-to-speech pipeline: an autoregressive TTS model, a PatchVAE codec and
    a text processor.

    ``text_to_speech`` tokenizes the text, lets the TTS model generate latent
    sequences (optionally continuing a reference prefix), and decodes each
    generated sequence to a waveform with the PatchVAE.
    """

    tts: ARTTSModel
    patchvae: PatchVAE
    text_processor: BasicTextProcessor

    def __init__(
        self,
        tts: ARTTSModel,
        patchvae: PatchVAE,
        text_processor: BasicTextProcessor,
    ):
        self.tts = tts
        self.patchvae = patchvae
        self.text_processor = text_processor

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        map_location: str = "cpu",
    ):
        """
        Load a pipeline from a local directory or a Hugging Face Hub repo id.

        The directory must contain ``config.json``, ``model.st`` (safetensors
        weights of the TTS model) and ``pretrained_tokenizer.json``. The PatchVAE
        is loaded from the ``patchvae_path`` found in the TTS config.

        Parameters
        ----------
        pretrained_model_name_or_path: str
            Local directory; if no such path exists it is treated as a repo id
            and downloaded with ``huggingface_hub.snapshot_download``.
        map_location: str
            Device onto which the TTS and PatchVAE weights are loaded.
        """
        local_path = Path(pretrained_model_name_or_path)
        if local_path.exists():
            path = local_path
        else:
            # Imported locally so huggingface_hub is only needed for remote checkpoints.
            from huggingface_hub import snapshot_download

            path = Path(snapshot_download(pretrained_model_name_or_path))

        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(path / "config.json", "r", encoding="utf-8") as f:
            config = json.load(f)

        artts_model, artts_config = ARTTSModel.instantiate_from_config(config)
        state_dict = load_file(path / "model.st", device=map_location)
        # assign=True keeps the properties (device, dtype) of the loaded tensors
        # rather than copying them into the freshly initialised parameters.
        artts_model.load_state_dict(state_dict, assign=True)
        # NOTE(review): the module is not switched to eval() here; confirm that
        # instantiate_from_config or generate() disables dropout-like layers.

        patchvae = PatchVAE.from_pretrained(
            artts_config.patchvae_path,
            map_location=map_location,
        )
        text_processor = BasicTextProcessor(str(path / "pretrained_tokenizer.json"))
        return cls(artts_model, patchvae, text_processor)

    def encode_reference(self, wav: torch.Tensor, sr: int):
        """
        Encode a reference recording with the PatchVAE.

        ``wav`` is resampled from ``sr`` to the codec's sampling rate before
        encoding. The result is meant to be used as the speech half of the
        ``prefix`` argument of ``text_to_speech``.
        """
        # Imported locally so torchaudio is only needed when this method is used.
        import torchaudio

        new_freq = self.patchvae.wavvae.sampling_rate
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=new_freq)
        return self.patchvae.encode(wav)

    @property
    def sampling_rate(self):
        """Sampling rate of the codec (``patchvae.wavvae``); references are resampled to it."""
        return self.patchvae.wavvae.sampling_rate

    def _encode_texts(self, texts: list[str], device: torch.device):
        """Tokenize ``texts`` (each terminated by ``[EOS]``), pad them and build attention masks."""
        token_ids = [torch.LongTensor(self.text_processor(t + "[EOS]")) for t in texts]
        lengths = torch.tensor([ids.shape[0] for ids in token_ids])
        # Presumably a (batch, max_len) validity mask — TODO confirm sequence_mask's contract.
        valid = sequence_mask(lengths).to(device)
        # Pairwise mask for text self-attention: valid[b, j] * valid[b, i].
        text_mask = valid[:, None] * valid[..., None]
        # Key-validity mask for cross-attention, with two singleton broadcast dims.
        crossatt_mask = valid[:, None, None]
        padded_ids = pad_sequence(token_ids, batch_first=True).to(device)
        return padded_ids, text_mask, crossatt_mask

    def text_to_speech(
        self,
        text: str | list[str],
        prefix: tuple[str, torch.Tensor] | None = None,
        patchvae_sampling_params: PatchVAESamplingParams | None = None,
        velocity_head_sampling_params: VelocityHeadSamplingParams | None = None,
        prefix_separator: str = ". ",
        max_seq_len: int = 600,
        stop_threshold: float = 0.5,
        cache: FLACache | None = None,
        **kwargs,
    ):
        """
        Synthesize speech for one text or a batch of texts.

        Parameters
        ----------
        text: str | list[str]
            The text to synthesize, or a sequence of texts synthesized as one batch.
        prefix: tuple[str, torch.Tensor] | None
            A pair (text, speech) consisting of a reference speech excerpt encoded
            (see encode_reference) and its corresponding text transcription.
            Synthesis is performed by continuing the prefix. If no prefix is given,
            the first frame is randomly sampled.
        patchvae_sampling_params: PatchVAESamplingParams
            PatchVAE sampling parameters (defaults used when None).
        velocity_head_sampling_params: VelocityHeadSamplingParams
            VelocityHead sampling parameters for AR sampling (defaults used when None).
        prefix_separator: str
            The separator that joins the prefix text to the target text.
        max_seq_len: int
            The maximum number of latents to generate.
        stop_threshold: float
            Threshold value at which AR prediction stops.
        cache: FLACache | None
            Forwarded to ``ARTTSModel.generate``.
        **kwargs
            Forwarded unchanged to ``ARTTSModel.generate``.

        Returns
        -------
        tuple
            ``(wavs, predictions)``: ``predictions`` as returned by the TTS model,
            and ``wavs``, a list with one PatchVAE decoding per prediction.

        Raises
        ------
        ValueError
            If ``text`` is an empty sequence.
        """
        device = next(self.tts.parameters()).device
        # isinstance (not `type(...) is str`) so str subclasses are not iterated
        # character by character.
        texts = [text] if isinstance(text, str) else list(text)
        if not texts:
            raise ValueError("text must be a string or a non-empty sequence of strings")
        batch_size = len(texts)

        if prefix is not None:
            prefix_text, prefix_speech = prefix
            # Trailing punctuation is dropped so prefix_separator alone joins the texts.
            prefix_text = prefix_text.strip().rstrip(string.punctuation)
            if prefix_text:
                texts = [prefix_text + prefix_separator + t for t in texts]
            # The same reference latents prefix every item of the batch.
            prefix_speech = prefix_speech.repeat(batch_size, 1, 1)
        else:
            # No reference: start from a single random latent frame per item.
            _, audio_latent_sz = self.tts.audio_embd.weight.shape
            prefix_speech = torch.randn(batch_size, 1, audio_latent_sz, device=device)

        text_ids, text_mask, crossatt_mask = self._encode_texts(texts, device)

        if velocity_head_sampling_params is None:
            velocity_head_sampling_params = VelocityHeadSamplingParams()
        if patchvae_sampling_params is None:
            patchvae_sampling_params = PatchVAESamplingParams()

        with torch.inference_mode():
            _, predictions = self.tts.generate(
                text_ids,
                text_mask=text_mask,
                crossatt_mask=crossatt_mask,
                prefix=prefix_speech.to(device),
                max_seq_len=max_seq_len,
                sampling_params=asdict(velocity_head_sampling_params),
                stop_threshold=stop_threshold,
                cache=cache,
                device=device,
                **kwargs,
            )

        # predictions are inference tensors created under inference_mode; decode
        # them under no_grad so no autograd graph is recorded (and no inference
        # tensor is saved for backward).
        decode_kwargs = asdict(patchvae_sampling_params)
        with torch.no_grad():
            wavs = [self.patchvae.decode(p, **decode_kwargs) for p in predictions]
        return wavs, predictions
|