""" Learning Rate Schedulers for v6.2.0 Advanced scheduling with warmup and phase-based adjustments """ import torch import math from typing import Optional, Dict, List, Any import numpy as np class WarmupCosineScheduler: """ Cosine annealing with linear warmup GPT-5 suggested: Essential for stable progressive splitting training """ def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int, total_steps: int, min_lr: float = 1e-6, max_lr: Optional[float] = None): self.optimizer = optimizer self.warmup_steps = warmup_steps self.total_steps = total_steps self.min_lr = min_lr self.max_lr = max_lr or optimizer.param_groups[0]['lr'] self.current_step = 0 def step(self): """Update learning rate""" self.current_step += 1 if self.current_step <= self.warmup_steps: # Linear warmup lr = self.max_lr * (self.current_step / self.warmup_steps) else: # Cosine annealing (GPT fix: guard against division by zero) if self.total_steps <= self.warmup_steps: lr = self.min_lr else: progress = (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps) progress = min(1.0, max(0.0, progress)) # Clamp to [0, 1] lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress)) for param_group in self.optimizer.param_groups: param_group['lr'] = lr return lr def get_lr(self): """Get current learning rate""" return self.optimizer.param_groups[0]['lr'] class PhaseBasedScheduler: """ Curriculum learning scheduler with phase transitions Adjusts learning rate based on training phases """ def __init__(self, optimizer: torch.optim.Optimizer, phase_configs: List[Dict], current_epoch: int = 0): """ Args: optimizer: PyTorch optimizer phase_configs: List of phase configurations [{ 'epochs': (start, end), 'lr': learning_rate, 'warmup_epochs': warmup_duration }, ...] """ self.optimizer = optimizer self.phase_configs = phase_configs self.current_epoch = current_epoch self.current_phase = 0 self.base_lr = optimizer.param_groups[0]['lr'] def step(self, epoch: Optional[int] = None): """Update learning rate based on current phase""" if epoch is not None: self.current_epoch = epoch # Find current phase for i, phase in enumerate(self.phase_configs): start_epoch, end_epoch = phase['epochs'] if start_epoch <= self.current_epoch <= end_epoch: self.current_phase = i break phase = self.phase_configs[self.current_phase] target_lr = phase['lr'] warmup_epochs = phase.get('warmup_epochs', 0) start_epoch = phase['epochs'][0] # Apply warmup if in warmup period if self.current_epoch - start_epoch < warmup_epochs: warmup_progress = (self.current_epoch - start_epoch + 1) / warmup_epochs lr = target_lr * warmup_progress else: lr = target_lr # Update optimizer for param_group in self.optimizer.param_groups: param_group['lr'] = lr return lr class AdaptiveScheduler: """ Adaptive learning rate based on validation metrics Reduces LR when metrics plateau """ def __init__(self, optimizer: torch.optim.Optimizer, mode: str = 'min', factor: float = 0.5, patience: int = 10, threshold: float = 1e-4, min_lr: float = 1e-7): """ Args: optimizer: PyTorch optimizer mode: 'min' or 'max' - whether to reduce LR when metric stops decreasing or increasing factor: Factor to reduce LR by patience: Number of epochs with no improvement to wait threshold: Minimum change to qualify as improvement min_lr: Minimum learning rate """ self.optimizer = optimizer self.mode = mode self.factor = factor self.patience = patience self.threshold = threshold self.min_lr = min_lr self.best_score = None self.num_bad_epochs = 0 self.last_reduction = 0 def step(self, metric: float, epoch: int = 0): """Update learning rate based on metric""" current_lr = self.optimizer.param_groups[0]['lr'] if self.best_score is None: self.best_score = metric else: if self.mode == 'min': improved = metric < self.best_score - self.threshold else: improved = metric > self.best_score + self.threshold if improved: self.best_score = metric self.num_bad_epochs = 0 else: self.num_bad_epochs += 1 # Reduce LR if patience exceeded if self.num_bad_epochs >= self.patience: new_lr = max(current_lr * self.factor, self.min_lr) if new_lr < current_lr: print(f"Reducing learning rate from {current_lr:.2e} to {new_lr:.2e}") for param_group in self.optimizer.param_groups: param_group['lr'] = new_lr self.num_bad_epochs = 0 self.last_reduction = epoch return current_lr class ProgressiveSplittingScheduler: """ Adaptive scheduler for progressive splitting No fixed targets - adjusts based on quality feedback """ def __init__(self, optimizer: torch.optim.Optimizer, initial_lr: float = 1e-4, min_reconstruction: float = 0.85, ema: float = 0.98, min_lr: float = 1e-7): self.optimizer = optimizer self.initial_lr = initial_lr self.min_reconstruction = min_reconstruction # Quality threshold self.ema = ema self.min_lr = min_lr # Adaptive multipliers based on performance self.quality_multiplier = 1.0 # Adjusts with reconstruction quality # No phases - continuous adaptation self.current_state = 'learning' # EMA tracking for smooth transitions self._ema_comp = None self._ema_recon = None def step(self, metrics: Dict[str, float]): """ Update learning rate based on current metrics GPT fix: EMA smoothing and minimum floor Args: metrics: Dictionary containing: - compression_ratio: Current compression ratio - reconstruction_acc: Reconstruction accuracy """ compression_ratio = float(metrics.get('compression_ratio', 0.0)) reconstruction_acc = float(metrics.get('reconstruction_acc', 0.0)) # Update EMA (GPT fix: smooth transitions) if self._ema_comp is None: self._ema_comp = compression_ratio self._ema_recon = reconstruction_acc else: self._ema_comp = self.ema * self._ema_comp + (1 - self.ema) * compression_ratio self._ema_recon = self.ema * self._ema_recon + (1 - self.ema) * reconstruction_acc # Adaptive adjustment based on reconstruction quality only # No fixed compression target - emerges from quality if self._ema_recon < self.min_reconstruction: # Poor reconstruction - reduce LR for careful learning self.quality_multiplier = max(0.5, self._ema_recon) else: # Good reconstruction - normal learning self.quality_multiplier = 1.0 # Smooth LR changes reconstruction_factor = max(0.1, self._ema_recon) # Combined learning rate (adaptive, no phase multiplier) lr = self.initial_lr * self.quality_multiplier * reconstruction_factor lr = max(lr, self.min_lr) # Ensure minimum LR # Update optimizer for param_group in self.optimizer.param_groups: param_group['lr'] = lr return lr class GumbelTemperatureScheduler: """ Temperature annealing for Gumbel-Softmax GPT-5 suggestion: Critical for progressive splitting """ def __init__(self, initial_temp: float = 1.0, final_temp: float = 0.1, anneal_rate: float = 0.99995, anneal_steps: Optional[int] = None): self.initial_temp = initial_temp self.final_temp = final_temp self.anneal_rate = anneal_rate self.anneal_steps = anneal_steps self.current_step = 0 self.current_temp = initial_temp def step(self): """Update temperature""" self.current_step += 1 if self.anneal_steps: # Linear annealing progress = min(1.0, self.current_step / self.anneal_steps) self.current_temp = self.initial_temp + (self.final_temp - self.initial_temp) * progress else: # Exponential annealing self.current_temp = max( self.final_temp, self.initial_temp * (self.anneal_rate ** self.current_step) ) return self.current_temp def get_temperature(self): """Get current temperature""" return self.current_temp class CompressionRatioScheduler: """ Schedule target compression ratio during training Gradually increase compression requirements """ def __init__(self, initial_ratio: float = 8.0, target_ratio: float = 24.0, warmup_epochs: int = 10, total_epochs: int = 100): self.initial_ratio = initial_ratio self.target_ratio = target_ratio self.warmup_epochs = warmup_epochs self.total_epochs = total_epochs self.current_epoch = 0 def step(self, epoch: Optional[int] = None): """Update target compression ratio""" if epoch is not None: self.current_epoch = epoch else: self.current_epoch += 1 if self.current_epoch < self.warmup_epochs: # Start with lower compression requirement ratio = self.initial_ratio else: # Gradually increase to target progress = (self.current_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs) progress = min(1.0, progress) ratio = self.initial_ratio + (self.target_ratio - self.initial_ratio) * progress return ratio class MultiScheduler: """ Combine multiple schedulers for comprehensive training control """ def __init__(self, schedulers: Dict): """ Args: schedulers: Dictionary of schedulers { 'lr': learning_rate_scheduler, 'gumbel': gumbel_temperature_scheduler, 'compression': compression_ratio_scheduler, ... } """ self.schedulers = schedulers def step(self, **kwargs): """ Step all schedulers GPT fix: unified input convention Returns: Dictionary with all scheduler outputs """ results = {} for name, scheduler in self.schedulers.items(): try: # Check scheduler type and pass appropriate arguments if hasattr(scheduler, '__class__'): class_name = scheduler.__class__.__name__ if class_name == 'AdaptiveScheduler' and 'metric' in kwargs: results[name] = scheduler.step(kwargs['metric'], kwargs.get('epoch', 0)) elif class_name == 'PhaseBasedScheduler' and 'epoch' in kwargs: results[name] = scheduler.step(kwargs['epoch']) elif class_name == 'CompressionRatioScheduler' and 'epoch' in kwargs: results[name] = scheduler.step(kwargs['epoch']) elif class_name == 'ProgressiveSplittingScheduler' and 'metrics' in kwargs: results[name] = scheduler.step(kwargs['metrics']) elif hasattr(scheduler, 'step'): # Generic step (no arguments) results[name] = scheduler.step() else: if hasattr(scheduler, 'step'): results[name] = scheduler.step() except Exception as e: print(f"Warning: Scheduler '{name}' step failed: {e}") results[name] = None return results def get_current_values(self): """Get current values from all schedulers""" values = {} for name, scheduler in self.schedulers.items(): if hasattr(scheduler, 'get_lr'): values[name] = scheduler.get_lr() elif hasattr(scheduler, 'get_temperature'): values[name] = scheduler.get_temperature() elif hasattr(scheduler, 'current_temp'): values[name] = scheduler.current_temp elif hasattr(scheduler, 'current_epoch'): values[name] = scheduler.current_epoch return values class GateWarmupScheduler: """게이트 파라미터 웜업 스케줄러 초기: 모든 레이어 동등 사용 (gate=1.0) 웜업: 점진적 게이트 학습 시작 후기: 최적 게이트 값으로 수렴 """ def __init__( self, optimizer: torch.optim.Optimizer, warmup_steps: int = 1000, gate_param_group_name: str = 'gates', importance_param_group_name: str = 'importance' ): """ Args: optimizer: 옵티마이저 warmup_steps: 웜업 스텝 수 gate_param_group_name: 게이트 파라미터 그룹 이름 importance_param_group_name: 중요도 파라미터 그룹 이름 """ self.optimizer = optimizer self.warmup_steps = warmup_steps self.gate_group_name = gate_param_group_name self.importance_group_name = importance_param_group_name # 초기 학습률 저장 self.base_lrs = {} for group in optimizer.param_groups: if 'name' in group: self.base_lrs[group['name']] = group['lr'] def get_gate_factor(self, step: int) -> float: """게이트 학습률 계수 계산 웜업 기간 동안은 낮은 학습률, 이후 정상 학습률로 전환 """ if step < self.warmup_steps: # 웜업 기간: 선형 증가 return step / self.warmup_steps else: # 정상 학습 return 1.0 def get_importance_factor(self, step: int) -> float: """중요도 학습률 계수 계산 게이트보다 느리게 학습 시작 """ delayed_warmup = self.warmup_steps * 1.5 if step < delayed_warmup: return step / delayed_warmup * 0.5 else: return 1.0 def step(self, current_step: int): """스케줄러 스텝 Args: current_step: 현재 글로벌 스텝 """ # 게이트 파라미터 그룹 학습률 조정 gate_factor = self.get_gate_factor(current_step) importance_factor = self.get_importance_factor(current_step) for group in self.optimizer.param_groups: if 'name' not in group: continue if group['name'] == self.gate_group_name: # 게이트 학습률 조정 group['lr'] = self.base_lrs[self.gate_group_name] * gate_factor elif group['name'] == self.importance_group_name: # 중요도 학습률 조정 group['lr'] = self.base_lrs[self.importance_group_name] * importance_factor def get_lr(self) -> Dict[str, float]: """현재 학습률 반환""" lrs = {} for group in self.optimizer.param_groups: if 'name' in group: lrs[group['name']] = group['lr'] return lrs class UniversalCosineScheduler: """Universal Cosine Annealing 스케줄러 모든 언어에 대해 동일한 스케줄 적용 """ def __init__( self, optimizer: torch.optim.Optimizer, warmup_steps: int = 1000, total_steps: int = 10000, min_lr_ratio: float = 0.1 ): self.optimizer = optimizer self.warmup_steps = warmup_steps self.total_steps = total_steps self.min_lr_ratio = min_lr_ratio self.current_step = 0 # 초기 학습률 저장 self.base_lrs = [group['lr'] for group in optimizer.param_groups] def step(self): """스케줄러 스텝""" self.current_step += 1 for idx, param_group in enumerate(self.optimizer.param_groups): if self.current_step < self.warmup_steps: # Warmup 단계 lr = self.base_lrs[idx] * (self.current_step / self.warmup_steps) else: # Cosine annealing if self.total_steps <= self.warmup_steps: # warmup_steps가 total_steps보다 크거나 같은 경우 lr = self.base_lrs[idx] * self.min_lr_ratio else: progress = min(1.0, (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)) lr = self.base_lrs[idx] * ( self.min_lr_ratio + (1 - self.min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress)) ) param_group['lr'] = lr def get_last_lr(self) -> List[float]: """마지막 학습률 반환""" return [group['lr'] for group in self.optimizer.param_groups] def state_dict(self) -> Dict[str, Any]: """스케줄러 상태 딕셔너리 반환 (체크포인트 저장용)""" return { 'current_step': self.current_step, 'warmup_steps': self.warmup_steps, 'total_steps': self.total_steps, 'min_lr_ratio': self.min_lr_ratio, 'base_lrs': self.base_lrs } def load_state_dict(self, state_dict: Dict[str, Any]): """스케줄러 상태 로드 (체크포인트 재시작용)""" self.current_step = state_dict['current_step'] self.warmup_steps = state_dict['warmup_steps'] self.total_steps = state_dict['total_steps'] self.min_lr_ratio = state_dict['min_lr_ratio'] self.base_lrs = state_dict['base_lrs'] class AdaptiveLayerScheduler: """레이어별 적응적 스케줄러 각 레이어의 학습 진행도에 따라 동적으로 조정 """ def __init__( self, layer_builder, threshold_active: float = 0.7, threshold_skip: float = 0.3 ): """ Args: layer_builder: LayerBuilder 인스턴스 threshold_active: 활성 레이어 임계값 threshold_skip: 스킵 레이어 임계값 """ self.layer_builder = layer_builder self.threshold_active = threshold_active self.threshold_skip = threshold_skip # 레이어별 통계 self.layer_stats = { 'usage_count': torch.zeros(5), 'contribution': torch.zeros(5) } def update_stats(self, batch_output): """배치 출력으로 통계 업데이트""" with torch.no_grad(): gates = torch.sigmoid(self.layer_builder.layer_gates) # 사용 횟수 업데이트 self.layer_stats['usage_count'] += (gates > self.threshold_skip).float() # 기여도 추정 (간단한 버전) importance = torch.nn.functional.softmax( self.layer_builder.layer_importance, dim=0 ) self.layer_stats['contribution'] += importance.detach() def get_layer_status(self) -> Dict[int, str]: """각 레이어의 상태 반환""" gates = torch.sigmoid(self.layer_builder.layer_gates) status = {} for i in range(5): if gates[i] > self.threshold_active: status[i] = "ACTIVE" elif gates[i] > self.threshold_skip: status[i] = "PARTIAL" else: status[i] = "SKIP" return status def suggest_pruning(self) -> List[int]: """프루닝 가능한 레이어 제안""" gates = torch.sigmoid(self.layer_builder.layer_gates) prunable = [] for i in range(5): if gates[i] < self.threshold_skip: # 낮은 게이트 값 + 낮은 기여도 if self.layer_stats['contribution'][i] < 0.1: prunable.append(i) return prunable if __name__ == "__main__": # Test schedulers print("Testing Schedulers") # Create dummy optimizer model = torch.nn.Linear(10, 10) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # Test WarmupCosineScheduler print("\n1. WarmupCosineScheduler:") scheduler = WarmupCosineScheduler(optimizer, warmup_steps=100, total_steps=1000) lrs = [] for step in range(200): lr = scheduler.step() if step % 20 == 0: print(f" Step {step}: LR = {lr:.6f}") lrs.append(lr) # Test PhaseBasedScheduler print("\n2. PhaseBasedScheduler:") phase_configs = [ {'epochs': (0, 30), 'lr': 1e-4, 'warmup_epochs': 5}, {'epochs': (31, 60), 'lr': 5e-5, 'warmup_epochs': 2}, {'epochs': (61, 100), 'lr': 1e-5, 'warmup_epochs': 0} ] scheduler = PhaseBasedScheduler(optimizer, phase_configs) for epoch in [0, 5, 31, 35, 61, 80]: lr = scheduler.step(epoch) print(f" Epoch {epoch}: LR = {lr:.6f}") # Test GumbelTemperatureScheduler print("\n3. GumbelTemperatureScheduler:") scheduler = GumbelTemperatureScheduler() for step in [0, 100, 500, 1000, 5000]: for _ in range(step - scheduler.current_step): scheduler.step() temp = scheduler.get_temperature() print(f" Step {step}: Temperature = {temp:.4f}") # Test CompressionRatioScheduler print("\n4. CompressionRatioScheduler:") scheduler = CompressionRatioScheduler() for epoch in [0, 5, 10, 30, 50, 80, 100]: ratio = scheduler.step(epoch) print(f" Epoch {epoch}: Target ratio = {ratio:.1f}:1")