Spaces: Running on Zero
Commit: Delete
generator.py  +0  -174

generator.py DELETED
@@ -1,174 +0,0 @@
import os
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torchaudio
from huggingface_hub import hf_hub_download
from models import Model
from moshi.models import loaders
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer
from watermarking import load_watermarker, watermark

CSM_1B_HF_WATERMARK = list(map(int, os.getenv("WATERMARK_KEY").split(" ")))


@dataclass
class Segment:
    speaker: int
    text: str
    # (num_samples,), sample_rate = 24_000
    audio: torch.Tensor


def load_llama3_tokenizer():
    """
    https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
    """
    tokenizer_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"{bos}:0 $A:0 {eos}:0",
        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
        special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
    )

    return tokenizer


class Generator:
    def __init__(
        self,
        model: Model,
    ):
        self._model = model
        self._model.setup_caches(1)

        self._text_tokenizer = load_llama3_tokenizer()

        device = next(model.parameters()).device
        mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
        mimi = loaders.get_mimi(mimi_weight, device=device)
        mimi.set_num_codebooks(32)
        self._audio_tokenizer = mimi

        self._watermarker = load_watermarker(device=device)

        self.sample_rate = mimi.sample_rate
        self.device = device

    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True

        frame_tokens.append(text_frame.to(self.device))
        frame_masks.append(text_frame_mask.to(self.device))

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        frame_tokens = []
        frame_masks = []

        # (K, T)
        audio = audio.to(self.device)
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
        # add EOS frame
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

        audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True

        frame_tokens.append(audio_frame)
        frame_masks.append(audio_frame_mask)

        return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (seq_len, 33), (seq_len, 33)
        """
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 50,
    ) -> torch.Tensor:
        self._model.reset_caches()

        max_audio_frames = int(max_audio_length_ms / 80)
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)

        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

        max_seq_len = 2048 - max_audio_frames
        if curr_tokens.size(1) >= max_seq_len:
            raise ValueError(f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}")

        for _ in range(max_audio_frames):
            sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
            if torch.all(sample == 0):
                break  # eos

            samples.append(sample)

            curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1

        audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)

        # This applies an imperceptible watermark to identify audio as AI-generated.
        # Watermarking ensures transparency, dissuades misuse, and enables traceability.
        # Please be a responsible AI citizen and keep the watermarking in place.
        # If using CSM 1B in another application, use your own private key and keep it secret.
        audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_HF_WATERMARK)
        audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)

        return audio


def load_csm_1b(device: str = "cuda") -> Generator:
    model = Model.from_pretrained("sesame/csm-1b")
    model.to(device=device, dtype=torch.bfloat16)

    generator = Generator(model)
    return generator
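For context, the deleted module was the Space's CSM-1B text-to-speech entry point. The following is a minimal usage sketch of the removed API, not part of the commit itself; it assumes generator.py from an earlier revision is still importable, the WATERMARK_KEY environment variable is set before import, a CUDA device is available, and prompt.wav is a hypothetical mono prompt clip.

# Usage sketch for the removed generator.py (assumptions noted above).
import torchaudio

from generator import Segment, load_csm_1b

generator = load_csm_1b(device="cuda")

# Optional context: a prior utterance that conditions the voice. The audio
# tensor must be 1-D mono at generator.sample_rate (24 kHz for Mimi).
prompt_audio, sr = torchaudio.load("prompt.wav")  # hypothetical prompt file
prompt_audio = torchaudio.functional.resample(
    prompt_audio.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
)
context = [Segment(speaker=0, text="Hello there.", audio=prompt_audio)]

audio = generator.generate(
    text="Hey, how are you doing today?",
    speaker=0,
    context=context,
    max_audio_length_ms=10_000,
)
torchaudio.save("output.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)

Since generate() converts its budget with max_audio_frames = int(max_audio_length_ms / 80), each Mimi frame covers 80 ms, so max_audio_length_ms=10_000 caps generation at 125 frames.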