|
|
""" |
|
|
Learning Rate Schedulers for v6.2.0 |
|
|
Advanced scheduling with warmup and phase-based adjustments |
|
|
""" |
|
|
|
|
|
import math
from typing import Optional, Dict, List, Any

import torch
|
|
|
|
|
|
|
|
class WarmupCosineScheduler: |
|
|
""" |
|
|
Cosine annealing with linear warmup |
|
|
GPT-5 suggested: Essential for stable progressive splitting training |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
warmup_steps: int, |
|
|
total_steps: int, |
|
|
min_lr: float = 1e-6, |
|
|
max_lr: Optional[float] = None): |
|
|
self.optimizer = optimizer |
|
|
self.warmup_steps = warmup_steps |
|
|
self.total_steps = total_steps |
|
|
self.min_lr = min_lr |
|
|
self.max_lr = max_lr or optimizer.param_groups[0]['lr'] |
|
|
self.current_step = 0 |
|
|
|
|
|
def step(self): |
|
|
"""Update learning rate""" |
|
|
self.current_step += 1 |
|
|
|
|
|
if self.current_step <= self.warmup_steps: |
|
|
|
|
|
lr = self.max_lr * (self.current_step / self.warmup_steps) |
|
|
else: |
|
|
|
|
|
if self.total_steps <= self.warmup_steps: |
|
|
lr = self.min_lr |
|
|
else: |
|
|
progress = (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps) |
|
|
progress = min(1.0, max(0.0, progress)) |
|
|
lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress)) |
|
|
|
|
|
for param_group in self.optimizer.param_groups: |
|
|
param_group['lr'] = lr |
|
|
|
|
|
return lr |
|
|
|
|
|
def get_lr(self): |
|
|
"""Get current learning rate""" |
|
|
return self.optimizer.param_groups[0]['lr'] |
|
|
|
|
|
|
|
|
class PhaseBasedScheduler: |
|
|
""" |
|
|
Curriculum learning scheduler with phase transitions |
|
|
Adjusts learning rate based on training phases |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
phase_configs: List[Dict], |
|
|
current_epoch: int = 0): |
|
|
""" |
|
|
Args: |
|
|
optimizer: PyTorch optimizer |
|
|
phase_configs: List of phase configurations |
|
|
[{ |
|
|
'epochs': (start, end), |
|
|
'lr': learning_rate, |
|
|
'warmup_epochs': warmup_duration |
|
|
}, ...] |
|
|
""" |
|
|
self.optimizer = optimizer |
|
|
self.phase_configs = phase_configs |
|
|
self.current_epoch = current_epoch |
|
|
self.current_phase = 0 |
|
|
self.base_lr = optimizer.param_groups[0]['lr'] |
|
|
|
|
|
def step(self, epoch: Optional[int] = None): |
|
|
"""Update learning rate based on current phase""" |
|
|
if epoch is not None: |
|
|
self.current_epoch = epoch |
|
|
|
|
|
|
|
|
for i, phase in enumerate(self.phase_configs): |
|
|
start_epoch, end_epoch = phase['epochs'] |
|
|
if start_epoch <= self.current_epoch <= end_epoch: |
|
|
self.current_phase = i |
|
|
break |
|
|
|
|
|
phase = self.phase_configs[self.current_phase] |
|
|
target_lr = phase['lr'] |
|
|
warmup_epochs = phase.get('warmup_epochs', 0) |
|
|
start_epoch = phase['epochs'][0] |
|
|
|
|
|
|
|
|
if self.current_epoch - start_epoch < warmup_epochs: |
|
|
warmup_progress = (self.current_epoch - start_epoch + 1) / warmup_epochs |
|
|
lr = target_lr * warmup_progress |
|
|
else: |
|
|
lr = target_lr |
|
|
|
|
|
|
|
|
for param_group in self.optimizer.param_groups: |
|
|
param_group['lr'] = lr |
|
|
|
|
|
return lr |
|
|
|
|
|
|
|
|
class AdaptiveScheduler: |
|
|
""" |
|
|
Adaptive learning rate based on validation metrics |
|
|
Reduces LR when metrics plateau |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
mode: str = 'min', |
|
|
factor: float = 0.5, |
|
|
patience: int = 10, |
|
|
threshold: float = 1e-4, |
|
|
min_lr: float = 1e-7): |
|
|
""" |
|
|
Args: |
|
|
optimizer: PyTorch optimizer |
|
|
mode: 'min' or 'max' - whether to reduce LR when metric stops decreasing or increasing |
|
|
factor: Factor to reduce LR by |
|
|
patience: Number of epochs with no improvement to wait |
|
|
threshold: Minimum change to qualify as improvement |
|
|
min_lr: Minimum learning rate |
|
|
""" |
|
|
self.optimizer = optimizer |
|
|
self.mode = mode |
|
|
self.factor = factor |
|
|
self.patience = patience |
|
|
self.threshold = threshold |
|
|
self.min_lr = min_lr |
|
|
|
|
|
self.best_score = None |
|
|
self.num_bad_epochs = 0 |
|
|
self.last_reduction = 0 |
|
|
|
|
|
def step(self, metric: float, epoch: int = 0): |
|
|
"""Update learning rate based on metric""" |
|
|
current_lr = self.optimizer.param_groups[0]['lr'] |
|
|
|
|
|
if self.best_score is None: |
|
|
self.best_score = metric |
|
|
else: |
|
|
if self.mode == 'min': |
|
|
improved = metric < self.best_score - self.threshold |
|
|
else: |
|
|
improved = metric > self.best_score + self.threshold |
|
|
|
|
|
if improved: |
|
|
self.best_score = metric |
|
|
self.num_bad_epochs = 0 |
|
|
else: |
|
|
self.num_bad_epochs += 1 |
|
|
|
|
|
|
|
|
if self.num_bad_epochs >= self.patience: |
|
|
new_lr = max(current_lr * self.factor, self.min_lr) |
|
|
|
|
|
if new_lr < current_lr: |
|
|
print(f"Reducing learning rate from {current_lr:.2e} to {new_lr:.2e}") |
|
|
|
|
|
for param_group in self.optimizer.param_groups: |
|
|
param_group['lr'] = new_lr |
|
|
|
|
|
self.num_bad_epochs = 0 |
|
|
self.last_reduction = epoch |
|
|
|
|
|
        # Return the (possibly reduced) current learning rate
        return self.optimizer.param_groups[0]['lr']
|
|
|
|
|
|
|
|
class ProgressiveSplittingScheduler: |
|
|
""" |
|
|
Adaptive scheduler for progressive splitting |
|
|
No fixed targets - adjusts based on quality feedback |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
initial_lr: float = 1e-4, |
|
|
min_reconstruction: float = 0.85, |
|
|
ema: float = 0.98, |
|
|
min_lr: float = 1e-7): |
|
|
self.optimizer = optimizer |
|
|
self.initial_lr = initial_lr |
|
|
self.min_reconstruction = min_reconstruction |
|
|
self.ema = ema |
|
|
self.min_lr = min_lr |
|
|
|
|
|
|
|
|
self.quality_multiplier = 1.0 |
|
|
|
|
|
|
|
|
self.current_state = 'learning' |
|
|
|
|
|
|
|
|
self._ema_comp = None |
|
|
self._ema_recon = None |
|
|
|
|
|
def step(self, metrics: Dict[str, float]): |
|
|
""" |
|
|
Update learning rate based on current metrics |
|
|
GPT fix: EMA smoothing and minimum floor |
|
|
|
|
|
Args: |
|
|
metrics: Dictionary containing: |
|
|
- compression_ratio: Current compression ratio |
|
|
- reconstruction_acc: Reconstruction accuracy |
|
|
""" |
|
|
compression_ratio = float(metrics.get('compression_ratio', 0.0)) |
|
|
reconstruction_acc = float(metrics.get('reconstruction_acc', 0.0)) |
|
|
|
|
|
|
|
|
if self._ema_comp is None: |
|
|
self._ema_comp = compression_ratio |
|
|
self._ema_recon = reconstruction_acc |
|
|
else: |
|
|
self._ema_comp = self.ema * self._ema_comp + (1 - self.ema) * compression_ratio |
|
|
self._ema_recon = self.ema * self._ema_recon + (1 - self.ema) * reconstruction_acc |
|
|
|
|
|
|
|
|
|
|
|
if self._ema_recon < self.min_reconstruction: |
|
|
|
|
|
self.quality_multiplier = max(0.5, self._ema_recon) |
|
|
else: |
|
|
|
|
|
self.quality_multiplier = 1.0 |
|
|
|
|
|
|
|
|
reconstruction_factor = max(0.1, self._ema_recon) |
|
|
|
|
|
|
|
|
lr = self.initial_lr * self.quality_multiplier * reconstruction_factor |
|
|
lr = max(lr, self.min_lr) |
|
|
|
|
|
|
|
|
for param_group in self.optimizer.param_groups: |
|
|
param_group['lr'] = lr |
|
|
|
|
|
return lr |
|
|
|
|
|
|
|
|
class GumbelTemperatureScheduler: |
|
|
""" |
|
|
Temperature annealing for Gumbel-Softmax |
|
|
GPT-5 suggestion: Critical for progressive splitting |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
initial_temp: float = 1.0, |
|
|
final_temp: float = 0.1, |
|
|
anneal_rate: float = 0.99995, |
|
|
anneal_steps: Optional[int] = None): |
|
|
self.initial_temp = initial_temp |
|
|
self.final_temp = final_temp |
|
|
self.anneal_rate = anneal_rate |
|
|
self.anneal_steps = anneal_steps |
|
|
self.current_step = 0 |
|
|
self.current_temp = initial_temp |
|
|
|
|
|
def step(self): |
|
|
"""Update temperature""" |
|
|
self.current_step += 1 |
|
|
|
|
|
if self.anneal_steps: |
|
|
|
|
|
progress = min(1.0, self.current_step / self.anneal_steps) |
|
|
self.current_temp = self.initial_temp + (self.final_temp - self.initial_temp) * progress |
|
|
else: |
|
|
|
|
|
self.current_temp = max( |
|
|
self.final_temp, |
|
|
self.initial_temp * (self.anneal_rate ** self.current_step) |
|
|
) |
|
|
|
|
|
return self.current_temp |
|
|
|
|
|
def get_temperature(self): |
|
|
"""Get current temperature""" |
|
|
return self.current_temp |
|
|
|
|
|
|
|
|
class CompressionRatioScheduler: |
|
|
""" |
|
|
Schedule target compression ratio during training |
|
|
Gradually increase compression requirements |
|
|
""" |
|
|
|
|
|
def __init__(self, |
|
|
initial_ratio: float = 8.0, |
|
|
target_ratio: float = 24.0, |
|
|
warmup_epochs: int = 10, |
|
|
total_epochs: int = 100): |
|
|
self.initial_ratio = initial_ratio |
|
|
self.target_ratio = target_ratio |
|
|
self.warmup_epochs = warmup_epochs |
|
|
self.total_epochs = total_epochs |
|
|
self.current_epoch = 0 |
|
|
|
|
|
def step(self, epoch: Optional[int] = None): |
|
|
"""Update target compression ratio""" |
|
|
if epoch is not None: |
|
|
self.current_epoch = epoch |
|
|
else: |
|
|
self.current_epoch += 1 |
|
|
|
|
|
if self.current_epoch < self.warmup_epochs: |
|
|
|
|
|
ratio = self.initial_ratio |
|
|
else: |
|
|
|
|
|
progress = (self.current_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs) |
|
|
progress = min(1.0, progress) |
|
|
ratio = self.initial_ratio + (self.target_ratio - self.initial_ratio) * progress |
|
|
|
|
|
return ratio |
|
|
|
|
|
|
|
|
class MultiScheduler: |
|
|
""" |
|
|
Combine multiple schedulers for comprehensive training control |
|
|
""" |
|
|
|
|
|
def __init__(self, schedulers: Dict): |
|
|
""" |
|
|
Args: |
|
|
schedulers: Dictionary of schedulers |
|
|
{ |
|
|
'lr': learning_rate_scheduler, |
|
|
'gumbel': gumbel_temperature_scheduler, |
|
|
'compression': compression_ratio_scheduler, |
|
|
... |
|
|
} |
|
|
""" |
|
|
self.schedulers = schedulers |
|
|
|
|
|
def step(self, **kwargs): |
|
|
""" |
|
|
Step all schedulers |
|
|
GPT fix: unified input convention |
|
|
|
|
|
Returns: |
|
|
Dictionary with all scheduler outputs |
|
|
""" |
|
|
results = {} |
|
|
|
|
|
for name, scheduler in self.schedulers.items(): |
|
|
try: |
|
|
|
|
|
                # Dispatch on the scheduler type so each scheduler receives the
                # arguments it expects from the unified keyword interface.
                class_name = scheduler.__class__.__name__

                if class_name == 'AdaptiveScheduler' and 'metric' in kwargs:
                    results[name] = scheduler.step(kwargs['metric'], kwargs.get('epoch', 0))
                elif class_name == 'PhaseBasedScheduler' and 'epoch' in kwargs:
                    results[name] = scheduler.step(kwargs['epoch'])
                elif class_name == 'CompressionRatioScheduler' and 'epoch' in kwargs:
                    results[name] = scheduler.step(kwargs['epoch'])
                elif class_name == 'ProgressiveSplittingScheduler' and 'metrics' in kwargs:
                    results[name] = scheduler.step(kwargs['metrics'])
                elif hasattr(scheduler, 'step'):
                    # Schedulers whose step() takes no arguments
                    results[name] = scheduler.step()
|
|
except Exception as e: |
|
|
print(f"Warning: Scheduler '{name}' step failed: {e}") |
|
|
results[name] = None |
|
|
|
|
|
return results |
|
|
|
|
|
def get_current_values(self): |
|
|
"""Get current values from all schedulers""" |
|
|
values = {} |
|
|
|
|
|
for name, scheduler in self.schedulers.items(): |
|
|
if hasattr(scheduler, 'get_lr'): |
|
|
values[name] = scheduler.get_lr() |
|
|
elif hasattr(scheduler, 'get_temperature'): |
|
|
values[name] = scheduler.get_temperature() |
|
|
elif hasattr(scheduler, 'current_temp'): |
|
|
values[name] = scheduler.current_temp |
|
|
elif hasattr(scheduler, 'current_epoch'): |
|
|
values[name] = scheduler.current_epoch |
|
|
|
|
|
return values |
|
|
|
|
|
|
|
|
class GateWarmupScheduler: |
|
|
"""๊ฒ์ดํธ ํ๋ผ๋ฏธํฐ ์์
์ค์ผ์ค๋ฌ |
|
|
|
|
|
์ด๊ธฐ: ๋ชจ๋ ๋ ์ด์ด ๋๋ฑ ์ฌ์ฉ (gate=1.0) |
|
|
์์
: ์ ์ง์ ๊ฒ์ดํธ ํ์ต ์์ |
|
|
ํ๊ธฐ: ์ต์ ๊ฒ์ดํธ ๊ฐ์ผ๋ก ์๋ ด |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
warmup_steps: int = 1000, |
|
|
gate_param_group_name: str = 'gates', |
|
|
importance_param_group_name: str = 'importance' |
|
|
): |
|
|
""" |
|
|
Args: |
|
|
optimizer: ์ตํฐ๋ง์ด์ |
|
|
warmup_steps: ์์
์คํ
์ |
|
|
gate_param_group_name: ๊ฒ์ดํธ ํ๋ผ๋ฏธํฐ ๊ทธ๋ฃน ์ด๋ฆ |
|
|
importance_param_group_name: ์ค์๋ ํ๋ผ๋ฏธํฐ ๊ทธ๋ฃน ์ด๋ฆ |
|
|
""" |
|
|
self.optimizer = optimizer |
|
|
self.warmup_steps = warmup_steps |
|
|
self.gate_group_name = gate_param_group_name |
|
|
self.importance_group_name = importance_param_group_name |
|
|
|
|
|
|
|
|
self.base_lrs = {} |
|
|
for group in optimizer.param_groups: |
|
|
if 'name' in group: |
|
|
self.base_lrs[group['name']] = group['lr'] |
|
|
|
|
|
def get_gate_factor(self, step: int) -> float: |
|
|
"""๊ฒ์ดํธ ํ์ต๋ฅ ๊ณ์ ๊ณ์ฐ |
|
|
|
|
|
์์
๊ธฐ๊ฐ ๋์์ ๋ฎ์ ํ์ต๋ฅ , |
|
|
์ดํ ์ ์ ํ์ต๋ฅ ๋ก ์ ํ |
|
|
""" |
|
|
if step < self.warmup_steps: |
|
|
|
|
|
return step / self.warmup_steps |
|
|
else: |
|
|
|
|
|
return 1.0 |
|
|
|
|
|
def get_importance_factor(self, step: int) -> float: |
|
|
"""์ค์๋ ํ์ต๋ฅ ๊ณ์ ๊ณ์ฐ |
|
|
|
|
|
๊ฒ์ดํธ๋ณด๋ค ๋๋ฆฌ๊ฒ ํ์ต ์์ |
|
|
""" |
|
|
delayed_warmup = self.warmup_steps * 1.5 |
|
|
if step < delayed_warmup: |
|
|
return step / delayed_warmup * 0.5 |
|
|
else: |
|
|
return 1.0 |
|
|
|
|
|
def step(self, current_step: int): |
|
|
"""์ค์ผ์ค๋ฌ ์คํ
|
|
|
|
|
|
Args: |
|
|
current_step: ํ์ฌ ๊ธ๋ก๋ฒ ์คํ
|
|
|
""" |
|
|
|
|
|
gate_factor = self.get_gate_factor(current_step) |
|
|
importance_factor = self.get_importance_factor(current_step) |
|
|
|
|
|
for group in self.optimizer.param_groups: |
|
|
if 'name' not in group: |
|
|
continue |
|
|
|
|
|
if group['name'] == self.gate_group_name: |
|
|
|
|
|
group['lr'] = self.base_lrs[self.gate_group_name] * gate_factor |
|
|
|
|
|
elif group['name'] == self.importance_group_name: |
|
|
|
|
|
group['lr'] = self.base_lrs[self.importance_group_name] * importance_factor |
|
|
|
|
|
def get_lr(self) -> Dict[str, float]: |
|
|
"""ํ์ฌ ํ์ต๋ฅ ๋ฐํ""" |
|
|
lrs = {} |
|
|
for group in self.optimizer.param_groups: |
|
|
if 'name' in group: |
|
|
lrs[group['name']] = group['lr'] |
|
|
return lrs |
|
|
|
|
|
|
|
|
class UniversalCosineScheduler: |
|
|
"""Universal Cosine Annealing ์ค์ผ์ค๋ฌ |
|
|
|
|
|
๋ชจ๋ ์ธ์ด์ ๋ํด ๋์ผํ ์ค์ผ์ค ์ ์ฉ |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
warmup_steps: int = 1000, |
|
|
total_steps: int = 10000, |
|
|
min_lr_ratio: float = 0.1 |
|
|
): |
|
|
self.optimizer = optimizer |
|
|
self.warmup_steps = warmup_steps |
|
|
self.total_steps = total_steps |
|
|
self.min_lr_ratio = min_lr_ratio |
|
|
self.current_step = 0 |
|
|
|
|
|
|
|
|
self.base_lrs = [group['lr'] for group in optimizer.param_groups] |
|
|
|
|
|
def step(self): |
|
|
"""์ค์ผ์ค๋ฌ ์คํ
""" |
|
|
self.current_step += 1 |
|
|
|
|
|
for idx, param_group in enumerate(self.optimizer.param_groups): |
|
|
if self.current_step < self.warmup_steps: |
|
|
|
|
|
lr = self.base_lrs[idx] * (self.current_step / self.warmup_steps) |
|
|
else: |
|
|
|
|
|
if self.total_steps <= self.warmup_steps: |
|
|
|
|
|
lr = self.base_lrs[idx] * self.min_lr_ratio |
|
|
else: |
|
|
progress = min(1.0, (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)) |
|
|
lr = self.base_lrs[idx] * ( |
|
|
self.min_lr_ratio + (1 - self.min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress)) |
|
|
) |
|
|
|
|
|
param_group['lr'] = lr |
|
|
|
|
|
def get_last_lr(self) -> List[float]: |
|
|
"""๋ง์ง๋ง ํ์ต๋ฅ ๋ฐํ""" |
|
|
return [group['lr'] for group in self.optimizer.param_groups] |
|
|
|
|
|
def state_dict(self) -> Dict[str, Any]: |
|
|
"""์ค์ผ์ค๋ฌ ์ํ ๋์
๋๋ฆฌ ๋ฐํ (์ฒดํฌํฌ์ธํธ ์ ์ฅ์ฉ)""" |
|
|
return { |
|
|
'current_step': self.current_step, |
|
|
'warmup_steps': self.warmup_steps, |
|
|
'total_steps': self.total_steps, |
|
|
'min_lr_ratio': self.min_lr_ratio, |
|
|
'base_lrs': self.base_lrs |
|
|
} |
|
|
|
|
|
def load_state_dict(self, state_dict: Dict[str, Any]): |
|
|
"""์ค์ผ์ค๋ฌ ์ํ ๋ก๋ (์ฒดํฌํฌ์ธํธ ์ฌ์์์ฉ)""" |
|
|
self.current_step = state_dict['current_step'] |
|
|
self.warmup_steps = state_dict['warmup_steps'] |
|
|
self.total_steps = state_dict['total_steps'] |
|
|
self.min_lr_ratio = state_dict['min_lr_ratio'] |
|
|
self.base_lrs = state_dict['base_lrs'] |
|
|
|
|
|
|
|
|
class AdaptiveLayerScheduler: |
|
|
"""๋ ์ด์ด๋ณ ์ ์์ ์ค์ผ์ค๋ฌ |
|
|
|
|
|
๊ฐ ๋ ์ด์ด์ ํ์ต ์งํ๋์ ๋ฐ๋ผ ๋์ ์ผ๋ก ์กฐ์ |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
layer_builder, |
|
|
threshold_active: float = 0.7, |
|
|
threshold_skip: float = 0.3 |
|
|
): |
|
|
""" |
|
|
Args: |
|
|
layer_builder: LayerBuilder ์ธ์คํด์ค |
|
|
threshold_active: ํ์ฑ ๋ ์ด์ด ์๊ณ๊ฐ |
|
|
threshold_skip: ์คํต ๋ ์ด์ด ์๊ณ๊ฐ |
|
|
""" |
|
|
self.layer_builder = layer_builder |
|
|
self.threshold_active = threshold_active |
|
|
self.threshold_skip = threshold_skip |
|
|
|
|
|
|
|
|
        # Per-layer statistics; the layer count is inferred from the builder's
        # gate parameters instead of being hard-coded.
        num_layers = self.layer_builder.layer_gates.shape[0]
        self.layer_stats = {
            'usage_count': torch.zeros(num_layers),
            'contribution': torch.zeros(num_layers)
        }
|
|
|
|
|
def update_stats(self, batch_output): |
|
|
"""๋ฐฐ์น ์ถ๋ ฅ์ผ๋ก ํต๊ณ ์
๋ฐ์ดํธ""" |
|
|
with torch.no_grad(): |
|
|
gates = torch.sigmoid(self.layer_builder.layer_gates) |
|
|
|
|
|
|
|
|
self.layer_stats['usage_count'] += (gates > self.threshold_skip).float() |
|
|
|
|
|
|
|
|
importance = torch.nn.functional.softmax( |
|
|
self.layer_builder.layer_importance, dim=0 |
|
|
) |
|
|
self.layer_stats['contribution'] += importance.detach() |
|
|
|
|
|
def get_layer_status(self) -> Dict[int, str]: |
|
|
"""๊ฐ ๋ ์ด์ด์ ์ํ ๋ฐํ""" |
|
|
gates = torch.sigmoid(self.layer_builder.layer_gates) |
|
|
status = {} |
|
|
|
|
|
        for i in range(len(gates)):
|
|
if gates[i] > self.threshold_active: |
|
|
status[i] = "ACTIVE" |
|
|
elif gates[i] > self.threshold_skip: |
|
|
status[i] = "PARTIAL" |
|
|
else: |
|
|
status[i] = "SKIP" |
|
|
|
|
|
return status |
|
|
|
|
|
def suggest_pruning(self) -> List[int]: |
|
|
"""ํ๋ฃจ๋ ๊ฐ๋ฅํ ๋ ์ด์ด ์ ์""" |
|
|
gates = torch.sigmoid(self.layer_builder.layer_gates) |
|
|
prunable = [] |
|
|
|
|
|
        for i in range(len(gates)):
|
|
if gates[i] < self.threshold_skip: |
|
|
|
|
|
if self.layer_stats['contribution'][i] < 0.1: |
|
|
prunable.append(i) |
|
|
|
|
|
return prunable |
|
|
|
|
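# Example usage of AdaptiveLayerScheduler (illustrative sketch only; it assumes a
# LayerBuilder object exposing `layer_gates` and `layer_importance` as 1-D tensors,
# which is not defined in this module):
#
#     layer_scheduler = AdaptiveLayerScheduler(layer_builder)
#     layer_scheduler.update_stats(batch_output)
#     print(layer_scheduler.get_layer_status())   # e.g. {0: "ACTIVE", 1: "SKIP", ...}
#     print(layer_scheduler.suggest_pruning())    # indices of prunable layers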
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
print("Testing Schedulers") |
|
|
|
|
|
|
|
|
model = torch.nn.Linear(10, 10) |
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) |
|
|
|
|
|
|
|
|
print("\n1. WarmupCosineScheduler:") |
|
|
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=100, total_steps=1000) |
|
|
lrs = [] |
|
|
for step in range(200): |
|
|
lr = scheduler.step() |
|
|
if step % 20 == 0: |
|
|
print(f" Step {step}: LR = {lr:.6f}") |
|
|
lrs.append(lr) |
|
|
|
|
|
|
|
|
print("\n2. PhaseBasedScheduler:") |
|
|
phase_configs = [ |
|
|
{'epochs': (0, 30), 'lr': 1e-4, 'warmup_epochs': 5}, |
|
|
{'epochs': (31, 60), 'lr': 5e-5, 'warmup_epochs': 2}, |
|
|
{'epochs': (61, 100), 'lr': 1e-5, 'warmup_epochs': 0} |
|
|
] |
|
|
scheduler = PhaseBasedScheduler(optimizer, phase_configs) |
|
|
for epoch in [0, 5, 31, 35, 61, 80]: |
|
|
lr = scheduler.step(epoch) |
|
|
print(f" Epoch {epoch}: LR = {lr:.6f}") |
|
|
|
|
|
|
|
|
print("\n3. GumbelTemperatureScheduler:") |
|
|
scheduler = GumbelTemperatureScheduler() |
|
|
for step in [0, 100, 500, 1000, 5000]: |
|
|
for _ in range(step - scheduler.current_step): |
|
|
scheduler.step() |
|
|
temp = scheduler.get_temperature() |
|
|
print(f" Step {step}: Temperature = {temp:.4f}") |
|
|
|
|
|
|
|
|
print("\n4. CompressionRatioScheduler:") |
|
|
scheduler = CompressionRatioScheduler() |
|
|
for epoch in [0, 5, 10, 30, 50, 80, 100]: |
|
|
ratio = scheduler.step(epoch) |
|
|
print(f" Epoch {epoch}: Target ratio = {ratio:.1f}:1") |
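
    # Illustrative usage of AdaptiveScheduler (not part of the original test
    # block). The loss values below are synthetic placeholders chosen so that
    # the plateau-detection and LR-reduction path is exercised.
    print("\n5. AdaptiveScheduler:")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = AdaptiveScheduler(optimizer, mode='min', factor=0.5, patience=3)
    fake_losses = [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7]
    for epoch, loss in enumerate(fake_losses):
        lr = scheduler.step(loss, epoch)
        print(f"  Epoch {epoch}: loss = {loss:.2f}, LR = {lr:.2e}")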
|
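    # Illustrative usage of ProgressiveSplittingScheduler (not part of the
    # original test block). The metrics are synthetic and only show how the LR
    # scales with the EMA-smoothed reconstruction accuracy.
    print("\n6. ProgressiveSplittingScheduler:")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = ProgressiveSplittingScheduler(optimizer, initial_lr=1e-4)
    for recon in [0.6, 0.7, 0.8, 0.9, 0.95]:
        lr = scheduler.step({'compression_ratio': 12.0, 'reconstruction_acc': recon})
        print(f"  Reconstruction {recon:.2f}: LR = {lr:.2e}")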
|
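    # Illustrative usage of MultiScheduler (not part of the original test
    # block): combines a Gumbel temperature schedule with a compression-ratio
    # schedule and steps both through the unified keyword interface.
    print("\n7. MultiScheduler:")
    multi = MultiScheduler({
        'gumbel': GumbelTemperatureScheduler(anneal_steps=100),
        'compression': CompressionRatioScheduler(total_epochs=50)
    })
    for epoch in [0, 10, 25, 50]:
        results = multi.step(epoch=epoch)
        print(f"  Epoch {epoch}: temp = {results['gumbel']:.3f}, "
              f"target ratio = {results['compression']:.1f}")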
|
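    # Illustrative usage of UniversalCosineScheduler (not part of the original
    # test block), including a state_dict round trip as used for checkpointing.
    print("\n8. UniversalCosineScheduler:")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = UniversalCosineScheduler(optimizer, warmup_steps=50, total_steps=500)
    for _ in range(250):
        scheduler.step()
    print(f"  Step 250: LR = {scheduler.get_last_lr()[0]:.6f}")
    state = scheduler.state_dict()
    restored = UniversalCosineScheduler(optimizer, warmup_steps=1, total_steps=1)
    restored.load_state_dict(state)
    print(f"  Restored at step {restored.current_step} of {restored.total_steps}")

    # Illustrative usage of GateWarmupScheduler (not part of the original test
    # block). It expects named parameter groups, so two dummy groups called
    # 'gates' and 'importance' are created here purely for demonstration.
    print("\n9. GateWarmupScheduler:")
    gate_params = [torch.nn.Parameter(torch.zeros(5))]
    importance_params = [torch.nn.Parameter(torch.zeros(5))]
    optimizer = torch.optim.Adam([
        {'params': gate_params, 'lr': 1e-3, 'name': 'gates'},
        {'params': importance_params, 'lr': 1e-3, 'name': 'importance'}
    ])
    scheduler = GateWarmupScheduler(optimizer, warmup_steps=1000)
    for step in [0, 500, 1000, 2000]:
        scheduler.step(step)
        print(f"  Step {step}: {scheduler.get_lr()}")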