import atexit
from dataclasses import fields
from time import perf_counter

import torch.multiprocessing as mp
from tqdm.auto import tqdm
from transformers import AutoTokenizer

from flashcosyvoice.config import Config, SamplingParams
from flashcosyvoice.engine.model_runner import ModelRunner
from flashcosyvoice.engine.scheduler import Scheduler
from flashcosyvoice.engine.sequence import Sequence


class LLMEngine:
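    """Offline inference engine combining a ModelRunner (model execution) and a
    Scheduler (continuous batching of Sequence requests)."""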

    def __init__(self, model: str, **kwargs):
        config_fields = {field.name for field in fields(Config)}
        config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
        config = Config(model, **config_kwargs)
        self.ps = []
        self.events = []
        ctx = mp.get_context("spawn")
        assert config.tensor_parallel_size == 1, "NOTE(xcsong): currently only tp=1 is supported"
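        # Spawn one worker process per additional tensor-parallel rank; with
        # tp=1 enforced above, this loop body never runs.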
        for i in range(1, config.tensor_parallel_size):
            event = ctx.Event()
            process = ctx.Process(target=ModelRunner, args=(config, i, event))
            process.start()
            self.ps.append(process)
            self.events.append(event)
        if hasattr(config.hf_config, "speech_vocab_size"):
            # NOTE: non-chat model, so these special tokens remain randomly initialized.
            special_tokens = {
                'eos_token': '<|endoftext|>',
                'pad_token': '<|endoftext|>',
                'additional_special_tokens': [
                    '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
                    '[breath]', '<strong>', '</strong>', '[noise]',
                    '[laughter]', '[cough]', '[clucking]', '[accent]',
                    '[quick_breath]',
                    "<laughter>", "</laughter>",
                    "[hissing]", "[sigh]", "[vocalized-noise]",
                    "[lipsmack]", "[mn]"
                ]
            }
            self.tokenizer = AutoTokenizer.from_pretrained(f"{config.model}/CosyVoice-BlankEN")
            self.tokenizer.add_special_tokens(special_tokens)
            self.skip_special_tokens = True
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
            self.skip_special_tokens = False  # keep the attribute defined on both branches
        if hasattr(config.hf_config, "eos_token_id"):
            config.eos = config.hf_config.eos_token_id
        else:
            config.eos = self.tokenizer.eos_token_id
        self.model_runner = ModelRunner(config, config.rank, self.events)
        self.scheduler = Scheduler(config)
        self.config = config
        atexit.register(self.exit)

    def exit(self):
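        """Tell the model runner to shut down, then join any worker processes."""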
        self.model_runner.call("exit")
        del self.model_runner
        for p in self.ps:
            p.join()

    def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
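        """Tokenize the prompt if it is a string and enqueue it with the scheduler."""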
        if isinstance(prompt, str):
            prompt = self.tokenizer.encode(prompt)
        seq = Sequence(prompt, sampling_params)
        self.scheduler.add(seq)

    def step(self):
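        """Run one scheduler step.

        Returns the (seq_id, completion_token_ids) pairs of sequences that
        finished this step, plus a token count that is positive (number of
        prefilled tokens) for prefill steps and negative (the decode batch
        size) for decode steps, so the caller can separate throughput stats.
        """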
        seqs, is_prefill = self.scheduler.schedule()
        token_ids = self.model_runner.call("run", seqs, is_prefill)
        self.scheduler.postprocess(seqs, token_ids)
        outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
        num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
        return outputs, num_tokens

    def is_finished(self):
        return self.scheduler.is_finished()

    def generate(
        self,
        prompts: list[str] | list[list[int]],
        sampling_params: SamplingParams | list[SamplingParams],
        use_tqdm: bool = True,
    ) -> list[dict]:
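        """Run all prompts to completion; returns, in request order, one dict
        per prompt with the decoded "text" and the raw "token_ids"."""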
        if use_tqdm:
            pbar = tqdm(total=len(prompts), desc="Generating tokens (LLM)", leave=False,
                        dynamic_ncols=True, position=self.config.rank + 1)
        if not isinstance(sampling_params, list):
            sampling_params = [sampling_params] * len(prompts)
        for prompt, sp in zip(prompts, sampling_params):
            self.add_request(prompt, sp)
        outputs = {}
        prefill_throughput = decode_throughput = instant_decode_throughput = 0.
        total_decode_tokens = 0
        total_decode_time = 0.
        while not self.is_finished():
            t = perf_counter()
            output, num_tokens = self.step()
            step_time = perf_counter() - t
            if use_tqdm:
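                # Positive num_tokens marks a prefill step; negative values
                # carry the decode batch size (see step()).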
                if num_tokens > 0:
                    prefill_throughput = num_tokens / step_time
                else:
                    instant_decode_throughput = -num_tokens / step_time
                    total_decode_tokens += -num_tokens
                    total_decode_time += step_time
                    decode_throughput = total_decode_tokens / total_decode_time if total_decode_time > 0 else 0
                pbar.set_postfix({
                    "Prefill": f"{int(prefill_throughput)}tok/s",
                    "AvgDecode": f"{int(decode_throughput)}tok/s",
                    "InstDecode": f"{int(instant_decode_throughput)}tok/s",
                })
            for seq_id, token_ids in output:
                outputs[seq_id] = token_ids
                if use_tqdm:
                    pbar.update(1)
        outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
        outputs = [
            {"text": self.tokenizer.decode(token_ids, skip_special_tokens=self.skip_special_tokens),
             "token_ids": token_ids}
            for token_ids in outputs
        ]
        if use_tqdm:
            pbar.close()
        return outputs
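

if __name__ == "__main__":
    # Minimal usage sketch, not part of the engine itself: the model path is a
    # placeholder and SamplingParams is constructed with its defaults; adjust
    # both to your setup.
    engine = LLMEngine("pretrained_models/CosyVoice2-0.5B")
    results = engine.generate(["Hello, world!"], SamplingParams())
    for result in results:
        print(result["text"])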