# ============================================
# PentachoraViT CIFAR-100 Evaluation
# ============================================
import torch
import torch.nn.functional as F
from collections import defaultdict
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
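# --------------------------------------------------------------------------
# NOTE: `get_cifar100_class_names` and `get_cifar100_dataloaders` are assumed
# to be defined earlier in the notebook. The guarded fallbacks below are a
# minimal sketch built on torchvision so this cell can run standalone; both
# are defined together, and the original helpers may use different
# transforms, normalization stats, or DataLoader settings.
# --------------------------------------------------------------------------
if 'get_cifar100_class_names' not in globals():
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    def get_cifar100_class_names(root='./data'):
        # Class names in label-index order, as exposed by torchvision.
        return datasets.CIFAR100(root=root, train=False, download=True).classes

    def get_cifar100_dataloaders(batch_size=128, root='./data'):
        # Plain ToTensor pipeline (assumption); swap in the notebook's own
        # augmentation/normalization if it differs.
        transform = transforms.ToTensor()
        train_set = datasets.CIFAR100(root=root, train=True, download=True, transform=transform)
        test_set = datasets.CIFAR100(root=root, train=False, download=True, transform=transform)
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
        return train_loader, test_loader
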
def evaluate_pentachora_vit(model, test_loader, device='cuda'):
"""Properly evaluate PentachoraViT model."""
model.eval()
# Get class names
class_names = get_cifar100_class_names()
# Check model configuration
print(f"Model Configuration:")
print(f" Internal dim: {model.dim}")
print(f" Vocab dim: {model.vocab_dim}")
print(f" Num classes: {model.num_classes}")
# Get the class crystals
if hasattr(model, 'cls_tokens') and hasattr(model.cls_tokens, 'class_pentachora'):
crystals = model.cls_tokens.class_pentachora # [100, 5, vocab_dim]
print(f" Crystal shape: {crystals.shape}")
else:
print(" No crystals found!")
return None
# Track metrics
all_predictions = []
all_targets = []
all_confidences = []
geometric_alignments_by_class = defaultdict(list)
aux_predictions = []
    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            # Get model outputs dictionary
            outputs = model(images)

            # Main predictions from primary head
            logits = outputs['logits']  # [batch, 100]
            probs = F.softmax(logits, dim=1)
            confidence, predicted = torch.max(probs, 1)

            # Store predictions
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_confidences.extend(confidence.cpu().numpy())

            # Auxiliary predictions
            if 'aux_logits' in outputs:
                aux_probs = F.softmax(outputs['aux_logits'], dim=1)
                _, aux_pred = torch.max(aux_probs, 1)
                aux_predictions.extend(aux_pred.cpu().numpy())

            # Geometric alignments - these show how patches align with class crystals
            if 'geometric_alignments' in outputs:
                # Shape: [batch, num_patches, num_classes]
                geo_align = outputs['geometric_alignments']
                # Average over patches to get per-sample class alignments
                geo_align_mean = geo_align.mean(dim=1)  # [batch, num_classes]
                for i, target_class in enumerate(targets):
                    class_idx = target_class.item()
                    # Store alignment score for the true class
                    geometric_alignments_by_class[class_idx].append(
                        geo_align_mean[i, class_idx].item()
                    )
    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)
    all_confidences = np.array(all_confidences)

    # Calculate per-class metrics
    class_results = []
    for class_idx in range(len(class_names)):
        mask = all_targets == class_idx
        if mask.sum() == 0:
            continue
        class_preds = all_predictions[mask]
        correct = (class_preds == class_idx).sum()
        total = mask.sum()
        accuracy = 100.0 * correct / total

        # Average confidence for this class
        class_conf = all_confidences[mask].mean()

        # Geometric alignment for this class
        geo_align = np.mean(geometric_alignments_by_class[class_idx]) if geometric_alignments_by_class[class_idx] else 0.0

        # Crystal statistics
        class_crystal = crystals[class_idx].detach().cpu()  # [5, vocab_dim]
        vertex_variance = class_crystal.var(dim=0).mean().item()
        # Crystal norm (average magnitude)
        crystal_norm = class_crystal.norm(dim=-1).mean().item()

        class_results.append({
            'class_idx': class_idx,
            'class_name': class_names[class_idx],
            'accuracy': accuracy,
            'correct': int(correct),
            'total': int(total),
            'avg_confidence': class_conf,
            'geometric_alignment': geo_align,
            'vertex_variance': vertex_variance,
            'crystal_norm': crystal_norm
        })

    # Sort by accuracy
    class_results.sort(key=lambda x: x['accuracy'], reverse=True)

    # Overall metrics
    overall_acc = 100.0 * (all_predictions == all_targets).mean()

    # Auxiliary head accuracy if available
    aux_acc = None
    if aux_predictions:
        aux_predictions = np.array(aux_predictions)
        aux_acc = 100.0 * (aux_predictions == all_targets).mean()
    # Print results
    print("\n" + "=" * 80)
    print("EVALUATION RESULTS")
    print("=" * 80)
    print(f"\nOverall Accuracy: {overall_acc:.2f}%")
    if aux_acc is not None:
        print(f"Auxiliary Head Accuracy: {aux_acc:.2f}%")

    # Top 10 classes
    print("\nTop 10 Classes:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-" * 70)
    for r in class_results[:10]:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")

    # Bottom 10 classes
    print("\nBottom 10 Classes:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-" * 70)
    for r in class_results[-10:]:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")

    # Analyze correlations
    accuracies = [r['accuracy'] for r in class_results]
    geo_aligns = [r['geometric_alignment'] for r in class_results]
    crystal_norms = [r['crystal_norm'] for r in class_results]
    vertex_vars = [r['vertex_variance'] for r in class_results]

    print("\nCorrelations with Accuracy:")
    print(f" Geometric Alignment: {np.corrcoef(accuracies, geo_aligns)[0, 1]:.3f}")
    print(f" Crystal Norm: {np.corrcoef(accuracies, crystal_norms)[0, 1]:.3f}")
    print(f" Vertex Variance: {np.corrcoef(accuracies, vertex_vars)[0, 1]:.3f}")
    # Visualizations
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # 1. Accuracy distribution
    ax = axes[0, 0]
    ax.hist(accuracies, bins=20, edgecolor='black', alpha=0.7)
    ax.axvline(overall_acc, color='red', linestyle='--', label=f'Overall: {overall_acc:.1f}%')
    ax.set_xlabel('Accuracy (%)')
    ax.set_ylabel('Count')
    ax.set_title('Per-Class Accuracy Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 2. Accuracy vs Geometric Alignment
    ax = axes[0, 1]
    scatter = ax.scatter(geo_aligns, accuracies, c=crystal_norms, cmap='viridis', alpha=0.6)
    ax.set_xlabel('Geometric Alignment Score')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Geometric Alignment\n(color = crystal norm)')
    plt.colorbar(scatter, ax=ax)
    ax.grid(True, alpha=0.3)

    # 3. Crystal Analysis
    ax = axes[1, 0]
    ax.scatter(crystal_norms, accuracies, alpha=0.6)
    ax.set_xlabel('Crystal Norm (avg magnitude)')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Crystal Norm')
    ax.grid(True, alpha=0.3)

    # 4. Top/Bottom comparison
    ax = axes[1, 1]
    top10_acc = [r['accuracy'] for r in class_results[:10]]
    bottom10_acc = [r['accuracy'] for r in class_results[-10:]]
    top10_geo = [r['geometric_alignment'] for r in class_results[:10]]
    bottom10_geo = [r['geometric_alignment'] for r in class_results[-10:]]
    x = np.arange(10)
    width = 0.35
    ax.bar(x - width / 2, top10_acc, width, label='Top 10 Accuracy', color='green', alpha=0.7)
    ax.bar(x + width / 2, bottom10_acc, width, label='Bottom 10 Accuracy', color='red', alpha=0.7)
    ax.set_xlabel('Rank within group')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Top 10 vs Bottom 10 Classes')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
    # ===================================================================================
    # FULL 100-CLASS DIAGNOSTIC SPECTRUM (SORTED BY CLASS IDX FOR CONSISTENCY)
    # ===================================================================================
    print(f"\n{'='*90}")
    print("Sparky — Full Class Spectrum")
    print(f"{'='*90}")
    print(f"{'Idx':<5} {'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12} {'Variance':<10}")
    print("-" * 90)
    for r in sorted(class_results, key=lambda x: x['class_idx']):
        print(f"{r['class_idx']:<5} {r['class_name']:<20} "
              f"{r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f} "
              f"{r['vertex_variance']:>9.8f}")

    return class_results, overall_acc

# Run evaluation
if 'model' in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _, test_loader = get_cifar100_dataloaders(batch_size=100)
    results, overall = evaluate_pentachora_vit(model, test_loader, device)

    # Additional crystal analysis
    print("\nCrystal Geometry Analysis:")
    print("-" * 50)

    # Get crystals
    crystals = model.cls_tokens.class_pentachora.detach().cpu()

    # Compute pairwise similarities between class crystals
    crystals_flat = crystals.mean(dim=1)  # Average over 5 vertices
    crystals_norm = F.normalize(crystals_flat, dim=1)
    similarities = torch.matmul(crystals_norm, crystals_norm.T)

    # Find confused pairs (high similarity, both low accuracy)
    print("\nMost similar classes with poor performance:")
    for i in range(len(results)):
        for j in range(i + 1, len(results)):
            if results[i]['accuracy'] < 20 and results[j]['accuracy'] < 20:
                sim = similarities[results[i]['class_idx'], results[j]['class_idx']].item()
                if sim > 0.5:
                    print(f" {results[i]['class_name']:<15} ({results[i]['accuracy']:.1f}%) ↔ "
                          f"{results[j]['class_name']:<15} ({results[j]['accuracy']:.1f}%) : {sim:.3f}")
else:
    print("No model found in memory!")