import os

import torch
import torchaudio.transforms as T
import yaml
from datasets import load_dataset
from huggingface_hub import snapshot_download
from snac import SNAC
from transformers import AutoTokenizer


def load_config(config_path):
    """
    Load tokenizer configuration from a YAML file.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        Dictionary with configuration values.
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config
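
# Example of the expected YAML layout (key names match what process_dataset()
# reads below; the numeric values here are placeholders, not the real token IDs
# from qwen3.yaml / lfm2.yaml):
#
#   TOKENIZER_LENGTH: 128256       # base vocab size of the text tokenizer
#   START_OF_TEXT: 128000
#   END_OF_TEXT: 128001
#   START_OF_SPEECH: 128002
#   END_OF_SPEECH: 128003
#   START_OF_HUMAN: 128004
#   END_OF_HUMAN: 128005
#   START_OF_AI: 128006
#   END_OF_AI: 128007
#   PAD_TOKEN: 128008
#   AUDIO_TOKENS_START: 128009     # first ID of the SNAC audio-token range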


def tokenise_audio(waveform, snac_model, ds_sample_rate, target_sample_rate, audio_tokens_start):
    """
    Tokenize an audio waveform using the SNAC codec.

    Args:
        waveform: Audio array from the dataset.
        snac_model: SNAC model instance.
        ds_sample_rate: Original dataset sample rate.
        target_sample_rate: Target sample rate (24000).
        audio_tokens_start: Offset added to raw SNAC codes to map them into
            the text tokenizer's extended vocabulary.

    Returns:
        List of audio token IDs with the proper offsets applied.
    """
    # Convert to a float32 tensor with a channel dimension
    waveform = torch.from_numpy(waveform).unsqueeze(0)
    waveform = waveform.to(dtype=torch.float32)

    # Resample to the target sample rate (a no-op if the rates already match)
    resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=target_sample_rate)
    waveform = resample_transform(waveform)

    # Add a batch dimension and move to the GPU
    waveform = waveform.unsqueeze(0).to("cuda")

    # Generate SNAC codes
    with torch.inference_mode():
        codes = snac_model.encode(waveform)

    # SNAC uses hierarchical vector quantization with 3 codebook levels: per
    # frame, level 0 yields 1 code, level 1 yields 2, and level 2 yields 4.
    # Flatten them into a fixed 7-code interleaving, shifting each position
    # into its own 4096-wide slice of the audio-token range.
    all_codes = []
    num_frames = codes[0].shape[1]
    for i in range(num_frames):
        all_codes.append(codes[0][0][i].item() + audio_tokens_start)                      # level 0
        all_codes.append(codes[1][0][2 * i].item() + audio_tokens_start + 4096)           # level 1
        all_codes.append(codes[2][0][4 * i].item() + audio_tokens_start + (2 * 4096))     # level 2
        all_codes.append(codes[2][0][4 * i + 1].item() + audio_tokens_start + (3 * 4096)) # level 2
        all_codes.append(codes[1][0][2 * i + 1].item() + audio_tokens_start + (4 * 4096)) # level 1
        all_codes.append(codes[2][0][4 * i + 2].item() + audio_tokens_start + (5 * 4096)) # level 2
        all_codes.append(codes[2][0][4 * i + 3].item() + audio_tokens_start + (6 * 4096)) # level 2
    return all_codes
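
# Quick sanity check (illustrative): every frame contributes exactly 7 codes,
# so the flat token list length is always a multiple of 7. Assuming `audio` is
# a 24 kHz mono numpy array and `snac_model` / AUDIO_TOKENS_START are loaded:
#
#   codes = tokenise_audio(audio, snac_model, 24000, 24000, AUDIO_TOKENS_START)
#   assert len(codes) % 7 == 0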


def remove_duplicate_frames(codes_list):
    """
    Remove consecutive duplicate audio frames to reduce redundancy.

    Each frame consists of 7 codes (1 + 2 + 4 from the 3 SNAC codebook
    levels). A frame whose first (level-0) code matches the previous kept
    frame's first code is treated as a duplicate and dropped.

    Args:
        codes_list: Flat list of audio codes, length divisible by 7.

    Returns:
        Deduplicated codes list.
    """
    if len(codes_list) % 7 != 0:
        raise ValueError("Input list length must be divisible by 7")

    # Always keep the first frame
    result = codes_list[:7]
    removed_frames = 0

    # Compare each subsequent frame's first code to the last kept frame's
    for i in range(7, len(codes_list), 7):
        current_first_code = codes_list[i]
        previous_first_code = result[-7]
        if current_first_code != previous_first_code:
            result.extend(codes_list[i:i + 7])
        else:
            removed_frames += 1

    return result
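
# Illustrative example: deduplication keys on the first (level-0) code only,
# so a frame is dropped whenever its first code matches the previous kept
# frame's first code, even if the finer codes differ:
#
#   remove_duplicate_frames([0, 1, 2, 3, 4, 5, 6,
#                            0, 8, 9, 10, 11, 12, 13])
#   # -> [0, 1, 2, 3, 4, 5, 6]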


def process_dataset(
    original_dataset,
    output_dataset,
    model_type="qwen3",
    text_field="text_scribe",
    target_sample_rate=24000,
):
    """
    Process a dataset: tokenize audio and text, and create training sequences.

    Args:
        original_dataset: HuggingFace dataset path to process.
        output_dataset: HuggingFace dataset path for the output.
        model_type: Model type, either "qwen3" or "lfm2" (default: "qwen3").
        text_field: Name of the text field in the dataset (default: "text_scribe").
        target_sample_rate: Target audio sample rate (default: 24000).
    """
    try:
        # Select the tokenizer and config based on the model type
        print(f"Setting up configuration for model_type: {model_type}")
        if model_type == "qwen3":
            tokenizer_model = "Qwen/Qwen3-0.6B"
            config_path = "qwen3.yaml"
        elif model_type == "lfm2":
            tokenizer_model = "LiquidAI/LFM2-350M"
            config_path = "lfm2.yaml"
        else:
            raise ValueError(f"Invalid model_type: {model_type}. Must be 'qwen3' or 'lfm2'")

        print(f"Tokenizer Model: {tokenizer_model}")
        print(f"Config Path: {config_path}")

        # Load configuration
        print(f"Loading config from: {config_path}")
        config = load_config(config_path)
        print(f"Config loaded successfully. Type: {type(config)}")
        if not isinstance(config, dict):
            raise TypeError(f"Config must be a dictionary, got {type(config)}")
    except Exception as e:
        print(f"Error in initial setup: {str(e)}")
        raise
    try:
        print("Extracting config values...")
        TOKENIZER_LENGTH = config['TOKENIZER_LENGTH']
        START_OF_TEXT = config['START_OF_TEXT']
        END_OF_TEXT = config['END_OF_TEXT']
        START_OF_SPEECH = config['START_OF_SPEECH']
        END_OF_SPEECH = config['END_OF_SPEECH']
        START_OF_HUMAN = config['START_OF_HUMAN']
        END_OF_HUMAN = config['END_OF_HUMAN']
        START_OF_AI = config['START_OF_AI']
        END_OF_AI = config['END_OF_AI']
        PAD_TOKEN = config['PAD_TOKEN']
        AUDIO_TOKENS_START = config['AUDIO_TOKENS_START']
        print("✓ All config values extracted successfully")
    except KeyError as e:
        print(f"Missing key in config: {e}")
        print(f"Available keys: {list(config.keys())}")
        raise
    except Exception as e:
        print(f"Error extracting config values: {str(e)}")
        raise
    # Download the dataset snapshot in parallel before loading it
    print(f"Downloading dataset: {original_dataset}")
    snapshot_download(
        repo_id=original_dataset,
        repo_type="dataset",
        revision="main",
        max_workers=64,
        token=os.environ.get("HF_TOKEN"),
    )

    # Load the dataset
    print("Loading dataset...")
    ds = load_dataset(original_dataset, split="train", token=os.environ.get("HF_TOKEN"))
    ds_sample_rate = ds[0]["audio"]["sampling_rate"]

    # Load the SNAC model onto the GPU
    print("Loading SNAC model: hubertsiuzdak/snac_24khz")
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac_model = snac_model.to("cuda")
    # Per-example audio tokenization
    def add_codes(example):
        """Attach SNAC audio codes to a dataset example."""
        codes_list = None
        try:
            audio_data = example.get("audio")
            if audio_data and "array" in audio_data:
                audio_array = audio_data["array"]
                codes_list = tokenise_audio(
                    audio_array,
                    snac_model,
                    ds_sample_rate,
                    target_sample_rate,
                    AUDIO_TOKENS_START,
                )
        except Exception as e:
            print(f"Skipping row due to error: {e}")
        example["codes_list"] = codes_list
        return example

    # Tokenize audio (runs in the main process so it can use the GPU-resident
    # SNAC model)
    print("Tokenizing audio...")
    ds = ds.map(add_codes, remove_columns=["audio"])

    # Load the text tokenizer
    print(f"Loading tokenizer: {tokenizer_model}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)
    num_proc = max(1, os.cpu_count() - 2)

    # Filter out failed tokenizations
    print("Filtering invalid examples...")
    ds = ds.filter(lambda x: x["codes_list"] is not None)
    ds = ds.filter(lambda x: len(x["codes_list"]) > 0)

    # Remove consecutive duplicate frames
    def remove_duplicate_frames_wrapper(example):
        """Wrapper for remove_duplicate_frames."""
        example["codes_list"] = remove_duplicate_frames(example["codes_list"])
        return example

    print("Removing duplicate frames...")
    ds = ds.map(remove_duplicate_frames_wrapper, num_proc=num_proc)
| print(f""" | |
| NOTE: Text prompt customization | |
| You can modify the text prompt in create_input_ids() below. | |
| For multispeaker models, ensure your dataset has a "source" field. | |
| - Single-speaker: uses example['{text_field}'] | |
| - Multi-speaker: uses example['source']: example['{text_field}'] | |
| """) | |
| def create_input_ids(example): | |
| """ | |
| Create training input sequence with proper formatting. | |
| Format: [HUMAN] text [/HUMAN] [AI] [SPEECH] audio_codes [/SPEECH] [/AI] | |
| """ | |
| # Determine whether to include the source field | |
| if "source" in example: | |
| text_prompt = f"{example['source']}: {example[text_field]}" | |
| else: | |
| text_prompt = example[text_field] | |
| # Tokenize text input | |
| text_ids = tokenizer.encode(text_prompt, add_special_tokens=True) | |
| text_ids.append(END_OF_TEXT) | |
| example["text_tokens"] = text_ids | |
| # Construct full sequence with special tokens | |
| input_ids = ( | |
| [START_OF_HUMAN] | |
| + example["text_tokens"] | |
| + [END_OF_HUMAN] | |
| + [START_OF_AI] | |
| + [START_OF_SPEECH] | |
| + example["codes_list"] | |
| + [END_OF_SPEECH] | |
| + [END_OF_AI] | |
| ) | |
| example["input_ids"] = input_ids | |
| example["labels"] = input_ids | |
| example["attention_mask"] = [1] * len(input_ids) | |
| return example | |
    # Create the final training sequences
    print("Creating input sequences...")
    ds = ds.map(
        create_input_ids,
        num_proc=num_proc,
        remove_columns=[text_field, "codes_list"],
    )

    # Keep only the training columns
    columns_to_keep = ["input_ids", "labels", "attention_mask"]
    columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
    ds = ds.remove_columns(columns_to_remove)

    # Upload the processed dataset
    print(f"Pushing dataset to: {output_dataset}")
    ds.push_to_hub(output_dataset, token=os.environ.get("HF_TOKEN"))
    print("Done!")