|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from typing import Optional, Union

import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed

from peft import (
    LoraConfig,
    get_peft_model,
)
|
|
|
|
|
|
|
|
def train(
    base_model: str = "path/to/model",
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "olora",
    batch_size: int = 16,
    num_epochs: int = 1,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 16,
    quantize: bool = False,
    eval_step: int = 100,
    save_step: int = 100,
    device_map: str = "auto",
    lora_r: int = 32,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: Optional[Union[list[str], str]] = None,
    torch_dtype: str = "float16",
    init_lora_weights: Union[bool, str] = "olora",
    seed: Optional[int] = None,
):
    """Fine-tune ``base_model`` with LoRA adapters on an instruction dataset and save the adapter.

    Args:
        base_model: Model id or path given to ``AutoModelForCausalLM`` / ``AutoTokenizer``.
        data_path: Dataset name for ``datasets.load_dataset``, or a local ``.json`` /
            ``.jsonl`` file (loaded with the ``json`` builder). Its ``"train"`` split is used.
        output_dir: Where checkpoints and the final adapter are written.
        batch_size: Per-device training batch size.
        num_epochs: Number of training epochs.
        learning_rate: Optimizer learning rate.
        cutoff_len: Maximum tokenized length; longer prompts are truncated.
        val_set_size: Number of examples held out for evaluation (passed as ``test_size``).
            ``0`` or less disables evaluation entirely.
        quantize: If True, load the base model in 4-bit NF4 via ``BitsAndBytesConfig``.
        eval_step: Steps between evaluations (ignored when evaluation is disabled).
        save_step: Steps between checkpoint saves.
        device_map: ``device_map`` for ``from_pretrained``; overridden with this process's
            index when running multi-process.
        lora_r, lora_alpha, lora_dropout, lora_target_modules, init_lora_weights:
            Forwarded to ``LoraConfig``.
        torch_dtype: Name of a ``torch`` dtype attribute, e.g. ``"float16"``.
        seed: If given, passed to ``transformers.set_seed`` before loading anything.
    """
    # WORLD_SIZE (or PMI_SIZE) > 1 means several processes were launched; pin each
    # process's copy of the model to its own device instead of using device_map as given.
    world_size = int(os.environ.get("WORLD_SIZE", 0)) or int(os.environ.get("PMI_SIZE", 0))
    ddp = world_size > 1
    if ddp and device_map != "cpu":
        from accelerate import Accelerator

        device_map = {"": Accelerator().process_index}

    if seed is not None:
        set_seed(seed)

    model_kwargs = {"torch_dtype": getattr(torch, torch_dtype), "device_map": device_map}
    if quantize:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    # The collator below pads batches, so a pad token is required; fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize(prompt, add_eos_token=True):
        """Tokenize one prompt (truncated to cutoff_len) and attach labels."""
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        input_ids = result["input_ids"]
        # Append EOS only if it is not already last (or the encoding is empty)
        # and it still fits within cutoff_len.
        if (
            add_eos_token
            and (not input_ids or input_ids[-1] != tokenizer.eos_token_id)
            and len(input_ids) < cutoff_len
        ):
            input_ids.append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        # Labels mirror the inputs, so every token (prompt and response) is a target.
        result["labels"] = input_ids.copy()
        return result

    def generate_and_tokenize_prompt(example):
        """Render a dataset record to text and tokenize it."""
        return tokenize(generate_prompt(example))

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        init_lora_weights=init_lora_weights,
    )
    model = get_peft_model(model, config)

    # Local JSON files need the "json" builder; anything else is treated as a dataset name/path.
    if data_path.endswith((".json", ".jsonl")):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    if val_set_size > 0:
        train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
        train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
    else:
        # train_test_split rejects test_size=0, so skip the split when no eval set is wanted.
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None
    has_eval = val_data is not None

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=batch_size,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            logging_steps=100,
            optim="adamw_torch",
            # Step-wise evaluation and "best model" tracking need an eval set.
            eval_strategy="steps" if has_eval else "no",
            save_strategy="steps",
            eval_steps=eval_step,
            save_steps=save_step,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=has_eval,
            ddp_find_unused_parameters=False if ddp else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    trainer.train()

    # Only the main process writes the adapter so multi-process runs don't race on the same files.
    if trainer.is_world_process_zero():
        model.save_pretrained(output_dir)
