adds scheduler stuff and hopes for the best with track tonic
config/train_gpt_oss_custom.py
CHANGED
@@ -59,7 +59,7 @@ class GPTOSSEnhancedCustomConfig:
     # ============================================================================
     # SCHEDULER CONFIGURATION
     # ============================================================================
-    scheduler: str = "
+    scheduler: str = "cosine"  # Default to broadly compatible scheduler; TRL special is opt-in
     lr_scheduler_kwargs: Optional[Dict] = None

     # ============================================================================
@@ -299,7 +299,8 @@ class GPTOSSEnhancedCustomConfig:
         # SCHEDULER CONFIGURATION DEFAULTS
         # ============================================================================
         if self.lr_scheduler_kwargs is None:
-
+            # Leave empty; training script will add TRL-specific keys only when needed
+            self.lr_scheduler_kwargs = {}

         # ============================================================================
         # CHAT TEMPLATE CONFIGURATION DEFAULTS (GPT-OSS Harmony Format)
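Taken together, the two hunks mean a freshly constructed config now defaults to the plain cosine scheduler with an empty kwargs dict. A minimal sketch of the expected behavior, assuming the defaults block above runs in the dataclass's `__post_init__`, that the remaining fields all have defaults, and that the class is importable from the file's path:

```python
# Sketch only: field names come from the diff above; the import path mirrors the file location.
from config.train_gpt_oss_custom import GPTOSSEnhancedCustomConfig

cfg = GPTOSSEnhancedCustomConfig()
print(cfg.scheduler)            # "cosine" (new default)
print(cfg.lr_scheduler_kwargs)  # {} (normalized from None by the defaults block)
```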
config/train_gpt_oss_medical_o1_sft.py
CHANGED
@@ -65,6 +65,10 @@ config = GPTOSSEnhancedCustomConfig(
     warmup_steps=50,
     max_grad_norm=1.0,

+    # Scheduler: use broadly compatible cosine by default to avoid TRL signature issues
+    scheduler="cosine",
+    lr_scheduler_kwargs={},
+
     # Sequence length
     max_seq_length=2048,

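If a particular run does want TRL's `cosine_with_min_lr` behavior, the same two fields can opt back in explicitly; the same applies to the memory-optimized config below. A hypothetical variant of the block above (not part of this commit, other constructor arguments elided as in the real config file):

```python
# Hypothetical opt-in to TRL's min-lr scheduler; values are illustrative.
config = GPTOSSEnhancedCustomConfig(
    warmup_steps=50,
    max_grad_norm=1.0,
    scheduler="cosine_with_min_lr",            # TRL-specific scheduler
    lr_scheduler_kwargs={"min_lr_rate": 0.1},  # or {"min_lr": 1e-6}
    max_seq_length=2048,
)
```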
config/train_gpt_oss_openhermes_fr_memory_optimized.py
CHANGED
@@ -193,11 +193,9 @@ config = GPTOSSEnhancedCustomConfig(
     beta2=0.95,  # GPT-OSS optimized beta2
     eps=1e-8,

-
-
-
-        "warmup_steps": None,  # Use warmup_ratio instead
-    },
+    # Use standard cosine for broad compatibility; TRL min-lr scheduler is optional
+    scheduler="cosine",
+    lr_scheduler_kwargs={},

     # Packing to increase token utilization per step (supported by TRL)
     packing=True,
requirements/requirements.txt
CHANGED
@@ -19,7 +19,8 @@ numpy>=1.24.0
 tqdm>=4.65.0

 # Experiment tracking
-trackio>=0.1.0
+# trackio>=0.1.0
+gradio>=5.0.0

 # Optional: for evaluation (commented out to reduce conflicts)
 # lighteval>=0.1.0
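With `trackio` commented out and only `gradio` pinned, monitoring code that imports trackio should degrade gracefully when the package is absent. A hedged sketch of such a guard (the helper name is illustrative and not from this repo, and the `trackio.log` call assumes a wandb-style API):

```python
# Illustrative guard: fall back to a no-op logger when trackio is not installed.
try:
    import trackio  # optional dependency, now commented out in requirements.txt
except ImportError:
    trackio = None

def log_metrics(metrics: dict) -> None:
    """Log to trackio when available; otherwise do nothing."""
    if trackio is not None:
        trackio.log(metrics)  # assumes a wandb-style log(); adjust to the actual trackio API
```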
requirements/requirements_minimal.txt
CHANGED
@@ -10,5 +10,4 @@ tokenizers>=0.13.0
 bitsandbytes>=0.41.0
 numpy>=1.24.0
 tqdm>=4.65.0
-trackio>=0.1.0
 psutil>=5.9.0
scripts/training/train_gpt_oss.py
CHANGED
@@ -191,12 +191,26 @@ def load_dataset_from_config(config):
     return dataset

 def build_scheduler_kwargs(config):
-    """Construct lr_scheduler_kwargs
-
+    """Construct lr_scheduler_kwargs compatibly across TRL/Transformers versions.
+
+    - For TRL's 'cosine_with_min_lr' scheduler, ensure a min_lr/min_lr_rate is set.
+    - For all other schedulers, strip TRL-specific keys to avoid unexpected kwargs
+      errors in Transformers' native schedulers.
     """
     skw = getattr(config, 'lr_scheduler_kwargs', {}) or {}
     if not isinstance(skw, dict):
         skw = {}
+
+    scheduler_type = getattr(config, 'scheduler', None)
+
+    # If we're NOT using TRL's special scheduler, drop incompatible keys early
+    if scheduler_type != 'cosine_with_min_lr':
+        for k in ('min_lr', 'min_lr_rate', 'warmup_steps', 'num_warmup_steps', 'warmup_ratio'):
+            if k in skw:
+                skw.pop(k, None)
+        return skw
+
+    # TRL cosine-with-min-lr: ensure one of min_lr or min_lr_rate is provided
     min_lr_cfg = getattr(config, 'min_lr', 1e-6)
     if 'min_lr' not in skw and 'min_lr_rate' not in skw:
         try:
@@ -206,6 +220,7 @@ def build_scheduler_kwargs(config):
             skw['min_lr_rate'] = 0.1
         except Exception:
             skw['min_lr_rate'] = 0.001
+
     # Remove warmup-related keys which conflict with some TRL schedulers
     for k in ('warmup_steps', 'num_warmup_steps', 'warmup_ratio'):
         if k in skw:
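The new early-return branch makes the helper safe for both scheduler families. A standalone sketch of the intended behavior, using `SimpleNamespace` as a stand-in for the real config object; run with `build_scheduler_kwargs` in scope (the exact values returned in the TRL branch depend on the elided min-lr computation above):

```python
from types import SimpleNamespace

plain = SimpleNamespace(scheduler="cosine",
                        lr_scheduler_kwargs={"min_lr": 1e-6, "warmup_steps": 50},
                        min_lr=1e-6)
trl = SimpleNamespace(scheduler="cosine_with_min_lr",
                      lr_scheduler_kwargs={},
                      min_lr=1e-6)

print(build_scheduler_kwargs(plain))  # {}  -> TRL-only and warmup keys stripped for native schedulers
print(build_scheduler_kwargs(trl))    # expected to carry a min_lr or min_lr_rate entry
```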
@@ -683,7 +698,8 @@ def create_sft_config(config, output_dir):

     # Learning rate configuration
     learning_rate = _as_float(getattr(config, 'learning_rate', 2e-4), 2e-4)
-
+    # Allow CLI/env override of scheduler
+    lr_scheduler_type = os.environ.get('GPT_OSS_SCHEDULER', getattr(config, 'scheduler', 'cosine'))
     lr_scheduler_kwargs = build_scheduler_kwargs(config)

     # Detect TRL scheduler signature incompatibilities and fall back gracefully
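The override resolves with the usual precedence: environment variable, then the config's `scheduler` attribute, then a `'cosine'` fallback. A self-contained check of that one line:

```python
import os
from types import SimpleNamespace

config = SimpleNamespace(scheduler="cosine")           # stand-in for the loaded config
os.environ["GPT_OSS_SCHEDULER"] = "cosine_with_min_lr"

lr_scheduler_type = os.environ.get("GPT_OSS_SCHEDULER", getattr(config, "scheduler", "cosine"))
assert lr_scheduler_type == "cosine_with_min_lr"       # the environment wins over the config value
```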
@@ -865,6 +881,57 @@ def train_gpt_oss(config_path, experiment_name, output_dir, trackio_url, trainer
     config.experiment_name = experiment_name
     config.trackio_url = trackio_url
     config.trainer_type = trainer_type
+
+    # Optional: scheduler overrides via environment variables set by CLI
+    try:
+        env_scheduler = os.environ.get("GPT_OSS_SCHEDULER")
+        if env_scheduler:
+            # Apply scheduler override
+            config.scheduler = env_scheduler
+            # Prepare/normalize lr scheduler kwargs container
+            if not hasattr(config, 'lr_scheduler_kwargs') or config.lr_scheduler_kwargs is None:
+                config.lr_scheduler_kwargs = {}
+
+            # Apply min lr overrides only when using TRL's special scheduler
+            if env_scheduler == 'cosine_with_min_lr':
+                env_min_lr = os.environ.get("GPT_OSS_MIN_LR")
+                env_min_lr_rate = os.environ.get("GPT_OSS_MIN_LR_RATE")
+                # Clear conflicting warmup keys to avoid signature issues
+                for k in ('warmup_steps', 'num_warmup_steps', 'warmup_ratio'):
+                    if k in config.lr_scheduler_kwargs:
+                        config.lr_scheduler_kwargs.pop(k, None)
+                # Prefer absolute min_lr if provided
+                if env_min_lr is not None:
+                    try:
+                        config.min_lr = float(env_min_lr)
+                        config.lr_scheduler_kwargs['min_lr'] = config.min_lr
+                        # Remove relative rate if present
+                        config.lr_scheduler_kwargs.pop('min_lr_rate', None)
+                    except Exception:
+                        pass
+                elif env_min_lr_rate is not None:
+                    try:
+                        config.lr_scheduler_kwargs['min_lr_rate'] = float(env_min_lr_rate)
+                        # Remove absolute min_lr if present in kwargs (leave config.min_lr untouched)
+                        config.lr_scheduler_kwargs.pop('min_lr', None)
+                    except Exception:
+                        pass
+                else:
+                    # Ensure at least one constraint exists; prefer absolute from config if valid
+                    try:
+                        if hasattr(config, 'min_lr') and config.min_lr is not None:
+                            config.lr_scheduler_kwargs['min_lr'] = float(config.min_lr)
+                        else:
+                            config.lr_scheduler_kwargs.setdefault('min_lr_rate', 0.1)
+                    except Exception:
+                        config.lr_scheduler_kwargs.setdefault('min_lr_rate', 0.1)
+            else:
+                # Non-TRL scheduler: strip TRL-specific keys to avoid unexpected kwargs
+                if hasattr(config, 'lr_scheduler_kwargs') and isinstance(config.lr_scheduler_kwargs, dict):
+                    for k in ('min_lr', 'min_lr_rate'):
+                        config.lr_scheduler_kwargs.pop(k, None)
+    except Exception:
+        pass

     # Load model and tokenizer
     model, tokenizer = load_gpt_oss_model_and_tokenizer(config)
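In practice this block is driven by the three `GPT_OSS_*` variables exported by `main()` below; setting them before `train_gpt_oss()` runs is enough to retarget the scheduler for a single run. A usage sketch (values are illustrative):

```python
import os

os.environ["GPT_OSS_SCHEDULER"] = "cosine_with_min_lr"
os.environ["GPT_OSS_MIN_LR"] = "1e-6"        # absolute floor; takes precedence over the rate
# os.environ["GPT_OSS_MIN_LR_RATE"] = "0.1"  # relative floor, only consulted if no GPT_OSS_MIN_LR is set

# train_gpt_oss(...) will then set config.scheduler, write min_lr into lr_scheduler_kwargs,
# and drop any warmup_* keys that would conflict with the TRL scheduler signature.
```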
@@ -1027,6 +1094,24 @@ def main():
     parser.add_argument("--output-dir", required=True, help="Output directory for checkpoints")
     parser.add_argument("--trackio-url", help="Trackio URL for monitoring")
     parser.add_argument("--trainer-type", default="sft", choices=["sft", "dpo"], help="Trainer type")
+    # Optional LR scheduler overrides (applied across any GPT-OSS config)
+    parser.add_argument(
+        "--scheduler",
+        choices=["linear", "cosine", "cosine_with_min_lr", "constant"],
+        help="Override LR scheduler for this run",
+    )
+    parser.add_argument(
+        "--min-lr",
+        type=float,
+        dest="min_lr",
+        help="Absolute floor for LR (used when scheduler is 'cosine_with_min_lr')",
+    )
+    parser.add_argument(
+        "--min-lr-rate",
+        type=float,
+        dest="min_lr_rate",
+        help="Relative LR floor rate in (0,1) for TRL scheduler (used when scheduler is 'cosine_with_min_lr')",
+    )

     args = parser.parse_args()

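A quick standalone check that the new flags parse as intended, mirroring the argument definitions above:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--scheduler", choices=["linear", "cosine", "cosine_with_min_lr", "constant"])
parser.add_argument("--min-lr", type=float, dest="min_lr")
parser.add_argument("--min-lr-rate", type=float, dest="min_lr_rate")

args = parser.parse_args(["--scheduler", "cosine_with_min_lr", "--min-lr", "1e-6"])
assert args.scheduler == "cosine_with_min_lr" and args.min_lr == 1e-6
```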
@@ -1039,7 +1124,16 @@
     os.makedirs(args.output_dir, exist_ok=True)

     try:
-
+        # If provided, expose scheduler overrides via environment so they can be picked up consistently
+        # across helper functions if needed.
+        if args.scheduler:
+            os.environ["GPT_OSS_SCHEDULER"] = args.scheduler
+        if args.min_lr is not None:
+            os.environ["GPT_OSS_MIN_LR"] = str(args.min_lr)
+        if args.min_lr_rate is not None:
+            os.environ["GPT_OSS_MIN_LR_RATE"] = str(args.min_lr_rate)
+
+        trainer = train_gpt_oss(
             config_path=args.config,
             experiment_name=args.experiment_name,
             output_dir=args.output_dir,
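End to end, a CLI flag such as `--scheduler cosine_with_min_lr --min-lr 1e-6` is exported to the `GPT_OSS_*` environment variables here, picked up again inside `train_gpt_oss()` to patch the loaded config, and finally consulted by `build_scheduler_kwargs()` and `create_sft_config()` when the SFT configuration is assembled. The exact launch command depends on whichever wrapper scripts invoke `scripts/training/train_gpt_oss.py`.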