#!/usr/bin/env python3
import os
import re
import subprocess

import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    default_data_collator,
)
from datasets import load_dataset

from myolmoe import MyOlmoeForCausalLM, OlmoeConfig


def main():
    print("Starting my COOL OLMoE training script for small experts")

    # Load config: prefer the local JSON file, fall back to from_pretrained
    config_path = os.path.join("myolmoe", "config.json")
    if os.path.exists(config_path):
        config = OlmoeConfig.from_json_file(config_path)
    else:
        config = OlmoeConfig.from_pretrained("myolmoe")

    # Load model
    model = MyOlmoeForCausalLM.from_pretrained(
        "myolmoe",
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        ignore_mismatched_sizes=True,
    )

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("myolmoe")
    tokenizer.pad_token = tokenizer.eos_token

    # Load dataset
    dataset = load_dataset("allenai/tulu-v2-sft-mixture", split="train")

    def tokenize_function(examples):
        texts = []
        for message_list in examples["messages"]:
            formatted = ""
            for msg in message_list:
                role = msg["role"]
                content = msg["content"]
                if role == "user":
                    formatted += f"User: {content}\n"
                elif role == "assistant":
                    formatted += f"Assistant: {content}\n"
                else:
                    formatted += f"{role.capitalize()}: {content}\n"
            texts.append(formatted)

        tokenized = tokenizer(
            texts,
            truncation=True,
            max_length=4096,
            padding="max_length",
        )
        # Copy input_ids as labels, but mask padding positions with -100 so the
        # loss ignores them. Since pad_token == eos_token here, mask via the
        # attention mask rather than by token id (which would also mask real EOS).
        tokenized["labels"] = [
            [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
            for ids, attn in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]
        return tokenized

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=4,
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./checkpoints",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=1e-4,
        num_train_epochs=3,
        logging_dir="./logs",
        logging_steps=10,
        save_steps=20,
        save_total_limit=1,
        bf16=True,
        gradient_checkpointing=False,  # Disabled for now
        report_to="tensorboard",
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        max_grad_norm=1.0,
    )

    # Freeze all parameters first
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze only the small experts and their gating networks
    trainable_params = []
    for name, param in model.named_parameters():
        if "small_experts" in name or "small_gate" in name:
            param.requires_grad = True
            trainable_params.append(name)

    # Check that small experts were actually found
    if trainable_params:
        print(f"[INFO] Found {len(trainable_params)} small_expert/small_gate parameters.")
    else:
        print("[WARNING] No small_expert or small_gate parameters found in model!")

    # Verify gradient requirements
    unfrozen = [name for name, param in model.named_parameters() if param.requires_grad]
    if unfrozen:
        print(f"[INFO] {len(unfrozen)} parameters are unfrozen and trainable.")
        for name in unfrozen:
            print(f"  - {name}")
    else:
        print("[ERROR] No parameters were unfrozen! Training will not update anything.")
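    # (Added) Also report the trainable fraction by element count, not just by
    # tensor name; a minimal sanity sketch that assumes only the `model` above.
    total_elems = sum(p.numel() for p in model.parameters())
    trainable_elems = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[INFO] Trainable elements: {trainable_elems:,} / {total_elems:,} "
          f"({trainable_elems / max(total_elems, 1):.2%})")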
    print(f"Total trainable parameters: {len(trainable_params)}")

    # Custom data collator: piggyback the router-logits flag onto every batch so
    # the MoE auxiliary losses are computed during training
    def data_collator(features):
        batch = default_data_collator(features)
        batch["output_router_logits"] = True
        return batch

    # CustomTrainer that tolerates extra kwargs passed by newer Trainer versions
    # (e.g. num_items_in_batch) and fails fast if the loss is detached
    class CustomTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # Drop arguments the model's forward() does not accept
            inputs = {k: v for k, v in inputs.items() if k != "num_items_in_batch"}

            # Ensure we're in training mode
            model.train()

            # Forward pass with gradients enabled
            with torch.set_grad_enabled(True):
                outputs = model(**inputs)
                loss = outputs.loss

            if not loss.requires_grad:
                raise RuntimeError("Loss doesn't require gradients. Check model parameters.")

            return (loss, outputs) if return_outputs else loss

    class GitPushCallback(TrainerCallback):
        def on_save(self, args, state, control, **kwargs):
            try:
                print("Saving checkpoint to Git repo...")
                # Add all changes (scope this to ./checkpoints/ if desired)
                subprocess.run(["git", "add", "."], check=True)

                # Skip the commit if nothing is staged
                result = subprocess.run(["git", "diff", "--cached", "--quiet"])
                if result.returncode == 0:
                    print("No changes to commit.")
                    return

                subprocess.run(
                    ["git", "commit", "-m", f"Checkpoint at step {state.global_step}"],
                    check=True,
                )
                subprocess.run(["git", "push"], check=True)
                print("Checkpoint pushed successfully.")
            except subprocess.CalledProcessError as e:
                print(f"Git push failed: {e}")

    class SmallExpertSaveCallback(TrainerCallback):
        def __init__(self, model, trainable_params):
            self.model = model
            self.trainable_params = trainable_params

        def on_save(self, args, state, control, **kwargs):
            # Save path inside the checkpoint dir
            checkpoint_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            small_expert_path = os.path.join(checkpoint_dir, "small_experts_and_gates.bin")

            # Detach to plain CPU tensors so torch.save stores weights rather
            # than live Parameter objects tied to the device
            small_expert_state_dict = {
                name: param.detach().cpu()
                for name, param in self.model.named_parameters()
                if name in self.trainable_params
            }

            if small_expert_state_dict:
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(small_expert_state_dict, small_expert_path)
                print(f"[INFO] Saved {len(small_expert_state_dict)} small_expert/small_gate "
                      f"parameters to {small_expert_path}")
            else:
                print("[ERROR] No small_expert or small_gate parameters found to save!")

    # Initialize trainer
    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        callbacks=[
            GitPushCallback(),
            SmallExpertSaveCallback(model, trainable_params),
        ],
    )

    # Test a forward/backward pass before training
    print("Testing gradient flow...")
    test_loader = DataLoader(tokenized_dataset, batch_size=1, collate_fn=data_collator)
    test_batch = next(iter(test_loader))

    # Move the batch to the model's device (non-tensor entries pass through as-is)
    device = next(model.parameters()).device
    test_batch = {
        k: v.to(device) if isinstance(v, torch.Tensor) else v
        for k, v in test_batch.items()
    }

    model.train()
    outputs = model(**test_batch)
    loss = outputs.loss
    print(f"Initial loss: {loss.item()}")

    loss.backward()
    print("Gradients computed successfully")

    # Check which parameters received gradients
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f"Parameter {name} received gradients")
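    # (Added) Sanity check before resetting: gradients should land only on the
    # small_expert/small_gate tensors collected above; anything else means the
    # freeze logic missed a parameter. A minimal sketch over `trainable_params`.
    unexpected = [
        name for name, param in model.named_parameters()
        if param.grad is not None and name not in trainable_params
    ]
    if unexpected:
        print(f"[WARNING] {len(unexpected)} frozen parameters received gradients, "
              f"e.g. {unexpected[:3]}")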
    # Reset gradients
    model.zero_grad()

    # Check for an existing checkpoint to resume from
    checkpoint_dir = None
    if os.path.isdir(training_args.output_dir):
        checkpoints = [
            os.path.join(training_args.output_dir, d)
            for d in os.listdir(training_args.output_dir)
            if re.match(r"checkpoint-\d+", d)
        ]
        if checkpoints:
            # Extract step numbers and resume from the highest
            checkpoint_dir = max(checkpoints, key=lambda x: int(x.split("-")[-1]))
            print(f"Resuming from checkpoint: {checkpoint_dir}")

    # Train
    print("Starting training...")
    trainer.train(resume_from_checkpoint=checkpoint_dir)

    # Save only the small experts and gates (detached, on CPU)
    print("Saving small experts and gates...")
    small_expert_state_dict = {
        name: param.detach().cpu()
        for name, param in model.named_parameters()
        if name in trainable_params
    }
    os.makedirs("./final_model", exist_ok=True)
    torch.save(small_expert_state_dict, "./final_model/small_experts_and_gates.bin")
    config.save_pretrained("./final_model")


if __name__ == "__main__":
    main()
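# (Added) Reloading the saved small-expert weights later: a minimal sketch,
# assuming the same MyOlmoeForCausalLM/OlmoeConfig classes and the paths used
# above. Kept commented out so running this script stays side-effect free.
#
#   from myolmoe import MyOlmoeForCausalLM, OlmoeConfig
#   config = OlmoeConfig.from_pretrained("./final_model")
#   model = MyOlmoeForCausalLM.from_pretrained("myolmoe", config=config)
#   small_state = torch.load("./final_model/small_experts_and_gates.bin")
#   # strict=False: the file holds only the small_expert/small_gate tensors
#   missing, unexpected = model.load_state_dict(small_state, strict=False)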