|
|
|
|
|
|
|
|
def generate_prompt(example):
    """Render one Alpaca-style record as a single training string.

    ``example`` must provide ``"instruction"`` and ``"output"``. When it also has a
    non-blank ``"input"`` (extra context for the instruction), that text is included in
    an ``### Input:`` section under the "paired with an input" header. Otherwise the
    input-free template is used.
    """
    context = example.get("input")
    if context and context.strip():
        return (
            "Below is an instruction that describes a task, paired with an input that provides "
            "further context. Write a response that appropriately completes the request.\n"
            f"### Instruction:\n{example['instruction']}\n"
            f"### Input:\n{context}\n"
            f"### Response:\n{example['output']}"
        )
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n"
        f"### Instruction:\n{example['instruction']}\n"
        f"### Response:\n{example['output']}"
    )
|
|
|
|
|
|
|
|
def parse_lora_target_modules(value):
    """Turn the ``--lora_target_modules`` CLI string into a value for ``LoraConfig``.

    - ``None`` stays ``None`` (let PEFT pick defaults).
    - A comma-separated list (``"q_proj,v_proj"``) becomes a list of stripped names.
    - A single plain name (letters, digits, ``_`` and ``.`` only) becomes a one-element list.
    - Anything else (a regex, or a special string such as ``"all-linear"``) is returned
      unchanged.

    A plain string would otherwise reach ``LoraConfig`` as-is, where a ``str``
    ``target_modules`` is treated as a pattern rather than a list of module names.
    Note: a regex that itself contains a comma will be split.
    """
    import re

    if value is None:
        return None
    if "," in value:
        names = [name.strip() for name in value.split(",") if name.strip()]
        return names or None
    if re.fullmatch(r"[\w.]+", value):
        return [value]
    return value


if __name__ == "__main__":
    import argparse

    # Every option name mirrors a train() keyword argument so the namespace can be splatted.
    parser = argparse.ArgumentParser(description="LoRA/OLoRA fine-tuning of a causal LM.")
    parser.add_argument("--base_model", type=str, default="path/to/model", help="Model id or local path.")
    parser.add_argument("--data_path", type=str, default="yahma/alpaca-cleaned", help="Dataset name or path.")
    parser.add_argument("--output_dir", type=str, default="olora", help="Directory for checkpoints and the adapter.")
    parser.add_argument("--batch_size", type=int, default=16, help="Per-device training batch size.")
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--cutoff_len", type=int, default=256, help="Max tokens per example.")
    parser.add_argument("--val_set_size", type=int, default=16, help="Examples held out for evaluation.")
    parser.add_argument("--quantize", action="store_true", help="Load the base model in 4-bit.")
    parser.add_argument("--eval_step", type=int, default=100)
    parser.add_argument("--save_step", type=int, default=100)
    parser.add_argument("--device_map", type=str, default="auto")
    parser.add_argument("--lora_r", type=int, default=32)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default=None,
        help="Comma-separated module names (e.g. q_proj,v_proj), a single name, "
        "or a regex / special string passed through unchanged.",
    )
    parser.add_argument("--torch_dtype", type=str, default="float16", help="torch dtype name, e.g. float16.")
    parser.add_argument("--init_lora_weights", type=str, default="olora")
    parser.add_argument("--seed", type=int, default=None)

    args = parser.parse_args()
    args.lora_target_modules = parse_lora_target_modules(args.lora_target_modules)

    train(**vars(args))
|
|
|