Spaces:
Running
on
Zero
Running
on
Zero
| import json | |
| import string | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| import torch | |
| from safetensors.torch import load_file | |
| from torch.nn.utils.rnn import pad_sequence | |
| from codec.models import PatchVAE | |
| from tts.model.cache_utils import FLACache | |
| from tts.text_processor import BasicTextProcessor | |
| from tts.tools import sequence_mask | |
| from tts.tts import ARTTSModel | |
@dataclass
class VelocityHeadSamplingParams:
    """Velocity head sampling parameters.

    Declared as a dataclass so that an instance can be built with keyword
    overrides and converted to a plain dict with ``dataclasses.asdict``.

    Attributes:
        cfg (float): CFG factor against unconditional prediction.
        cfg_ref (float): CFG factor against a reference (to be used with a
            cache of size 2*batch_size and unfold).
        temperature (float): scale factor of z0 ~ N(0, 1).
        num_steps (int): number of ODE steps.
        solver (str): parameter passed to NeuralODE.
        sensitivity (str): parameter passed to NeuralODE.
    """

    cfg: float = 1.3
    cfg_ref: float = 1.5
    temperature: float = 0.9
    num_steps: int = 13
    solver: str = "euler"
    sensitivity: str = "adjoint"
@dataclass
class PatchVAESamplingParams:
    """PatchVAE sampling parameters.

    Declared as a dataclass so that an instance can be built with keyword
    overrides and converted to a plain dict with ``dataclasses.asdict``.

    Attributes:
        cfg (float): CFG factor against unconditional prediction.
        temperature (float): scale factor of z0 ~ N(0, 1).
        num_steps (int): number of ODE steps.
        solver (str): parameter passed to NeuralODE.
        sensitivity (str): parameter passed to NeuralODE.
    """

    cfg: float = 2.0
    temperature: float = 1.0
    num_steps: int = 10
    solver: str = "euler"
    sensitivity: str = "adjoint"
