import os
import re
import subprocess

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    default_data_collator,
)

from myolmoe import MyOlmoeForCausalLM, OlmoeConfig


def main():
    print("Starting my COOL OLMoE training script for small experts")

    # Load the model config, preferring a local config.json if it exists.
    config_path = os.path.join("myolmoe", "config.json")
    if os.path.exists(config_path):
        config = OlmoeConfig.from_json_file(config_path)
    else:
        config = OlmoeConfig.from_pretrained("myolmoe")

    # ignore_mismatched_sizes=True lets loading proceed even when some module
    # shapes in the config differ from the weights stored in the checkpoint.
    model = MyOlmoeForCausalLM.from_pretrained(
        "myolmoe",
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        ignore_mismatched_sizes=True,
    )

    tokenizer = AutoTokenizer.from_pretrained("myolmoe")
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset("allenai/tulu-v2-sft-mixture", split="train")

    def tokenize_function(examples):
        # Flatten each chat transcript into plain "Role: content" text.
        texts = []
        for message_list in examples["messages"]:
            formatted = ""
            for msg in message_list:
                role = msg["role"]
                content = msg["content"]
                if role == "user":
                    formatted += f"User: {content}\n"
                elif role == "assistant":
                    formatted += f"Assistant: {content}\n"
                else:
                    formatted += f"{role.capitalize()}: {content}\n"
            texts.append(formatted)

        tokenized = tokenizer(
            texts,
            truncation=True,
            max_length=4096,
            padding="max_length",
        )
        # Standard causal-LM setup: labels are a copy of the input ids.
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=4,
    )

    training_args = TrainingArguments(
        output_dir="./checkpoints",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        num_train_epochs=3,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=20,
        save_total_limit=1,
        bf16=True,
        gradient_checkpointing=False,
        report_to="tensorboard",
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        max_grad_norm=1.0,
    )

    # Freeze everything, then unfreeze only the small-expert and small-gate weights.
    for param in model.parameters():
        param.requires_grad = False

    trainable_params = []
    for name, param in model.named_parameters():
        if "small_experts" in name or "small_gate" in name:
            param.requires_grad = True
            trainable_params.append(name)

    if trainable_params:
        print(f"[INFO] Found {len(trainable_params)} small_expert/small_gate parameters.")
    else:
        print("[WARNING] No small_expert or small_gate parameters found in model!")

    unfrozen = [name for name, param in model.named_parameters() if param.requires_grad]
    if unfrozen:
        print(f"[INFO] {len(unfrozen)} parameters are unfrozen and trainable.")
        for name in unfrozen:
            print(f" - {name}")
    else:
        print("[ERROR] No parameters were unfrozen! Training will not update anything.")

    print(f"Total trainable parameters: {len(trainable_params)}")

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"Parameter {name} requires grad: {param.requires_grad}")

    def data_collator(features):
        batch = default_data_collator(features)
        # Ask the model to return router logits so any auxiliary router loss
        # can be computed during training.
        batch["output_router_logits"] = True
        return batch

    class CustomTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # Drop bookkeeping keys that the model's forward() does not accept.
            inputs = {k: v for k, v in inputs.items() if k not in ["num_items_in_batch"]}

            model.train()

            with torch.set_grad_enabled(True):
                outputs = model(**inputs)
                loss = outputs.loss

            if not loss.requires_grad:
                raise RuntimeError("Loss doesn't require gradients. Check model parameters.")

            return (loss, outputs) if return_outputs else loss

    class GitPushCallback(TrainerCallback):
        def on_save(self, args, state, control, **kwargs):
            try:
                print("Saving checkpoint to Git repo...")

                subprocess.run(["git", "add", "."], check=True)

                # `git diff --cached --quiet` exits with 0 when nothing is staged.
                result = subprocess.run(["git", "diff", "--cached", "--quiet"])
                if result.returncode == 0:
                    print("No changes to commit.")
                    return

                subprocess.run(
                    ["git", "commit", "-m", f"Checkpoint at step {state.global_step}"],
                    check=True,
                )
                subprocess.run(["git", "push"], check=True)
                print("Checkpoint pushed successfully.")
            except subprocess.CalledProcessError as e:
                print(f"Git push failed: {e}")

    class SmallExpertSaveCallback(TrainerCallback):
        def __init__(self, model, trainable_params):
            self.model = model
            self.trainable_params = trainable_params

        def on_save(self, args, state, control, **kwargs):
            checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            small_expert_path = os.path.join(checkpoint_dir, "small_experts_and_gates.bin")

            small_expert_state_dict = {
                name: param
                for name, param in self.model.named_parameters()
                if name in self.trainable_params
            }

            if small_expert_state_dict:
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(small_expert_state_dict, small_expert_path)
                print(
                    f"[INFO] Saved {len(small_expert_state_dict)} small_expert/small_gate parameters "
                    f"to {small_expert_path}"
                )
            else:
                print("[ERROR] No small_expert or small_gate parameters found to save!")

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        callbacks=[
            GitPushCallback(),
            SmallExpertSaveCallback(model, trainable_params),
        ],
    )

print("Testing gradient flow...") |
|
|
test_loader = DataLoader(tokenized_dataset, batch_size=1, collate_fn=data_collator) |
|
|
test_batch = next(iter(test_loader)) |
|
|
|
|
|
|
|
|
device = next(model.parameters()).device |
|
|
test_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in test_batch.items()} |
|
|
|
|
|
model.train() |
|
|
outputs = model(**test_batch) |
|
|
loss = outputs.loss |
|
|
print(f"Initial loss: {loss.item()}") |
|
|
|
|
|
loss.backward() |
|
|
print("Gradients computed successfully") |
|
|
|
|
|
|
|
|
for name, param in model.named_parameters(): |
|
|
if param.grad is not None: |
|
|
print(f"Parameter {name} received gradients") |
|
|
|
|
|
|
|
|
model.zero_grad() |
|
|
|
|
|
|
|
|
    # Resume from the most recent checkpoint in the output directory, if any.
    checkpoint_dir = None
    if os.path.isdir(training_args.output_dir):
        checkpoints = [
            os.path.join(training_args.output_dir, d)
            for d in os.listdir(training_args.output_dir)
            if re.match(r"checkpoint-\d+", d)
        ]
        if checkpoints:
            checkpoint_dir = max(checkpoints, key=lambda x: int(x.split("-")[-1]))
            print(f"Resuming from checkpoint: {checkpoint_dir}")

print("Starting training...") |
|
|
trainer.train(resume_from_checkpoint=checkpoint_dir) |
|
|
|
|
|
|
|
|
print("Saving small experts and gates...") |
|
|
small_expert_state_dict = { |
|
|
name: param for name, param in model.named_parameters() |
|
|
if name in trainable_params |
|
|
} |
|
|
|
|
|
os.makedirs("./final_model", exist_ok=True) |
|
|
torch.save(small_expert_state_dict, "./final_model/small_experts_and_gates.bin") |
|
|
config.save_pretrained("./final_model") |
|
|
|
|
|
if __name__ == "__main__":
    main()