import json
import string
from dataclasses import asdict, dataclass
from pathlib import Path

import torch
from safetensors.torch import load_file
from torch.nn.utils.rnn import pad_sequence

from codec.models import PatchVAE
from tts.model.cache_utils import FLACache
from tts.text_processor import BasicTextProcessor
from tts.tools import sequence_mask
from tts.tts import ARTTSModel


@dataclass
class VelocityHeadSamplingParams:
    """
    Velocity head sampling parameters

    Attributes:

    cfg (float): CFG factor against unconditional prediction.
    cfg_ref (float): CFG factor against a reference (to be used with a cache of size 2*batch_size and unfold).
    temperature (float): scale factor of z0 ~ 𝒩(0,1)
    num_steps (int): number of ODE steps
    solver (str): parameter passed to NeuralODE
    sensitivity (str): parameter passed to NeuralODE
    """
    cfg: float = 1.3
    cfg_ref: float = 1.5
    temperature: float = 0.9
    num_steps: int = 13
    solver: str = "euler"
    sensitivity: str = "adjoint"


@dataclass
class PatchVAESamplingParams:
    """
    PatchVAE sampling parameters

    Attributes:

    cfg (float): CFG factor against unconditional prediction.
    temperature (float): scale factor of z0 ~ 𝒩(0,1)
    num_steps (int): number of ODE steps
    solver (str): parameter passed to NeuralODE
    sensitivity (str): parameter passed to NeuralODE
    """

    cfg: float = 2.0
    temperature: float = 1.0
    num_steps: int = 10
    solver: str = "euler"
    sensitivity: str = "adjoint"


class PardiSpeech:
    tts: ARTTSModel
    patchvae: PatchVAE
    text_processor: BasicTextProcessor

    def __init__(
        self,
        tts: ARTTSModel,
        patchvae: PatchVAE,
        text_processor: BasicTextProcessor,
    ):
        self.tts = tts
        self.patchvae = patchvae
        self.text_processor = text_processor

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        map_location: str = "cpu",
    ):
        """Load weights from a local directory or a Hugging Face Hub repository."""
        if Path(pretrained_model_name_or_path).exists():
            path = pretrained_model_name_or_path
        else:
            from huggingface_hub import snapshot_download

            path = snapshot_download(pretrained_model_name_or_path)

        with open(Path(path) / "config.json", "r") as f:
            config = json.load(f)

        artts_model, artts_config = ARTTSModel.instantiate_from_config(config)
        state_dict = load_file(
            Path(path) / "model.st",
            device=map_location,
        )
        artts_model.load_state_dict(state_dict, assign=True)
        patchvae = PatchVAE.from_pretrained(
            artts_config.patchvae_path,
            map_location=map_location,
        )
        text_processor = BasicTextProcessor(
            str(Path(path) / "pretrained_tokenizer.json")
        )

        return cls(artts_model, patchvae, text_processor)
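
    # Usage sketch (the repo id below is a hypothetical placeholder, not a
    # published checkpoint):
    #
    #     model = PardiSpeech.from_pretrained("someuser/pardi-speech")
    #     model = PardiSpeech.from_pretrained("/local/checkpoint", map_location="cuda")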

    def encode_reference(self, wav: torch.Tensor, sr: int):
        """Resample a reference waveform to the codec sampling rate and encode it."""
        import torchaudio

        new_freq = self.patchvae.wavvae.sampling_rate
        wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=new_freq)
        return self.patchvae.encode(wav)
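
    # Usage sketch (assumes a local file "reference.wav" exists and that
    # PatchVAE.encode accepts the (channels, time) layout torchaudio.load
    # returns):
    #
    #     import torchaudio
    #     wav, sr = torchaudio.load("reference.wav")
    #     reference = model.encode_reference(wav, sr)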

    @property
    def sampling_rate(self):
        """Output sampling rate of the underlying waveform VAE."""
        return self.patchvae.wavvae.sampling_rate

    def text_to_speech(
        self,
        text: str,
        prefix: tuple[str, torch.Tensor] | None = None,
        patchvae_sampling_params: PatchVAESamplingParams | None = None,
        velocity_head_sampling_params: VelocityHeadSamplingParams | None = None,
        prefix_separator: str = ". ",
        max_seq_len: int = 600,
        stop_threshold: float = 0.5,
        cache: FLACache | None = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        text: str
            The text to synthesize.

        prefix: tuple[str, torch.Tensor] | None
            A pair (text, speech) consisting of a reference speech excerpt encoded (see encode_reference) and its corresponding text transcription. Synthesis is performed by continuing the prefix. If no prefix is given, the first frame is randomly sampled.

        patchvae_sampling_params: PatchVAESamplingParams
            PatchVAE sampling parameters

        velocity_head_sampling_params: VelocityHeadSamplingParams
            VelocityHead sampling parameters (AR sampling)

        prefix_separator: str
            The separator that joins the prefix text to the target text.

        max_seq_len: int
            The maximum number of latent to generate.

        stop_threshold: float
            Threshold value at which AR prediction stops.
        """

        device = next(self.tts.parameters()).device

        if isinstance(text, str):
            text = [text]

        if prefix is not None:
            prefix_text, prefix_speech = prefix
            # Strip trailing punctuation so the separator joins cleanly.
            prefix_text = prefix_text.strip().rstrip(string.punctuation)
            if prefix_text != "":
                text = [prefix_text + prefix_separator + t for t in text]

            prefix_speech = prefix_speech.repeat(len(text), 1, 1)
        else:
            # No prefix: sample a random first latent frame to seed generation.
            _, audio_latent_sz = self.tts.audio_embd.weight.shape
            prefix_speech = torch.randn(len(text), 1, audio_latent_sz, device=device)

        # if self.bos:
        #    text = "[BOS]" + text

        # if self.eos:
        #    text = text + "[EOS]"
        text_ids = [torch.LongTensor(self.text_processor(x + "[EOS]")) for x in text]
        # Per-sample validity mask over text positions: (batch, max_len).
        text_pre_mask = sequence_mask(
            torch.tensor([x.shape[0] for x in text_ids])
        ).to(device)
        # Self-attention mask (batch, max_len, max_len) and a broadcastable
        # cross-attention mask (batch, 1, 1, max_len).
        text_mask = text_pre_mask[:, None] * text_pre_mask[..., None]
        crossatt_mask = text_pre_mask[:, None, None]

        text_ids = pad_sequence(text_ids, batch_first=True)

        if velocity_head_sampling_params is None:
            velocity_head_sampling_params = VelocityHeadSamplingParams()

        if patchvae_sampling_params is None:
            patchvae_sampling_params = PatchVAESamplingParams()

        with torch.inference_mode():
            _, predictions = self.tts.generate(
                text_ids.to(device),
                text_mask=text_mask,
                crossatt_mask=crossatt_mask,
                prefix=prefix_speech.to(device),
                max_seq_len=max_seq_len,
                sampling_params=asdict(velocity_head_sampling_params),
                stop_threshold=stop_threshold,
                cache=cache,
                device=device,
                **kwargs,
            )
            # Decode each predicted latent sequence back to a waveform.
            wavs = [
                self.patchvae.decode(p, **asdict(patchvae_sampling_params))
                for p in predictions
            ]

        return wavs, predictions
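

# Minimal end-to-end sketch. Everything below is illustrative: the checkpoint
# id and file names are placeholders, and the decoded waveform is assumed to
# be a (channels, time) tensor as torchaudio.save expects.
if __name__ == "__main__":
    import torchaudio

    model = PardiSpeech.from_pretrained("someuser/pardi-speech")

    # Optional voice prompt: an encoded reference excerpt plus its transcription.
    ref_wav, ref_sr = torchaudio.load("reference.wav")
    reference = model.encode_reference(ref_wav, ref_sr)

    wavs, _ = model.text_to_speech(
        "Hello, world.",
        prefix=("Transcription of the reference excerpt.", reference),
    )
    torchaudio.save("output.wav", wavs[0].cpu(), model.sampling_rate)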