adds defensive programming (boo) and adaptations based on transformers versions
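In short: the change swaps a direct TrainingArguments(...) call for a kwargs dict that is validated against inspect.signature(TrainingArguments.__init__), so the script tolerates the rename of 'evaluation_strategy' to 'eval_strategy' in newer transformers releases and drops any keyword the installed version does not accept. A minimal sketch of the underlying version check, assuming only that transformers is installed; the variable names here are illustrative, not from the commit:

import inspect
from transformers import TrainingArguments

# Which parameters does the installed transformers version accept?
params = set(inspect.signature(TrainingArguments.__init__).parameters)

# Older releases spell it 'evaluation_strategy'; newer ones renamed it to
# 'eval_strategy', and passing the unknown spelling raises a TypeError.
eval_key = "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"

args = TrainingArguments(output_dir="out", eval_steps=50, **{eval_key: "steps"})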
scripts/training/train_gpt_oss.py
CHANGED
@@ -8,6 +8,7 @@ Based on the GPT-OSS fine-tuning tutorial
 import os
 import sys
 import argparse
+import inspect
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
 from peft import LoraConfig, get_peft_model
@@ -386,62 +387,72 @@ def create_sft_config(config, output_dir):
     print(f" • Gradient accumulation: {gradient_accumulation_steps}")
     print(f" • Effective batch size: {per_device_train_batch_size * gradient_accumulation_steps}")
 
-    sft_config = TrainingArguments(
+    # Build kwargs dynamically to be compatible across transformers versions
+    ta_kwargs = {
         # Training duration
-        num_train_epochs=num_train_epochs,
-        max_steps=max_steps,
-
+        "num_train_epochs": num_train_epochs,
+        "max_steps": max_steps,
         # Learning rate
-        learning_rate=learning_rate,
-        lr_scheduler_type=lr_scheduler_type,
-        warmup_ratio=warmup_ratio,
-        warmup_steps=warmup_steps,
-
+        "learning_rate": learning_rate,
+        "lr_scheduler_type": lr_scheduler_type,
+        "warmup_ratio": warmup_ratio,
+        "warmup_steps": warmup_steps,
         # Batch configuration
-        per_device_train_batch_size=per_device_train_batch_size,
-        per_device_eval_batch_size=per_device_eval_batch_size,
-        gradient_accumulation_steps=gradient_accumulation_steps,
-
+        "per_device_train_batch_size": per_device_train_batch_size,
+        "per_device_eval_batch_size": per_device_eval_batch_size,
+        "gradient_accumulation_steps": gradient_accumulation_steps,
         # Model configuration
-        gradient_checkpointing=getattr(config, 'use_gradient_checkpointing', True),
-
+        "gradient_checkpointing": getattr(config, 'use_gradient_checkpointing', True),
         # Mixed precision
-        fp16=fp16,
-        bf16=bf16,
-
+        "fp16": fp16,
+        "bf16": bf16,
         # Regularization
-        weight_decay=weight_decay,
-        max_grad_norm=max_grad_norm,
-
-        # Evaluation
-        evaluation_strategy=eval_strategy,
-        eval_steps=eval_steps,
-
+        "weight_decay": weight_decay,
+        "max_grad_norm": max_grad_norm,
+        # Evaluation (name may vary across versions)
+        "evaluation_strategy": eval_strategy,
+        "eval_steps": eval_steps,
         # Logging
-        logging_steps=logging_steps,
-
+        "logging_steps": logging_steps,
         # Saving
-        save_strategy=save_strategy,
-        save_steps=save_steps,
-        save_total_limit=save_total_limit,
-
+        "save_strategy": save_strategy,
+        "save_steps": save_steps,
+        "save_total_limit": save_total_limit,
         # Output
-        output_dir=output_dir,
-
+        "output_dir": output_dir,
         # Data loading
-        dataloader_num_workers=getattr(config, 'dataloader_num_workers', 4),
-        dataloader_pin_memory=getattr(config, 'dataloader_pin_memory', True),
-
+        "dataloader_num_workers": getattr(config, 'dataloader_num_workers', 4),
+        "dataloader_pin_memory": getattr(config, 'dataloader_pin_memory', True),
         # Performance
-        group_by_length=getattr(config, 'group_by_length', True),
-        remove_unused_columns=getattr(config, 'remove_unused_columns', True),
-
+        "group_by_length": getattr(config, 'group_by_length', True),
+        "remove_unused_columns": getattr(config, 'remove_unused_columns', True),
         # HuggingFace Hub
-        push_to_hub=push_to_hub,
-
+        "push_to_hub": push_to_hub,
         # Monitoring
-        report_to=("trackio" if getattr(config, 'enable_tracking', False) else None),
-    )
+        "report_to": ("trackio" if getattr(config, 'enable_tracking', False) else None),
+    }
+
+    # Adapt to transformers versions where 'evaluation_strategy' was renamed
+    try:
+        ta_sig = inspect.signature(TrainingArguments.__init__)
+        param_names = set(ta_sig.parameters.keys())
+    except Exception:
+        param_names = set()
+
+    if "evaluation_strategy" not in param_names and "eval_strategy" in param_names:
+        # Move value to 'eval_strategy'
+        ta_kwargs["eval_strategy"] = ta_kwargs.pop("evaluation_strategy")
+    elif "evaluation_strategy" not in param_names:
+        # If neither is supported, drop it
+        ta_kwargs.pop("evaluation_strategy", None)
+
+    # Remove any kwargs not supported by current transformers version
+    if param_names:
+        unsupported = [k for k in ta_kwargs.keys() if k not in param_names]
+        for k in unsupported:
+            ta_kwargs.pop(k, None)
+
+    sft_config = TrainingArguments(**ta_kwargs)
 
     return sft_config
 
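The same defensive pattern generalizes beyond TrainingArguments: keep only the kwargs the target callable declares, and pass everything through when the signature cannot be inspected. A standalone sketch; the helper filter_supported_kwargs is hypothetical and not part of this commit:

import inspect

def filter_supported_kwargs(func, kwargs):
    # Restrict kwargs to parameters that func actually declares.
    # If the signature cannot be inspected (e.g. some C extensions),
    # fall back to passing everything through unchanged.
    try:
        supported = set(inspect.signature(func).parameters)
    except (TypeError, ValueError):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in supported}

# Toy stand-in for a constructor whose signature varies across versions:
def make_config(output_dir, eval_strategy="no"):
    return {"output_dir": output_dir, "eval_strategy": eval_strategy}

cfg = make_config(**filter_supported_kwargs(
    make_config,
    {"output_dir": "out", "eval_strategy": "steps", "bogus_option": 1},
))
print(cfg)  # {'output_dir': 'out', 'eval_strategy': 'steps'}

Silently dropping unknown kwargs trades loud failures for portability; logging the dropped keys, as the commit's 'unsupported' loop could easily do, makes that trade-off visible.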