# Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from datetime import datetime

import s3tokenizer
import torch
from tqdm import tqdm

from flashcosyvoice.config import Config, SamplingParams
from flashcosyvoice.engine.llm_engine import LLMEngine
from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
from flashcosyvoice.modules.hifigan import HiFTGenerator


class CosyVoice2(torch.nn.Module):
    def __init__(self, config: Config | None = None):
        super().__init__()
        self.config = Config() if config is None else config

        self.audio_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda().eval()

        self.llm = LLMEngine(**self.config.__dict__)

        # Only the first process on each node drives tqdm progress bars.
        self.use_tqdm = torch.distributed.get_node_local_rank() == 0

        self.flow = CausalMaskedDiffWithXvec()
        if self.config.hf_config.fp16_flow:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [INFO] - Casting flow to fp16")
            self.flow.half()
        self.flow.load_state_dict(torch.load(f"{self.config.model}/flow.pt", map_location="cpu", weights_only=True), strict=True)
        self.flow.cuda().eval()

        self.hift = HiFTGenerator()
        # Checkpoint keys carry a 'generator.' prefix; strip it to match HiFTGenerator.
        hift_state_dict = {
            k.replace('generator.', ''): v
            for k, v in torch.load(f"{self.config.model}/hift.pt", map_location="cpu", weights_only=True).items()
        }
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.cuda().eval()

    @torch.inference_mode()
    def forward(
        self, prompt_mels_for_llm: torch.Tensor, prompt_mels_lens_for_llm: torch.Tensor,
        prompt_text_tokens_for_llm: list[list[int]], text_tokens_for_llm: list[list[int]],
        prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor,
        spk_emb_for_flow: torch.Tensor,
        sampling_params: SamplingParams | list[SamplingParams],
        batch_size_flow: int,
        only_llm: bool,
        **kwargs,  # for compatibility
    ):
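        """Batch text-to-speech forward pass.

        Quantizes the prompt mels into speech tokens, runs the LLM to generate
        new speech tokens conditioned on the text, and (unless ``only_llm`` is
        set) converts tokens to mels with the flow model and mels to waveforms
        with HiFi-GAN. Returns ``(results_dict, timing_stats)``.
        """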
        timing_stats = {}

        # Audio tokenization
        start_time = time.time()
        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
            prompt_mels_for_llm.cuda(), prompt_mels_lens_for_llm.cuda()
        )
        timing_stats['audio_tokenization'] = time.time() - start_time

        batch_size = prompt_speech_tokens.shape[0]
        assert len(prompt_text_tokens_for_llm) == batch_size

        # Prepare LLM inputs
        start_time = time.time()
        valid_prompt_speech_tokens = []
        inputs = []
        for i in range(batch_size):
            speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
            valid_prompt_speech_tokens.append(speech_tokens_i)
            inputs.append(
                [self.config.hf_config.speech_vocab_size]  # sos/eos marker token
                + prompt_text_tokens_for_llm[i]
                + text_tokens_for_llm[i]
                + [self.config.hf_config.speech_vocab_size + 1]  # task-id token
                + speech_tokens_i
            )
        timing_stats['prepare_llm_inputs'] = time.time() - start_time

        # LLM generation
        start_time = time.time()
        llm_outputs = self.llm.generate(inputs, sampling_params, use_tqdm=self.use_tqdm)
        timing_stats['llm_generation'] = time.time() - start_time

        results_dict = {
            "prompt_speech_tokens": valid_prompt_speech_tokens,
            "generated_speech_tokens": [o['token_ids'][:-1] for o in llm_outputs],
        }
        if only_llm:
            return results_dict, timing_stats

        # Prepare Flow inputs
        start_time = time.time()
        flow_inputs = []
        flow_inputs_lens = []
        for i, o in enumerate(llm_outputs):
            generated_speech_tokens = o['token_ids'][:-1]  # ignore last eos
            prompt_speech_tokens_i = valid_prompt_speech_tokens[i]
            flow_inputs.append(torch.tensor(prompt_speech_tokens_i + generated_speech_tokens))
            flow_inputs_lens.append(len(prompt_speech_tokens_i) + len(generated_speech_tokens))
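        # Right-pad the variable-length (prompt + generated) token sequences into one batch tensor.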
        flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
        flow_inputs_lens = torch.tensor(flow_inputs_lens)
        timing_stats['prepare_flow_inputs'] = time.time() - start_time

        # Flow generation and HiFi-GAN generation (with batching)
        total_batch_size = flow_inputs.shape[0]
        generated_wavs = []
        flow_total_time = 0.0
        hifigan_total_time = 0.0

        # Process in chunks of at most batch_size_flow requests.
        # NOTE(xcsong): When the LLM and Flow run on the same GPU, Flow can easily
        #   saturate SM and memory usage, so chunked processing is required to avoid OOM.
        num_batches = (total_batch_size + batch_size_flow - 1) // batch_size_flow
        batch_iterator = range(0, total_batch_size, batch_size_flow)
        if self.use_tqdm:
            batch_iterator = tqdm(batch_iterator, desc="Generating wavs (Flow+HiFi-GAN)", leave=False, unit="batch",
                                  total=num_batches, dynamic_ncols=True, position=self.config.rank + 1)

        for start_idx in batch_iterator:
            end_idx = min(start_idx + batch_size_flow, total_batch_size)
            batch_flow_inputs = flow_inputs[start_idx:end_idx]
            batch_flow_inputs_lens = flow_inputs_lens[start_idx:end_idx]
            batch_prompt_mels = prompt_mels_for_flow[start_idx:end_idx]
            batch_prompt_mels_lens = prompt_mels_lens_for_flow[start_idx:end_idx]
            batch_spk_emb = spk_emb_for_flow[start_idx:end_idx]

            # Flow generation for this batch
            flow_start_time = time.time()
            with torch.amp.autocast("cuda", dtype=torch.float16 if self.config.hf_config.fp16_flow else torch.float32):
                batch_generated_mels, batch_generated_mels_lens = self.flow(
                    batch_flow_inputs.cuda(), batch_flow_inputs_lens.cuda(),
                    batch_prompt_mels.cuda(), batch_prompt_mels_lens.cuda(), batch_spk_emb.cuda(),
                    streaming=False, finalize=True
                )
            flow_total_time += time.time() - flow_start_time

            # HiFi-GAN generation for this batch
            hifigan_start_time = time.time()
            batch_size_current = end_idx - start_idx
            for i in range(batch_size_current):
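                # Slice out only the newly generated frames, skipping the prompt-mel prefix.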
                mel = batch_generated_mels[i, :, batch_prompt_mels_lens[i].item():batch_generated_mels_lens[i].item()].unsqueeze(0)
                wav, _ = self.hift(speech_feat=mel)
                generated_wavs.append(wav)
            hifigan_total_time += time.time() - hifigan_start_time

        timing_stats['flow_generation'] = flow_total_time
        timing_stats['hifigan_generation'] = hifigan_total_time

        # Calculate total time across all timing stages; the batch statistics
        # appended below are not included in this sum.
        timing_stats['model.forward_total'] = sum(timing_stats.values())
        timing_stats['batch_size'] = len(generated_wavs)
        timing_stats['batch_size_flow'] = batch_size_flow

        results_dict['generated_wavs'] = generated_wavs
        return results_dict, timing_stats
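

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes `Config.model` points at a local CosyVoice2 checkpoint directory,
# that torch.distributed has been initialized (``__init__`` queries the
# node-local rank), and that prompt mels, text token ids, and speaker
# embeddings have already been extracted with the appropriate frontends;
# constructor arguments and shapes below are placeholders, not authoritative.
#
#   model = CosyVoice2(Config(model="path/to/CosyVoice2-0.5B"))
#   results, timing = model(
#       prompt_mels_for_llm=prompt_mels, prompt_mels_lens_for_llm=prompt_mels_lens,
#       prompt_text_tokens_for_llm=prompt_text_ids, text_tokens_for_llm=text_ids,
#       prompt_mels_for_flow=flow_mels, prompt_mels_lens_for_flow=flow_mels_lens,
#       spk_emb_for_flow=spk_emb,
#       sampling_params=SamplingParams(),
#       batch_size_flow=4, only_llm=False,
#   )
#   wavs = results["generated_wavs"]  # list of torch.Tensor waveforms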