""" MobiusNet Trainer with TensorBoard, SafeTensors, and HuggingFace Upload ======================================================================= """ import os import re import json import math import shutil import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from typing import Tuple, Optional, Dict, Any from torchvision import datasets, transforms from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from tqdm.auto import tqdm from datetime import datetime from pathlib import Path from safetensors.torch import save_file as save_safetensors, load_file as load_safetensors from huggingface_hub import HfApi, login # Colab HF login try: from google.colab import userdata token = userdata.get('HF_TOKEN') os.environ['HF_TOKEN'] = token login(token=token) print("Logged in to HuggingFace via Colab") except: # Not in Colab or token not set pass device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}") # Enable TF32 for faster computation on Ampere+ GPUs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.set_float32_matmul_precision('high') # ============================================================================ # MÖBIUS LENS # ============================================================================ class MobiusLens(nn.Module): def __init__( self, dim: int, layer_idx: int, total_layers: int, scale_range: Tuple[float, float] = (1.0, 9.0), ): super().__init__() self.dim = dim self.layer_idx = layer_idx self.total_layers = total_layers self.t = layer_idx / max(total_layers - 1, 1) scale_span = scale_range[1] - scale_range[0] step = scale_span / max(total_layers, 1) scale_low = scale_range[0] + self.t * scale_span scale_high = scale_low + step self.register_buffer('scales', torch.tensor([scale_low, scale_high])) self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi)) self.twist_in_proj = nn.Linear(dim, dim, bias=False) nn.init.orthogonal_(self.twist_in_proj.weight) self.omega = nn.Parameter(torch.tensor(math.pi)) self.alpha = nn.Parameter(torch.tensor(1.5)) self.phase_l = nn.Parameter(torch.zeros(2)) self.drift_l = nn.Parameter(torch.ones(2)) self.phase_m = nn.Parameter(torch.zeros(2)) self.drift_m = nn.Parameter(torch.zeros(2)) self.phase_r = nn.Parameter(torch.zeros(2)) self.drift_r = nn.Parameter(-torch.ones(2)) self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4])) self.xor_weight = nn.Parameter(torch.tensor(0.7)) self.gate_norm = nn.LayerNorm(dim) self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi)) self.twist_out_proj = nn.Linear(dim, dim, bias=False) nn.init.orthogonal_(self.twist_out_proj.weight) def _twist_in(self, x: Tensor) -> Tensor: cos_t = torch.cos(self.twist_in_angle) sin_t = torch.sin(self.twist_in_angle) return x * cos_t + self.twist_in_proj(x) * sin_t def _center_lens(self, x: Tensor) -> Tensor: x_norm = torch.tanh(x) t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2) x_exp = x_norm.unsqueeze(-2) s = self.scales.view(-1, 1) def wave(phase, drift): a = self.alpha.abs() + 0.1 pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1) return torch.exp(-a * torch.sin(pos).pow(2)).prod(dim=-2) L = wave(self.phase_l, self.drift_l) M = wave(self.phase_m, self.drift_m) R = wave(self.phase_r, self.drift_r) w = torch.softmax(self.accum_weights, dim=0) xor_w = torch.sigmoid(self.xor_weight) xor_comp = (L + R - 2 * L * R).abs() and_comp = L * R lr = xor_w * xor_comp + (1 - xor_w) * and_comp gate = w[0] * L + w[1] * M + w[2] * R gate = gate * (0.5 + 0.5 * lr) gate = torch.sigmoid(self.gate_norm(gate)) return x * gate def _twist_out(self, x: Tensor) -> Tensor: cos_t = torch.cos(self.twist_out_angle) sin_t = torch.sin(self.twist_out_angle) return x * cos_t + self.twist_out_proj(x) * sin_t def forward(self, x: Tensor) -> Tensor: return self._twist_out(self._center_lens(self._twist_in(x))) def get_lens_stats(self) -> Dict[str, float]: """Return lens parameters for logging.""" return { 'omega': self.omega.item(), 'alpha': self.alpha.item(), 'twist_in_angle': self.twist_in_angle.item(), 'twist_out_angle': self.twist_out_angle.item(), 'xor_weight': torch.sigmoid(self.xor_weight).item(), 'accum_weights_l': torch.softmax(self.accum_weights, dim=0)[0].item(), 'accum_weights_m': torch.softmax(self.accum_weights, dim=0)[1].item(), 'accum_weights_r': torch.softmax(self.accum_weights, dim=0)[2].item(), } # ============================================================================ # MÖBIUS CONV BLOCK # ============================================================================ class MobiusConvBlock(nn.Module): def __init__( self, channels: int, layer_idx: int, total_layers: int, scale_range: Tuple[float, float] = (1.0, 9.0), reduction: float = 0.5, ): super().__init__() self.conv = nn.Sequential( nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False), nn.Conv2d(channels, channels, 1, bias=False), nn.BatchNorm2d(channels), ) self.lens = MobiusLens(channels, layer_idx, total_layers, scale_range) third = channels // 3 which_third = layer_idx % 3 mask = torch.ones(channels) start = which_third * third end = start + third + (channels % 3 if which_third == 2 else 0) mask[start:end] = reduction self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1)) self.residual_weight = nn.Parameter(torch.tensor(0.9)) def forward(self, x: Tensor) -> Tensor: identity = x h = self.conv(x) B, D, H, W = h.shape h = h.permute(0, 2, 3, 1) h = self.lens(h) h = h.permute(0, 3, 1, 2) h = h * self.thirds_mask rw = torch.sigmoid(self.residual_weight) return rw * identity + (1 - rw) * h def get_residual_weight(self) -> float: return torch.sigmoid(self.residual_weight).item() # ============================================================================ # MÖBIUS NET # ============================================================================ class MobiusNet(nn.Module): def __init__( self, in_chans: int = 3, num_classes: int = 200, channels: Tuple[int, ...] = (64, 128, 256, 512), depths: Tuple[int, ...] = (2, 2, 2, 2), scale_range: Tuple[float, float] = (0.5, 2.5), use_integrator: bool = True, ): super().__init__() num_stages = len(depths) total_layers = sum(depths) self.total_layers = total_layers self.scale_range = scale_range self.channels = tuple(channels) self.depths = tuple(depths) self.num_stages = num_stages self.use_integrator = use_integrator self.num_classes = num_classes self.in_chans = in_chans channels = list(channels) while len(channels) < num_stages: channels.append(channels[-1]) self.stem = nn.Sequential( nn.Conv2d(in_chans, channels[0], 3, stride=1, padding=1, bias=False), nn.BatchNorm2d(channels[0]), ) layer_idx = 0 self.stages = nn.ModuleList() self.downsamples = nn.ModuleList() for stage_idx in range(num_stages): ch = channels[stage_idx] stage = nn.ModuleList() for _ in range(depths[stage_idx]): stage.append(MobiusConvBlock(ch, layer_idx, total_layers, scale_range)) layer_idx += 1 self.stages.append(stage) if stage_idx < num_stages - 1: ch_next = channels[stage_idx + 1] self.downsamples.append(nn.Sequential( nn.Conv2d(ch, ch_next, 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(ch_next), )) final_ch = channels[num_stages - 1] if use_integrator: self.integrator = nn.Sequential( nn.Conv2d(final_ch, final_ch, 3, padding=1, bias=False), nn.BatchNorm2d(final_ch), nn.GELU(), ) else: self.integrator = nn.Identity() self.pool = nn.AdaptiveAvgPool2d(1) self.head = nn.Linear(final_ch, num_classes) def forward(self, x: Tensor) -> Tensor: x = self.stem(x) for i, stage in enumerate(self.stages): for block in stage: x = block(x) if i < len(self.downsamples): x = self.downsamples[i](x) x = self.integrator(x) return self.head(self.pool(x).flatten(1)) def get_config(self) -> Dict[str, Any]: """Return model configuration for saving.""" return { 'in_chans': self.in_chans, 'num_classes': self.num_classes, 'channels': self.channels, 'depths': self.depths, 'scale_range': self.scale_range, 'use_integrator': self.use_integrator, 'total_layers': self.total_layers, 'num_stages': self.num_stages, } def get_all_lens_stats(self) -> Dict[str, Dict[str, float]]: """Return stats from all lenses for logging.""" stats = {} layer_idx = 0 for stage_idx, stage in enumerate(self.stages): for block_idx, block in enumerate(stage): key = f"stage{stage_idx}_block{block_idx}" stats[key] = block.lens.get_lens_stats() stats[key]['residual_weight'] = block.get_residual_weight() layer_idx += 1 return stats # ============================================================================ # TINY IMAGENET DATASET # ============================================================================ def get_tiny_imagenet_loaders(data_dir='./data/tiny-imagenet-200', batch_size=128): train_dir = os.path.join(data_dir, 'train') val_dir = os.path.join(data_dir, 'val') val_images_dir = os.path.join(val_dir, 'images') if os.path.exists(val_images_dir): print("Reorganizing validation folder...") reorganize_val_folder(val_dir) train_transform = transforms.Compose([ transforms.RandomCrop(64, padding=8), transforms.RandomHorizontalFlip(), transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) val_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ]) train_dataset = datasets.ImageFolder(train_dir, transform=train_transform) val_dataset = datasets.ImageFolder(val_dir, transform=val_transform) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True ) val_loader = DataLoader( val_dataset, batch_size=256, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True ) return train_loader, val_loader def reorganize_val_folder(val_dir): """Reorganize Tiny ImageNet val folder into class subfolders.""" val_images_dir = os.path.join(val_dir, 'images') val_annotations = os.path.join(val_dir, 'val_annotations.txt') if not os.path.exists(val_images_dir): return with open(val_annotations, 'r') as f: for line in f: parts = line.strip().split('\t') img_name, class_id = parts[0], parts[1] class_dir = os.path.join(val_dir, class_id) os.makedirs(class_dir, exist_ok=True) src = os.path.join(val_images_dir, img_name) dst = os.path.join(class_dir, img_name) if os.path.exists(src): shutil.move(src, dst) if os.path.exists(val_images_dir): shutil.rmtree(val_images_dir) if os.path.exists(val_annotations): os.remove(val_annotations) print("Validation folder reorganized.") # ============================================================================ # PRESETS # ============================================================================ PRESETS = { 'mobius_tiny_s': { 'channels': (64, 128, 256), 'depths': (2, 2, 2), 'scale_range': (0.5, 2.5), }, 'mobius_tiny_m': { 'channels': (64, 128, 256, 512, 768), 'depths': (2, 2, 4, 2, 2), 'scale_range': (0.25, 2.75), }, 'mobius_tiny_l': { 'channels': (96, 192, 384, 768), 'depths': (3, 3, 3, 3), 'scale_range': (0.5, 3.5), }, 'mobius_base': { 'channels': (128, 256, 512, 768, 1024), 'depths': (2, 2, 2, 2, 2), 'scale_range': (0.25, 2.75), }, } # ============================================================================ # CHECKPOINT MANAGER # ============================================================================ class CheckpointManager: def __init__( self, base_dir: str, variant_name: str, dataset_name: str, hf_repo: str = "AbstractPhil/mobiusnet", upload_every_n_epochs: int = 10, save_every_n_epochs: int = 10, timestamp: Optional[str] = None, ): self.timestamp = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S") self.variant_name = variant_name self.dataset_name = dataset_name self.hf_repo = hf_repo self.upload_every_n_epochs = upload_every_n_epochs self.save_every_n_epochs = save_every_n_epochs # Directory structure self.run_name = f"{variant_name}_{dataset_name}" self.run_dir = Path(base_dir) / "checkpoints" / self.run_name / self.timestamp self.checkpoints_dir = self.run_dir / "checkpoints" self.tensorboard_dir = self.run_dir / "tensorboard" # Create directories self.checkpoints_dir.mkdir(parents=True, exist_ok=True) self.tensorboard_dir.mkdir(parents=True, exist_ok=True) # TensorBoard writer self.writer = SummaryWriter(log_dir=str(self.tensorboard_dir)) # HuggingFace API self.hf_api = HfApi() self.uploaded_files = set() # Track best self.best_acc = 0.0 self.best_epoch = 0 self.best_changed_since_upload = False print(f"Checkpoint directory: {self.run_dir}") @staticmethod def extract_timestamp(checkpoint_path: str) -> Optional[str]: """Extract timestamp from checkpoint path.""" # Match YYYYMMDD_HHMMSS pattern match = re.search(r'(\d{8}_\d{6})', checkpoint_path) if match: return match.group(1) return None def save_config(self, config: Dict[str, Any], training_config: Dict[str, Any]): """Save model and training configuration.""" full_config = { 'model': config, 'training': training_config, 'timestamp': self.timestamp, 'variant_name': self.variant_name, 'dataset_name': self.dataset_name, } config_path = self.run_dir / "config.json" with open(config_path, 'w') as f: json.dump(full_config, f, indent=2) return config_path def save_checkpoint( self, model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: Any, epoch: int, train_acc: float, val_acc: float, train_loss: float, is_best: bool = False, ): """Save checkpoint every N epochs, always save best (overwriting).""" # Unwrap compiled model if necessary raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model # Checkpoint data checkpoint = { 'epoch': epoch, 'train_acc': train_acc, 'val_acc': val_acc, 'train_loss': train_loss, 'best_acc': self.best_acc, 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), } # Save epoch checkpoint every N epochs if epoch % self.save_every_n_epochs == 0: epoch_pt_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt" torch.save({**checkpoint, 'model_state_dict': raw_model.state_dict()}, epoch_pt_path) epoch_st_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors" save_safetensors(raw_model.state_dict(), str(epoch_st_path)) # Save best model (overwrites previous best) if is_best: self.best_acc = val_acc self.best_epoch = epoch self.best_changed_since_upload = True # PyTorch best best_pt_path = self.checkpoints_dir / "best_model.pt" torch.save({**checkpoint, 'model_state_dict': raw_model.state_dict()}, best_pt_path) # SafeTensors best best_st_path = self.checkpoints_dir / "best_model.safetensors" save_safetensors(raw_model.state_dict(), str(best_st_path)) # Save accuracy info acc_path = self.run_dir / "best_accuracy.json" with open(acc_path, 'w') as f: json.dump({ 'best_acc': val_acc, 'best_epoch': epoch, 'train_acc': train_acc, 'train_loss': train_loss, }, f, indent=2) def save_final(self, model: nn.Module, final_acc: float, final_epoch: int): """Save final model.""" raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model # SafeTensors final final_st_path = self.checkpoints_dir / "final_model.safetensors" save_safetensors(raw_model.state_dict(), str(final_st_path)) # PyTorch final final_pt_path = self.checkpoints_dir / "final_model.pt" torch.save({ 'model_state_dict': raw_model.state_dict(), 'final_acc': final_acc, 'final_epoch': final_epoch, 'best_acc': self.best_acc, 'best_epoch': self.best_epoch, }, final_pt_path) # Final accuracy info acc_path = self.run_dir / "final_accuracy.json" with open(acc_path, 'w') as f: json.dump({ 'final_acc': final_acc, 'final_epoch': final_epoch, 'best_acc': self.best_acc, 'best_epoch': self.best_epoch, }, f, indent=2) return final_st_path, final_pt_path def log_scalars(self, epoch: int, scalars: Dict[str, float], prefix: str = ""): """Log scalars to TensorBoard.""" for name, value in scalars.items(): tag = f"{prefix}/{name}" if prefix else name self.writer.add_scalar(tag, value, epoch) def log_lens_stats(self, epoch: int, model: nn.Module): """Log lens statistics to TensorBoard.""" raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model stats = raw_model.get_all_lens_stats() for block_name, block_stats in stats.items(): for stat_name, value in block_stats.items(): self.writer.add_scalar(f"lens/{block_name}/{stat_name}", value, epoch) def log_histograms(self, epoch: int, model: nn.Module): """Log weight histograms to TensorBoard.""" raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model for name, param in raw_model.named_parameters(): if param.requires_grad: self.writer.add_histogram(f"weights/{name}", param.data, epoch) if param.grad is not None: self.writer.add_histogram(f"gradients/{name}", param.grad, epoch) def upload_to_hf(self, epoch: int, force: bool = False): """Upload checkpoint every N epochs. Best uploads only on upload epochs if changed.""" if not force and epoch % self.upload_every_n_epochs != 0: return try: hf_base_path = f"checkpoints/{self.run_name}/{self.timestamp}" files_to_upload = [] # Always upload config config_path = self.run_dir / "config.json" if config_path.exists(): files_to_upload.append(config_path) # Upload checkpoint if saved this epoch if epoch % self.save_every_n_epochs == 0: ckpt_st = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors" ckpt_pt = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt" if ckpt_st.exists(): files_to_upload.append(ckpt_st) if ckpt_pt.exists(): files_to_upload.append(ckpt_pt) # Upload best if it changed since last upload if self.best_changed_since_upload: best_files = [ self.checkpoints_dir / "best_model.safetensors", self.checkpoints_dir / "best_model.pt", self.run_dir / "best_accuracy.json", ] for f in best_files: if f.exists(): files_to_upload.append(f) self.best_changed_since_upload = False # Upload files for local_path in files_to_upload: rel_path = local_path.relative_to(self.run_dir) hf_path = f"{hf_base_path}/{rel_path}" try: self.hf_api.upload_file( path_or_fileobj=str(local_path), path_in_repo=hf_path, repo_id=self.hf_repo, repo_type="model", ) print(f"Uploaded: {hf_path}") except Exception as e: print(f"Failed to upload {rel_path}: {e}") except Exception as e: print(f"HuggingFace upload error: {e}") def close(self): """Close TensorBoard writer.""" self.writer.close() @staticmethod def load_checkpoint( checkpoint_path: str, model: nn.Module, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[Any] = None, hf_repo: str = "AbstractPhil/mobiusnet", device: torch.device = torch.device('cpu'), ) -> Dict[str, Any]: """ Load checkpoint from local path or HuggingFace repo. Args: checkpoint_path: Either: - Local file path to .pt checkpoint - Local directory containing checkpoints - HuggingFace path like "checkpoints/variant_dataset/timestamp" model: Model to load weights into optimizer: Optional optimizer to restore state scheduler: Optional scheduler to restore state hf_repo: HuggingFace repo ID device: Device to load tensors to Returns: Dict with checkpoint info (epoch, best_acc, etc.) """ from huggingface_hub import hf_hub_download, list_repo_files checkpoint_file = None # Check if it's a local file if os.path.isfile(checkpoint_path): checkpoint_file = checkpoint_path # Check if it's a local directory elif os.path.isdir(checkpoint_path): # Look for best_model.pt or latest checkpoint best_path = os.path.join(checkpoint_path, "checkpoints", "best_model.pt") if os.path.exists(best_path): checkpoint_file = best_path else: # Find latest epoch checkpoint ckpt_dir = os.path.join(checkpoint_path, "checkpoints") if os.path.isdir(ckpt_dir): pt_files = sorted([f for f in os.listdir(ckpt_dir) if f.startswith("checkpoint_epoch_") and f.endswith(".pt")]) if pt_files: checkpoint_file = os.path.join(ckpt_dir, pt_files[-1]) # Try HuggingFace download if checkpoint_file is None: print(f"Attempting to download from HuggingFace: {hf_repo}/{checkpoint_path}") try: # If checkpoint_path is a directory path in the repo if not checkpoint_path.endswith(".pt"): # Try to download best_model.pt try: checkpoint_file = hf_hub_download( repo_id=hf_repo, filename=f"{checkpoint_path}/checkpoints/best_model.pt", repo_type="model", ) print(f"Downloaded best_model.pt from {hf_repo}") except: # List files and find latest checkpoint files = list_repo_files(repo_id=hf_repo, repo_type="model") ckpt_files = sorted([f for f in files if checkpoint_path in f and f.endswith(".pt") and "checkpoint_epoch_" in f]) if ckpt_files: checkpoint_file = hf_hub_download( repo_id=hf_repo, filename=ckpt_files[-1], repo_type="model", ) print(f"Downloaded {ckpt_files[-1]} from {hf_repo}") else: # Direct file path checkpoint_file = hf_hub_download( repo_id=hf_repo, filename=checkpoint_path, repo_type="model", ) print(f"Downloaded {checkpoint_path} from {hf_repo}") except Exception as e: raise FileNotFoundError(f"Could not find or download checkpoint: {checkpoint_path}. Error: {e}") if checkpoint_file is None: raise FileNotFoundError(f"Could not find checkpoint: {checkpoint_path}") print(f"Loading checkpoint from: {checkpoint_file}") checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=False) # Load model weights raw_model = model._orig_mod if hasattr(model, '_orig_mod') else model raw_model.load_state_dict(checkpoint['model_state_dict']) print(f"Loaded model weights") # Load optimizer state if optimizer is not None and 'optimizer_state_dict' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer_state_dict']) print(f"Loaded optimizer state") # Load scheduler state if scheduler is not None and 'scheduler_state_dict' in checkpoint: scheduler.load_state_dict(checkpoint['scheduler_state_dict']) print(f"Loaded scheduler state") info = { 'epoch': checkpoint.get('epoch', 0), 'best_acc': checkpoint.get('best_acc', 0.0), 'train_acc': checkpoint.get('train_acc', 0.0), 'val_acc': checkpoint.get('val_acc', 0.0), 'train_loss': checkpoint.get('train_loss', 0.0), } print(f"Resuming from epoch {info['epoch']} (best_acc: {info['best_acc']:.4f})") return info # ============================================================================ # TRAINING # ============================================================================ def train_tiny_imagenet( preset: str = 'mobius_tiny_m', epochs: int = 100, lr: float = 1e-3, batch_size: int = 128, use_integrator: bool = True, data_dir: str = './data/tiny-imagenet-200', output_dir: str = './outputs', hf_repo: str = "AbstractPhil/mobiusnet", save_every_n_epochs: int = 10, upload_every_n_epochs: int = 10, log_histograms_every: int = 10, use_compile: bool = True, continue_from: Optional[str] = None, ): """ Train MobiusNet on Tiny ImageNet. Args: preset: Model preset name epochs: Total epochs to train lr: Learning rate batch_size: Batch size use_integrator: Whether to use integrator layer data_dir: Path to Tiny ImageNet data output_dir: Output directory for checkpoints hf_repo: HuggingFace repo for uploads/downloads save_every_n_epochs: Save checkpoint every N epochs upload_every_n_epochs: Upload to HF every N epochs log_histograms_every: Log weight histograms every N epochs use_compile: Whether to use torch.compile continue_from: Resume from checkpoint. Can be: - Local .pt file path - Local checkpoint directory - HuggingFace path (e.g., "checkpoints/mobius_base_tiny_imagenet/20240101_120000") """ config = PRESETS[preset] dataset_name = "tiny_imagenet" print("=" * 70) print(f"MÖBIUS NET - {preset.upper()} - TINY IMAGENET") print("=" * 70) print(f"Device: {device}") print(f"Channels: {config['channels']}") print(f"Depths: {config['depths']}") print(f"Scale range: {config['scale_range']}") print(f"Integrator: {use_integrator}") if continue_from: print(f"Continuing from: {continue_from}") print() # Extract timestamp from checkpoint path if continuing resume_timestamp = None if continue_from: resume_timestamp = CheckpointManager.extract_timestamp(continue_from) if resume_timestamp: print(f"Using original timestamp: {resume_timestamp}") # Initialize checkpoint manager ckpt_manager = CheckpointManager( base_dir=output_dir, variant_name=preset, dataset_name=dataset_name, hf_repo=hf_repo, upload_every_n_epochs=upload_every_n_epochs, save_every_n_epochs=save_every_n_epochs, timestamp=resume_timestamp, ) # Data train_loader, val_loader = get_tiny_imagenet_loaders(data_dir, batch_size) # Model model = MobiusNet( in_chans=3, num_classes=200, use_integrator=use_integrator, **config ).to(device) total_params = sum(p.numel() for p in model.parameters()) print(f"Total params: {total_params:,}") print() # Save config training_config = { 'epochs': epochs, 'lr': lr, 'batch_size': batch_size, 'optimizer': 'AdamW', 'weight_decay': 0.05, 'scheduler': 'CosineAnnealingLR', 'total_params': total_params, } ckpt_manager.save_config(model.get_config(), training_config) # Compile model if use_compile: model = torch.compile(model, mode='reduce-overhead') # Optimizer and scheduler optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) # Load checkpoint if continuing start_epoch = 1 best_acc = 0.0 if continue_from: ckpt_info = CheckpointManager.load_checkpoint( checkpoint_path=continue_from, model=model, optimizer=optimizer, scheduler=scheduler, hf_repo=hf_repo, device=device, ) start_epoch = ckpt_info['epoch'] + 1 best_acc = ckpt_info['best_acc'] ckpt_manager.best_acc = best_acc ckpt_manager.best_epoch = ckpt_info['epoch'] print(f"Resuming training from epoch {start_epoch}") for epoch in range(start_epoch, epochs + 1): # Training model.train() train_loss, train_correct, train_total = 0, 0, 0 pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}") for x, y in pbar: x, y = x.to(device), y.to(device) optimizer.zero_grad() logits = model(x) loss = F.cross_entropy(logits, y) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() train_loss += loss.item() * x.size(0) train_correct += (logits.argmax(1) == y).sum().item() train_total += x.size(0) pbar.set_postfix(loss=f"{loss.item():.4f}") scheduler.step() # Validation model.eval() val_correct, val_total = 0, 0 with torch.no_grad(): for x, y in val_loader: x, y = x.to(device), y.to(device) logits = model(x) val_correct += (logits.argmax(1) == y).sum().item() val_total += x.size(0) # Metrics train_acc = train_correct / train_total val_acc = val_correct / val_total avg_loss = train_loss / train_total current_lr = scheduler.get_last_lr()[0] is_best = val_acc > best_acc if is_best: best_acc = val_acc marker = " ★" if is_best else "" print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | " f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}") # TensorBoard logging ckpt_manager.log_scalars(epoch, { 'loss': avg_loss, 'train_acc': train_acc, 'val_acc': val_acc, 'best_acc': best_acc, 'learning_rate': current_lr, }, prefix="train") # Log lens stats ckpt_manager.log_lens_stats(epoch, model) # Log histograms periodically if epoch % log_histograms_every == 0: ckpt_manager.log_histograms(epoch, model) # Save checkpoint ckpt_manager.save_checkpoint( model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, train_acc=train_acc, val_acc=val_acc, train_loss=avg_loss, is_best=is_best, ) # Upload to HuggingFace (handles both checkpoint and best) ckpt_manager.upload_to_hf(epoch) # Save final model ckpt_manager.save_final(model, val_acc, epochs) # Final upload ckpt_manager.upload_to_hf(epochs, force=True) ckpt_manager.close() print() print("=" * 70) print("FINAL RESULTS") print("=" * 70) print(f"Preset: {preset}") print(f"Best accuracy: {best_acc:.4f}") print(f"Total params: {total_params:,}") print(f"Checkpoints: {ckpt_manager.run_dir}") print("=" * 70) return model, best_acc # ============================================================================ # RUN # ============================================================================ if __name__ == '__main__': model, best_acc = train_tiny_imagenet( preset='mobius_base', epochs=200, lr=3e-4, batch_size=128, use_integrator=True, data_dir='./data/tiny-imagenet-200', output_dir='./outputs', hf_repo='AbstractPhil/mobiusnet', save_every_n_epochs=10, upload_every_n_epochs=10, log_histograms_every=10, use_compile=True, continue_from='/content/outputs/checkpoints/mobius_base_tiny_imagenet/20260110_132436/checkpoints/best_model.pt', # Set to path or HF checkpoint to resume # Examples: # continue_from="./outputs/checkpoints/mobius_base_tiny_imagenet/20240101_120000" # continue_from="./outputs/checkpoints/mobius_base_tiny_imagenet/20240101_120000/checkpoints/best_model.pt" # continue_from="checkpoints/mobius_base_tiny_imagenet/20240101_120000" # downloads from HF )