# pardi-speech/pardi_speech.py
# Initial demo of Lina-speech (pardi-speech).
import json
import string
from dataclasses import asdict, dataclass
from pathlib import Path
import torch
from safetensors.torch import load_file
from torch.nn.utils.rnn import pad_sequence
from codec.models import PatchVAE
from tts.model.cache_utils import FLACache
from tts.text_processor import BasicTextProcessor
from tts.tools import sequence_mask
from tts.tts import ARTTSModel


@dataclass
class VelocityHeadSamplingParams:
"""
Velocity head sampling parameters
Attributes:
        cfg (float): classifier-free guidance (CFG) factor against the
            unconditional prediction.
        cfg_ref (float): CFG factor against a reference (to be used with a
            cache of size 2*batch_size and unfold).
        temperature (float): scale factor applied to the initial noise
            z0 ~ 𝒩(0, 1).
        num_steps (int): number of ODE integration steps.
        solver (str): solver name passed to NeuralODE.
        sensitivity (str): sensitivity mode passed to NeuralODE.
"""
cfg: float = 1.3
cfg_ref: float = 1.5
temperature: float = 0.9
num_steps: int = 13
solver: str = "euler"
sensitivity: str = "adjoint"


@dataclass
class PatchVAESamplingParams:
"""
PatchVAE sampling parameters
Attributes:
        cfg (float): classifier-free guidance (CFG) factor against the
            unconditional prediction.
        temperature (float): scale factor applied to the initial noise
            z0 ~ 𝒩(0, 1).
        num_steps (int): number of ODE integration steps.
        solver (str): solver name passed to NeuralODE.
        sensitivity (str): sensitivity mode passed to NeuralODE.
"""
cfg: float = 2.0
temperature: float = 1.0
num_steps: int = 10
solver: str = "euler"
sensitivity: str = "adjoint"


class PardiSpeech:
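    """Text-to-speech pipeline combining an autoregressive latent model
    (ARTTSModel), a PatchVAE audio latent codec, and a text processor."""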
tts: ARTTSModel
patchvae: PatchVAE
text_processor: BasicTextProcessor
def __init__(
self,
tts: ARTTSModel,
patchvae: PatchVAE,
text_processor: BasicTextProcessor,
):
self.tts = tts
self.patchvae = patchvae
self.text_processor = text_processor

    @classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
map_location: str = "cpu",
):
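        """Load weights from a local directory or, failing that, from a
        Hugging Face Hub repo id fetched with snapshot_download."""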
if Path(pretrained_model_name_or_path).exists():
path = pretrained_model_name_or_path
else:
from huggingface_hub import snapshot_download
path = snapshot_download(pretrained_model_name_or_path)
with open(Path(path) / "config.json", "r") as f:
config = json.load(f)
artts_model, artts_config = ARTTSModel.instantiate_from_config(config)
state_dict = load_file(
Path(path) / "model.st",
device=map_location,
)
artts_model.load_state_dict(state_dict, assign=True)
patchvae = PatchVAE.from_pretrained(
artts_config.patchvae_path,
map_location=map_location,
)
text_processor = BasicTextProcessor(
str(Path(path) / "pretrained_tokenizer.json")
)
return cls(artts_model, patchvae, text_processor)

    def encode_reference(self, wav: torch.Tensor, sr: int):
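        """Resample wav from sr to the codec sampling rate and encode it
        into PatchVAE latents (usable as the speech part of a prefix)."""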
import torchaudio
new_freq = self.patchvae.wavvae.sampling_rate
wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=new_freq)
return self.patchvae.encode(wav)

    @property
def sampling_rate(self):
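        """Sampling rate (Hz) of the underlying waveform codec."""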
return self.patchvae.wavvae.sampling_rate

    def text_to_speech(
self,
        text: str | list[str],
prefix: tuple[str, torch.Tensor] | None = None,
patchvae_sampling_params: PatchVAESamplingParams | None = None,
velocity_head_sampling_params: VelocityHeadSamplingParams | None = None,
prefix_separator: str = ". ",
max_seq_len: int = 600,
stop_threshold: float = 0.5,
cache: FLACache | None = None,
**kwargs,
):
"""
Parameters
----------
        text: str | list[str]
            The text to synthesize (a batch if a list is given).
        prefix: tuple[str, torch.Tensor] | None
            A (text, speech) pair: a reference transcription and the
            corresponding speech excerpt encoded with encode_reference.
            Synthesis continues the prefix; if no prefix is given, the
            first latent frame is sampled randomly.
patchvae_sampling_params: PatchVAESamplingParams
PatchVAE sampling parameters
velocity_head_sampling_params: VelocityHeadSamplingParams
VelocityHead sampling parameters (AR sampling)
prefix_separator: str
The separator that joins the prefix text to the target text.
        max_seq_len: int
            The maximum number of latent frames to generate.
        stop_threshold: float
            Threshold on the stop prediction at which AR generation halts.
        cache: FLACache | None
            Optional cache forwarded to ARTTSModel.generate.

        Returns
        -------
        A pair (wavs, predictions): the decoded waveforms and the
        generated latent sequences.
        """
device = next(self.tts.parameters()).device
        if isinstance(text, str):
            text = [text]
if prefix is not None:
prefix_text, prefix_speech = prefix
prefix_text = prefix_text.strip().rstrip(string.punctuation)
if prefix_text != "":
text = [prefix_text + prefix_separator + t for t in text]
prefix_speech = prefix_speech.repeat(len(text), 1, 1)
else:
_, audio_latent_sz = self.tts.audio_embd.weight.shape
prefix_speech = torch.randn(len(text), 1, audio_latent_sz, device=device)
text_ids = [torch.LongTensor(self.text_processor(x + "[EOS]")) for x in text]
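        # Padding masks: text_mask (B, T, T) for text self-attention,
        # crossatt_mask (B, 1, 1, T) for cross-attention over the text.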
        text_pre_mask = sequence_mask(
            torch.tensor([x.shape[0] for x in text_ids])
        ).to(device)
        text_mask = text_pre_mask[:, None] * text_pre_mask[..., None]
        crossatt_mask = text_pre_mask[:, None, None]
text_ids = pad_sequence(text_ids, batch_first=True)
if velocity_head_sampling_params is None:
velocity_head_sampling_params = VelocityHeadSamplingParams()
if patchvae_sampling_params is None:
patchvae_sampling_params = PatchVAESamplingParams()
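        # Sample latent frames autoregressively with the AR model, then
        # decode each predicted latent sequence to a waveform.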
with torch.inference_mode():
_, predictions = self.tts.generate(
text_ids.to(device),
text_mask=text_mask,
crossatt_mask=crossatt_mask,
prefix=prefix_speech.to(device),
max_seq_len=max_seq_len,
sampling_params=asdict(velocity_head_sampling_params),
stop_threshold=stop_threshold,
cache=cache,
device=device,
**kwargs,
)
        wavs = [
            self.patchvae.decode(p, **asdict(patchvae_sampling_params))
            for p in predictions
        ]
return wavs, predictions
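

if __name__ == "__main__":
    # Minimal usage sketch; the checkpoint path, reference clip, and
    # transcript below are placeholders, not files shipped with the repo.
    import torchaudio

    model = PardiSpeech.from_pretrained("path/to/checkpoint")  # or a Hub repo id
    ref_wav, ref_sr = torchaudio.load("reference.wav")  # hypothetical reference
    ref_latents = model.encode_reference(ref_wav, ref_sr)

    wavs, latents = model.text_to_speech(
        "Hello from pardi-speech!",
        prefix=("Transcript of the reference clip.", ref_latents),
        velocity_head_sampling_params=VelocityHeadSamplingParams(cfg=1.5),
    )
    # Assumes PatchVAE.decode returns a (1, T) waveform tensor.
    torchaudio.save("output.wav", wavs[0].cpu(), model.sampling_rate)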