# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from typing import Optional

import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

from peft import (
    WaveFTConfig,
    get_peft_model,
)


def train(
    base_model: str,
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "waveft",
    batch_size: int = 16,
    num_epochs: int = 1,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 16,
    eval_step: int = 100,
    save_step: int = 100,
    device_map: str = "auto",
    waveft_n_frequency: int = 2592,
    waveft_target_modules: Optional[list[str]] = None,
    waveft_scaling: float = 25.0,
    waveft_wavelet_family: str = "db1",
    waveft_use_idwt: bool = True,
    torch_dtype: str = "float16",
    seed: Optional[int] = None,
):
    """Fine-tune a causal LM with a WaveFT adapter on an instruction dataset.

    Args:
        base_model: Model name or path passed to ``from_pretrained`` for both model and tokenizer.
        data_path: Dataset name/path passed to ``datasets.load_dataset``; its ``"train"`` split is used.
        output_dir: Directory for Trainer checkpoints and the final adapter (``save_pretrained``).
        batch_size: Per-device training batch size.
        num_epochs: Number of training epochs.
        learning_rate: Optimizer learning rate.
        cutoff_len: Maximum tokenized length; longer prompts are truncated.
        val_set_size: Number of examples held out for evaluation. When ``<= 0`` no split is made
            and evaluation / best-model loading are disabled.
        eval_step: Evaluate every this many steps (only when a validation split exists).
        save_step: Save a checkpoint every this many steps.
        device_map: ``device_map`` for ``from_pretrained``; replaced by the local process index under DDP.
        waveft_n_frequency: ``WaveFTConfig.n_frequency``.
        waveft_target_modules: ``WaveFTConfig.target_modules``; ``None`` leaves the choice to PEFT.
        waveft_scaling: ``WaveFTConfig.scaling``.
        waveft_wavelet_family: ``WaveFTConfig.wavelet_family``.
        waveft_use_idwt: ``WaveFTConfig.use_idwt``.
        torch_dtype: Name of a ``torch`` dtype attribute (e.g. ``"float16"``, ``"bfloat16"``).
        seed: If given, passed to ``transformers.set_seed`` before model loading.

    Raises:
        ValueError: If ``torch_dtype`` does not name a ``torch.dtype``.
    """
    # Validate the dtype name before any expensive loading so a typo fails fast and clearly.
    dtype = getattr(torch, torch_dtype, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(
            f"Unsupported torch_dtype {torch_dtype!r}; expected a torch dtype name such as "
            "'float16', 'bfloat16' or 'float32'."
        )

    # Set device_map to the right place when enabling DDP.
    world_size = int(os.environ.get("WORLD_SIZE", 0)) or int(os.environ.get("PMI_SIZE", 0))
    if world_size > 1 and device_map != "cpu":
        from accelerate import Accelerator

        # Pin the whole model to this process's device instead of sharding it with "auto".
        device_map = {"": Accelerator().process_index}

    if seed is not None:
        set_seed(seed)

    model_kwargs = {"dtype": dtype, "device_map": device_map}
    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    # For some tokenizer with no pad token like llama
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize(prompt, add_eos_token=True):
        """Tokenize ``prompt`` (truncated to ``cutoff_len``) and attach causal-LM labels."""
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        # Append EOS only when there is room for it and it is not already present, so the
        # model learns where a response ends without exceeding cutoff_len.
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)
        # Labels mirror input_ids, so the loss covers the whole prompt (instruction included).
        result["labels"] = result["input_ids"].copy()
        return result

    def generate_and_tokenize_prompt(example):
        """Render one dataset record as a prompt and tokenize it."""
        return tokenize(generate_prompt(example))

    config = WaveFTConfig(
        n_frequency=waveft_n_frequency,
        scaling=waveft_scaling,
        wavelet_family=waveft_wavelet_family,
        use_idwt=waveft_use_idwt,
        target_modules=waveft_target_modules,
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    data = load_dataset(data_path)
    if val_set_size > 0:
        train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
    else:
        # train_test_split rejects test_size=0, so train on the full split without evaluation.
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None
    has_eval = val_data is not None

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=batch_size,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            logging_steps=100,
            optim="adamw_torch",
            eval_strategy="steps" if has_eval else "no",
            save_strategy="steps",
            eval_steps=eval_step,
            save_steps=save_step,
            output_dir=output_dir,
            save_total_limit=3,
            # Picking the "best" checkpoint needs evaluation results.
            load_best_model_at_end=has_eval,
            # Only meaningful under DDP; None keeps the Trainer default otherwise.
            ddp_find_unused_parameters=False if world_size > 1 else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    trainer.train()
    model.save_pretrained(output_dir)


def generate_prompt(example):
    """Render an Alpaca-style record (``instruction``, optional ``input``, ``output``) as one prompt.

    When the record carries a non-empty ``input`` it is included as a separate section;
    otherwise the prompt contains only the instruction and the response.
    """
    instruction = example["instruction"]
    output = example["output"]
    # Datasets without an "input" column (or records with an empty one) use the shorter template.
    context = example.get("input")
    if context:
        return (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{instruction}\n\n"
            f"### Input:\n{context}\n\n"
            f"### Response:\n{output}"
        )
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{instruction}\n\n"
        f"### Response:\n{output}"
    )


def _parse_module_list(value: str) -> list[str]:
    """Split a comma-separated CLI value such as ``"q_proj,v_proj"`` into module names.

    PEFT interprets a plain string ``target_modules`` as a full-match regex over the module
    path, so a bare name like ``"q_proj"`` would match nothing; a list matches by suffix.
    """
    modules = [name.strip() for name in value.split(",") if name.strip()]
    if not modules:
        raise argparse.ArgumentTypeError("expected at least one module name, e.g. 'q_proj,v_proj'")
    return modules


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", type=str, required=True)
    parser.add_argument("--data_path", type=str, default="yahma/alpaca-cleaned")
    parser.add_argument("--output_dir", type=str, default="waveft")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--cutoff_len", type=int, default=256)
    parser.add_argument("--val_set_size", type=int, default=16)
    parser.add_argument("--eval_step", type=int, default=100)
    parser.add_argument("--save_step", type=int, default=100)
    parser.add_argument("--device_map", type=str, default="auto")
    parser.add_argument("--waveft_n_frequency", type=int, default=2592)
    parser.add_argument(
        "--waveft_target_modules",
        type=_parse_module_list,
        default=None,
        help="Comma-separated module names, e.g. q_proj,v_proj",
    )
    parser.add_argument("--waveft_scaling", type=float, default=25.0)
    parser.add_argument("--waveft_wavelet_family", type=str, default="db1")
    # BooleanOptionalAction adds --no-waveft_use_idwt; store_true with default=True could never be disabled.
    parser.add_argument("--waveft_use_idwt", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--torch_dtype", type=str, default="float16")
    parser.add_argument("--seed", type=int, default=None)
    args = parser.parse_args()

    train(
        base_model=args.base_model,
        data_path=args.data_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        cutoff_len=args.cutoff_len,
        val_set_size=args.val_set_size,
        eval_step=args.eval_step,
        save_step=args.save_step,
        device_map=args.device_map,
        waveft_n_frequency=args.waveft_n_frequency,
        waveft_target_modules=args.waveft_target_modules,
        waveft_scaling=args.waveft_scaling,
        waveft_wavelet_family=args.waveft_wavelet_family,
        waveft_use_idwt=args.waveft_use_idwt,
        torch_dtype=args.torch_dtype,
        seed=args.seed,
    )