#!/usr/bin/env python3
import os
import re
import subprocess

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    default_data_collator,
)

from myolmoe import MyOlmoeForCausalLM, OlmoeConfig
def main():
print("Starting my COOL OLMoE training script for small experts")
# Load config - first try from local file, then from pretrained
config_path = os.path.join("myolmoe", "config.json")
if os.path.exists(config_path):
config = OlmoeConfig.from_json_file(config_path)
else:
config = OlmoeConfig.from_pretrained("myolmoe")
# Load model
model = MyOlmoeForCausalLM.from_pretrained(
"myolmoe",
config=config,
torch_dtype=torch.bfloat16,
device_map="auto",
ignore_mismatched_sizes=True
)
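    # NOTE: ignore_mismatched_sizes=True tolerates tensors whose shapes differ
    # from the pretrained checkpoint (they are re-initialized instead of raising
    # an error); brand-new modules such as the small experts are simply left at
    # their random init with a missing-keys warning.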
    # Load tokenizer; fall back to EOS if no pad token is defined
    tokenizer = AutoTokenizer.from_pretrained("myolmoe")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
# Load dataset
dataset = load_dataset("allenai/tulu-v2-sft-mixture", split="train")
def tokenize_function(examples):
texts = []
for message_list in examples["messages"]:
formatted = ""
for msg in message_list:
role = msg["role"]
content = msg["content"]
if role == "user":
formatted += f"User: {content}\n"
elif role == "assistant":
formatted += f"Assistant: {content}\n"
else:
formatted += f"{role.capitalize()}: {content}\n"
texts.append(formatted)
        tokenized = tokenizer(
            texts,
            truncation=True,
            max_length=4096,
            padding="max_length",
        )
        # Mask padding positions with -100 so the loss ignores them; copying
        # input_ids directly would train on pad (= EOS) tokens.
        tokenized["labels"] = [
            [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
            for ids, attn in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]
        return tokenized
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=dataset.column_names,
num_proc=4
)
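    # NOTE (efficiency aside, not a change to the original setup): padding every
    # example to the full 4096 tokens keeps batch shapes static but is wasteful
    # for short chats; tokenizing with padding=False plus dynamic padding in the
    # collator would be the leaner alternative.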
# Training arguments
training_args = TrainingArguments(
output_dir="./checkpoints",
per_device_train_batch_size=2,
gradient_accumulation_steps=8,
learning_rate=1e-4,
num_train_epochs=3,
logging_dir="./logs",
logging_steps=10,
save_steps=20,
save_total_limit=1,
bf16=True,
gradient_checkpointing=False, # Disabled for now
report_to="tensorboard",
optim="adamw_torch",
lr_scheduler_type="cosine",
warmup_ratio=0.1,
max_grad_norm=1.0,
)
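    # Effective batch size: 2 per device x 8 gradient-accumulation steps = 16
    # sequences per optimizer step, per device, before any data parallelism.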
# Freeze all parameters first
for param in model.parameters():
param.requires_grad = False
# Unfreeze only the small experts and their gating networks
trainable_params = []
for name, param in model.named_parameters():
if (
"small_experts" in name or
"small_gate" in name
):
param.requires_grad = True
trainable_params.append(name)
    # Verify that small experts/gates were found and unfrozen
    if trainable_params:
        print(f"[INFO] {len(trainable_params)} small_expert/small_gate parameters are unfrozen and trainable:")
        for name in trainable_params:
            print(f"  - {name}")
    else:
        print("[ERROR] No small_expert or small_gate parameters found! Training will not update anything.")
# Custom data collator
def data_collator(features):
batch = default_data_collator(features)
batch["output_router_logits"] = True
return batch
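    # Injecting output_router_logits=True asks the forward pass to return router
    # logits; assuming MyOlmoeForCausalLM mirrors HF's OlmoeForCausalLM, it also
    # folds the load-balancing auxiliary loss into outputs.loss, so the gates
    # receive a training signal.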
    # Custom Trainer that tolerates the extra arguments newer transformers
    # versions (>= 4.46) pass to compute_loss, e.g. num_items_in_batch
    class CustomTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # **kwargs absorbs num_items_in_batch; also drop it if a collator
            # ever leaks it into the batch dict itself
            inputs = {k: v for k, v in inputs.items() if k != "num_items_in_batch"}
# Ensure we're in training mode
model.train()
# Forward pass with gradients
with torch.set_grad_enabled(True):
outputs = model(**inputs)
loss = outputs.loss
if not loss.requires_grad:
raise RuntimeError("Loss doesn't require gradients. Check model parameters.")
return (loss, outputs) if return_outputs else loss
class GitPushCallback(TrainerCallback):
def on_save(self, args, state, control, **kwargs):
try:
print("Saving checkpoint to Git repo...")
# Add all changes (you can scope this to ./checkpoints/ if desired)
subprocess.run(["git", "add", "."], check=True)
# Skip commit if no changes
result = subprocess.run(["git", "diff", "--cached", "--quiet"])
if result.returncode == 0:
print("No changes to commit.")
return
subprocess.run(["git", "commit", "-m", f'Checkpoint at step {state.global_step}'], check=True)
subprocess.run(["git", "push"], check=True)
print("Checkpoint pushed successfully.")
except subprocess.CalledProcessError as e:
print(f"Git push failed: {e}")
class SmallExpertSaveCallback(TrainerCallback):
def __init__(self, model, trainable_params):
self.model = model
self.trainable_params = trainable_params
def on_save(self, args, state, control, **kwargs):
# Define save path inside the checkpoint dir
checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
small_expert_path = os.path.join(checkpoint_dir, "small_experts_and_gates.bin")
            small_expert_state_dict = {
                name: param.detach().cpu()
                for name, param in self.model.named_parameters()
                if name in self.trainable_params
            }
if small_expert_state_dict:
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(small_expert_state_dict, small_expert_path)
print(f"[INFO] Saved {len(small_expert_state_dict)} small_expert/small_gate parameters "
f"to {small_expert_path}")
else:
print("[ERROR] No small_expert or small_gate parameters found to save!")
# Initialize trainer
trainer = CustomTrainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
data_collator=data_collator,
callbacks=[
GitPushCallback(),
SmallExpertSaveCallback(model, trainable_params)
]
)
# Test forward/backward pass before training
print("Testing gradient flow...")
test_loader = DataLoader(tokenized_dataset, batch_size=1, collate_fn=data_collator)
test_batch = next(iter(test_loader))
# Move batch to model's device
device = next(model.parameters()).device
test_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in test_batch.items()}
model.train()
outputs = model(**test_batch)
loss = outputs.loss
print(f"Initial loss: {loss.item()}")
loss.backward()
print("Gradients computed successfully")
# Check which parameters received gradients
for name, param in model.named_parameters():
if param.grad is not None:
print(f"Parameter {name} received gradients")
# Reset gradients
model.zero_grad()
    # Check for an existing checkpoint to resume from (re is imported at the top)
    checkpoint_dir = None
if os.path.isdir(training_args.output_dir):
checkpoints = [
os.path.join(training_args.output_dir, d)
for d in os.listdir(training_args.output_dir)
            if re.fullmatch(r"checkpoint-\d+", d)
]
if checkpoints:
# Extract step numbers and find the highest
checkpoint_dir = max(checkpoints, key=lambda x: int(x.split('-')[-1]))
print(f"Resuming from checkpoint: {checkpoint_dir}")
# Train
print("Starting training...")
trainer.train(resume_from_checkpoint=checkpoint_dir)
    # Save only the small experts and gates (detached, on CPU)
    print("Saving small experts and gates...")
    small_expert_state_dict = {
        name: param.detach().cpu()
        for name, param in model.named_parameters()
        if name in trainable_params
    }
    os.makedirs("./final_model", exist_ok=True)
    torch.save(small_expert_state_dict, "./final_model/small_experts_and_gates.bin")
    config.save_pretrained("./final_model")
    tokenizer.save_pretrained("./final_model")
if __name__ == "__main__":
main()