class PardiSpeech:
    """Text-to-speech pipeline bundling an AR TTS model, a PatchVAE and a text processor.

    Attributes:
        tts: model whose ``generate`` method produces the latent predictions.
        patchvae: codec used to encode reference audio and to decode the
            predictions returned by ``tts.generate``.
        text_processor: callable applied to each input string to obtain
            token ids.
    """

    tts: ARTTSModel
    patchvae: PatchVAE
    text_processor: BasicTextProcessor

    def __init__(
        self,
        tts: ARTTSModel,
        patchvae: PatchVAE,
        text_processor: BasicTextProcessor,
    ):
        """Store the three already-instantiated components."""
        self.tts = tts
        self.patchvae = patchvae
        self.text_processor = text_processor

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        map_location: str = "cpu",
    ):
        """Build a pipeline from a local directory or a Hugging Face Hub repo.

        Parameters
        ----------
        pretrained_model_name_or_path: str
            If this path exists locally it is used directly; otherwise it is
            passed to ``huggingface_hub.snapshot_download``. The resulting
            directory must contain ``config.json``, ``model.st`` and
            ``pretrained_tokenizer.json``.
        map_location: str
            Device on which the weights are loaded.

        Returns
        -------
        An instance of ``cls`` wrapping the loaded model, PatchVAE and text
        processor.
        """
        if Path(pretrained_model_name_or_path).exists():
            path = pretrained_model_name_or_path
        else:
            # Imported lazily so that purely local loading does not need
            # huggingface_hub.
            from huggingface_hub import snapshot_download

            path = snapshot_download(pretrained_model_name_or_path)

        root = Path(path)
        with open(root / "config.json", "r", encoding="utf-8") as f:
            config = json.load(f)

        artts_model, artts_config = ARTTSModel.instantiate_from_config(config)
        state_dict = load_file(
            root / "model.st",
            device=map_location,
        )
        # assign=True makes the module adopt the loaded tensors instead of
        # copying them into its existing parameters.
        artts_model.load_state_dict(state_dict, assign=True)

        # The PatchVAE location comes from the TTS config, not from `path`.
        patchvae = PatchVAE.from_pretrained(
            artts_config.patchvae_path,
            map_location=map_location,
        )
        text_processor = BasicTextProcessor(
            str(root / "pretrained_tokenizer.json")
        )
        return cls(artts_model, patchvae, text_processor)

    def encode_reference(self, wav: torch.Tensor, sr: int):
        """Resample `wav` from `sr` to the codec sampling rate and encode it.

        Returns the output of ``self.patchvae.encode`` on the resampled audio.
        """
        # Imported lazily: torchaudio is only needed when encoding a reference.
        import torchaudio

        new_freq = self.patchvae.wavvae.sampling_rate
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=new_freq)
        return self.patchvae.encode(wav)

    # NOTE(review): this is a plain method here, so it must be called as
    # ``sampling_rate()`` - confirm callers do not expect a property.
    def sampling_rate(self):
        """Return the sampling rate reported by ``self.patchvae.wavvae``."""
        return self.patchvae.wavvae.sampling_rate

    def text_to_speech(
        self,
        text: str | list[str],
        prefix: tuple[str, torch.Tensor] | None = None,
        patchvae_sampling_params: PatchVAESamplingParams | None = None,
        velocity_head_sampling_params: VelocityHeadSamplingParams | None = None,
        prefix_separator: str = ". ",
        max_seq_len: int = 600,
        stop_threshold: float = 0.5,
        cache: FLACache | None = None,
        **kwargs,
    ):
        """
        Synthesize speech for one text or a batch of texts.

        Parameters
        ----------
        text: str | list[str]
            The text to synthesize. A single string is treated as a batch of one.
        prefix: tuple[str, torch.Tensor] | None
            A pair (text, speech) consisting of a reference speech excerpt encoded (see encode_reference) and its corresponding text transcription. Synthesis is performed by continuing the prefix. If no prefix is given, the first frame is randomly sampled.
        patchvae_sampling_params: PatchVAESamplingParams
            PatchVAE sampling parameters. Defaults are used when None.
        velocity_head_sampling_params: VelocityHeadSamplingParams
            VelocityHead sampling parameters (AR sampling). Defaults are used when None.
        prefix_separator: str
            The separator that joins the prefix text to the target text.
        max_seq_len: int
            The maximum number of latent to generate.
        stop_threshold: float
            Threshold value at which AR prediction stops.
        cache: FLACache | None
            Forwarded unchanged to ``self.tts.generate``.
        **kwargs
            Extra keyword arguments forwarded to ``self.tts.generate``.

        Returns
        -------
        (wavs, predictions)
            ``predictions`` is the second value returned by
            ``self.tts.generate``; ``wavs`` is the list obtained by decoding
            each of its items with ``self.patchvae.decode``.

        Raises
        ------
        ValueError
            If `text` is an empty collection.
        """
        device = next(self.tts.parameters()).device

        # Normalise the input to a list of strings.
        if isinstance(text, str):
            text = [text]
        else:
            text = list(text)
        if not text:
            raise ValueError("text must contain at least one string to synthesize")

        if prefix is not None:
            prefix_text, prefix_speech = prefix
            # Drop trailing punctuation so that the separator is not doubled.
            prefix_text = prefix_text.strip().rstrip(string.punctuation)
            if prefix_text != "":
                text = [prefix_text + prefix_separator + t for t in text]
            # One copy of the reference per batch item
            # (assumes a 3-D tensor with a leading dimension of 1 - TODO confirm).
            prefix_speech = prefix_speech.repeat(len(text), 1, 1)
        else:
            # No reference: start every item from a random first frame.
            _, audio_latent_sz = self.tts.audio_embd.weight.shape
            prefix_speech = torch.randn(len(text), 1, audio_latent_sz, device=device)

        text_ids = [torch.LongTensor(self.text_processor(x + "[EOS]")) for x in text]
        # Masks are built from the unpadded lengths, before padding the ids.
        text_pre_mask = sequence_mask(
            torch.tensor([x.shape[0] for x in text_ids])
        ).to(device)
        text_mask = text_pre_mask[:, None] * text_pre_mask[..., None]
        crossatt_mask = text_pre_mask[:, None, None]
        text_ids = pad_sequence(text_ids, batch_first=True)

        if velocity_head_sampling_params is None:
            velocity_head_sampling_params = VelocityHeadSamplingParams()
        if patchvae_sampling_params is None:
            patchvae_sampling_params = PatchVAESamplingParams()

        with torch.inference_mode():
            _, predictions = self.tts.generate(
                text_ids.to(device),
                text_mask=text_mask,
                crossatt_mask=crossatt_mask,
                prefix=prefix_speech.to(device),
                max_seq_len=max_seq_len,
                sampling_params=asdict(velocity_head_sampling_params),
                stop_threshold=stop_threshold,
                cache=cache,
                device=device,
                **kwargs,
            )
            # Decoded in the same inference-mode context that produced the
            # predictions.
            wavs = [
                self.patchvae.decode(
                    p,
                    **asdict(patchvae_sampling_params),
                )
                for p in predictions
            ]

        return wavs, predictions