# Source: penta-vit-experiments/legacy/penta_vit_model_v1.py (uploaded by AbstractPhil)
# Commit 2e10481 (verified): Rename penta_vit_model_v1.py to legacy/penta_vit_model_v1.py
"""
PentachoraViT: Vision Transformer with Pentachoron Geometric Structure
Enhanced with Geometric Attention for improved head cohesion and generalization
FIXED: All parameters initialized at module creation time (no lazy init)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
import math
from typing import Optional, Dict, Tuple, List, Any
from dataclasses import dataclass
import warnings
# ============================================
# CONFIGURATION CLASSES
# ============================================
@dataclass
class PentachoraConfig:
    """Hyperparameters describing one PentachoraViT model variant."""

    # Side length of the (square) input image and of each (square) patch.
    img_size: int = 32
    patch_size: int = 4
    num_classes: int = 100
    # Transformer embedding width.
    dim: int = 512
    # Vocabulary embedding width; when None the model falls back to `dim`.
    vocab_dim: Optional[int] = None
    depth: int = 12
    heads: int = 8
    mlp_ratio: float = 4.0
    # Layers with index below `preserve_structure_until_layer` use geometric
    # attention, provided `use_mesh_attention` is enabled.
    use_mesh_attention: bool = True
    preserve_structure_until_layer: int = 6
    dropout_rate: float = 0.0
    drop_path_rate: float = 0.0
    aux_loss_weight: float = 0.0
    geo_loss_weight: float = 0.0
    # Optional vocabulary object; the model only uses it if it exposes
    # an `encode_batch` method.
    vocab: Optional[Any] = None

    @property
    def num_patches(self) -> int:
        """Return the patch count of the square grid (floor division per side)."""
        patches_per_side = self.img_size // self.patch_size
        return patches_per_side ** 2
# ============================================
# GEOMETRIC ATTENTION COMPONENTS (FIXED INIT)
# ============================================
def perfect_4simplex(device):
    """Return the 5 vertices of a regular 4-simplex as a float32 [5, 4] tensor.

    The unscaled coordinates have pairwise distance sqrt(8); the final
    halving brings every edge to length sqrt(2).
    """
    inv_sqrt5 = 1 / math.sqrt(5)
    rows = [
        [1, 1, 1, -inv_sqrt5],
        [1, -1, -1, -inv_sqrt5],
        [-1, 1, -1, -inv_sqrt5],
        [-1, -1, 1, -inv_sqrt5],
        [0, 0, 0, 4 / math.sqrt(5)],
    ]
    simplex = torch.tensor(rows, device=device, dtype=torch.float32)
    # Halving is exact in floating point, so this matches dividing by 2.
    return simplex * 0.5
def softmin_over_last(distances, tau):
    """Return a smooth minimum of `distances` along the last dimension.

    The result is the softmax-weighted mean of the distances, with weights
    ``softmax(-distances / tau)``. As ``tau`` shrinks it approaches the hard
    minimum; as ``tau`` grows it approaches the plain mean.

    The previous implementation summed the softmax weights themselves, which
    is identically 1.0 for every input and therefore carried no information
    about the distances.

    Args:
        distances: tensor whose last dimension is reduced.
        tau: positive temperature controlling how sharp the minimum is.

    Returns:
        Tensor with the last dimension of `distances` removed.
    """
    weights = F.softmax(-distances / tau, dim=-1)
    return (weights * distances).sum(dim=-1)
@dataclass
class GeometricConfig:
    """Settings consumed by the geometric navigator and attention modules."""

    # Temperature passed to `softmin_over_last`.
    softmin_tau: float = 0.05
    # Blend factor: alpha * ribbon score + (1 - alpha) * dispatcher score.
    fuse_alpha: float = 0.7
    # Phase angles (radians) at which dispatcher/shifted vertices are mixed:
    # the four quarter turns 0, pi/2, pi, 3*pi/2.
    phases: Tuple[float, ...] = tuple(k * math.pi / 2 for k in range(4))
    # Standard deviation of the Gaussian noise added to initial vertices.
    jitter: float = 0.02
    # Constant offset added to the rotated copy of the base simplex.
    shift: float = 0.71
    # Period (in region index) of the initial rotation angle.
    rotate_cycle: int = 11
    use_phase_variance: bool = False
    geometry_type: str = "pentachoron"
class GeometricNavigator(nn.Module):
    """Score inputs against learned 5-vertex regions in a 4D navigation space.

    Inputs are projected to 4D, then compared with two vertex sets per region:
    the "dispatcher" vertices ``D`` and phase-mixed "ribbon" vertices built
    from ``D`` and the shifted set ``S``. All parameters are created in
    ``__init__`` (no lazy initialisation).

    Args:
        input_dim: feature width of the incoming vectors.
        num_regions: number of regions scored per input.
        config: geometric hyperparameters.
        num_heads: when > 1, independent projections/regions are kept per head.
        device: device for every parameter and buffer; CPU when None.
    """

    def __init__(self, input_dim: int, num_regions: int, config: GeometricConfig,
                 num_heads: int = 1, device=None):
        super().__init__()
        self.input_dim = input_dim
        self.num_regions = num_regions
        self.config = config
        self.num_heads = num_heads

        if device is None:
            device = torch.device('cpu')

        if num_heads > 1:
            # One 4D projection and one set of vertex logits per head.
            self.to_nav = nn.Parameter(torch.randn(num_heads, input_dim, 4, device=device) * 0.02)
            self.vertex_w = nn.Parameter(torch.zeros(num_heads, num_regions, 5, device=device))
        else:
            # Created on `device` like every other tensor of this module
            # (previously this layer ignored the argument).
            self.to_nav = nn.Linear(input_dim, 4, bias=False, device=device)
            self.vertex_w = nn.Parameter(torch.zeros(num_regions, 5, device=device))

        # cos/sin of the configured phases, cached for vectorised mixing.
        phases = torch.tensor(config.phases, dtype=torch.float32, device=device)
        self.register_buffer('phase_cos', torch.cos(phases))
        self.register_buffer('phase_sin', torch.sin(phases))

        self._init_geometry(device)

    def _sample_region(self, base: torch.Tensor, region_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw one region's (dispatcher, shifted) vertex sets, each [5, 4]."""
        jitter = self.config.jitter
        dispatcher = base + jitter * torch.randn_like(base)

        # Rotation in the first two coordinates; the angle cycles with the
        # region index.
        theta = torch.tensor(0.27 + 0.05 * (region_idx % self.config.rotate_cycle),
                             device=base.device)
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)
        rot = torch.eye(4, device=base.device)
        rot[0, 0] = cos_t
        rot[0, 1] = -sin_t
        rot[1, 0] = sin_t
        rot[1, 1] = cos_t

        shifted = (base @ rot) + self.config.shift
        shifted = shifted + jitter * torch.randn_like(shifted)
        return dispatcher, shifted

    def _init_geometry(self, device):
        """Create the learnable vertex sets ``D`` and ``S`` at construction time."""
        base = perfect_4simplex(device)
        if self.num_heads > 1:
            D = torch.zeros(self.num_heads, self.num_regions, 5, 4, device=device)
            S = torch.zeros(self.num_heads, self.num_regions, 5, 4, device=device)
            for h in range(self.num_heads):
                for r in range(self.num_regions):
                    D[h, r], S[h, r] = self._sample_region(base, r)
        else:
            D = torch.zeros(self.num_regions, 5, 4, device=device)
            S = torch.zeros(self.num_regions, 5, 4, device=device)
            for r in range(self.num_regions):
                D[r], S[r] = self._sample_region(base, r)
        self.D = nn.Parameter(D)
        self.S = nn.Parameter(S)

    def _fuse(self, s_ribbon: torch.Tensor, s_disp: torch.Tensor) -> torch.Tensor:
        """Blend ribbon and dispatcher scores with ``config.fuse_alpha``."""
        alpha = self.config.fuse_alpha
        return alpha * s_ribbon + (1 - alpha) * s_disp

    def _navigate_multi(self, x: torch.Tensor):
        """Multi-head path; `x` is unpacked as [BT, H, head_dim]."""
        BT, H, _ = x.shape
        tau = self.config.softmin_tau

        nav_x = torch.einsum('bhi,hio->bho', x, self.to_nav)  # [BT, H, 4]

        # Dispatcher: distance to every vertex of every region.
        nav_x_disp = nav_x.view(BT, H, 1, 1, 4)
        d_disp = torch.norm(nav_x_disp - self.D.unsqueeze(0), dim=-1)
        s_disp = -softmin_over_last(d_disp, tau)

        # All phase variants at once: [phases, H, regions, 5, 4].
        cos_phases = self.phase_cos.view(-1, 1, 1, 1, 1)
        sin_phases = self.phase_sin.view(-1, 1, 1, 1, 1)
        Vt_all = cos_phases * self.D.unsqueeze(0) + sin_phases * self.S.unsqueeze(0)

        # Pull each vertex toward the region's vertex mean by its weight.
        w = F.softmax(self.vertex_w, dim=-1)
        w_exp = w.unsqueeze(0).unsqueeze(-1)  # [1, H, regions, 5, 1]
        Vt_mean = Vt_all.mean(dim=3, keepdim=True)
        Vt_all = (1.0 - w_exp) * Vt_all + w_exp * Vt_mean

        # Ribbon: same distance test against every phase variant, then
        # averaged over the phase axis.
        nav_x_ribbon = nav_x.view(BT, 1, H, 1, 1, 4)
        d_ribbon_all = torch.norm(nav_x_ribbon - Vt_all.unsqueeze(0), dim=-1)
        s_ribbon = (-softmin_over_last(d_ribbon_all, tau)).mean(dim=1)

        scores = self._fuse(s_ribbon, s_disp).reshape(BT * H, self.num_regions)
        diagnostics = {
            'dispatcher_scores': s_disp.reshape(BT * H, -1).detach(),
            'ribbon_scores': s_ribbon.reshape(BT * H, -1).detach(),
        }
        return scores, diagnostics

    def _navigate_single(self, x: torch.Tensor):
        """Single-head path; `x` is indexed as [N, input_dim]."""
        tau = self.config.softmin_tau

        nav_x = self.to_nav(x)
        d_disp = torch.norm(nav_x[:, None, None, :] - self.D[None, :, :, :], dim=-1)
        s_disp = -softmin_over_last(d_disp, tau)

        w = F.softmax(self.vertex_w, dim=1)

        cos_phases = self.phase_cos.view(-1, 1, 1, 1)
        sin_phases = self.phase_sin.view(-1, 1, 1, 1)
        Vt_all = cos_phases * self.D.unsqueeze(0) + sin_phases * self.S.unsqueeze(0)
        w_expanded = w.unsqueeze(0).unsqueeze(-1)
        Vt_mean = Vt_all.mean(dim=2, keepdim=True)
        Vt_all = (1.0 - w_expanded) * Vt_all + w_expanded * Vt_mean

        nav_x_phase = nav_x[:, None, None, None, :]
        d_ribbon_all = torch.norm(nav_x_phase - Vt_all.unsqueeze(0), dim=-1)
        s_ribbon = (-softmin_over_last(d_ribbon_all, tau)).mean(dim=1)

        scores = self._fuse(s_ribbon, s_disp)
        diagnostics = {
            'dispatcher_scores': s_disp.detach(),
            'ribbon_scores': s_ribbon.detach(),
        }
        return scores, diagnostics

    def navigate(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Score `x` against every region.

        Args:
            x: [BT, H, head_dim] when ``num_heads > 1``; otherwise a 2D
                tensor whose last dimension is ``input_dim``.

        Returns:
            Dict with ``'scores'`` (one row per input/head pair, one column
            per region) and ``'diagnostics'`` holding detached dispatcher and
            ribbon scores.
        """
        if self.num_heads > 1:
            scores, diagnostics = self._navigate_multi(x)
        else:
            scores, diagnostics = self._navigate_single(x)
        return {'scores': scores, 'diagnostics': diagnostics}
class GeometricAttention(nn.Module):
    """Multi-head attention whose logits come from geometric region scores.

    Queries and keys are each mapped to per-region scores by a
    `GeometricNavigator`; the attention logit between two tokens is the dot
    product of their score vectors divided by sqrt(num_regions).

    Args:
        dim: model width; must be divisible by `num_heads`.
        num_heads: number of attention heads.
        num_regions: regions per head; defaults to ``min(head_dim, 16)``.
        config: geometric hyperparameters; a default `GeometricConfig` if None.
        dropout: dropout probability applied to attention weights and output.
        device: device handed to the navigators for their parameters.

    Raises:
        ValueError: if `dim` is not divisible by `num_heads`.
    """

    def __init__(self, dim: int, num_heads: int = 8, num_regions: Optional[int] = None,
                 config: Optional[GeometricConfig] = None, dropout: float = 0.0, device=None):
        super().__init__()
        if dim % num_heads != 0:
            # Fail early with a clear message instead of a reshape error in forward().
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        if num_regions is None:
            num_regions = min(self.head_dim, 16)
        if config is None:
            config = GeometricConfig()
        self.config = config

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        # One batched navigator each for queries and keys.
        self.q_navigator = GeometricNavigator(self.head_dim, num_regions, config,
                                              num_heads=num_heads, device=device)
        self.k_navigator = GeometricNavigator(self.head_dim, num_regions, config,
                                              num_heads=num_heads, device=device)
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                return_diagnostics: bool = False) -> Tuple[torch.Tensor, Optional[Dict]]:
        """Apply geometric attention.

        Args:
            x: input unpacked as [B, T, D].
            mask: optional tensor; positions equal to 0 are excluded as keys.
                It is broadcast after two unsqueezes at dims 1 and 2
                (presumably [B, T] -- confirm with callers).
            return_diagnostics: also return the navigators' diagnostics.

        Returns:
            ``(output, diagnostics)`` where `diagnostics` is None unless requested.
        """
        B, T, D = x.shape
        H, hd = self.num_heads, self.head_dim

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # Navigators take [B*T, H, head_dim]; reshaping directly avoids the
        # transpose / transpose-back round trip.
        q_nav = self.q_navigator.navigate(q.reshape(B * T, H, hd))
        k_nav = self.k_navigator.navigate(k.reshape(B * T, H, hd))
        v = v.reshape(B, T, H, hd).transpose(1, 2)

        q_scores = q_nav['scores'].reshape(B, T, H, -1).transpose(1, 2)
        k_scores = k_nav['scores'].reshape(B, T, H, -1).transpose(1, 2)

        scale = math.sqrt(q_scores.size(-1))
        attn = torch.einsum('bhqr,bhkr->bhqk', q_scores, k_scores) / scale

        if mask is not None:
            mask_expanded = mask.unsqueeze(1).unsqueeze(2)
            # dtype-aware "minus infinity": a literal -1e9 is not
            # representable in float16 and makes masked_fill raise.
            attn = attn.masked_fill(mask_expanded == 0, torch.finfo(attn.dtype).min)

        attn = self.dropout(F.softmax(attn, dim=-1))

        out = torch.einsum('bhqk,bhkd->bhqd', attn, v)
        out = out.transpose(1, 2).reshape(B, T, D)
        output = self.dropout(self.out_proj(out))

        if return_diagnostics:
            return output, {'q_diagnostics': q_nav['diagnostics'],
                            'k_diagnostics': k_nav['diagnostics']}
        return output, None
# ============================================
# DROP PATH (STOCHASTIC DEPTH)
# ============================================
class DropPath(nn.Module):
    """Stochastic depth: zero whole samples at random during training."""

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` unchanged in eval mode or when `drop_prob` is 0.

        Otherwise each sample (first dimension) is kept with probability
        ``1 - drop_prob`` and kept samples are rescaled by ``1 / keep_prob``.
        """
        active = self.training and self.drop_prob != 0.
        if not active:
            return x
        keep_prob = 1 - self.drop_prob
        # One random draw per sample, broadcast over all remaining dims.
        per_sample = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.rand(per_sample, dtype=x.dtype, device=x.device)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        keep_mask = (keep_prob + keep_mask).floor_()
        return x.div(keep_prob) * keep_mask
# ============================================
# HIERARCHICAL CLS WITH PENTACHORA
# ============================================
class HierarchicalPentachoronCLS(nn.Module):
    """
    Produces one global CLS token plus five vertex CLS tokens per sample.
    The vertex tokens come from a per-class bank of 5-vertex pentachora
    stored in vocabulary space and projected to the model width.
    """

    def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100):
        super().__init__()
        self.dim = dim
        self.vocab_dim = vocab_dim
        self.num_classes = num_classes

        # Bank of class pentachora, [num_classes, 5, vocab_dim]. It is a
        # buffer (not a parameter), randomly initialised here.
        self.register_buffer('class_pentachora', torch.randn(num_classes, 5, vocab_dim) * 0.02)

        # Vocabulary space -> model space; a no-op when the widths agree.
        self.vocab_to_model = nn.Identity() if vocab_dim == dim else nn.Linear(vocab_dim, dim)

        # Logits of the vertex mixture used to form the global token.
        self.vertex_weights = nn.Parameter(torch.ones(5) / 5)
        # Learnable additive offset for the global token.
        self.global_offset = nn.Parameter(torch.zeros(1, 1, dim))

        self.vertex_norm = nn.LayerNorm(dim)
        self.global_norm = nn.LayerNorm(dim)

    def forward(self, batch_size: int, class_indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build CLS tokens for a batch.

        When `class_indices` is given with one entry per sample, each sample
        gets its own class pentachoron; otherwise every sample gets the mean
        pentachoron over all classes.

        Returns:
            ``(global_cls, vertex_cls)`` shaped [B, 1, dim] and [B, 5, dim].
        """
        bank = self.class_pentachora
        per_sample = class_indices is not None and class_indices.shape[0] == batch_size
        if per_sample:
            vocab_vertices = bank[class_indices]
        else:
            vocab_vertices = bank.mean(dim=0, keepdim=True).expand(batch_size, -1, -1)

        vertex_cls = self.vertex_norm(self.vocab_to_model(vocab_vertices))

        # Global token: softmax-weighted mix of the five vertex tokens.
        mix = F.softmax(self.vertex_weights, dim=0)
        global_cls = torch.einsum('bvd,v->bd', vertex_cls, mix).unsqueeze(1)
        global_cls = self.global_norm(global_cls + self.global_offset)
        return global_cls, vertex_cls

    def get_class_prototypes(self) -> torch.Tensor:
        """Return one [dim] prototype per class: the weighted vertex mix in model space."""
        projected = self.vocab_to_model(self.class_pentachora)
        mix = F.softmax(self.vertex_weights, dim=0)
        return torch.einsum('cvd,v->cd', projected, mix)
# ============================================
# GEOMETRIC PROJECTION LAYER
# ============================================
class GeometricProjection(nn.Module):
    """Scores patch tokens against the five vertices of every class pentachoron."""

    def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100, dropout: float = 0.1):
        super().__init__()
        self.dim = dim
        self.vocab_dim = vocab_dim
        self.num_classes = num_classes

        # Model space -> vocabulary space.
        self.to_vocab_space = nn.Linear(dim, vocab_dim)
        # One bias-free projection per pentachoron vertex.
        self.vertex_projections = nn.ModuleList(
            nn.Linear(vocab_dim, vocab_dim, bias=False) for _ in range(5)
        )
        # Learnable divisor applied to the cosine similarities.
        self.temperature = nn.Parameter(torch.ones(1))
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patches: torch.Tensor, pentachora: torch.Tensor) -> torch.Tensor:
        """Return patch-to-class alignment scores.

        Args:
            patches: rank-3 tensor [B, N, dim].
            pentachora: indexed as [C, 5, vocab_dim].

        Returns:
            [B, N, C] cosine similarities divided by the temperature and
            averaged over the five vertices, followed by dropout.
        """
        _batch, _tokens, _width = patches.shape  # unpacking requires rank 3

        in_vocab = self.to_vocab_space(self.norm(patches))
        in_vocab = F.normalize(in_vocab, dim=-1)

        per_vertex = []
        for v, project in enumerate(self.vertex_projections):
            patch_dir = F.normalize(project(in_vocab), dim=-1)
            vertex_dir = F.normalize(pentachora[:, v, :], dim=-1)
            per_vertex.append(torch.matmul(patch_dir, vertex_dir.T) / self.temperature)

        averaged = torch.stack(per_vertex, dim=-1).mean(dim=-1)
        return self.dropout(averaged)
# ============================================
# MLP BLOCK
# ============================================
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, in_features: int, hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None, dropout: float = 0.):
        super().__init__()
        # Falsy widths (None or 0) fall back to the input width.
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to the last dimension of `x`."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
# ============================================
# VIT BLOCK WITH GEOMETRIC ATTENTION
# ============================================
class PentachoronViTBlock(nn.Module):
    """Pre-norm transformer block using geometric or standard self-attention."""

    def __init__(self, dim: int, heads: int = 8, mlp_ratio: float = 4.0,
                 use_mesh: bool = True, dropout: float = 0., attn_dropout: float = 0.,
                 drop_path: float = 0., device=None):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)

        if use_mesh:
            # Geometric attention with at most 16 regions per head.
            self.attn = GeometricAttention(
                dim=dim,
                num_heads=heads,
                num_regions=min(dim // heads, 16),
                config=GeometricConfig(),
                dropout=attn_dropout,
                device=device,
            )
        else:
            # Plain multi-head self-attention (batch-first layout).
            self.attn = nn.MultiheadAttention(dim, heads, dropout=attn_dropout, batch_first=True)
        self.use_mesh = use_mesh

        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), dropout=dropout)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor, preserve_structure: bool = True) -> torch.Tensor:
        """Run attention and MLP sub-layers, each with a residual connection.

        `preserve_structure` is accepted for interface compatibility; this
        method does not read it.
        """
        normed = self.norm1(x)
        if self.use_mesh:
            attn_out, _ = self.attn(normed)
        else:
            # Self-attention: query, key and value are the same tensor.
            attn_out, _ = self.attn(normed, normed, normed)
        x = x + self.drop_path1(attn_out)

        return x + self.drop_path2(self.mlp(self.norm2(x)))
# ============================================
# PATCH EMBEDDING
# ============================================
class PatchEmbed(nn.Module):
    """Turns an image into a sequence of normalised patch embeddings."""

    def __init__(self, img_size: int = 32, patch_size: int = 4,
                 in_chans: int = 3, embed_dim: int = 512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        grid_side = img_size // patch_size
        self.num_patches = grid_side ** 2
        # Kernel size == stride, so each output position covers one
        # non-overlapping patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map [B, C, H, W] images to [B, num_tokens, embed_dim] tokens."""
        feature_map = self.proj(x)
        # Flatten the spatial grid into the token axis, channels last.
        tokens = rearrange(feature_map, 'b c h w -> b (h w) c')
        return self.norm(tokens)
# ============================================
# PENTACHORA VISION TRANSFORMER
# ============================================
class PentachoraViT(nn.Module):
    """
    Vision Transformer that prepends one global CLS token and five
    pentachoron-vertex CLS tokens to the patch sequence and classifies by
    matching the global token against per-class prototypes.
    """

    def __init__(self, config: Optional[PentachoraConfig] = None, **kwargs):
        super().__init__()

        # An explicit config wins; otherwise the keyword arguments build one.
        cfg = config if config is not None else PentachoraConfig(**kwargs)
        self.config = cfg
        self.num_classes = cfg.num_classes
        self.dim = cfg.dim
        self.depth = cfg.depth
        self.preserve_structure_until_layer = cfg.preserve_structure_until_layer

        # Vocabulary width: config value, else a raw kwarg, else the model width.
        if cfg.vocab_dim is None:
            self.vocab_dim = kwargs.get('vocab_dim', cfg.dim)
        else:
            self.vocab_dim = cfg.vocab_dim

        # Patch embedding (3 input channels).
        self.patch_embed = PatchEmbed(cfg.img_size, cfg.patch_size, 3, cfg.dim)
        num_patches = self.patch_embed.num_patches

        # Learned positional embedding for the patch tokens only.
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.dim) * 0.02)
        self.pos_drop = nn.Dropout(cfg.dropout_rate)

        # Hierarchical CLS tokens and the patch/pentachoron alignment head.
        self.cls_tokens = HierarchicalPentachoronCLS(cfg.dim, self.vocab_dim, cfg.num_classes)
        self.geometric_proj = GeometricProjection(cfg.dim, self.vocab_dim, cfg.num_classes, cfg.dropout_rate)

        if cfg.vocab is not None:
            self._init_from_vocab(cfg.vocab)

        # Stochastic-depth rate grows linearly from 0 to drop_path_rate.
        drop_rates = torch.linspace(0, cfg.drop_path_rate, cfg.depth).tolist()

        # Early layers use geometric attention, later ones standard attention.
        self.blocks = nn.ModuleList()
        for layer_idx, rate in enumerate(drop_rates):
            self.blocks.append(
                PentachoronViTBlock(
                    dim=cfg.dim,
                    heads=cfg.heads,
                    mlp_ratio=cfg.mlp_ratio,
                    use_mesh=(cfg.use_mesh_attention and layer_idx < cfg.preserve_structure_until_layer),
                    dropout=cfg.dropout_rate,
                    attn_dropout=cfg.dropout_rate,
                    drop_path=rate,
                    device=torch.device('cpu'),  # built on CPU; move the model afterwards
                )
            )

        self.norm = nn.LayerNorm(cfg.dim)

        # Primary classifier: prototype matching, so no linear head is built.
        self.use_prototype_classifier = True
        self.head = None if self.use_prototype_classifier else nn.Linear(cfg.dim, cfg.num_classes)

        # Auxiliary classifier over the five concatenated vertex tokens.
        self.head_aux = nn.Linear(cfg.dim * 5, cfg.num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """Per-module initialiser passed to ``self.apply``."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _init_from_vocab(self, vocab):
        """Load class pentachora from `vocab`; keep the random bank on any failure.

        Best effort by design: every problem is printed and swallowed so that
        model construction still succeeds.
        """
        try:
            print("Initializing pentachora from vocabulary...")
            if not hasattr(vocab, 'encode_batch'):
                print("Vocabulary provided but encode_batch method not found, using random initialization")
                return

            # Encode the first `num_classes` CIFAR-100 class names.
            names = self._get_cifar100_classes()
            encoded = vocab.encode_batch(names[:self.num_classes], generate=True)
            pentachora = np.stack(encoded, axis=0)

            # The encoder decides the vocabulary width.
            actual_vocab_dim = pentachora.shape[-1]
            print(f"Encoded pentachora shape: {pentachora.shape}")
            print(f"Detected vocabulary dimension: {actual_vocab_dim}")

            shape_ok = pentachora.shape[0] == self.num_classes and pentachora.shape[1] == 5
            if not shape_ok:
                print(f"Invalid shape: expected ({self.num_classes}, 5, ?), got {pentachora.shape}")
                print("Using random initialization")
                return

            # Propagate the detected width to every component that stores it.
            self.vocab_dim = actual_vocab_dim
            self.cls_tokens.vocab_dim = actual_vocab_dim
            self.geometric_proj.vocab_dim = actual_vocab_dim

            # Swap the random bank for the encoded one.
            self.cls_tokens.class_pentachora = torch.tensor(pentachora, dtype=torch.float32)

            # Rebuild every layer whose shape depends on the vocabulary width.
            needs_projection = actual_vocab_dim != self.dim
            if needs_projection:
                self.cls_tokens.vocab_to_model = nn.Linear(actual_vocab_dim, self.dim)
            else:
                self.cls_tokens.vocab_to_model = nn.Identity()

            self.geometric_proj.to_vocab_space = nn.Linear(self.dim, actual_vocab_dim)
            self.geometric_proj.vertex_projections = nn.ModuleList([
                nn.Linear(actual_vocab_dim, actual_vocab_dim, bias=False) for _ in range(5)
            ])

            # Initialise the freshly created layers.
            nn.init.xavier_uniform_(self.geometric_proj.to_vocab_space.weight)
            for vertex_layer in self.geometric_proj.vertex_projections:
                nn.init.xavier_uniform_(vertex_layer.weight)
            if needs_projection:
                nn.init.xavier_uniform_(self.cls_tokens.vocab_to_model.weight)

            print(f"✓ Successfully initialized {self.num_classes} class pentachora from vocabulary")
            print(f" Vocabulary dimension: {actual_vocab_dim}")
            print(f" Model internal dimension: {self.dim}")
        except Exception as e:
            print(f"Error initializing from vocabulary: {e}")
            print("Using random initialization")

    def _get_cifar100_classes(self):
        """Return the 100 CIFAR-100 class names used to query the vocabulary."""
        return [
            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
            'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
            'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
            'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
            'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
            'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
            'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
            'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
            'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
            'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
            'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
            'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
        ]

    def forward_features(self, x: torch.Tensor, class_indices: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Encode images and return the global, vertex and patch tokens.

        Token layout along dim 1 is [global (1), vertices (5), patches].
        """
        batch_size = x.shape[0]

        tokens = self.patch_embed(x)
        tokens = self.pos_drop(tokens + self.pos_embed)

        global_cls, vertex_cls = self.cls_tokens(batch_size, class_indices)
        tokens = torch.cat([global_cls, vertex_cls, tokens], dim=1)

        for layer_idx, block in enumerate(self.blocks):
            keep_structure = layer_idx < self.preserve_structure_until_layer
            tokens = block(tokens, preserve_structure=keep_structure)

        tokens = self.norm(tokens)

        # Undo the concatenation: 1 global token, 5 vertex tokens, then patches.
        return {
            'global_cls': tokens[:, 0],
            'vertex_cls': tokens[:, 1:6],
            'patches': tokens[:, 6:]
        }

    def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Run the model and return logits plus intermediate tokens.

        In training mode, `targets` (when given) select the class-specific
        pentachora used to seed the CLS tokens; otherwise the class-mean
        pentachoron is used.
        NOTE(review): this feeds ground-truth labels into the forward pass
        during training -- confirm that is intended.
        """
        if self.training and targets is not None:
            class_indices = targets
        else:
            class_indices = None
        features = self.forward_features(x, class_indices)
        global_token = features['global_cls']
        vertex_tokens = features['vertex_cls']

        if self.use_prototype_classifier:
            # Cosine similarity to each class prototype, scaled by 20.
            prototypes = F.normalize(self.cls_tokens.get_class_prototypes(), dim=-1)
            global_cls_norm = F.normalize(global_token, dim=-1)
            logits = torch.matmul(global_cls_norm, prototypes.T) * 20.0
        else:
            logits = self.head(global_token)

        # Auxiliary logits from the flattened vertex tokens.
        vertex_flat = vertex_tokens.reshape(vertex_tokens.shape[0], -1)
        aux_logits = self.head_aux(vertex_flat)

        # Per-patch alignment with every class pentachoron.
        geometric_alignments = self.geometric_proj(
            features['patches'],
            self.cls_tokens.class_pentachora
        )

        return {
            'logits': logits,
            'aux_logits': aux_logits,
            'geometric_alignments': geometric_alignments,
            'vertex_cls': vertex_tokens,
            'global_cls': global_token,
            'patches': features['patches']
        }
# ============================================
# LOSS FUNCTIONS
# ============================================
class PentachoraLoss(nn.Module):
    """Cross-entropy on the main logits plus optional weighted auxiliary terms."""

    def __init__(self, aux_weight: float = 0.3, geo_weight: float = 0.1,
                 smoothing: float = 0.0):
        super().__init__()
        self.aux_weight = aux_weight
        self.geo_weight = geo_weight
        # The same criterion (with label smoothing) is reused for every term.
        self.criterion = nn.CrossEntropyLoss(label_smoothing=smoothing)

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        """Return the total loss for a model output dict.

        Terms:
            * ``outputs['logits']`` -- always included.
            * ``outputs['aux_logits']`` -- scaled by `aux_weight` when present
              and the weight is positive.
            * ``outputs['geometric_alignments']`` -- averaged over dim 1, then
              scaled by `geo_weight` when present and the weight is positive.
        """
        total = self.criterion(outputs['logits'], targets)

        if self.aux_weight > 0 and 'aux_logits' in outputs:
            total = total + self.aux_weight * self.criterion(outputs['aux_logits'], targets)

        if self.geo_weight > 0 and 'geometric_alignments' in outputs:
            pooled = outputs['geometric_alignments'].mean(dim=1)
            total = total + self.geo_weight * self.criterion(pooled, targets)

        return total
# ============================================
# MODEL REGISTRY AND BUILDERS
# ============================================
# Registry of named model presets. Keys are the `variant` strings accepted by
# `create_pentachora_vit`; fields not listed keep the PentachoraConfig defaults.
# The values are module-level PentachoraConfig instances shared by every lookup.
# In the `*_xs_*` presets `preserve_structure_until_layer` (4) exceeds `depth`
# (2), so every layer index satisfies `i < preserve_structure_until_layer`.
MODEL_CONFIGS = {
    'pentachora_spark_xs': PentachoraConfig(
        dim=100, depth=2, heads=10, mlp_ratio=4.0,
        preserve_structure_until_layer=1,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_spark': PentachoraConfig(
        dim=100, depth=5, heads=4, mlp_ratio=4.0,
        preserve_structure_until_layer=1,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    # Only preset with patch_size=5.
    'pentachora_shock': PentachoraConfig(
        dim=100, depth=10, heads=5, mlp_ratio=4.0,
        patch_size=5, preserve_structure_until_layer=5,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_32d': PentachoraConfig(
        dim=32, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_64d': PentachoraConfig(
        dim=64, depth=2, heads=8, mlp_ratio=1.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    # vocab_dim (256) differs from dim (128).
    'pentachora_shock_xs_128d': PentachoraConfig(
        dim=128, depth=2, heads=8, mlp_ratio=2.0,
        preserve_structure_until_layer=4,
        vocab_dim=256,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    # The two "pixie" presets differ only in patch_size.
    'vit_pixie_256_patch4': PentachoraConfig(
        dim=256, depth=10, heads=16, mlp_ratio=1.0,
        preserve_structure_until_layer=10,
        vocab_dim=256, patch_size=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'vit_pixie_256_patch2': PentachoraConfig(
        dim=256, depth=10, heads=16, mlp_ratio=1.0,
        preserve_structure_until_layer=10,
        vocab_dim=256, patch_size=2,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    # vocab_dim (128) differs from dim (256).
    'pentachora_shock_xs_256d': PentachoraConfig(
        dim=256, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4,
        vocab_dim=128,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_512d': PentachoraConfig(
        dim=512, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    # The presets below are the only ones with non-zero dropout / drop-path.
    'pentachora_tiny': PentachoraConfig(
        dim=384, depth=12, heads=6, mlp_ratio=4.0,
        preserve_structure_until_layer=6,
        dropout_rate=0.1, drop_path_rate=0.1
    ),
    'pentachora_small': PentachoraConfig(
        dim=512, depth=12, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=6,
        dropout_rate=0.1, drop_path_rate=0.1
    ),
    'pentachora_base': PentachoraConfig(
        dim=768, depth=12, heads=12, mlp_ratio=4.0,
        preserve_structure_until_layer=8,
        dropout_rate=0.1, drop_path_rate=0.2
    ),
    'pentachora_large': PentachoraConfig(
        dim=1024, depth=24, heads=16, mlp_ratio=4.0,
        preserve_structure_until_layer=12,
        dropout_rate=0.1, drop_path_rate=0.3
    ),
}
def create_pentachora_vit(variant: str = 'pentachora_small',
                          pretrained: bool = False,
                          **kwargs) -> PentachoraViT:
    """Build a PentachoraViT from a named preset.

    Args:
        variant: key into ``MODEL_CONFIGS``.
        pretrained: accepted for API symmetry; only triggers a warning because
            no pretrained weights exist.
        **kwargs: config attributes to override for this model only.

    Returns:
        A freshly constructed model.

    Raises:
        ValueError: if `variant` is not a known preset.
    """
    import copy

    if variant not in MODEL_CONFIGS:
        raise ValueError(f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}")

    # Override a copy: writing into the registry entry itself would leak this
    # call's overrides into every later model built from the same preset.
    config = copy.copy(MODEL_CONFIGS[variant])
    for key, value in kwargs.items():
        setattr(config, key, value)

    model = PentachoraViT(config)
    if pretrained:
        warnings.warn("Pretrained weights not available yet")
    return model
# Convenience functions for each variant
def pentachora_vit_spark_tiny(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_spark_xs' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_spark_xs'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
def pentachora_shock_xs_64d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_shock_xs_64d' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_shock_xs_64d'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
def pentachora_vit_spark(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_spark' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_spark'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
def pentachora_shock_xs_32d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_shock_xs_32d' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_shock_xs_32d'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
def pentachora_shock_xs_256d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_shock_xs_256d' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_shock_xs_256d'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
def pentachora_shock_xs_512d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_shock_xs_512d' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_shock_xs_512d'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
def pentachora_vit_shock(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_shock' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_shock'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
def pentachora_vit_tiny(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_tiny' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_tiny'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
def pentachora_vit_small(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_small' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_small'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
def pentachora_vit_base(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_base' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_base'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
def pentachora_vit_large(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_large' preset; kwargs are forwarded as config overrides."""
    variant = 'pentachora_large'
    return create_pentachora_vit(variant, pretrained=pretrained, **kwargs)
# ============================================
# TRAINING UTILITIES
# ============================================
def get_parameter_groups(model: PentachoraViT,
                         weight_decay: float = 0.05) -> List[Dict[str, Any]]:
    """Split trainable parameters into decayed and non-decayed optimizer groups.

    A parameter is exempt from weight decay when its qualified name contains
    'bias', 'norm' or 'LayerNorm' (substring match). Frozen parameters are
    left out of both groups.

    Returns:
        Two group dicts: the decayed group first, then the exempt group with
        ``weight_decay`` 0.0.
    """
    exempt_markers = ('bias', 'norm', 'LayerNorm')
    decayed: List[Any] = []
    exempt: List[Any] = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        is_exempt = any(marker in name for marker in exempt_markers)
        bucket = exempt if is_exempt else decayed
        bucket.append(param)

    return [
        {'params': decayed, 'weight_decay': weight_decay},
        {'params': exempt, 'weight_decay': 0.0}
    ]
def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Return element counts of all, trainable and frozen parameters of `model`."""
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return {
        'total': total,
        'trainable': trainable,
        'non_trainable': total - trainable
    }
# ============================================
# INFERENCE UTILITIES
# ============================================
@torch.no_grad()
def extract_features(model: PentachoraViT,
                     images: torch.Tensor,
                     feature_type: str = 'global_cls') -> torch.Tensor:
    """Return one feature tensor for `images`, computed without gradients.

    Switches `model` to eval mode (and leaves it there), runs
    ``forward_features`` and returns the entry named `feature_type`; an
    unknown name falls back to the ``'global_cls'`` entry.
    """
    model.eval()
    feature_map = model.forward_features(images)
    fallback = feature_map['global_cls']
    return feature_map.get(feature_type, fallback)
# ============================================
# EXAMPLE USAGE AND TESTING
# ============================================
def test_model():
    """Smoke-test a few presets: build, forward pass, timing and loss."""
    import time

    print("Testing Fixed PentachoraViT Model")
    print("=" * 50)

    on_gpu = torch.cuda.is_available()

    for variant in ('pentachora_spark', 'pentachora_shock_xs_256d', 'pentachora_small'):
        print(f"\nTesting {variant}:")

        # Every preset is built with the same overrides, including vocab_dim.
        model = create_pentachora_vit(
            variant=variant,
            img_size=32,
            patch_size=4,
            num_classes=100,
            vocab_dim=64
        )

        params = count_parameters(model)
        print(f" Total parameters: {params['total']:,}")
        print(f" Trainable parameters: {params['trainable']:,}")

        x = torch.randn(2, 3, 32, 32)
        if on_gpu:
            model = model.cuda()
            x = x.cuda()
            # Drain pending GPU work so it is not counted in the timing.
            torch.cuda.synchronize()

        start = time.time()
        outputs = model(x)
        if on_gpu:
            torch.cuda.synchronize()
        elapsed_ms = (time.time() - start) * 1000

        print(" Output shapes:")
        print(f" Logits: {outputs['logits'].shape}")
        print(f" Aux logits: {outputs['aux_logits'].shape}")
        print(f" Geometric alignments: {outputs['geometric_alignments'].shape}")
        print(f" Forward pass time: {elapsed_ms:.2f}ms")

        # Loss on random targets, placed on the same device as the outputs.
        loss_fn = PentachoraLoss()
        targets = torch.randint(0, 100, (2,))
        if on_gpu:
            targets = targets.cuda()
        loss = loss_fn(outputs, targets)
        print(f" Loss: {loss.item():.4f}")

    print("\n" + "=" * 50)
    print("All tests passed!")
if __name__ == "__main__":
    # Smoke tests first.
    test_model()

    # Then a worked example of building one model for training.
    print("\nExample: Creating model with proper initialization")
    model = pentachora_shock_xs_256d(
        img_size=32,
        num_classes=100,
        vocab_dim=100,
        dropout_rate=0.0,
        drop_path_rate=0.0
    )

    # Parameters exist as soon as the constructor returns.
    total_params = count_parameters(model)['total']
    print(f"Model has {total_params:,} parameters")
    print("All geometric parameters initialized at creation time")

    if torch.cuda.is_available():
        model = model.cuda()
        print("Model moved to CUDA")

    # torch.compile is optional: skipped when missing, and any failure is
    # only reported.
    if hasattr(torch, 'compile'):
        print("Compiling model with torch.compile...")
        try:
            model = torch.compile(model)
        except Exception as e:
            print(f"Compilation warning: {e}")
            print("Continuing without compilation")
        else:
            print("✓ Model compiled successfully")

    print("\nModel ready for training with all parameters properly initialized!")