|
|
import os |
|
|
|
|
|
import torch |
|
|
from datasets import load_dataset |
|
|
from transformers import ( |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
BitsAndBytesConfig, |
|
|
DataCollatorForLanguageModeling, |
|
|
Trainer, |
|
|
TrainingArguments, |
|
|
) |
|
|
|
|
|
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training |
|
|
|
|
|
|
|
|
def _load_base_model(base_model: str, hf_token, quantize: bool, device: torch.device):
    """Load the causal-LM base model used for aLoRA training.

    Args:
        base_model: Hub id or local path of the model.
        hf_token: Hugging Face access token (may be None).
        quantize: if True, load 4-bit NF4 weights with double quantization and run
            ``prepare_model_for_kbit_training`` with gradient checkpointing enabled.
        device: target device. Quantized models are placed there at load time via
            ``device_map``; non-quantized models are returned wherever
            ``from_pretrained`` put them and must be moved by the caller.

    Returns:
        The loaded (and, if quantized, k-bit-training-prepared) model.
    """
    if not quantize:
        return AutoModelForCausalLM.from_pretrained(base_model, token=hf_token)

    # bf16 compute where the GPU supports it, fp16 otherwise.
    compute_dtype = (
        torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        token=hf_token,
        # bitsandbytes-quantized models should be placed at load time instead of being
        # moved with `.to()` afterwards (unsupported for 8-bit, version-dependent for 4-bit).
        device_map={"": device},
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        ),
    )
    return prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)


