"""
Learning Rate Schedulers for v6.2.0
Advanced scheduling with warmup and phase-based adjustments
"""
import torch
import math
from typing import Optional, Dict, List, Any
class WarmupCosineScheduler:
"""
Cosine annealing with linear warmup
GPT-5 suggested: Essential for stable progressive splitting training
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
total_steps: int,
min_lr: float = 1e-6,
max_lr: Optional[float] = None):
self.optimizer = optimizer
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.min_lr = min_lr
self.max_lr = max_lr or optimizer.param_groups[0]['lr']
self.current_step = 0
def step(self):
"""Update learning rate"""
self.current_step += 1
if self.current_step <= self.warmup_steps:
# Linear warmup
lr = self.max_lr * (self.current_step / self.warmup_steps)
else:
# Cosine annealing (GPT fix: guard against division by zero)
if self.total_steps <= self.warmup_steps:
lr = self.min_lr
else:
progress = (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
progress = min(1.0, max(0.0, progress)) # Clamp to [0, 1]
lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
def get_lr(self):
"""Get current learning rate"""
return self.optimizer.param_groups[0]['lr']
class PhaseBasedScheduler:
"""
Curriculum learning scheduler with phase transitions
Adjusts learning rate based on training phases
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
phase_configs: List[Dict],
current_epoch: int = 0):
"""
Args:
optimizer: PyTorch optimizer
phase_configs: List of phase configurations
[{
'epochs': (start, end),
'lr': learning_rate,
'warmup_epochs': warmup_duration
}, ...]
"""
self.optimizer = optimizer
self.phase_configs = phase_configs
self.current_epoch = current_epoch
self.current_phase = 0
self.base_lr = optimizer.param_groups[0]['lr']
def step(self, epoch: Optional[int] = None):
"""Update learning rate based on current phase"""
if epoch is not None:
self.current_epoch = epoch
# Find current phase
for i, phase in enumerate(self.phase_configs):
start_epoch, end_epoch = phase['epochs']
if start_epoch <= self.current_epoch <= end_epoch:
self.current_phase = i
break
phase = self.phase_configs[self.current_phase]
target_lr = phase['lr']
warmup_epochs = phase.get('warmup_epochs', 0)
start_epoch = phase['epochs'][0]
# Apply warmup if in warmup period
if self.current_epoch - start_epoch < warmup_epochs:
warmup_progress = (self.current_epoch - start_epoch + 1) / warmup_epochs
lr = target_lr * warmup_progress
else:
lr = target_lr
# Update optimizer
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
class AdaptiveScheduler:
"""
Adaptive learning rate based on validation metrics
Reduces LR when metrics plateau
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
mode: str = 'min',
factor: float = 0.5,
patience: int = 10,
threshold: float = 1e-4,
min_lr: float = 1e-7):
"""
Args:
optimizer: PyTorch optimizer
            mode: 'min' or 'max' - reduce LR when the metric stops decreasing ('min') or stops increasing ('max')
factor: Factor to reduce LR by
patience: Number of epochs with no improvement to wait
threshold: Minimum change to qualify as improvement
min_lr: Minimum learning rate
"""
self.optimizer = optimizer
self.mode = mode
self.factor = factor
self.patience = patience
self.threshold = threshold
self.min_lr = min_lr
self.best_score = None
self.num_bad_epochs = 0
self.last_reduction = 0
def step(self, metric: float, epoch: int = 0):
"""Update learning rate based on metric"""
current_lr = self.optimizer.param_groups[0]['lr']
if self.best_score is None:
self.best_score = metric
else:
if self.mode == 'min':
improved = metric < self.best_score - self.threshold
else:
improved = metric > self.best_score + self.threshold
if improved:
self.best_score = metric
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1
# Reduce LR if patience exceeded
if self.num_bad_epochs >= self.patience:
new_lr = max(current_lr * self.factor, self.min_lr)
if new_lr < current_lr:
print(f"Reducing learning rate from {current_lr:.2e} to {new_lr:.2e}")
for param_group in self.optimizer.param_groups:
param_group['lr'] = new_lr
self.num_bad_epochs = 0
self.last_reduction = epoch
        # Return the LR actually in effect after any reduction
        return self.optimizer.param_groups[0]['lr']
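
# Illustrative usage sketch for AdaptiveScheduler (not part of the training
# pipeline): feeds a synthetic, plateauing validation loss so the patience-based
# LR reduction can be observed. The demo function name is our own addition.
def _demo_adaptive_scheduler():
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = AdaptiveScheduler(optimizer, mode='min', factor=0.5, patience=3)
    # Synthetic losses: improve briefly, then plateau to trigger a reduction
    losses = [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7]
    for epoch, loss in enumerate(losses):
        lr = scheduler.step(loss, epoch)
        print(f"  Epoch {epoch}: loss={loss:.2f}, LR={lr:.2e}")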
class ProgressiveSplittingScheduler:
"""
Adaptive scheduler for progressive splitting
No fixed targets - adjusts based on quality feedback
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
initial_lr: float = 1e-4,
min_reconstruction: float = 0.85,
ema: float = 0.98,
min_lr: float = 1e-7):
self.optimizer = optimizer
self.initial_lr = initial_lr
self.min_reconstruction = min_reconstruction # Quality threshold
self.ema = ema
self.min_lr = min_lr
# Adaptive multipliers based on performance
self.quality_multiplier = 1.0 # Adjusts with reconstruction quality
# No phases - continuous adaptation
self.current_state = 'learning'
# EMA tracking for smooth transitions
self._ema_comp = None
self._ema_recon = None
def step(self, metrics: Dict[str, float]):
"""
Update learning rate based on current metrics
GPT fix: EMA smoothing and minimum floor
Args:
metrics: Dictionary containing:
- compression_ratio: Current compression ratio
- reconstruction_acc: Reconstruction accuracy
"""
compression_ratio = float(metrics.get('compression_ratio', 0.0))
reconstruction_acc = float(metrics.get('reconstruction_acc', 0.0))
# Update EMA (GPT fix: smooth transitions)
if self._ema_comp is None:
self._ema_comp = compression_ratio
self._ema_recon = reconstruction_acc
else:
self._ema_comp = self.ema * self._ema_comp + (1 - self.ema) * compression_ratio
self._ema_recon = self.ema * self._ema_recon + (1 - self.ema) * reconstruction_acc
# Adaptive adjustment based on reconstruction quality only
# No fixed compression target - emerges from quality
if self._ema_recon < self.min_reconstruction:
# Poor reconstruction - reduce LR for careful learning
self.quality_multiplier = max(0.5, self._ema_recon)
else:
# Good reconstruction - normal learning
self.quality_multiplier = 1.0
# Smooth LR changes
reconstruction_factor = max(0.1, self._ema_recon)
# Combined learning rate (adaptive, no phase multiplier)
lr = self.initial_lr * self.quality_multiplier * reconstruction_factor
lr = max(lr, self.min_lr) # Ensure minimum LR
# Update optimizer
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
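
# Illustrative sketch of ProgressiveSplittingScheduler driven by synthetic metrics
# (in real training the metrics come from the training loop; the values below are
# made up for demonstration).
def _demo_progressive_splitting_scheduler():
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = ProgressiveSplittingScheduler(optimizer, initial_lr=1e-4, min_reconstruction=0.85)
    for step, recon in enumerate([0.60, 0.75, 0.85, 0.92, 0.97]):
        lr = scheduler.step({'compression_ratio': 8.0 + step, 'reconstruction_acc': recon})
        print(f"  Step {step}: recon={recon:.2f}, LR={lr:.2e}")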
class GumbelTemperatureScheduler:
"""
Temperature annealing for Gumbel-Softmax
GPT-5 suggestion: Critical for progressive splitting
"""
def __init__(self,
initial_temp: float = 1.0,
final_temp: float = 0.1,
anneal_rate: float = 0.99995,
anneal_steps: Optional[int] = None):
self.initial_temp = initial_temp
self.final_temp = final_temp
self.anneal_rate = anneal_rate
self.anneal_steps = anneal_steps
self.current_step = 0
self.current_temp = initial_temp
def step(self):
"""Update temperature"""
self.current_step += 1
if self.anneal_steps:
# Linear annealing
progress = min(1.0, self.current_step / self.anneal_steps)
self.current_temp = self.initial_temp + (self.final_temp - self.initial_temp) * progress
else:
# Exponential annealing
self.current_temp = max(
self.final_temp,
self.initial_temp * (self.anneal_rate ** self.current_step)
)
return self.current_temp
def get_temperature(self):
"""Get current temperature"""
return self.current_temp
class CompressionRatioScheduler:
"""
Schedule target compression ratio during training
Gradually increase compression requirements
"""
def __init__(self,
initial_ratio: float = 8.0,
target_ratio: float = 24.0,
warmup_epochs: int = 10,
total_epochs: int = 100):
self.initial_ratio = initial_ratio
self.target_ratio = target_ratio
self.warmup_epochs = warmup_epochs
self.total_epochs = total_epochs
self.current_epoch = 0
def step(self, epoch: Optional[int] = None):
"""Update target compression ratio"""
if epoch is not None:
self.current_epoch = epoch
else:
self.current_epoch += 1
if self.current_epoch < self.warmup_epochs:
# Start with lower compression requirement
ratio = self.initial_ratio
else:
# Gradually increase to target
            progress = (self.current_epoch - self.warmup_epochs) / max(1, self.total_epochs - self.warmup_epochs)
progress = min(1.0, progress)
ratio = self.initial_ratio + (self.target_ratio - self.initial_ratio) * progress
return ratio
class MultiScheduler:
"""
Combine multiple schedulers for comprehensive training control
"""
def __init__(self, schedulers: Dict):
"""
Args:
schedulers: Dictionary of schedulers
{
'lr': learning_rate_scheduler,
'gumbel': gumbel_temperature_scheduler,
'compression': compression_ratio_scheduler,
...
}
"""
self.schedulers = schedulers
def step(self, **kwargs):
"""
Step all schedulers
GPT fix: unified input convention
Returns:
Dictionary with all scheduler outputs
"""
results = {}
for name, scheduler in self.schedulers.items():
try:
                # Dispatch on the scheduler type and pass the matching arguments
                class_name = scheduler.__class__.__name__
                if class_name == 'AdaptiveScheduler' and 'metric' in kwargs:
                    results[name] = scheduler.step(kwargs['metric'], kwargs.get('epoch', 0))
                elif class_name == 'PhaseBasedScheduler' and 'epoch' in kwargs:
                    results[name] = scheduler.step(kwargs['epoch'])
                elif class_name == 'CompressionRatioScheduler' and 'epoch' in kwargs:
                    results[name] = scheduler.step(kwargs['epoch'])
                elif class_name == 'ProgressiveSplittingScheduler' and 'metrics' in kwargs:
                    results[name] = scheduler.step(kwargs['metrics'])
                elif hasattr(scheduler, 'step'):
                    # Generic step (no arguments)
                    results[name] = scheduler.step()
except Exception as e:
print(f"Warning: Scheduler '{name}' step failed: {e}")
results[name] = None
return results
def get_current_values(self):
"""Get current values from all schedulers"""
values = {}
for name, scheduler in self.schedulers.items():
if hasattr(scheduler, 'get_lr'):
values[name] = scheduler.get_lr()
elif hasattr(scheduler, 'get_temperature'):
values[name] = scheduler.get_temperature()
elif hasattr(scheduler, 'current_temp'):
values[name] = scheduler.current_temp
elif hasattr(scheduler, 'current_epoch'):
values[name] = scheduler.current_epoch
return values
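
# A minimal sketch of combining several schedulers under MultiScheduler. The keyword
# names ('metric', 'epoch', 'metrics') follow the dispatch logic in MultiScheduler.step;
# the concrete schedulers and values here are only illustrative.
def _demo_multi_scheduler():
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    multi = MultiScheduler({
        'lr': WarmupCosineScheduler(optimizer, warmup_steps=10, total_steps=100),
        'gumbel': GumbelTemperatureScheduler(),
        'compression': CompressionRatioScheduler(total_epochs=50),
    })
    results = multi.step(epoch=5, metric=0.5, metrics={'reconstruction_acc': 0.9})
    print(f"  MultiScheduler outputs: {results}")
    print(f"  Current values: {multi.get_current_values()}")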
class GateWarmupScheduler:
"""๊ฒŒ์ดํŠธ ํŒŒ๋ผ๋ฏธํ„ฐ ์›œ์—… ์Šค์ผ€์ค„๋Ÿฌ
์ดˆ๊ธฐ: ๋ชจ๋“  ๋ ˆ์ด์–ด ๋™๋“ฑ ์‚ฌ์šฉ (gate=1.0)
์›œ์—…: ์ ์ง„์  ๊ฒŒ์ดํŠธ ํ•™์Šต ์‹œ์ž‘
ํ›„๊ธฐ: ์ตœ์  ๊ฒŒ์ดํŠธ ๊ฐ’์œผ๋กœ ์ˆ˜๋ ด
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int = 1000,
gate_param_group_name: str = 'gates',
importance_param_group_name: str = 'importance'
):
"""
Args:
            optimizer: PyTorch optimizer
            warmup_steps: Number of warmup steps
            gate_param_group_name: Name of the gate parameter group
            importance_param_group_name: Name of the importance parameter group
"""
self.optimizer = optimizer
self.warmup_steps = warmup_steps
self.gate_group_name = gate_param_group_name
self.importance_group_name = importance_param_group_name
        # Store the initial learning rates
self.base_lrs = {}
for group in optimizer.param_groups:
if 'name' in group:
self.base_lrs[group['name']] = group['lr']
def get_gate_factor(self, step: int) -> float:
"""๊ฒŒ์ดํŠธ ํ•™์Šต๋ฅ  ๊ณ„์ˆ˜ ๊ณ„์‚ฐ
์›œ์—… ๊ธฐ๊ฐ„ ๋™์•ˆ์€ ๋‚ฎ์€ ํ•™์Šต๋ฅ ,
์ดํ›„ ์ •์ƒ ํ•™์Šต๋ฅ ๋กœ ์ „ํ™˜
"""
if step < self.warmup_steps:
            # Warmup period: linear increase
return step / self.warmup_steps
else:
            # Normal training
return 1.0
def get_importance_factor(self, step: int) -> float:
"""์ค‘์š”๋„ ํ•™์Šต๋ฅ  ๊ณ„์ˆ˜ ๊ณ„์‚ฐ
๊ฒŒ์ดํŠธ๋ณด๋‹ค ๋А๋ฆฌ๊ฒŒ ํ•™์Šต ์‹œ์ž‘
"""
delayed_warmup = self.warmup_steps * 1.5
if step < delayed_warmup:
return step / delayed_warmup * 0.5
else:
return 1.0
def step(self, current_step: int):
"""์Šค์ผ€์ค„๋Ÿฌ ์Šคํ…
Args:
current_step: ํ˜„์žฌ ๊ธ€๋กœ๋ฒŒ ์Šคํ…
"""
# ๊ฒŒ์ดํŠธ ํŒŒ๋ผ๋ฏธํ„ฐ ๊ทธ๋ฃน ํ•™์Šต๋ฅ  ์กฐ์ •
gate_factor = self.get_gate_factor(current_step)
importance_factor = self.get_importance_factor(current_step)
for group in self.optimizer.param_groups:
if 'name' not in group:
continue
if group['name'] == self.gate_group_name:
# ๊ฒŒ์ดํŠธ ํ•™์Šต๋ฅ  ์กฐ์ •
group['lr'] = self.base_lrs[self.gate_group_name] * gate_factor
elif group['name'] == self.importance_group_name:
# ์ค‘์š”๋„ ํ•™์Šต๋ฅ  ์กฐ์ •
group['lr'] = self.base_lrs[self.importance_group_name] * importance_factor
def get_lr(self) -> Dict[str, float]:
"""ํ˜„์žฌ ํ•™์Šต๋ฅ  ๋ฐ˜ํ™˜"""
lrs = {}
for group in self.optimizer.param_groups:
if 'name' in group:
lrs[group['name']] = group['lr']
return lrs
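
# A sketch of GateWarmupScheduler usage, assuming the optimizer was built with named
# parameter groups ('gates', 'importance') as the scheduler expects; the tiny
# parameters below stand in for the real gate/importance tensors.
def _demo_gate_warmup_scheduler():
    gates = torch.nn.Parameter(torch.ones(5))
    importance = torch.nn.Parameter(torch.zeros(5))
    optimizer = torch.optim.Adam([
        {'params': [gates], 'lr': 1e-2, 'name': 'gates'},
        {'params': [importance], 'lr': 1e-2, 'name': 'importance'},
    ])
    scheduler = GateWarmupScheduler(optimizer, warmup_steps=100)
    for step in [0, 50, 100, 200]:
        scheduler.step(step)
        print(f"  Step {step}: {scheduler.get_lr()}")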
class UniversalCosineScheduler:
"""Universal Cosine Annealing ์Šค์ผ€์ค„๋Ÿฌ
๋ชจ๋“  ์–ธ์–ด์— ๋Œ€ํ•ด ๋™์ผํ•œ ์Šค์ผ€์ค„ ์ ์šฉ
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int = 1000,
total_steps: int = 10000,
min_lr_ratio: float = 0.1
):
self.optimizer = optimizer
self.warmup_steps = warmup_steps
self.total_steps = total_steps
self.min_lr_ratio = min_lr_ratio
self.current_step = 0
        # Store the initial learning rates
self.base_lrs = [group['lr'] for group in optimizer.param_groups]
def step(self):
"""์Šค์ผ€์ค„๋Ÿฌ ์Šคํ…"""
self.current_step += 1
for idx, param_group in enumerate(self.optimizer.param_groups):
if self.current_step < self.warmup_steps:
                # Warmup phase
lr = self.base_lrs[idx] * (self.current_step / self.warmup_steps)
else:
# Cosine annealing
if self.total_steps <= self.warmup_steps:
                    # warmup_steps is greater than or equal to total_steps
lr = self.base_lrs[idx] * self.min_lr_ratio
else:
progress = min(1.0, (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps))
lr = self.base_lrs[idx] * (
self.min_lr_ratio + (1 - self.min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))
)
param_group['lr'] = lr
def get_last_lr(self) -> List[float]:
"""๋งˆ์ง€๋ง‰ ํ•™์Šต๋ฅ  ๋ฐ˜ํ™˜"""
return [group['lr'] for group in self.optimizer.param_groups]
def state_dict(self) -> Dict[str, Any]:
"""์Šค์ผ€์ค„๋Ÿฌ ์ƒํƒœ ๋”•์…”๋„ˆ๋ฆฌ ๋ฐ˜ํ™˜ (์ฒดํฌํฌ์ธํŠธ ์ €์žฅ์šฉ)"""
return {
'current_step': self.current_step,
'warmup_steps': self.warmup_steps,
'total_steps': self.total_steps,
'min_lr_ratio': self.min_lr_ratio,
'base_lrs': self.base_lrs
}
def load_state_dict(self, state_dict: Dict[str, Any]):
"""์Šค์ผ€์ค„๋Ÿฌ ์ƒํƒœ ๋กœ๋“œ (์ฒดํฌํฌ์ธํŠธ ์žฌ์‹œ์ž‘์šฉ)"""
self.current_step = state_dict['current_step']
self.warmup_steps = state_dict['warmup_steps']
self.total_steps = state_dict['total_steps']
self.min_lr_ratio = state_dict['min_lr_ratio']
self.base_lrs = state_dict['base_lrs']
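
# A short sketch of UniversalCosineScheduler, including the state_dict round trip
# used for checkpointing (the step counts and LR values here are illustrative).
def _demo_universal_cosine_scheduler():
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = UniversalCosineScheduler(optimizer, warmup_steps=10, total_steps=100)
    for _ in range(25):
        scheduler.step()
    print(f"  LR after 25 steps: {scheduler.get_last_lr()}")
    # Checkpoint round trip
    state = scheduler.state_dict()
    restored = UniversalCosineScheduler(optimizer, warmup_steps=10, total_steps=100)
    restored.load_state_dict(state)
    print(f"  Restored at step {restored.current_step}")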
class AdaptiveLayerScheduler:
"""๋ ˆ์ด์–ด๋ณ„ ์ ์‘์  ์Šค์ผ€์ค„๋Ÿฌ
๊ฐ ๋ ˆ์ด์–ด์˜ ํ•™์Šต ์ง„ํ–‰๋„์— ๋”ฐ๋ผ ๋™์ ์œผ๋กœ ์กฐ์ •
"""
def __init__(
self,
layer_builder,
threshold_active: float = 0.7,
threshold_skip: float = 0.3
):
"""
Args:
            layer_builder: LayerBuilder instance
            threshold_active: Threshold above which a layer is considered active
            threshold_skip: Threshold below which a layer is considered skipped
"""
self.layer_builder = layer_builder
self.threshold_active = threshold_active
self.threshold_skip = threshold_skip
# ๋ ˆ์ด์–ด๋ณ„ ํ†ต๊ณ„
self.layer_stats = {
'usage_count': torch.zeros(5),
'contribution': torch.zeros(5)
}
def update_stats(self, batch_output):
"""๋ฐฐ์น˜ ์ถœ๋ ฅ์œผ๋กœ ํ†ต๊ณ„ ์—…๋ฐ์ดํŠธ"""
with torch.no_grad():
gates = torch.sigmoid(self.layer_builder.layer_gates)
            # Update usage counts
self.layer_stats['usage_count'] += (gates > self.threshold_skip).float()
# ๊ธฐ์—ฌ๋„ ์ถ”์ • (๊ฐ„๋‹จํ•œ ๋ฒ„์ „)
importance = torch.nn.functional.softmax(
self.layer_builder.layer_importance, dim=0
)
self.layer_stats['contribution'] += importance.detach()
def get_layer_status(self) -> Dict[int, str]:
"""๊ฐ ๋ ˆ์ด์–ด์˜ ์ƒํƒœ ๋ฐ˜ํ™˜"""
gates = torch.sigmoid(self.layer_builder.layer_gates)
status = {}
for i in range(5):
if gates[i] > self.threshold_active:
status[i] = "ACTIVE"
elif gates[i] > self.threshold_skip:
status[i] = "PARTIAL"
else:
status[i] = "SKIP"
return status
def suggest_pruning(self) -> List[int]:
"""ํ”„๋ฃจ๋‹ ๊ฐ€๋Šฅํ•œ ๋ ˆ์ด์–ด ์ œ์•ˆ"""
gates = torch.sigmoid(self.layer_builder.layer_gates)
prunable = []
for i in range(5):
if gates[i] < self.threshold_skip:
# ๋‚ฎ์€ ๊ฒŒ์ดํŠธ ๊ฐ’ + ๋‚ฎ์€ ๊ธฐ์—ฌ๋„
if self.layer_stats['contribution'][i] < 0.1:
prunable.append(i)
return prunable
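
# An illustrative sketch of AdaptiveLayerScheduler. The real LayerBuilder lives
# elsewhere in the project; this stand-in only provides the two attributes the
# scheduler reads (layer_gates, layer_importance), each sized for 5 layers.
def _demo_adaptive_layer_scheduler():
    class _FakeLayerBuilder:
        def __init__(self):
            self.layer_gates = torch.tensor([2.0, 1.0, 0.0, -1.0, -2.0])
            self.layer_importance = torch.tensor([1.0, 0.5, 0.0, -0.5, -1.0])

    scheduler = AdaptiveLayerScheduler(_FakeLayerBuilder())
    scheduler.update_stats(batch_output=None)
    print(f"  Layer status: {scheduler.get_layer_status()}")
    print(f"  Prunable layers: {scheduler.suggest_pruning()}")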
if __name__ == "__main__":
# Test schedulers
print("Testing Schedulers")
# Create dummy optimizer
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Test WarmupCosineScheduler
print("\n1. WarmupCosineScheduler:")
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=100, total_steps=1000)
lrs = []
for step in range(200):
lr = scheduler.step()
if step % 20 == 0:
print(f" Step {step}: LR = {lr:.6f}")
lrs.append(lr)
# Test PhaseBasedScheduler
print("\n2. PhaseBasedScheduler:")
phase_configs = [
{'epochs': (0, 30), 'lr': 1e-4, 'warmup_epochs': 5},
{'epochs': (31, 60), 'lr': 5e-5, 'warmup_epochs': 2},
{'epochs': (61, 100), 'lr': 1e-5, 'warmup_epochs': 0}
]
scheduler = PhaseBasedScheduler(optimizer, phase_configs)
for epoch in [0, 5, 31, 35, 61, 80]:
lr = scheduler.step(epoch)
print(f" Epoch {epoch}: LR = {lr:.6f}")
# Test GumbelTemperatureScheduler
print("\n3. GumbelTemperatureScheduler:")
scheduler = GumbelTemperatureScheduler()
for step in [0, 100, 500, 1000, 5000]:
for _ in range(step - scheduler.current_step):
scheduler.step()
temp = scheduler.get_temperature()
print(f" Step {step}: Temperature = {temp:.4f}")
# Test CompressionRatioScheduler
print("\n4. CompressionRatioScheduler:")
scheduler = CompressionRatioScheduler()
for epoch in [0, 5, 10, 30, 50, 80, 100]:
ratio = scheduler.step(epoch)
print(f" Epoch {epoch}: Target ratio = {ratio:.1f}:1")