kadirnar committed
Commit 5f66bcf · verified · 1 Parent(s): 32f28d2

Create audio_tokenizer.py

Files changed (1)
  1. audio_tokenizer.py +276 -0
audio_tokenizer.py ADDED
@@ -0,0 +1,276 @@
import os
import yaml
import torch
import torchaudio.transforms as T
from datasets import load_dataset
from huggingface_hub import snapshot_download
from snac import SNAC
from transformers import AutoTokenizer


def load_config(config_path):
    """
    Load tokenizer configuration from YAML file.

    Args:
        config_path: Path to YAML config file

    Returns:
        Dictionary with configuration values
    """
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config

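# The YAML configs referenced in process_dataset() below (e.g.
# vyvotts/configs/inference/qwen3.yaml) are expected to define the special
# token ids read there. Sketch of the expected layout; the values shown are
# placeholders, not the real ids:
#
#   TOKENIZER_LENGTH: <base text vocab size>
#   START_OF_TEXT: <id>
#   END_OF_TEXT: <id>
#   START_OF_SPEECH: <id>
#   END_OF_SPEECH: <id>
#   START_OF_HUMAN: <id>
#   END_OF_HUMAN: <id>
#   START_OF_AI: <id>
#   END_OF_AI: <id>
#   PAD_TOKEN: <id>
#   AUDIO_TOKENS_START: <first audio token id>
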
def tokenise_audio(waveform, snac_model, ds_sample_rate, target_sample_rate, audio_tokens_start):
    """
    Tokenize audio waveform using SNAC codec.

    Args:
        waveform: Audio array from dataset
        snac_model: SNAC model instance
        ds_sample_rate: Original dataset sample rate
        target_sample_rate: Target sample rate (24000)
        audio_tokens_start: Offset for audio tokens

    Returns:
        List of audio token IDs with proper offsets applied
    """
    # Convert to tensor and prepare for processing
    waveform = torch.from_numpy(waveform).unsqueeze(0)
    waveform = waveform.to(dtype=torch.float32)

    # Resample to target sample rate if needed
    resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=target_sample_rate)
    waveform = resample_transform(waveform)
    waveform = waveform.unsqueeze(0).to("cuda")

    # Generate SNAC codes
    with torch.inference_mode():
        codes = snac_model.encode(waveform)

    # Interleave codes from 3 codebooks with proper offsets
    # SNAC uses hierarchical vector quantization with 3 levels
    all_codes = []
    num_frames = codes[0].shape[1]

    for i in range(num_frames):
        # Level 0: 1 code per frame
        all_codes.append(codes[0][0][i].item() + audio_tokens_start)

        # Level 1: 2 codes per frame
        all_codes.append(codes[1][0][2*i].item() + audio_tokens_start + 4096)

        # Level 2: 4 codes per frame
        all_codes.append(codes[2][0][4*i].item() + audio_tokens_start + (2 * 4096))
        all_codes.append(codes[2][0][4*i + 1].item() + audio_tokens_start + (3 * 4096))

        # Continue level 1 and 2 interleaving
        all_codes.append(codes[1][0][2*i + 1].item() + audio_tokens_start + (4 * 4096))
        all_codes.append(codes[2][0][4*i + 2].item() + audio_tokens_start + (5 * 4096))
        all_codes.append(codes[2][0][4*i + 3].item() + audio_tokens_start + (6 * 4096))

    return all_codes

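# Note: each SNAC frame above yields 7 tokens, interleaved and offset so the
# seven positions occupy disjoint id ranges of width 4096:
#   [ c0[i] + 0*4096, c1[2i] + 1*4096, c2[4i] + 2*4096, c2[4i+1] + 3*4096,
#     c1[2i+1] + 4*4096, c2[4i+2] + 5*4096, c2[4i+3] + 6*4096 ]
# with audio_tokens_start added to every term. This matches the loop order in
# tokenise_audio() and is the 7-token frame layout remove_duplicate_frames()
# relies on.
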
def remove_duplicate_frames(codes_list):
    """
    Remove consecutive duplicate audio frames to reduce redundancy.

    Each frame consists of 7 codes (1 + 2 + 4 from 3 SNAC codebook levels).
    Frames with identical first codes are considered duplicates.

    Args:
        codes_list: List of audio codes

    Returns:
        Deduplicated codes list
    """
    if len(codes_list) % 7 != 0:
        raise ValueError("Input list length must be divisible by 7")

    # Keep first frame
    result = codes_list[:7]
    removed_frames = 0

    # Check each subsequent frame
    for i in range(7, len(codes_list), 7):
        current_first_code = codes_list[i]
        previous_first_code = result[-7]

        if current_first_code != previous_first_code:
            result.extend(codes_list[i:i+7])
        else:
            removed_frames += 1

    return result

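# Illustrative example (hypothetical code values): given three frames whose
# first codes are [10, 10, 11], remove_duplicate_frames() keeps frames 1 and 3
# and drops frame 2, because frame 2 repeats the first code of the most
# recently kept frame. Only the first code of each 7-token frame is compared.
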
def process_dataset(
    original_dataset,
    output_dataset,
    model_type="qwen3",
    text_field="text_scribe",
    target_sample_rate=24000
):
    """
    Process dataset: tokenize audio and text, create training sequences.

    Args:
        original_dataset: HuggingFace dataset path to process
        output_dataset: HuggingFace dataset path for output
        model_type: Model type - either "qwen3" or "lfm2" (default: "qwen3")
        text_field: Name of text field in dataset (default: "text_scribe")
        target_sample_rate: Target audio sample rate (default: 24000)
    """
    # Set tokenizer and config based on model type
    if model_type == "qwen3":
        tokenizer_model = "Qwen/Qwen3-0.6B"
        config_path = "vyvotts/configs/inference/qwen3.yaml"
    elif model_type == "lfm2":
        tokenizer_model = "LiquidAI/LFM2-350M"
        config_path = "vyvotts/configs/inference/lfm2.yaml"
    else:
        raise ValueError(f"Invalid model_type: {model_type}. Must be 'qwen3' or 'lfm2'")

    # Load configuration
    print(f"Loading config from: {config_path}")
    config = load_config(config_path)

    TOKENIZER_LENGTH = config['TOKENIZER_LENGTH']
    START_OF_TEXT = config['START_OF_TEXT']
    END_OF_TEXT = config['END_OF_TEXT']
    START_OF_SPEECH = config['START_OF_SPEECH']
    END_OF_SPEECH = config['END_OF_SPEECH']
    START_OF_HUMAN = config['START_OF_HUMAN']
    END_OF_HUMAN = config['END_OF_HUMAN']
    START_OF_AI = config['START_OF_AI']
    END_OF_AI = config['END_OF_AI']
    PAD_TOKEN = config['PAD_TOKEN']
    AUDIO_TOKENS_START = config['AUDIO_TOKENS_START']

    # Download dataset
    print(f"Downloading dataset: {original_dataset}")
    snapshot_download(
        repo_id=original_dataset,
        repo_type="dataset",
        revision="main",
        max_workers=64,
    )

    # Load dataset
    print("Loading dataset...")
    ds = load_dataset(original_dataset, split="train")
    ds_sample_rate = ds[0]["audio"]["sampling_rate"]

    # Load SNAC model
    print("Loading SNAC model: hubertsiuzdak/snac_24khz")
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac_model = snac_model.to("cuda")

    # Define processing functions
    def add_codes(example):
        """Add audio codes to dataset example."""
        codes_list = None

        try:
            audio_data = example.get("audio")
            if audio_data and "array" in audio_data:
                audio_array = audio_data["array"]
                codes_list = tokenise_audio(
                    audio_array,
                    snac_model,
                    ds_sample_rate,
                    target_sample_rate,
                    AUDIO_TOKENS_START
                )
        except Exception as e:
            print(f"Skipping row due to error: {e}")

        example["codes_list"] = codes_list
        return example

    # Process dataset: tokenize audio
    print("Tokenizing audio...")
    ds = ds.map(add_codes, remove_columns=["audio"])

    # Load text tokenizer
    print(f"Loading tokenizer: {tokenizer_model}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
    num_proc = os.cpu_count() - 2

    # Filter out failed tokenizations
    print("Filtering invalid examples...")
    ds = ds.filter(lambda x: x["codes_list"] is not None)
    ds = ds.filter(lambda x: len(x["codes_list"]) > 0)

    # Remove duplicate frames
    def remove_duplicate_frames_wrapper(example):
        """Wrapper for remove_duplicate_frames."""
        example["codes_list"] = remove_duplicate_frames(example["codes_list"])
        return example

    print("Removing duplicate frames...")
    ds = ds.map(remove_duplicate_frames_wrapper, num_proc=num_proc)

    print(f"""
    NOTE: Text prompt customization
    You can modify the text prompt in create_input_ids() below.
    For multispeaker models, ensure your dataset has a "source" field.
    - Single-speaker: uses example['{text_field}']
    - Multi-speaker: uses example['source']: example['{text_field}']
    """)

    def create_input_ids(example):
        """
        Create training input sequence with proper formatting.

        Format: [HUMAN] text [/HUMAN] [AI] [SPEECH] audio_codes [/SPEECH] [/AI]
        """
        # Determine whether to include the source field
        if "source" in example:
            text_prompt = f"{example['source']}: {example[text_field]}"
        else:
            text_prompt = example[text_field]

        # Tokenize text input
        text_ids = tokenizer.encode(text_prompt, add_special_tokens=True)
        text_ids.append(END_OF_TEXT)
        example["text_tokens"] = text_ids

        # Construct full sequence with special tokens
        input_ids = (
            [START_OF_HUMAN]
            + example["text_tokens"]
            + [END_OF_HUMAN]
            + [START_OF_AI]
            + [START_OF_SPEECH]
            + example["codes_list"]
            + [END_OF_SPEECH]
            + [END_OF_AI]
        )

        example["input_ids"] = input_ids
        example["labels"] = input_ids
        example["attention_mask"] = [1] * len(input_ids)

        return example

    # Create final training sequences
    print("Creating input sequences...")
    ds = ds.map(
        create_input_ids,
        num_proc=num_proc,
        remove_columns=[text_field, "codes_list"]
    )

    # Keep only training columns
    columns_to_keep = ["input_ids", "labels", "attention_mask"]
    columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
    ds = ds.remove_columns(columns_to_remove)

    # Upload processed dataset
    print(f"Pushing dataset to: {output_dataset}")
    ds.push_to_hub(output_dataset)
    print("Done!")