hamedhub1 commited on
Commit
3d5dc72
·
verified ·
1 Parent(s): bc7091a

Remove

Files changed (1) hide show
  1. generator.py +0 -174
generator.py DELETED
@@ -1,174 +0,0 @@
1
- import os
2
- from dataclasses import dataclass
3
- from typing import List, Tuple
4
-
5
- import torch
6
- import torchaudio
7
- from huggingface_hub import hf_hub_download
8
- from models import Model
9
- from moshi.models import loaders
10
- from tokenizers.processors import TemplateProcessing
11
- from transformers import AutoTokenizer
12
- from watermarking import load_watermarker, watermark
13
-
14
- CSM_1B_HF_WATERMARK = list(map(int, os.getenv("WATERMARK_KEY").split(" ")))
15
-
16
-
17
- @dataclass
18
- class Segment:
19
- speaker: int
20
- text: str
21
- # (num_samples,), sample_rate = 24_000
22
- audio: torch.Tensor
23
-
24
-
25
- def load_llama3_tokenizer():
26
- """
27
- https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
28
- """
29
- tokenizer_name = "meta-llama/Llama-3.2-1B"
30
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
31
- bos = tokenizer.bos_token
32
- eos = tokenizer.eos_token
33
- tokenizer._tokenizer.post_processor = TemplateProcessing(
34
- single=f"{bos}:0 $A:0 {eos}:0",
35
- pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
36
- special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
37
- )
38
-
39
- return tokenizer
40
-
41
-
42
- class Generator:
43
- def __init__(
44
- self,
45
- model: Model,
46
- ):
47
- self._model = model
48
- self._model.setup_caches(1)
49
-
50
- self._text_tokenizer = load_llama3_tokenizer()
51
-
52
- device = next(model.parameters()).device
53
- mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
54
- mimi = loaders.get_mimi(mimi_weight, device=device)
55
- mimi.set_num_codebooks(32)
56
- self._audio_tokenizer = mimi
57
-
58
- self._watermarker = load_watermarker(device=device)
59
-
60
- self.sample_rate = mimi.sample_rate
61
- self.device = device
62
-
63
- def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
64
- frame_tokens = []
65
- frame_masks = []
66
-
67
- text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
68
- text_frame = torch.zeros(len(text_tokens), 33).long()
69
- text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
70
- text_frame[:, -1] = torch.tensor(text_tokens)
71
- text_frame_mask[:, -1] = True
72
-
73
- frame_tokens.append(text_frame.to(self.device))
74
- frame_masks.append(text_frame_mask.to(self.device))
75
-
76
- return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
77
-
78
- def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
79
- frame_tokens = []
80
- frame_masks = []
81
-
82
- # (K, T)
83
- audio = audio.to(self.device)
84
- audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
85
- # add EOS frame
86
- eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
87
- audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
88
-
89
- audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
90
- audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
91
- audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
92
- audio_frame_mask[:, :-1] = True
93
-
94
- frame_tokens.append(audio_frame)
95
- frame_masks.append(audio_frame_mask)
96
-
97
- return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
98
-
99
- def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
100
- """
101
- Returns:
102
- (seq_len, 33), (seq_len, 33)
103
- """
104
- text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
105
- audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
106
-
107
- return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
108
-
109
- @torch.inference_mode()
110
- def generate(
111
- self,
112
- text: str,
113
- speaker: int,
114
- context: List[Segment],
115
- max_audio_length_ms: float = 90_000,
116
- temperature: float = 0.9,
117
- topk: int = 50,
118
- ) -> torch.Tensor:
119
- self._model.reset_caches()
120
-
121
- max_audio_frames = int(max_audio_length_ms / 80)
122
- tokens, tokens_mask = [], []
123
- for segment in context:
124
- segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
125
- tokens.append(segment_tokens)
126
- tokens_mask.append(segment_tokens_mask)
127
-
128
- gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
129
- tokens.append(gen_segment_tokens)
130
- tokens_mask.append(gen_segment_tokens_mask)
131
-
132
- prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
133
- prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
134
-
135
- samples = []
136
- curr_tokens = prompt_tokens.unsqueeze(0)
137
- curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
138
- curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
139
-
140
- max_seq_len = 2048 - max_audio_frames
141
- if curr_tokens.size(1) >= max_seq_len:
142
- raise ValueError(f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}")
143
-
144
- for _ in range(max_audio_frames):
145
- sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
146
- if torch.all(sample == 0):
147
- break # eos
148
-
149
- samples.append(sample)
150
-
151
- curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
152
- curr_tokens_mask = torch.cat(
153
- [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
154
- ).unsqueeze(1)
155
- curr_pos = curr_pos[:, -1:] + 1
156
-
157
- audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
158
-
159
- # This applies an imperceptible watermark to identify audio as AI-generated.
160
- # Watermarking ensures transparency, dissuades misuse, and enables traceability.
161
- # Please be a responsible AI citizen and keep the watermarking in place.
162
- # If using CSM 1B in another application, use your own private key and keep it secret.
163
- audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_HF_WATERMARK)
164
- audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
165
-
166
- return audio
167
-
168
-
169
- def load_csm_1b(device: str = "cuda") -> Generator:
170
- model = Model.from_pretrained("sesame/csm-1b")
171
- model.to(device=device, dtype=torch.bfloat16)
172
-
173
- generator = Generator(model)
174
- return generator