|
|
""" |
|
|
MobiusNet Trainer with TensorBoard, SafeTensors, and HuggingFace Upload |
|
|
======================================================================= |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
import json |
|
|
import math |
|
|
import shutil |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor |
|
|
from typing import Tuple, Optional, Dict, Any |
|
|
from torchvision import datasets, transforms |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from tqdm.auto import tqdm |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from safetensors.torch import save_file as save_safetensors, load_file as load_safetensors |
|
|
from huggingface_hub import HfApi, login |
|
|
|
|
|
|
|
|
try: |
|
|
from google.colab import userdata |
|
|
token = userdata.get('HF_TOKEN') |
|
|
os.environ['HF_TOKEN'] = token |
|
|
login(token=token) |
|
|
print("Logged in to HuggingFace via Colab") |
|
|
except Exception:
    # Not running in Colab, or no HF_TOKEN secret available; skip the login.
    pass
|
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
print(f"Device: {device}") |
|
|
|
|
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
|
torch.backends.cudnn.allow_tf32 = True |
|
|
torch.set_float32_matmul_precision('high') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MobiusLens(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
layer_idx: int, |
|
|
total_layers: int, |
|
|
scale_range: Tuple[float, float] = (1.0, 9.0), |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.dim = dim |
|
|
self.layer_idx = layer_idx |
|
|
self.total_layers = total_layers |
|
|
self.t = layer_idx / max(total_layers - 1, 1) |
|
|
|
|
|
scale_span = scale_range[1] - scale_range[0] |
|
|
step = scale_span / max(total_layers, 1) |
|
|
scale_low = scale_range[0] + self.t * scale_span |
|
|
scale_high = scale_low + step |
|
|
|
|
|
self.register_buffer('scales', torch.tensor([scale_low, scale_high])) |
|
|
|
|
|
self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi)) |
|
|
self.twist_in_proj = nn.Linear(dim, dim, bias=False) |
|
|
nn.init.orthogonal_(self.twist_in_proj.weight) |
|
|
|
|
|
self.omega = nn.Parameter(torch.tensor(math.pi)) |
|
|
self.alpha = nn.Parameter(torch.tensor(1.5)) |
|
|
|
|
|
self.phase_l = nn.Parameter(torch.zeros(2)) |
|
|
self.drift_l = nn.Parameter(torch.ones(2)) |
|
|
self.phase_m = nn.Parameter(torch.zeros(2)) |
|
|
self.drift_m = nn.Parameter(torch.zeros(2)) |
|
|
self.phase_r = nn.Parameter(torch.zeros(2)) |
|
|
self.drift_r = nn.Parameter(-torch.ones(2)) |
|
|
|
|
|
self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4])) |
|
|
self.xor_weight = nn.Parameter(torch.tensor(0.7)) |
|
|
|
|
|
self.gate_norm = nn.LayerNorm(dim) |
|
|
|
|
|
self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi)) |
|
|
self.twist_out_proj = nn.Linear(dim, dim, bias=False) |
|
|
nn.init.orthogonal_(self.twist_out_proj.weight) |
|
|
|
|
|
def _twist_in(self, x: Tensor) -> Tensor: |
|
|
cos_t = torch.cos(self.twist_in_angle) |
|
|
sin_t = torch.sin(self.twist_in_angle) |
|
|
return x * cos_t + self.twist_in_proj(x) * sin_t |
|
|
|
|
|
def _center_lens(self, x: Tensor) -> Tensor: |
|
|
x_norm = torch.tanh(x) |
|
|
t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2) |
|
|
|
|
|
x_exp = x_norm.unsqueeze(-2) |
|
|
s = self.scales.view(-1, 1) |
|
|
|
|
|
def wave(phase, drift): |
|
|
a = self.alpha.abs() + 0.1 |
|
|
pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1) |
|
|
return torch.exp(-a * torch.sin(pos).pow(2)).prod(dim=-2) |
|
|
|
|
|
L = wave(self.phase_l, self.drift_l) |
|
|
M = wave(self.phase_m, self.drift_m) |
|
|
R = wave(self.phase_r, self.drift_r) |
|
|
|
|
|
w = torch.softmax(self.accum_weights, dim=0) |
|
|
xor_w = torch.sigmoid(self.xor_weight) |
|
|
|
|
|
xor_comp = (L + R - 2 * L * R).abs() |
|
|
and_comp = L * R |
|
|
lr = xor_w * xor_comp + (1 - xor_w) * and_comp |
|
|
|
|
|
gate = w[0] * L + w[1] * M + w[2] * R |
|
|
gate = gate * (0.5 + 0.5 * lr) |
|
|
gate = torch.sigmoid(self.gate_norm(gate)) |
|
|
|
|
|
return x * gate |
|
|
|
|
|
def _twist_out(self, x: Tensor) -> Tensor: |
|
|
cos_t = torch.cos(self.twist_out_angle) |
|
|
sin_t = torch.sin(self.twist_out_angle) |
|
|
return x * cos_t + self.twist_out_proj(x) * sin_t |
|
|
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
return self._twist_out(self._center_lens(self._twist_in(x))) |
|
|
|
|
|
def get_lens_stats(self) -> Dict[str, float]: |
|
|
"""Return lens parameters for logging.""" |
|
|
return { |
|
|
'omega': self.omega.item(), |
|
|
'alpha': self.alpha.item(), |
|
|
'twist_in_angle': self.twist_in_angle.item(), |
|
|
'twist_out_angle': self.twist_out_angle.item(), |
|
|
'xor_weight': torch.sigmoid(self.xor_weight).item(), |
|
|
'accum_weights_l': torch.softmax(self.accum_weights, dim=0)[0].item(), |
|
|
'accum_weights_m': torch.softmax(self.accum_weights, dim=0)[1].item(), |
|
|
'accum_weights_r': torch.softmax(self.accum_weights, dim=0)[2].item(), |
|
|
} |
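
# Minimal shape-check sketch for MobiusLens (illustrative, not part of the
# training path). The lens is shape-preserving over the last dimension:
#
#   lens = MobiusLens(dim=64, layer_idx=0, total_layers=8)
#   y = lens(torch.randn(2, 16, 64))
#   assert y.shape == (2, 16, 64)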
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MobiusConvBlock(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
channels: int, |
|
|
layer_idx: int, |
|
|
total_layers: int, |
|
|
scale_range: Tuple[float, float] = (1.0, 9.0), |
|
|
reduction: float = 0.5, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.conv = nn.Sequential( |
|
|
nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False), |
|
|
nn.Conv2d(channels, channels, 1, bias=False), |
|
|
nn.BatchNorm2d(channels), |
|
|
) |
|
|
|
|
|
self.lens = MobiusLens(channels, layer_idx, total_layers, scale_range) |
|
|
|
|
|
third = channels // 3 |
|
|
which_third = layer_idx % 3 |
|
|
mask = torch.ones(channels) |
|
|
start = which_third * third |
|
|
end = start + third + (channels % 3 if which_third == 2 else 0) |
|
|
mask[start:end] = reduction |
|
|
self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1)) |
|
|
|
|
|
self.residual_weight = nn.Parameter(torch.tensor(0.9)) |
|
|
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
identity = x |
|
|
|
|
|
h = self.conv(x) |
|
|
        # (B, C, H, W) -> channels-last so the lens acts on the feature dim
|
|
h = h.permute(0, 2, 3, 1) |
|
|
h = self.lens(h) |
|
|
h = h.permute(0, 3, 1, 2) |
|
|
h = h * self.thirds_mask |
|
|
|
|
|
rw = torch.sigmoid(self.residual_weight) |
|
|
return rw * identity + (1 - rw) * h |
|
|
|
|
|
def get_residual_weight(self) -> float: |
|
|
return torch.sigmoid(self.residual_weight).item() |
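
# Sketch: MobiusConvBlock preserves (B, C, H, W); the lens runs over the
# channel axis after a channels-last permute. Illustrative only:
#
#   block = MobiusConvBlock(channels=64, layer_idx=0, total_layers=8)
#   out = block(torch.randn(2, 64, 32, 32))   # -> (2, 64, 32, 32)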
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MobiusNet(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
in_chans: int = 3, |
|
|
num_classes: int = 200, |
|
|
channels: Tuple[int, ...] = (64, 128, 256, 512), |
|
|
depths: Tuple[int, ...] = (2, 2, 2, 2), |
|
|
scale_range: Tuple[float, float] = (0.5, 2.5), |
|
|
use_integrator: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
num_stages = len(depths) |
|
|
total_layers = sum(depths) |
|
|
|
|
|
self.total_layers = total_layers |
|
|
self.scale_range = scale_range |
|
|
self.channels = tuple(channels) |
|
|
self.depths = tuple(depths) |
|
|
self.num_stages = num_stages |
|
|
self.use_integrator = use_integrator |
|
|
self.num_classes = num_classes |
|
|
self.in_chans = in_chans |
|
|
|
|
|
channels = list(channels) |
|
|
while len(channels) < num_stages: |
|
|
channels.append(channels[-1]) |
|
|
|
|
|
self.stem = nn.Sequential( |
|
|
nn.Conv2d(in_chans, channels[0], 3, stride=1, padding=1, bias=False), |
|
|
nn.BatchNorm2d(channels[0]), |
|
|
) |
|
|
|
|
|
layer_idx = 0 |
|
|
self.stages = nn.ModuleList() |
|
|
self.downsamples = nn.ModuleList() |
|
|
|
|
|
for stage_idx in range(num_stages): |
|
|
ch = channels[stage_idx] |
|
|
|
|
|
stage = nn.ModuleList() |
|
|
for _ in range(depths[stage_idx]): |
|
|
stage.append(MobiusConvBlock(ch, layer_idx, total_layers, scale_range)) |
|
|
layer_idx += 1 |
|
|
self.stages.append(stage) |
|
|
|
|
|
if stage_idx < num_stages - 1: |
|
|
ch_next = channels[stage_idx + 1] |
|
|
self.downsamples.append(nn.Sequential( |
|
|
nn.Conv2d(ch, ch_next, 3, stride=2, padding=1, bias=False), |
|
|
nn.BatchNorm2d(ch_next), |
|
|
)) |
|
|
|
|
|
final_ch = channels[num_stages - 1] |
|
|
if use_integrator: |
|
|
self.integrator = nn.Sequential( |
|
|
nn.Conv2d(final_ch, final_ch, 3, padding=1, bias=False), |
|
|
nn.BatchNorm2d(final_ch), |
|
|
nn.GELU(), |
|
|
) |
|
|
else: |
|
|
self.integrator = nn.Identity() |
|
|
|
|
|
self.pool = nn.AdaptiveAvgPool2d(1) |
|
|
self.head = nn.Linear(final_ch, num_classes) |
|
|
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
|
x = self.stem(x) |
|
|
|
|
|
for i, stage in enumerate(self.stages): |
|
|
for block in stage: |
|
|
x = block(x) |
|
|
if i < len(self.downsamples): |
|
|
x = self.downsamples[i](x) |
|
|
|
|
|
x = self.integrator(x) |
|
|
return self.head(self.pool(x).flatten(1)) |
|
|
|
|
|
def get_config(self) -> Dict[str, Any]: |
|
|
"""Return model configuration for saving.""" |
|
|
return { |
|
|
'in_chans': self.in_chans, |
|
|
'num_classes': self.num_classes, |
|
|
'channels': self.channels, |
|
|
'depths': self.depths, |
|
|
'scale_range': self.scale_range, |
|
|
'use_integrator': self.use_integrator, |
|
|
            # Derived fields (not __init__ arguments); drop them before
            # reconstructing via MobiusNet(**config).
            'total_layers': self.total_layers,
            'num_stages': self.num_stages,
|
|
} |
|
|
|
|
|
def get_all_lens_stats(self) -> Dict[str, Dict[str, float]]: |
|
|
"""Return stats from all lenses for logging.""" |
|
|
stats = {} |
|
|
|
|
for stage_idx, stage in enumerate(self.stages): |
|
|
for block_idx, block in enumerate(stage): |
|
|
key = f"stage{stage_idx}_block{block_idx}" |
|
|
stats[key] = block.lens.get_lens_stats() |
|
|
stats[key]['residual_weight'] = block.get_residual_weight() |
|
|
|
|
return stats |
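
# Sketch: end-to-end shapes with the defaults. 64x64 inputs match Tiny
# ImageNet; the 4-stage default applies three stride-2 downsamples (64->8):
#
#   net = MobiusNet(num_classes=200)
#   logits = net(torch.randn(2, 3, 64, 64))   # -> (2, 200)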
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_tiny_imagenet_loaders(data_dir='./data/tiny-imagenet-200', batch_size=128): |
|
|
train_dir = os.path.join(data_dir, 'train') |
|
|
val_dir = os.path.join(data_dir, 'val') |
|
|
|
|
|
val_images_dir = os.path.join(val_dir, 'images') |
|
|
if os.path.exists(val_images_dir): |
|
|
print("Reorganizing validation folder...") |
|
|
reorganize_val_folder(val_dir) |
|
|
|
|
|
train_transform = transforms.Compose([ |
|
|
transforms.RandomCrop(64, padding=8), |
|
|
transforms.RandomHorizontalFlip(), |
|
|
transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET), |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), |
|
|
]) |
|
|
|
|
|
val_transform = transforms.Compose([ |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), |
|
|
]) |
|
|
|
|
|
train_dataset = datasets.ImageFolder(train_dir, transform=train_transform) |
|
|
val_dataset = datasets.ImageFolder(val_dir, transform=val_transform) |
|
|
|
|
|
train_loader = DataLoader( |
|
|
train_dataset, batch_size=batch_size, shuffle=True, |
|
|
num_workers=8, pin_memory=True, persistent_workers=True |
|
|
) |
|
|
val_loader = DataLoader( |
|
|
val_dataset, batch_size=256, shuffle=False, |
|
|
num_workers=4, pin_memory=True, persistent_workers=True |
|
|
) |
|
|
|
|
|
return train_loader, val_loader |
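
# Sketch for fetching Tiny ImageNet when it is not already under data_dir.
# The cs231n mirror URL is an assumption; verify it before relying on it:
#
#   import urllib.request, zipfile
#   urllib.request.urlretrieve(
#       'http://cs231n.stanford.edu/tiny-imagenet-200.zip',
#       'tiny-imagenet-200.zip')
#   with zipfile.ZipFile('tiny-imagenet-200.zip') as z:
#       z.extractall('./data')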
|
|
|
|
|
|
|
|
def reorganize_val_folder(val_dir): |
|
|
"""Reorganize Tiny ImageNet val folder into class subfolders.""" |
|
|
val_images_dir = os.path.join(val_dir, 'images') |
|
|
val_annotations = os.path.join(val_dir, 'val_annotations.txt') |
|
|
|
|
|
if not os.path.exists(val_images_dir): |
|
|
return |
|
|
|
|
|
with open(val_annotations, 'r') as f: |
|
|
for line in f: |
|
|
parts = line.strip().split('\t') |
|
|
img_name, class_id = parts[0], parts[1] |
|
|
|
|
|
class_dir = os.path.join(val_dir, class_id) |
|
|
os.makedirs(class_dir, exist_ok=True) |
|
|
|
|
|
src = os.path.join(val_images_dir, img_name) |
|
|
dst = os.path.join(class_dir, img_name) |
|
|
|
|
|
if os.path.exists(src): |
|
|
shutil.move(src, dst) |
|
|
|
|
|
if os.path.exists(val_images_dir): |
|
|
shutil.rmtree(val_images_dir) |
|
|
if os.path.exists(val_annotations): |
|
|
os.remove(val_annotations) |
|
|
|
|
|
print("Validation folder reorganized.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PRESETS = { |
|
|
'mobius_tiny_s': { |
|
|
'channels': (64, 128, 256), |
|
|
'depths': (2, 2, 2), |
|
|
'scale_range': (0.5, 2.5), |
|
|
}, |
|
|
'mobius_tiny_m': { |
|
|
'channels': (64, 128, 256, 512, 768), |
|
|
'depths': (2, 2, 4, 2, 2), |
|
|
'scale_range': (0.25, 2.75), |
|
|
}, |
|
|
'mobius_tiny_l': { |
|
|
'channels': (96, 192, 384, 768), |
|
|
'depths': (3, 3, 3, 3), |
|
|
'scale_range': (0.5, 3.5), |
|
|
}, |
|
|
'mobius_base': { |
|
|
'channels': (128, 256, 512, 768, 1024), |
|
|
'depths': (2, 2, 2, 2, 2), |
|
|
'scale_range': (0.25, 2.75), |
|
|
}, |
|
|
} |
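
# Sketch: compare preset sizes before committing to a run (CPU-only, no data
# required; counts vary with num_classes):
#
#   for name, cfg in PRESETS.items():
#       m = MobiusNet(num_classes=200, **cfg)
#       print(f"{name}: {sum(p.numel() for p in m.parameters()):,} params")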
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CheckpointManager: |
|
|
def __init__( |
|
|
self, |
|
|
base_dir: str, |
|
|
variant_name: str, |
|
|
dataset_name: str, |
|
|
hf_repo: str = "AbstractPhil/mobiusnet", |
|
|
upload_every_n_epochs: int = 10, |
|
|
save_every_n_epochs: int = 10, |
|
|
timestamp: Optional[str] = None, |
|
|
): |
|
|
self.timestamp = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S") |
|
|
self.variant_name = variant_name |
|
|
self.dataset_name = dataset_name |
|
|
self.hf_repo = hf_repo |
|
|
self.upload_every_n_epochs = upload_every_n_epochs |
|
|
self.save_every_n_epochs = save_every_n_epochs |
|
|
|
|
|
|
|
|
self.run_name = f"{variant_name}_{dataset_name}" |
|
|
self.run_dir = Path(base_dir) / "checkpoints" / self.run_name / self.timestamp |
|
|
self.checkpoints_dir = self.run_dir / "checkpoints" |
|
|
self.tensorboard_dir = self.run_dir / "tensorboard" |
|
|
|
|
|
|
|
|
self.checkpoints_dir.mkdir(parents=True, exist_ok=True) |
|
|
self.tensorboard_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
self.writer = SummaryWriter(log_dir=str(self.tensorboard_dir)) |
|
|
|
|
|
|
|
|
self.hf_api = HfApi() |
|
|
|
|
|
|
|
|
|
|
self.best_acc = 0.0 |
|
|
self.best_epoch = 0 |
|
|
self.best_changed_since_upload = False |
|
|
|
|
|
print(f"Checkpoint directory: {self.run_dir}") |
|
|
|
|
|
@staticmethod |
|
|
def extract_timestamp(checkpoint_path: str) -> Optional[str]: |
|
|
"""Extract timestamp from checkpoint path.""" |
|
|
|
|
|
match = re.search(r'(\d{8}_\d{6})', checkpoint_path) |
|
|
if match: |
|
|
return match.group(1) |
|
|
return None |
|
|
|
|
|
def save_config(self, config: Dict[str, Any], training_config: Dict[str, Any]): |
|
|
"""Save model and training configuration.""" |
|
|
full_config = { |
|
|
'model': config, |
|
|
'training': training_config, |
|
|
'timestamp': self.timestamp, |
|
|
'variant_name': self.variant_name, |
|
|
'dataset_name': self.dataset_name, |
|
|
} |
|
|
|
|
|
config_path = self.run_dir / "config.json" |
|
|
with open(config_path, 'w') as f: |
|
|
json.dump(full_config, f, indent=2) |
|
|
|
|
|
return config_path |
|
|
|
|
|
def save_checkpoint( |
|
|
self, |
|
|
model: nn.Module, |
|
|
optimizer: torch.optim.Optimizer, |
|
|
scheduler: Any, |
|
|
epoch: int, |
|
|
train_acc: float, |
|
|
val_acc: float, |
|
|
train_loss: float, |
|
|
is_best: bool = False, |
|
|
): |
|
|
"""Save checkpoint every N epochs, always save best (overwriting).""" |
|
|
|
|
|
|
|
|
        # Unwrap the torch.compile wrapper (if any) so state_dict keys stay clean.
        raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model
|
|
|
|
|
|
|
|
checkpoint = { |
|
|
'epoch': epoch, |
|
|
'train_acc': train_acc, |
|
|
'val_acc': val_acc, |
|
|
'train_loss': train_loss, |
|
|
            # Record the post-update best so resumed runs see the right value.
            'best_acc': max(self.best_acc, val_acc),
|
|
'optimizer_state_dict': optimizer.state_dict(), |
|
|
'scheduler_state_dict': scheduler.state_dict(), |
|
|
} |
|
|
|
|
|
|
|
|
if epoch % self.save_every_n_epochs == 0: |
|
|
epoch_pt_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt" |
|
|
torch.save({**checkpoint, 'model_state_dict': raw_model.state_dict()}, epoch_pt_path) |
|
|
|
|
|
epoch_st_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors" |
|
|
save_safetensors(raw_model.state_dict(), str(epoch_st_path)) |
|
|
|
|
|
|
|
|
if is_best: |
|
|
self.best_acc = val_acc |
|
|
self.best_epoch = epoch |
|
|
self.best_changed_since_upload = True |
|
|
|
|
|
|
|
|
best_pt_path = self.checkpoints_dir / "best_model.pt" |
|
|
torch.save({**checkpoint, 'model_state_dict': raw_model.state_dict()}, best_pt_path) |
|
|
|
|
|
|
|
|
best_st_path = self.checkpoints_dir / "best_model.safetensors" |
|
|
save_safetensors(raw_model.state_dict(), str(best_st_path)) |
|
|
|
|
|
|
|
|
acc_path = self.run_dir / "best_accuracy.json" |
|
|
with open(acc_path, 'w') as f: |
|
|
json.dump({ |
|
|
'best_acc': val_acc, |
|
|
'best_epoch': epoch, |
|
|
'train_acc': train_acc, |
|
|
'train_loss': train_loss, |
|
|
}, f, indent=2) |
|
|
|
|
|
def save_final(self, model: nn.Module, final_acc: float, final_epoch: int): |
|
|
"""Save final model.""" |
|
|
raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model |
|
|
|
|
|
|
|
|
final_st_path = self.checkpoints_dir / "final_model.safetensors" |
|
|
save_safetensors(raw_model.state_dict(), str(final_st_path)) |
|
|
|
|
|
|
|
|
final_pt_path = self.checkpoints_dir / "final_model.pt" |
|
|
torch.save({ |
|
|
'model_state_dict': raw_model.state_dict(), |
|
|
'final_acc': final_acc, |
|
|
'final_epoch': final_epoch, |
|
|
'best_acc': self.best_acc, |
|
|
'best_epoch': self.best_epoch, |
|
|
}, final_pt_path) |
|
|
|
|
|
|
|
|
acc_path = self.run_dir / "final_accuracy.json" |
|
|
with open(acc_path, 'w') as f: |
|
|
json.dump({ |
|
|
'final_acc': final_acc, |
|
|
'final_epoch': final_epoch, |
|
|
'best_acc': self.best_acc, |
|
|
'best_epoch': self.best_epoch, |
|
|
}, f, indent=2) |
|
|
|
|
|
return final_st_path, final_pt_path |
|
|
|
|
|
def log_scalars(self, epoch: int, scalars: Dict[str, float], prefix: str = ""): |
|
|
"""Log scalars to TensorBoard.""" |
|
|
for name, value in scalars.items(): |
|
|
tag = f"{prefix}/{name}" if prefix else name |
|
|
self.writer.add_scalar(tag, value, epoch) |
|
|
|
|
|
def log_lens_stats(self, epoch: int, model: nn.Module): |
|
|
"""Log lens statistics to TensorBoard.""" |
|
|
raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model |
|
|
stats = raw_model.get_all_lens_stats() |
|
|
|
|
|
for block_name, block_stats in stats.items(): |
|
|
for stat_name, value in block_stats.items(): |
|
|
self.writer.add_scalar(f"lens/{block_name}/{stat_name}", value, epoch) |
|
|
|
|
|
def log_histograms(self, epoch: int, model: nn.Module): |
|
|
"""Log weight histograms to TensorBoard.""" |
|
|
raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model |
|
|
|
|
|
for name, param in raw_model.named_parameters(): |
|
|
if param.requires_grad: |
|
|
self.writer.add_histogram(f"weights/{name}", param.data, epoch) |
|
|
if param.grad is not None: |
|
|
self.writer.add_histogram(f"gradients/{name}", param.grad, epoch) |
|
|
|
|
|
def upload_to_hf(self, epoch: int, force: bool = False): |
|
|
"""Upload checkpoint every N epochs. Best uploads only on upload epochs if changed.""" |
|
|
if not force and epoch % self.upload_every_n_epochs != 0: |
|
|
return |
|
|
|
|
|
try: |
|
|
hf_base_path = f"checkpoints/{self.run_name}/{self.timestamp}" |
|
|
|
|
|
files_to_upload = [] |
|
|
|
|
|
|
|
|
config_path = self.run_dir / "config.json" |
|
|
if config_path.exists(): |
|
|
files_to_upload.append(config_path) |
|
|
|
|
|
|
|
|
if epoch % self.save_every_n_epochs == 0: |
|
|
ckpt_st = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors" |
|
|
ckpt_pt = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt" |
|
|
if ckpt_st.exists(): |
|
|
files_to_upload.append(ckpt_st) |
|
|
if ckpt_pt.exists(): |
|
|
files_to_upload.append(ckpt_pt) |
|
|
|
|
|
|
|
|
if self.best_changed_since_upload: |
|
|
best_files = [ |
|
|
self.checkpoints_dir / "best_model.safetensors", |
|
|
self.checkpoints_dir / "best_model.pt", |
|
|
self.run_dir / "best_accuracy.json", |
|
|
] |
|
|
for f in best_files: |
|
|
if f.exists(): |
|
|
files_to_upload.append(f) |
|
|
self.best_changed_since_upload = False |
|
|
|
|
|
|
|
|
for local_path in files_to_upload: |
|
|
rel_path = local_path.relative_to(self.run_dir) |
|
|
hf_path = f"{hf_base_path}/{rel_path}" |
|
|
|
|
|
try: |
|
|
self.hf_api.upload_file( |
|
|
path_or_fileobj=str(local_path), |
|
|
path_in_repo=hf_path, |
|
|
repo_id=self.hf_repo, |
|
|
repo_type="model", |
|
|
) |
|
|
print(f"Uploaded: {hf_path}") |
|
|
except Exception as e: |
|
|
print(f"Failed to upload {rel_path}: {e}") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"HuggingFace upload error: {e}") |
|
|
|
|
|
def close(self): |
|
|
"""Close TensorBoard writer.""" |
|
|
self.writer.close() |
|
|
|
|
|
@staticmethod |
|
|
def load_checkpoint( |
|
|
checkpoint_path: str, |
|
|
model: nn.Module, |
|
|
optimizer: Optional[torch.optim.Optimizer] = None, |
|
|
scheduler: Optional[Any] = None, |
|
|
hf_repo: str = "AbstractPhil/mobiusnet", |
|
|
device: torch.device = torch.device('cpu'), |
|
|
) -> Dict[str, Any]: |
|
|
""" |
|
|
Load checkpoint from local path or HuggingFace repo. |
|
|
|
|
|
Args: |
|
|
checkpoint_path: Either: |
|
|
- Local file path to .pt checkpoint |
|
|
- Local directory containing checkpoints |
|
|
- HuggingFace path like "checkpoints/variant_dataset/timestamp" |
|
|
model: Model to load weights into |
|
|
optimizer: Optional optimizer to restore state |
|
|
scheduler: Optional scheduler to restore state |
|
|
hf_repo: HuggingFace repo ID |
|
|
device: Device to load tensors to |
|
|
|
|
|
Returns: |
|
|
Dict with checkpoint info (epoch, best_acc, etc.) |
|
|
""" |
|
|
from huggingface_hub import hf_hub_download, list_repo_files |
|
|
|
|
|
checkpoint_file = None |
|
|
|
|
|
|
|
|
if os.path.isfile(checkpoint_path): |
|
|
checkpoint_file = checkpoint_path |
|
|
|
|
|
|
|
|
elif os.path.isdir(checkpoint_path): |
|
|
|
|
|
best_path = os.path.join(checkpoint_path, "checkpoints", "best_model.pt") |
|
|
if os.path.exists(best_path): |
|
|
checkpoint_file = best_path |
|
|
else: |
|
|
|
|
|
ckpt_dir = os.path.join(checkpoint_path, "checkpoints") |
|
|
if os.path.isdir(ckpt_dir): |
|
|
pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.startswith("checkpoint_epoch_") and f.endswith(".pt")]) |
|
|
if pt_files: |
|
|
checkpoint_file = os.path.join(ckpt_dir, pt_files[-1]) |
|
|
|
|
|
|
|
|
if checkpoint_file is None: |
|
|
print(f"Attempting to download from HuggingFace: {hf_repo}/{checkpoint_path}") |
|
|
try: |
|
|
|
|
|
if not checkpoint_path.endswith(".pt"): |
|
|
|
|
|
try: |
|
|
checkpoint_file = hf_hub_download( |
|
|
repo_id=hf_repo, |
|
|
filename=f"{checkpoint_path}/checkpoints/best_model.pt", |
|
|
repo_type="model", |
|
|
) |
|
|
print(f"Downloaded best_model.pt from {hf_repo}") |
|
|
                except Exception:
                    # best_model.pt not found; fall back to the newest epoch checkpoint.
|
|
|
|
|
files = list_repo_files(repo_id=hf_repo, repo_type="model") |
|
|
ckpt_files = sorted([f for f in files if checkpoint_path in f and f.endswith(".pt") and "checkpoint_epoch_" in f]) |
|
|
if ckpt_files: |
|
|
checkpoint_file = hf_hub_download( |
|
|
repo_id=hf_repo, |
|
|
filename=ckpt_files[-1], |
|
|
repo_type="model", |
|
|
) |
|
|
print(f"Downloaded {ckpt_files[-1]} from {hf_repo}") |
|
|
else: |
|
|
|
|
|
checkpoint_file = hf_hub_download( |
|
|
repo_id=hf_repo, |
|
|
filename=checkpoint_path, |
|
|
repo_type="model", |
|
|
) |
|
|
print(f"Downloaded {checkpoint_path} from {hf_repo}") |
|
|
except Exception as e: |
|
|
raise FileNotFoundError(f"Could not find or download checkpoint: {checkpoint_path}. Error: {e}") |
|
|
|
|
|
if checkpoint_file is None: |
|
|
raise FileNotFoundError(f"Could not find checkpoint: {checkpoint_path}") |
|
|
|
|
|
print(f"Loading checkpoint from: {checkpoint_file}") |
|
|
checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=False) |
|
|
|
|
|
|
|
|
raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model |
|
|
raw_model.load_state_dict(checkpoint['model_state_dict']) |
|
|
print(f"Loaded model weights") |
|
|
|
|
|
|
|
|
if optimizer is not None and 'optimizer_state_dict' in checkpoint: |
|
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict']) |
|
|
print(f"Loaded optimizer state") |
|
|
|
|
|
|
|
|
if scheduler is not None and 'scheduler_state_dict' in checkpoint: |
|
|
scheduler.load_state_dict(checkpoint['scheduler_state_dict']) |
|
|
print(f"Loaded scheduler state") |
|
|
|
|
|
info = { |
|
|
'epoch': checkpoint.get('epoch', 0), |
|
|
'best_acc': checkpoint.get('best_acc', 0.0), |
|
|
'train_acc': checkpoint.get('train_acc', 0.0), |
|
|
'val_acc': checkpoint.get('val_acc', 0.0), |
|
|
'train_loss': checkpoint.get('train_loss', 0.0), |
|
|
} |
|
|
|
|
|
print(f"Resuming from epoch {info['epoch']} (best_acc: {info['best_acc']:.4f})") |
|
|
|
|
|
return info |
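
# Sketch: weights-only inference load from SafeTensors, using the config.json
# written by save_config (paths illustrative). Note that get_config() also
# stores the derived 'total_layers'/'num_stages' fields, which are not
# constructor arguments and must be dropped first:
#
#   with open('outputs/checkpoints/<run>/<timestamp>/config.json') as f:
#       cfg = json.load(f)['model']
#   cfg.pop('total_layers', None); cfg.pop('num_stages', None)
#   model = MobiusNet(**cfg)
#   model.load_state_dict(load_safetensors('.../best_model.safetensors'))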
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_tiny_imagenet( |
|
|
preset: str = 'mobius_tiny_m', |
|
|
epochs: int = 100, |
|
|
lr: float = 1e-3, |
|
|
batch_size: int = 128, |
|
|
use_integrator: bool = True, |
|
|
data_dir: str = './data/tiny-imagenet-200', |
|
|
output_dir: str = './outputs', |
|
|
hf_repo: str = "AbstractPhil/mobiusnet", |
|
|
save_every_n_epochs: int = 10, |
|
|
upload_every_n_epochs: int = 10, |
|
|
log_histograms_every: int = 10, |
|
|
use_compile: bool = True, |
|
|
continue_from: Optional[str] = None, |
|
|
): |
|
|
""" |
|
|
Train MobiusNet on Tiny ImageNet. |
|
|
|
|
|
Args: |
|
|
preset: Model preset name |
|
|
epochs: Total epochs to train |
|
|
lr: Learning rate |
|
|
batch_size: Batch size |
|
|
use_integrator: Whether to use integrator layer |
|
|
data_dir: Path to Tiny ImageNet data |
|
|
output_dir: Output directory for checkpoints |
|
|
hf_repo: HuggingFace repo for uploads/downloads |
|
|
save_every_n_epochs: Save checkpoint every N epochs |
|
|
upload_every_n_epochs: Upload to HF every N epochs |
|
|
log_histograms_every: Log weight histograms every N epochs |
|
|
use_compile: Whether to use torch.compile |
|
|
continue_from: Resume from checkpoint. Can be: |
|
|
- Local .pt file path |
|
|
- Local checkpoint directory |
|
|
- HuggingFace path (e.g., "checkpoints/mobius_base_tiny_imagenet/20240101_120000") |
|
|
""" |
|
|
config = PRESETS[preset] |
|
|
dataset_name = "tiny_imagenet" |
|
|
|
|
|
print("=" * 70) |
|
|
print(f"MÖBIUS NET - {preset.upper()} - TINY IMAGENET") |
|
|
print("=" * 70) |
|
|
print(f"Device: {device}") |
|
|
print(f"Channels: {config['channels']}") |
|
|
print(f"Depths: {config['depths']}") |
|
|
print(f"Scale range: {config['scale_range']}") |
|
|
print(f"Integrator: {use_integrator}") |
|
|
if continue_from: |
|
|
print(f"Continuing from: {continue_from}") |
|
|
print() |
|
|
|
|
|
|
|
|
resume_timestamp = None |
|
|
if continue_from: |
|
|
resume_timestamp = CheckpointManager.extract_timestamp(continue_from) |
|
|
if resume_timestamp: |
|
|
print(f"Using original timestamp: {resume_timestamp}") |
|
|
|
|
|
|
|
|
ckpt_manager = CheckpointManager( |
|
|
base_dir=output_dir, |
|
|
variant_name=preset, |
|
|
dataset_name=dataset_name, |
|
|
hf_repo=hf_repo, |
|
|
upload_every_n_epochs=upload_every_n_epochs, |
|
|
save_every_n_epochs=save_every_n_epochs, |
|
|
timestamp=resume_timestamp, |
|
|
) |
|
|
|
|
|
|
|
|
train_loader, val_loader = get_tiny_imagenet_loaders(data_dir, batch_size) |
|
|
|
|
|
|
|
|
model = MobiusNet( |
|
|
in_chans=3, |
|
|
num_classes=200, |
|
|
use_integrator=use_integrator, |
|
|
**config |
|
|
).to(device) |
|
|
|
|
|
total_params = sum(p.numel() for p in model.parameters()) |
|
|
print(f"Total params: {total_params:,}") |
|
|
print() |
|
|
|
|
|
|
|
|
training_config = { |
|
|
'epochs': epochs, |
|
|
'lr': lr, |
|
|
'batch_size': batch_size, |
|
|
'optimizer': 'AdamW', |
|
|
'weight_decay': 0.05, |
|
|
'scheduler': 'CosineAnnealingLR', |
|
|
'total_params': total_params, |
|
|
} |
|
|
ckpt_manager.save_config(model.get_config(), training_config) |
|
|
|
|
|
|
|
|
if use_compile: |
|
|
model = torch.compile(model, mode='reduce-overhead') |
|
|
|
|
|
|
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05) |
|
|
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) |
|
|
|
|
|
|
|
|
    start_epoch = 1
    best_acc = 0.0
    val_acc = 0.0  # defined up front in case the resume epoch already exceeds `epochs`
|
|
|
|
|
if continue_from: |
|
|
ckpt_info = CheckpointManager.load_checkpoint( |
|
|
checkpoint_path=continue_from, |
|
|
model=model, |
|
|
optimizer=optimizer, |
|
|
scheduler=scheduler, |
|
|
hf_repo=hf_repo, |
|
|
device=device, |
|
|
) |
|
|
start_epoch = ckpt_info['epoch'] + 1 |
|
|
best_acc = ckpt_info['best_acc'] |
|
|
ckpt_manager.best_acc = best_acc |
|
|
ckpt_manager.best_epoch = ckpt_info['epoch'] |
|
|
print(f"Resuming training from epoch {start_epoch}") |
|
|
|
|
|
for epoch in range(start_epoch, epochs + 1): |
|
|
|
|
|
model.train() |
|
|
train_loss, train_correct, train_total = 0, 0, 0 |
|
|
|
|
|
pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}") |
|
|
for x, y in pbar: |
|
|
x, y = x.to(device), y.to(device) |
|
|
|
|
|
optimizer.zero_grad() |
|
|
logits = model(x) |
|
|
loss = F.cross_entropy(logits, y) |
|
|
loss.backward() |
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
|
|
optimizer.step() |
|
|
|
|
|
train_loss += loss.item() * x.size(0) |
|
|
train_correct += (logits.argmax(1) == y).sum().item() |
|
|
train_total += x.size(0) |
|
|
|
|
|
pbar.set_postfix(loss=f"{loss.item():.4f}") |
|
|
|
|
|
scheduler.step() |
|
|
|
|
|
|
|
|
model.eval() |
|
|
val_correct, val_total = 0, 0 |
|
|
with torch.no_grad(): |
|
|
for x, y in val_loader: |
|
|
x, y = x.to(device), y.to(device) |
|
|
logits = model(x) |
|
|
val_correct += (logits.argmax(1) == y).sum().item() |
|
|
val_total += x.size(0) |
|
|
|
|
|
|
|
|
train_acc = train_correct / train_total |
|
|
val_acc = val_correct / val_total |
|
|
avg_loss = train_loss / train_total |
|
|
current_lr = scheduler.get_last_lr()[0] |
|
|
|
|
|
is_best = val_acc > best_acc |
|
|
if is_best: |
|
|
best_acc = val_acc |
|
|
|
|
|
marker = " ★" if is_best else "" |
|
|
print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | " |
|
|
f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}") |
|
|
|
|
|
|
|
|
ckpt_manager.log_scalars(epoch, { |
|
|
'loss': avg_loss, |
|
|
'train_acc': train_acc, |
|
|
'val_acc': val_acc, |
|
|
'best_acc': best_acc, |
|
|
'learning_rate': current_lr, |
|
|
}, prefix="train") |
|
|
|
|
|
|
|
|
ckpt_manager.log_lens_stats(epoch, model) |
|
|
|
|
|
|
|
|
if epoch % log_histograms_every == 0: |
|
|
ckpt_manager.log_histograms(epoch, model) |
|
|
|
|
|
|
|
|
ckpt_manager.save_checkpoint( |
|
|
model=model, |
|
|
optimizer=optimizer, |
|
|
scheduler=scheduler, |
|
|
epoch=epoch, |
|
|
train_acc=train_acc, |
|
|
val_acc=val_acc, |
|
|
train_loss=avg_loss, |
|
|
is_best=is_best, |
|
|
) |
|
|
|
|
|
|
|
|
ckpt_manager.upload_to_hf(epoch) |
|
|
|
|
|
|
|
|
ckpt_manager.save_final(model, val_acc, epochs) |
|
|
|
|
|
|
|
|
ckpt_manager.upload_to_hf(epochs, force=True) |
|
|
ckpt_manager.close() |
|
|
|
|
|
print() |
|
|
print("=" * 70) |
|
|
print("FINAL RESULTS") |
|
|
print("=" * 70) |
|
|
print(f"Preset: {preset}") |
|
|
print(f"Best accuracy: {best_acc:.4f}") |
|
|
print(f"Total params: {total_params:,}") |
|
|
print(f"Checkpoints: {ckpt_manager.run_dir}") |
|
|
print("=" * 70) |
|
|
|
|
|
return model, best_acc |
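
# Sketch: resuming from a HuggingFace-hosted run. The path mirrors the layout
# used by upload_to_hf; the timestamp below is illustrative:
#
#   train_tiny_imagenet(
#       preset='mobius_base',
#       continue_from='checkpoints/mobius_base_tiny_imagenet/20240101_120000',
#   )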
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
model, best_acc = train_tiny_imagenet( |
|
|
preset='mobius_base', |
|
|
epochs=200, |
|
|
lr=3e-4, |
|
|
batch_size=128, |
|
|
use_integrator=True, |
|
|
data_dir='./data/tiny-imagenet-200', |
|
|
output_dir='./outputs', |
|
|
hf_repo='AbstractPhil/mobiusnet', |
|
|
save_every_n_epochs=10, |
|
|
upload_every_n_epochs=10, |
|
|
log_histograms_every=10, |
|
|
use_compile=True, |
|
|
        # Path to a previous run's best checkpoint; update or remove to start fresh.
        continue_from='/content/outputs/checkpoints/mobius_base_tiny_imagenet/20260110_132436/checkpoints/best_model.pt',
    )