AbstractPhil committed on
Commit d7a1a55 · verified · 1 Parent(s): 7009caa

Update penta_vit_model_v1.py

Files changed (1):
  1. penta_vit_model_v1.py +72 -74
penta_vit_model_v1.py CHANGED
@@ -1,9 +1,7 @@
  """
  PentachoraViT: Vision Transformer with Pentachoron Geometric Structure
  Enhanced with Geometric Attention for improved head cohesion and generalization
-
- Author: AbstractPhil
-
+ FIXED: All parameters initialized at module creation time (no lazy init)
  """

  import torch
@@ -44,7 +42,7 @@ class PentachoraConfig:
  return (self.img_size // self.patch_size) ** 2

  # ============================================
- # GEOMETRIC ATTENTION COMPONENTS (OPTIMIZED)
+ # GEOMETRIC ATTENTION COMPONENTS (FIXED INIT)
  # ============================================

  def perfect_4simplex(device):
@@ -70,42 +68,42 @@ class GeometricConfig:
  fuse_alpha: float = 0.7
  phases: Tuple[float, ...] = (0.0, math.pi/2, math.pi, 3*math.pi/2)
  jitter: float = 0.02
- shift: float = 0.25
+ shift: float = 0.71
  rotate_cycle: int = 11
  use_phase_variance: bool = False
  geometry_type: str = "pentachoron"

  class GeometricNavigator(nn.Module):
- """Maps inputs to geometric regions in 4D space - OPTIMIZED with vectorized operations."""
+ """Maps inputs to geometric regions in 4D space - FIXED with immediate initialization."""

- def __init__(self, input_dim: int, num_regions: int, config: GeometricConfig, num_heads: int = 1):
+ def __init__(self, input_dim: int, num_regions: int, config: GeometricConfig, num_heads: int = 1, device=None):
  super().__init__()
  self.input_dim = input_dim
  self.num_regions = num_regions
  self.config = config
  self.num_heads = num_heads

+ # Use CPU by default if device not specified
+ if device is None:
+ device = torch.device('cpu')
+
  # Create separate parameters for each head if num_heads > 1
  if num_heads > 1:
- self.to_nav = nn.Parameter(torch.randn(num_heads, input_dim, 4) * 0.02)
- self.vertex_w = nn.Parameter(torch.zeros(num_heads, num_regions, 5))
+ self.to_nav = nn.Parameter(torch.randn(num_heads, input_dim, 4, device=device) * 0.02)
+ self.vertex_w = nn.Parameter(torch.zeros(num_heads, num_regions, 5, device=device))
  else:
  self.to_nav = nn.Linear(input_dim, 4, bias=False)
- self.vertex_w = nn.Parameter(torch.zeros(num_regions, 5))
+ self.vertex_w = nn.Parameter(torch.zeros(num_regions, 5, device=device))

  # Pre-compute phase tensors for vectorization
- self.register_buffer('phase_cos', torch.cos(torch.tensor(config.phases, dtype=torch.float32)))
- self.register_buffer('phase_sin', torch.sin(torch.tensor(config.phases, dtype=torch.float32)))
+ self.register_buffer('phase_cos', torch.cos(torch.tensor(config.phases, dtype=torch.float32, device=device)))
+ self.register_buffer('phase_sin', torch.sin(torch.tensor(config.phases, dtype=torch.float32, device=device)))

- # Initialize geometry after module is created
- self.register_parameter('D', None)
- self.register_parameter('S', None)
-
- def _lazy_init_geometry(self, device):
- """Initialize geometry on first forward pass."""
- if self.D is not None:
- return
+ # Initialize geometry immediately at creation time
+ self._init_geometry(device)

+ def _init_geometry(self, device):
+ """Initialize geometry at module creation time."""
  base = perfect_4simplex(device)

  if self.num_heads > 1:
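The hunk above swaps the lazy `_lazy_init_geometry` path for an `_init_geometry(device)` call made from `__init__`. A minimal sketch of why that matters, using a hypothetical stand-in rather than the real `GeometricNavigator`: when every tensor exists at construction, `.cuda()`, optimizer setup, and `torch.compile` all see the full parameter set, whereas lazily registered parameters only appear after the first forward pass.

import torch
import torch.nn as nn

class EagerGeometry(nn.Module):
    # Hypothetical stand-in, not the repo's GeometricNavigator: all tensors are
    # created in __init__ on an explicit device, never on the first forward.
    def __init__(self, input_dim: int, num_regions: int, device=None):
        super().__init__()
        device = device or torch.device('cpu')
        self.to_nav = nn.Parameter(torch.randn(input_dim, 4, device=device) * 0.02)
        self.vertices = nn.Parameter(torch.randn(num_regions, 5, 4, device=device) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        nav = x @ self.to_nav                                            # [B, 4]
        # Distance from each navigated point to every vertex of every region.
        d = (nav[:, None, None, :] - self.vertices[None]).norm(dim=-1)   # [B, R, 5]
        return -d.amin(dim=-1)                                           # region scores [B, R]

m = EagerGeometry(input_dim=64, num_regions=8)
opt = torch.optim.AdamW(m.parameters(), lr=1e-3)   # optimizer sees every parameter up front
print(m(torch.randn(2, 64)).shape)                 # torch.Size([2, 8])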
@@ -143,8 +141,6 @@ class GeometricNavigator(nn.Module):

  def navigate(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
  """Navigate inputs through geometric space - OPTIMIZED with vectorized phase computation."""
- self._lazy_init_geometry(x.device)
-
  if self.num_heads > 1:
  # Batched navigation for multiple heads
  BT, H, head_dim = x.shape
@@ -159,8 +155,8 @@
  s_disp = -softmin_over_last(d_disp, self.config.softmin_tau)

  # OPTIMIZED: Vectorized phase computation (no loop)
- cos_phases = self.phase_cos.to(x.device).view(-1, 1, 1, 1, 1)
- sin_phases = self.phase_sin.to(x.device).view(-1, 1, 1, 1, 1)
+ cos_phases = self.phase_cos.view(-1, 1, 1, 1, 1)
+ sin_phases = self.phase_sin.view(-1, 1, 1, 1, 1)

  # Compute all phase variants at once [phases, H, regions, 5, 4]
  Vt_all = cos_phases * self.D.unsqueeze(0) + sin_phases * self.S.unsqueeze(0)
@@ -193,8 +189,8 @@
  w = F.softmax(self.vertex_w, dim=1)

  # OPTIMIZED: Vectorized phase computation for single head
- cos_phases = self.phase_cos.to(x.device).view(-1, 1, 1, 1)
- sin_phases = self.phase_sin.to(x.device).view(-1, 1, 1, 1)
+ cos_phases = self.phase_cos.view(-1, 1, 1, 1)
+ sin_phases = self.phase_sin.view(-1, 1, 1, 1)

  Vt_all = cos_phases * self.D.unsqueeze(0) + sin_phases * self.S.unsqueeze(0)
  w_expanded = w.unsqueeze(0).unsqueeze(-1)
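The `.to(x.device)` calls on `phase_cos` / `phase_sin` disappear in the new version because tensors registered via `register_buffer` already follow the module through `.to()` / `.cuda()`. A small illustrative sketch (a hypothetical `PhaseTable`, not code from this file) of that behavior and of broadcasting all phase variants in one shot:

import math
import torch
import torch.nn as nn

class PhaseTable(nn.Module):
    # Hypothetical mini-module: buffers registered here move with the module,
    # so forward() needs no per-call .to(x.device) transfer.
    def __init__(self, phases=(0.0, math.pi / 2, math.pi, 3 * math.pi / 2)):
        super().__init__()
        t = torch.tensor(phases, dtype=torch.float32)
        self.register_buffer('phase_cos', torch.cos(t))
        self.register_buffer('phase_sin', torch.sin(t))

    def forward(self, D: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
        # All phase variants at once: [num_phases, *D.shape]
        cos = self.phase_cos.view(-1, *([1] * D.dim()))
        sin = self.phase_sin.view(-1, *([1] * D.dim()))
        return cos * D.unsqueeze(0) + sin * S.unsqueeze(0)

table = PhaseTable()
if torch.cuda.is_available():
    table = table.cuda()              # buffers are carried along automatically
dev = table.phase_cos.device
out = table(torch.randn(3, 5, 4, device=dev), torch.randn(3, 5, 4, device=dev))
print(out.shape)                       # torch.Size([4, 3, 5, 4])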
@@ -217,10 +213,10 @@
  return {'scores': scores, 'diagnostics': diagnostics}

  class GeometricAttention(nn.Module):
- """Multi-head geometric attention with Q-K alignment - OPTIMIZED with batched processing."""
+ """Multi-head geometric attention with Q-K alignment - FIXED with proper device handling."""

  def __init__(self, dim: int, num_heads: int = 8, num_regions: Optional[int] = None,
- config: Optional[GeometricConfig] = None, dropout: float = 0.0):
+ config: Optional[GeometricConfig] = None, dropout: float = 0.0, device=None):
  super().__init__()
  self.dim = dim
  self.num_heads = num_heads
@@ -234,9 +230,9 @@ class GeometricAttention(nn.Module):
  self.config = config
  self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

- # Create batched navigators
- self.q_navigator = GeometricNavigator(self.head_dim, num_regions, config, num_heads=num_heads)
- self.k_navigator = GeometricNavigator(self.head_dim, num_regions, config, num_heads=num_heads)
+ # Create batched navigators with device
+ self.q_navigator = GeometricNavigator(self.head_dim, num_regions, config, num_heads=num_heads, device=device)
+ self.k_navigator = GeometricNavigator(self.head_dim, num_regions, config, num_heads=num_heads, device=device)

  self.out_proj = nn.Linear(dim, dim)
  self.dropout = nn.Dropout(dropout)
@@ -342,10 +338,13 @@ class HierarchicalPentachoronCLS(nn.Module):

  def forward(self, batch_size: int, class_indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
  """Generate CLS tokens for batch."""
+ # Get class-specific pentachora
+ class_pentachora = self.class_pentachora # This is now a computed property
+
  if class_indices is not None and class_indices.shape[0] == batch_size:
- vertex_cls_vocab = self.class_pentachora[class_indices]
+ vertex_cls_vocab = class_pentachora[class_indices]
  else:
- vertex_cls_vocab = self.class_pentachora.mean(dim=0, keepdim=True)
+ vertex_cls_vocab = class_pentachora.mean(dim=0, keepdim=True)
  vertex_cls_vocab = vertex_cls_vocab.expand(batch_size, -1, -1)

  # Project from vocabulary dimension to model dimension
@@ -362,7 +361,8 @@

  def get_class_prototypes(self) -> torch.Tensor:
  """Get class prototypes in model dimension."""
- pentachora_model = self.vocab_to_model(self.class_pentachora)
+ class_pentachora = self.class_pentachora # Get computed pentachora
+ pentachora_model = self.vocab_to_model(class_pentachora)
  weights = F.softmax(self.vertex_weights, dim=0)
  prototypes = torch.einsum('cvd,v->cd', pentachora_model, weights)
  return prototypes
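Both methods above now read `self.class_pentachora` once into a local variable and index that, which the added comments attribute to `class_pentachora` becoming a computed property. The property itself is outside this diff; the sketch below is only a guess at the shape of that pattern, with a made-up `TinyCLSBank` standing in for `HierarchicalPentachoronCLS`:

import torch
import torch.nn as nn

class TinyCLSBank(nn.Module):
    # Hypothetical sketch only - the real class_pentachora property is not part
    # of this diff. It illustrates "computed property, read once per call".
    def __init__(self, num_classes: int, vocab_dim: int):
        super().__init__()
        self.base = nn.Parameter(torch.randn(num_classes, 5, vocab_dim) * 0.02)
        self.scale = nn.Parameter(torch.ones(1))

    @property
    def class_pentachora(self) -> torch.Tensor:
        # Derived on access, so callers grab it once and reuse the local tensor.
        return self.base * self.scale

    def forward(self, batch_size: int, class_indices=None) -> torch.Tensor:
        pentachora = self.class_pentachora          # single property read
        if class_indices is not None and class_indices.shape[0] == batch_size:
            return pentachora[class_indices]        # [B, 5, vocab_dim]
        return pentachora.mean(dim=0, keepdim=True).expand(batch_size, -1, -1)

bank = TinyCLSBank(num_classes=10, vocab_dim=16)
print(bank(batch_size=4, class_indices=torch.tensor([0, 3, 7, 9])).shape)  # [4, 5, 16]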
@@ -453,7 +453,7 @@
  """ViT block with geometric attention for structured layers."""
  def __init__(self, dim: int, heads: int = 8, mlp_ratio: float = 4.0,
  use_mesh: bool = True, dropout: float = 0., attn_dropout: float = 0.,
- drop_path: float = 0.):
+ drop_path: float = 0., device=None):
  super().__init__()
  self.norm1 = nn.LayerNorm(dim)

@@ -464,7 +464,8 @@
  num_heads=heads,
  num_regions=min(dim // heads, 16),
  config=GeometricConfig(),
- dropout=attn_dropout
+ dropout=attn_dropout,
+ device=device
  )
  else:
  # Standard multi-head attention for later layers
@@ -578,7 +579,8 @@ class PentachoraViT(nn.Module):
  use_mesh=(cfg.use_mesh_attention and i < cfg.preserve_structure_until_layer),
  dropout=cfg.dropout_rate,
  attn_dropout=cfg.dropout_rate,
- drop_path=dpr[i]
+ drop_path=dpr[i],
+ device=torch.device('cpu') # Initialize on CPU, will be moved later
  )
  for i in range(cfg.depth)
  ])
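These three hunks thread a `device` argument from the block list down through `PentachoronViTBlock` into `GeometricAttention`, defaulting to CPU so the whole module tree is materialized before any `.cuda()` call. A toy sketch of the pattern (hypothetical `Block` / `LeafAttn` classes, not the repo's):

import torch
import torch.nn as nn

class LeafAttn(nn.Module):
    # Hypothetical leaf module: allocates its parameters on the device it is given.
    def __init__(self, dim: int, device=None):
        super().__init__()
        device = device or torch.device('cpu')
        self.w = nn.Parameter(torch.zeros(dim, dim, device=device))

class Block(nn.Module):
    # Hypothetical block: simply forwards the device argument to its children.
    def __init__(self, dim: int, device=None):
        super().__init__()
        self.attn = LeafAttn(dim, device=device)

blocks = nn.ModuleList([Block(64, device=torch.device('cpu')) for _ in range(2)])
print(next(blocks.parameters()).device)   # cpu; a later .cuda() moves everything at once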
@@ -750,10 +752,10 @@
  vertex_flat = features['vertex_cls'].reshape(B, -1)
  aux_logits = self.head_aux(vertex_flat)

- # Geometric alignment scores - use class_pentachora directly
+ # Geometric alignment scores
  geometric_alignments = self.geometric_proj(
  features['patches'],
- self.cls_tokens.class_pentachora # Back to original
+ self.cls_tokens.class_pentachora
  )

  return {
@@ -822,18 +824,32 @@ MODEL_CONFIGS = {
  dropout_rate=0.0, drop_path_rate=0.0
  ),
  'pentachora_shock_xs_64d': PentachoraConfig(
- dim=64, depth=2, heads=8, mlp_ratio=4.0,
+ dim=64, depth=2, heads=8, mlp_ratio=1.0,
  preserve_structure_until_layer=4,
  dropout_rate=0.0, drop_path_rate=0.0
  ),
  'pentachora_shock_xs_128d': PentachoraConfig(
- dim=128, depth=2, heads=8, mlp_ratio=4.0,
+ dim=128, depth=2, heads=8, mlp_ratio=2.0,
  preserve_structure_until_layer=4,
+ vocab_dim=256,
+ dropout_rate=0.0, drop_path_rate=0.0
+ ),
+ 'vit_pixie_256_patch4': PentachoraConfig(
+ dim=256, depth=10, heads=16, mlp_ratio=1.0,
+ preserve_structure_until_layer=10,
+ vocab_dim=256, patch_size=4,
+ dropout_rate=0.0, drop_path_rate=0.0
+ ),
+ 'vit_pixie_256_patch2': PentachoraConfig(
+ dim=256, depth=10, heads=16, mlp_ratio=1.0,
+ preserve_structure_until_layer=10,
+ vocab_dim=256, patch_size=2,
  dropout_rate=0.0, drop_path_rate=0.0
  ),
  'pentachora_shock_xs_256d': PentachoraConfig(
  dim=256, depth=2, heads=8, mlp_ratio=4.0,
- preserve_structure_until_layer=4,
+ preserve_structure_until_layer=4,
+ vocab_dim=128,
  dropout_rate=0.0, drop_path_rate=0.0
  ),
  'pentachora_shock_xs_512d': PentachoraConfig(
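The two new `vit_pixie_256` entries mainly differ in `patch_size`, which drives the token count through the `num_patches` property shown in the second hunk of this diff. A quick back-of-envelope check for the 32x32 images used later in `__main__` (the helper below is illustrative, not part of the file):

def num_patches(img_size: int, patch_size: int) -> int:
    # Same formula as PentachoraConfig.num_patches: (img_size // patch_size) ** 2
    return (img_size // patch_size) ** 2

print(num_patches(32, 4))   # 64 tokens per image for vit_pixie_256_patch4
print(num_patches(32, 2))   # 256 tokens per image for vit_pixie_256_patch2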
@@ -983,7 +999,7 @@ def extract_features(model: PentachoraViT,

  def test_model():
  """Test model creation and forward pass."""
- print("Testing Optimized PentachoraViT Model")
+ print("Testing Fixed PentachoraViT Model")
  print("=" * 50)

  # Test different variants
@@ -1041,10 +1057,10 @@ def test_model():

  if __name__ == "__main__":
  # Run tests
- #test_model()
+ test_model()

- # Example: Create model for A100 training
- print("\nExample: Creating optimized model for A100 training")
+ # Example: Create model for training
+ print("\nExample: Creating model with proper initialization")
  model = pentachora_shock_xs_256d(
  img_size=32,
  num_classes=100,
@@ -1053,41 +1069,23 @@ if __name__ == "__main__":
  drop_path_rate=0.0
  )

- # Move model to CUDA first if available
+ # All parameters are initialized immediately
+ print(f"Model has {count_parameters(model)['total']:,} parameters")
+ print("All geometric parameters initialized at creation time")
+
+ # Move model to CUDA if available
  if torch.cuda.is_available():
  model = model.cuda()
  print("Model moved to CUDA")

- # Now try torch.compile (PyTorch 2.0+)
- # Model reformatted to allow eager compiling, speeds along training substantially.
+ # Now torch.compile should work without issues
  if hasattr(torch, 'compile'):
  print("Compiling model with torch.compile...")
  try:
- model = torch.compile(model, backend="eager")
- print("Model compiled successfully")
+ model = torch.compile(model)
+ print("Model compiled successfully")
  except Exception as e:
  print(f"Compilation warning: {e}")
- print("Continuing without compilation - vectorized ops will still provide speedup")
-
- # Get parameter groups for optimizer
- param_groups = get_parameter_groups(model, weight_decay=0.05)
- print(f"Number of parameter groups: {len(param_groups)}")
-
- # Example batch - FULL PRECISION
- images = torch.randn(4, 3, 32, 32)
- targets = torch.randint(0, 100, (4,))
-
- if torch.cuda.is_available():
- images = images.cuda()
- targets = targets.cuda()
-
- # Forward pass in FULL PRECISION (no autocast)
- outputs = model(images)
-
- # Compute loss
- criterion = PentachoraLoss(aux_weight=0.3, geo_weight=0.1)
- loss = criterion(outputs, targets)
+ print("Continuing without compilation")

- print(f"Training loss (full precision): {loss.item():.4f}")
- print("\nModel ready for full precision A100 training!")
- print("Eager initialization ensures all parameters are created upfront")
+ print("\nModel ready for training with all parameters properly initialized!")