Spaces:
Running
Running
adapt lr_scheduler according to trl version
Browse files
scripts/training/train_gpt_oss.py
CHANGED
|
@@ -155,6 +155,10 @@ def build_scheduler_kwargs(config):
|
|
| 155 |
skw['min_lr_rate'] = 0.1
|
| 156 |
except Exception:
|
| 157 |
skw['min_lr_rate'] = 0.001
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
return skw
|
| 159 |
|
| 160 |
def apply_dataset_filtering(dataset, config):
|
|
@@ -509,6 +513,36 @@ def create_sft_config(config, output_dir):
|
|
| 509 |
learning_rate = _as_float(getattr(config, 'learning_rate', 2e-4), 2e-4)
|
| 510 |
lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr')
|
| 511 |
lr_scheduler_kwargs = build_scheduler_kwargs(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 512 |
|
| 513 |
# Batch configuration
|
| 514 |
per_device_train_batch_size = _as_int(getattr(config, 'batch_size', 2), 2)
|
|
|
|
| 155 |
skw['min_lr_rate'] = 0.1
|
| 156 |
except Exception:
|
| 157 |
skw['min_lr_rate'] = 0.001
|
| 158 |
+
# Remove warmup-related keys which conflict with some TRL schedulers
|
| 159 |
+
for k in ('warmup_steps', 'num_warmup_steps', 'warmup_ratio'):
|
| 160 |
+
if k in skw:
|
| 161 |
+
skw.pop(k, None)
|
| 162 |
return skw
|
| 163 |
|
| 164 |
def apply_dataset_filtering(dataset, config):
|
|
|
|
| 513 |
learning_rate = _as_float(getattr(config, 'learning_rate', 2e-4), 2e-4)
|
| 514 |
lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr')
|
| 515 |
lr_scheduler_kwargs = build_scheduler_kwargs(config)
|
| 516 |
+
|
| 517 |
+
# Detect TRL scheduler signature incompatibilities and fall back gracefully
|
| 518 |
+
# Some TRL versions call get_cosine_with_min_lr_schedule_with_warmup with
|
| 519 |
+
# 'warmup_steps' instead of 'num_warmup_steps', which raises:
|
| 520 |
+
# get_cosine_with_min_lr_schedule_with_warmup() got an unexpected keyword
|
| 521 |
+
# argument 'warmup_steps'
|
| 522 |
+
# To avoid this, we fall back to the standard 'cosine' scheduler and strip
|
| 523 |
+
# incompatible kwargs when the incompatible signature is detected.
|
| 524 |
+
if lr_scheduler_type == 'cosine_with_min_lr':
|
| 525 |
+
try:
|
| 526 |
+
from trl.trainer import utils as trl_utils # type: ignore
|
| 527 |
+
import inspect as _inspect
|
| 528 |
+
if hasattr(trl_utils, 'get_cosine_with_min_lr_schedule_with_warmup'):
|
| 529 |
+
_sig = _inspect.signature(trl_utils.get_cosine_with_min_lr_schedule_with_warmup)
|
| 530 |
+
# If the function does NOT accept 'warmup_steps' explicitly, some TRL versions
|
| 531 |
+
# still pass it internally as a kwarg, causing a TypeError. Fall back to 'cosine'.
|
| 532 |
+
if 'warmup_steps' not in _sig.parameters:
|
| 533 |
+
print("Warning: Incompatible TRL scheduler signature detected; falling back to 'cosine'.")
|
| 534 |
+
lr_scheduler_type = 'cosine'
|
| 535 |
+
lr_scheduler_kwargs = {}
|
| 536 |
+
else:
|
| 537 |
+
# Function missing; fallback
|
| 538 |
+
print("Warning: TRL min-lr cosine scheduler not available; falling back to 'cosine'.")
|
| 539 |
+
lr_scheduler_type = 'cosine'
|
| 540 |
+
lr_scheduler_kwargs = {}
|
| 541 |
+
except Exception:
|
| 542 |
+
# Any import/signature issues -> safe fallback
|
| 543 |
+
print("Warning: Unable to verify TRL scheduler; falling back to 'cosine'.")
|
| 544 |
+
lr_scheduler_type = 'cosine'
|
| 545 |
+
lr_scheduler_kwargs = {}
|
| 546 |
|
| 547 |
# Batch configuration
|
| 548 |
per_device_train_batch_size = _as_int(getattr(config, 'batch_size', 2), 2)
|