Tonic committed on
Commit
976e218
·
1 Parent(s): 7f45871

adapt lr_scheduler according to trl version

Browse files
Files changed (1) hide show
  1. scripts/training/train_gpt_oss.py +34 -0
scripts/training/train_gpt_oss.py CHANGED
@@ -155,6 +155,10 @@ def build_scheduler_kwargs(config):
155
  skw['min_lr_rate'] = 0.1
156
  except Exception:
157
  skw['min_lr_rate'] = 0.001
 
 
 
 
158
  return skw
159
 
160
  def apply_dataset_filtering(dataset, config):
@@ -509,6 +513,36 @@ def create_sft_config(config, output_dir):
509
  learning_rate = _as_float(getattr(config, 'learning_rate', 2e-4), 2e-4)
510
  lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr')
511
  lr_scheduler_kwargs = build_scheduler_kwargs(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
 
513
  # Batch configuration
514
  per_device_train_batch_size = _as_int(getattr(config, 'batch_size', 2), 2)
 
155
  skw['min_lr_rate'] = 0.1
156
  except Exception:
157
  skw['min_lr_rate'] = 0.001
158
+ # Remove warmup-related keys which conflict with some TRL schedulers
159
+ for k in ('warmup_steps', 'num_warmup_steps', 'warmup_ratio'):
160
+ if k in skw:
161
+ skw.pop(k, None)
162
  return skw
163
 
164
  def apply_dataset_filtering(dataset, config):
 
513
  learning_rate = _as_float(getattr(config, 'learning_rate', 2e-4), 2e-4)
514
  lr_scheduler_type = getattr(config, 'scheduler', 'cosine_with_min_lr')
515
  lr_scheduler_kwargs = build_scheduler_kwargs(config)
516
+
517
+ # Detect TRL scheduler signature incompatibilities and fall back gracefully
518
+ # Some TRL versions call get_cosine_with_min_lr_schedule_with_warmup with
519
+ # 'warmup_steps' instead of 'num_warmup_steps', which raises:
520
+ # get_cosine_with_min_lr_schedule_with_warmup() got an unexpected keyword
521
+ # argument 'warmup_steps'
522
+ # To avoid this, we fallback to the standard 'cosine' scheduler and strip
523
+ # incompatible kwargs when the incompatible signature is detected.
524
+ if lr_scheduler_type == 'cosine_with_min_lr':
525
+ try:
526
+ from trl.trainer import utils as trl_utils # type: ignore
527
+ import inspect as _inspect
528
+ if hasattr(trl_utils, 'get_cosine_with_min_lr_schedule_with_warmup'):
529
+ _sig = _inspect.signature(trl_utils.get_cosine_with_min_lr_schedule_with_warmup)
530
+ # If the function does NOT accept 'warmup_steps' explicitly, some TRL versions
531
+ # still pass it internally as a kwarg, causing a TypeError. Fallback to 'cosine'.
532
+ if 'warmup_steps' not in _sig.parameters:
533
+ print("Warning: Incompatible TRL scheduler signature detected; falling back to 'cosine'.")
534
+ lr_scheduler_type = 'cosine'
535
+ lr_scheduler_kwargs = {}
536
+ else:
537
+ # Function missing; fallback
538
+ print("Warning: TRL min-lr cosine scheduler not available; falling back to 'cosine'.")
539
+ lr_scheduler_type = 'cosine'
540
+ lr_scheduler_kwargs = {}
541
+ except Exception:
542
+ # Any import/signature issues -> safe fallback
543
+ print("Warning: Unable to verify TRL scheduler; falling back to 'cosine'.")
544
+ lr_scheduler_type = 'cosine'
545
+ lr_scheduler_kwargs = {}
546
 
547
  # Batch configuration
548
  per_device_train_batch_size = _as_int(getattr(config, 'batch_size', 2), 2)