def train_model(
    base_model: str,
    data_path: str,
    output_dir: str,
    batch_size: int,
    num_epochs: int,
    learning_rate: float,
    cutoff_len: int,
    val_set_size: int,
    invocation_string: str,
    quantize: bool,
    eval_step: int,
    save_step: int,
    device: str,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    lora_target_modules: str,
    hub_model_id: str,
    push_to_hub: bool,
):
    """Fine-tune an Activated LoRA (aLoRA) adapter on a chat-formatted dataset.

    Each example's ``input``/``output`` columns are rendered as a user/assistant turn with
    the tokenizer's chat template. The adapter is configured with the token ids of
    ``invocation_string`` as its activation sequence.

    Args:
        base_model: Hub id or local path of the base causal LM.
        data_path: dataset name/path for ``datasets.load_dataset``; must have a "train"
            split with "input" and "output" columns. A "test" split, if present, is
            passed to the Trainer as ``eval_dataset``.
        output_dir: where checkpoints, the adapter and the tokenizer are written.
        batch_size: per-device train and eval batch size.
        num_epochs: number of training epochs.
        learning_rate: optimizer learning rate.
        cutoff_len: examples are truncated to this many tokens.
        val_set_size: cap on the number of "test" examples kept as eval set (<= 0 keeps all).
        invocation_string: text whose token ids activate the aLoRA adapter.
        quantize: load the base model in 4-bit NF4 and prepare it for k-bit training.
        eval_step: used as the Trainer logging interval (``logging_steps``).
        save_step: checkpoint interval (``save_steps``).
        device: torch device string for the model, e.g. "cuda:0".
        lora_r, lora_alpha, lora_dropout: LoRA hyper-parameters.
        lora_target_modules: comma-separated module names; defaults to q/k/v projections.
        hub_model_id: Hub repository used when ``push_to_hub`` is set.
        push_to_hub: push the trained model to the Hub after training.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    hf_token = os.getenv("HF_TOKEN")

    device = torch.device(device)
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
    # Padded positions are excluded from the loss (label -100). If the pad token were EOS,
    # the genuine end-of-turn EOS would be masked too and the adapter would never learn to
    # stop, so prefer UNK. EOS is only a last resort for tokenizers without an UNK token.
    if tokenizer.pad_token is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
        tokenizer.pad_token = tokenizer.unk_token if tokenizer.unk_token is not None else tokenizer.eos_token
    invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)

    model = _load_base_model(base_model, hf_token, quantize, device)

    # Tolerate spaces/empty entries such as "q_proj, v_proj,".
    target_modules = [name.strip() for name in (lora_target_modules or "").split(",") if name.strip()]
    if not target_modules:
        target_modules = ["q_proj", "k_proj", "v_proj"]

    lora_config = LoraConfig(
        task_type="CAUSAL_LM",
        alora_invocation_tokens=invocation_tokens,
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    if not quantize:
        # Quantized models were already placed through `device_map` at load time.
        # NOTE(review): Trainer may re-place a non-quantized model on TrainingArguments'
        # own device; confirm when using a non-default --device.
        model.to(device)

    dataset = load_dataset(data_path)

    def tokenize_function(examples):
        formatted_texts = [
            tokenizer.apply_chat_template(
                [
                    {"role": "user", "content": user_msg},
                    {"role": "assistant", "content": assistant_msg},
                ],
                tokenize=False,
                add_generation_prompt=False,
            )
            for user_msg, assistant_msg in zip(examples["input"], examples["output"])
        ]
        # The rendered template already carries its special tokens (e.g. BOS), so don't add
        # them a second time. Padding and labels are produced per batch by the collator, so
        # batches are only as long as their longest example instead of always cutoff_len.
        return tokenizer(
            formatted_texts,
            truncation=True,
            max_length=cutoff_len,
            add_special_tokens=False,
        )

    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

    eval_dataset = tokenized_datasets["test"] if "test" in tokenized_datasets else None
    if eval_dataset is not None and val_set_size > 0:
        eval_dataset = eval_dataset.select(range(min(val_set_size, len(eval_dataset))))

    # With mlm=False the collator pads each batch and sets labels = input_ids with padded
    # positions replaced by -100.
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=eval_step,
        save_steps=save_step,
        save_total_limit=2,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
        gradient_accumulation_steps=16,
        # fp16 mixed precision is only valid with CUDA; enabling it unconditionally makes
        # TrainingArguments reject CPU-only runs.
        fp16=torch.cuda.is_available(),
        learning_rate=learning_rate,
        hub_token=hf_token,
    )

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    trainer.train()

    if push_to_hub:
        trainer.push_to_hub(commit_message="Fine-tuned model")

    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
|
|
|
|
|
|
|
|
def model_inference(
    model_path: str,
    adapter_path: str,
    prompt: str = None,
    data_path: str = None,
    max_new_tokens: int = 20,
):
    """Run one generation with a trained aLoRA adapter and print the result.

    Loads the base model from ``model_path``, attaches the adapter from ``adapter_path``,
    renders ``prompt`` as a single user turn with the chat template (including the
    generation prompt), and generates up to ``max_new_tokens`` tokens.

    Purely for demonstration purposes. See the [paper](https://huggingface.co/papers/2504.12397)
    for realistic multiturn examples, including KV-cache reuse between the base model and
    the adapter.

    Args:
        model_path: base model name or path the adapter was trained on.
        adapter_path: directory or Hub id of the saved aLoRA adapter.
        prompt: user message to send. If None, the "input" field of the first example of
            ``data_path``'s "test" split (or "train" if there is no test split) is used.
        data_path: dataset name/path; only needed when ``prompt`` is None.
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        The decoded response (prompt tokens removed, special tokens skipped).

    Raises:
        ValueError: if neither ``prompt`` nor ``data_path`` is provided.
    """
    if prompt is None:
        if data_path is None:
            raise ValueError("Either `prompt` or `data_path` must be provided.")
        dataset = load_dataset(data_path)
        split = "test" if "test" in dataset else "train"
        prompt = dataset[split][0]["input"]

    # Same token source as train_model, so gated base models load consistently.
    hf_token = os.getenv("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token)
    base_model = AutoModelForCausalLM.from_pretrained(model_path, token=hf_token)
    alora_model = PeftModel.from_pretrained(base_model, adapter_path)

    chat = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains its special tokens (e.g. BOS); adding them
    # again would prepend a duplicate BOS.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(base_model.device)

    output_dict = alora_model.generate(**inputs, return_dict_in_generate=True, max_new_tokens=max_new_tokens)
    alora_outputs = output_dict.sequences

    print(f"Prompt: {text}")
    # Drop the prompt tokens so only newly generated text is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(alora_outputs[0][prompt_length:], skip_special_tokens=True)
    print(f"Trained adapter response: {response}")
    return response
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # (flag, add_argument keyword options) in --help order. Every flag's dest name
    # matches exactly one train_model() parameter.
    cli_options = (
        ("--base_model", {"type": str, "default": "mistralai/Mistral-7B-Instruct-v0.3", "help": "Base model path or name"}),
        (
            "--data_path",
            {
                "type": str,
                "default": "Lots-of-LoRAs/task1660_super_glue_question_generation",
                "help": "Dataset path or name",
            },
        ),
        ("--output_dir", {"type": str, "default": "path/to/output", "help": "Output directory for the fine-tuned model"}),
        ("--batch_size", {"type": int, "default": 2, "help": "Batch size"}),
        ("--num_epochs", {"type": int, "default": 1, "help": "Number of training epochs"}),
        ("--learning_rate", {"type": float, "default": 1e-4, "help": "Learning rate"}),
        ("--cutoff_len", {"type": int, "default": 2048, "help": "Cutoff length for tokenization"}),
        ("--val_set_size", {"type": int, "default": 500, "help": "Validation set size"}),
        (
            "--invocation_string",
            {
                "type": str,
                "default": "[/INST]",
                "help": "String that activates the aLoRA adapter. Model dependent.",
            },
        ),
        ("--quantize", {"action": "store_true", "help": "Use quantization"}),
        ("--eval_step", {"type": int, "default": 10, "help": "Evaluation step interval"}),
        ("--save_step", {"type": int, "default": 100, "help": "Save step interval"}),
        ("--device", {"type": str, "default": "cuda:0", "help": "Device to use for training"}),
        ("--lora_r", {"type": int, "default": 32, "help": "LoRA rank"}),
        ("--lora_alpha", {"type": int, "default": 32, "help": "LoRA alpha"}),
        ("--lora_dropout", {"type": float, "default": 0.05, "help": "LoRA dropout rate"}),
        (
            "--lora_target_modules",
            {"type": str, "default": None, "help": "Comma-separated list of target modules for LoRA"},
        ),
        (
            "--hub_model_id",
            {
                "type": str,
                "default": "path/to/repo",
                "help": "Repository name to push the model on the Hugging Face Hub",
            },
        ),
        ("--push_to_hub", {"action": "store_true", "help": "Whether to push the model to Hugging Face Hub"}),
    )

    parser = argparse.ArgumentParser(description="Fine-tune Mistral with Activated LoRA")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    cli_args = parser.parse_args()

    # The parsed namespace holds one attribute per train_model() parameter, nothing more.
    train_model(**vars(cli_args))

    print("Model trained. Running test inference.")
    model_inference(model_path=cli_args.base_model, adapter_path=cli_args.output_dir, data_path=cli_args.data_path)
|
|
|