# ============================================
# PentachoraViT CIFAR-100 Evaluation
# ============================================
import torch
import torch.nn.functional as F
from collections import defaultdict
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
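# --------------------------------------------
# NOTE: get_cifar100_class_names() and get_cifar100_dataloaders() are called
# below but not defined in this file; they are assumed to exist elsewhere in
# the repo/notebook. The torchvision-based sketch below is a stand-in under
# that assumption; adjust the transforms/normalization to match the model's
# training pipeline before relying on the numbers.
# --------------------------------------------
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def get_cifar100_class_names():
    """Return the 100 CIFAR-100 fine-label names in index order."""
    return datasets.CIFAR100(root='./data', train=False, download=True).classes

def get_cifar100_dataloaders(batch_size=100):
    """Return (train_loader, test_loader) for CIFAR-100 with standard normalization."""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
    ])
    train_set = datasets.CIFAR100('./data', train=True, download=True, transform=transform)
    test_set = datasets.CIFAR100('./data', train=False, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader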
def evaluate_pentachora_vit(model, test_loader, device='cuda'):
    """Properly evaluate PentachoraViT model."""
    model.eval()

    # Get class names
    class_names = get_cifar100_class_names()

    # Check model configuration
    print(f"Model Configuration:")
    print(f" Internal dim: {model.dim}")
    print(f" Vocab dim: {model.vocab_dim}")
    print(f" Num classes: {model.num_classes}")

    # Get the class crystals
    if hasattr(model, 'cls_tokens') and hasattr(model.cls_tokens, 'class_pentachora'):
        crystals = model.cls_tokens.class_pentachora  # [100, 5, vocab_dim]
        print(f" Crystal shape: {crystals.shape}")
    else:
        print(" No crystals found!")
        # Keep the return arity consistent with the (class_results, overall_acc) tuple below
        return None, None
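    # `crystals` holds one pentachoron per class: 5 vertices embedded in the
    # vocab space (a pentachoron is the 4-simplex, hence the 5 in [100, 5, vocab_dim]).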
    # Track metrics
    all_predictions = []
    all_targets = []
    all_confidences = []
    geometric_alignments_by_class = defaultdict(list)
    aux_predictions = []
    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            # Get model outputs dictionary
            outputs = model(images)

            # Main predictions from primary head
            logits = outputs['logits']  # [batch, 100]
            probs = F.softmax(logits, dim=1)
            confidence, predicted = torch.max(probs, 1)

            # Store predictions
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_confidences.extend(confidence.cpu().numpy())

            # Auxiliary predictions
            if 'aux_logits' in outputs:
                aux_probs = F.softmax(outputs['aux_logits'], dim=1)
                _, aux_pred = torch.max(aux_probs, 1)
                aux_predictions.extend(aux_pred.cpu().numpy())

            # Geometric alignments - these show how patches align with class crystals
            if 'geometric_alignments' in outputs:
                # Shape: [batch, num_patches, num_classes]
                geo_align = outputs['geometric_alignments']
                # Average over patches to get per-sample class alignments
                geo_align_mean = geo_align.mean(dim=1)  # [batch, num_classes]

                for i, target_class in enumerate(targets):
                    class_idx = target_class.item()
                    # Store alignment score for the true class
                    geometric_alignments_by_class[class_idx].append(
                        geo_align_mean[i, class_idx].item()
                    )
    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)
    all_confidences = np.array(all_confidences)
    # Calculate per-class metrics
    class_results = []
    for class_idx in range(len(class_names)):
        mask = all_targets == class_idx
        if mask.sum() == 0:
            continue

        class_preds = all_predictions[mask]
        correct = (class_preds == class_idx).sum()
        total = mask.sum()
        accuracy = 100.0 * correct / total

        # Average confidence for this class
        class_conf = all_confidences[mask].mean()

        # Geometric alignment for this class
        geo_align = np.mean(geometric_alignments_by_class[class_idx]) if geometric_alignments_by_class[class_idx] else 0

        # Crystal statistics
        class_crystal = crystals[class_idx].detach().cpu()  # [5, vocab_dim]
        vertex_variance = class_crystal.var(dim=0).mean().item()

        # Crystal norm (average magnitude)
        crystal_norm = class_crystal.norm(dim=-1).mean().item()

        class_results.append({
            'class_idx': class_idx,
            'class_name': class_names[class_idx],
            'accuracy': accuracy,
            'correct': int(correct),
            'total': int(total),
            'avg_confidence': class_conf,
            'geometric_alignment': geo_align,
            'vertex_variance': vertex_variance,
            'crystal_norm': crystal_norm
        })
    # Sort by accuracy
    class_results.sort(key=lambda x: x['accuracy'], reverse=True)

    # Overall metrics
    overall_acc = 100.0 * (all_predictions == all_targets).mean()

    # Auxiliary head accuracy if available
    aux_acc = None
    if aux_predictions:
        aux_predictions = np.array(aux_predictions)
        aux_acc = 100.0 * (aux_predictions == all_targets).mean()
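    # aux_acc remains None when the model exposes no 'aux_logits' head.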
    # Print results
    print(f"\n" + "="*80)
    print(f"EVALUATION RESULTS")
    print(f"="*80)
    print(f"\nOverall Accuracy: {overall_acc:.2f}%")
    if aux_acc is not None:
        print(f"Auxiliary Head Accuracy: {aux_acc:.2f}%")
    # Top 10 classes
    print(f"\nTop 10 Classes:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-"*70)
    for r in class_results[:10]:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")

    # Bottom 10 classes
    print(f"\nBottom 10 Classes:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-"*70)
    for r in class_results[-10:]:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")
    # Analyze correlations
    accuracies = [r['accuracy'] for r in class_results]
    geo_aligns = [r['geometric_alignment'] for r in class_results]
    crystal_norms = [r['crystal_norm'] for r in class_results]
    vertex_vars = [r['vertex_variance'] for r in class_results]
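    # np.corrcoef(x, y) returns the 2x2 Pearson correlation matrix;
    # the [0, 1] entry is the correlation between the two sequences.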
print(f"\nCorrelations with Accuracy:")
print(f" Geometric Alignment: {np.corrcoef(accuracies, geo_aligns)[0,1]:.3f}")
print(f" Crystal Norm: {np.corrcoef(accuracies, crystal_norms)[0,1]:.3f}")
print(f" Vertex Variance: {np.corrcoef(accuracies, vertex_vars)[0,1]:.3f}")
    # Visualizations
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # 1. Accuracy distribution
    ax = axes[0, 0]
    ax.hist(accuracies, bins=20, edgecolor='black', alpha=0.7)
    ax.axvline(overall_acc, color='red', linestyle='--', label=f'Overall: {overall_acc:.1f}%')
    ax.set_xlabel('Accuracy (%)')
    ax.set_ylabel('Count')
    ax.set_title('Per-Class Accuracy Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 2. Accuracy vs Geometric Alignment
    ax = axes[0, 1]
    scatter = ax.scatter(geo_aligns, accuracies, c=crystal_norms, cmap='viridis', alpha=0.6)
    ax.set_xlabel('Geometric Alignment Score')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Geometric Alignment\n(color = crystal norm)')
    plt.colorbar(scatter, ax=ax)
    ax.grid(True, alpha=0.3)

    # 3. Crystal Analysis
    ax = axes[1, 0]
    ax.scatter(crystal_norms, accuracies, alpha=0.6)
    ax.set_xlabel('Crystal Norm (avg magnitude)')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Crystal Norm')
    ax.grid(True, alpha=0.3)

    # 4. Top/Bottom comparison
    ax = axes[1, 1]
    top10_acc = [r['accuracy'] for r in class_results[:10]]
    bottom10_acc = [r['accuracy'] for r in class_results[-10:]]
    top10_geo = [r['geometric_alignment'] for r in class_results[:10]]
    bottom10_geo = [r['geometric_alignment'] for r in class_results[-10:]]
    x = np.arange(10)
    width = 0.35
    ax.bar(x - width/2, top10_acc, width, label='Top 10 Accuracy', color='green', alpha=0.7)
    ax.bar(x + width/2, bottom10_acc, width, label='Bottom 10 Accuracy', color='red', alpha=0.7)
    ax.set_xlabel('Rank within group')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Top 10 vs Bottom 10 Classes')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
    # ===================================================================================
    # FULL 100-CLASS DIAGNOSTIC SPECTRUM (SORTED BY CLASS IDX FOR CONSISTENCY)
    # ===================================================================================
    print(f"\n{'='*90}")
    print("Sparky — Full Class Spectrum")
    print(f"{'='*90}")
    print(f"{'Idx':<5} {'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12} {'Variance':<10}")
    print("-" * 90)
    for r in sorted(class_results, key=lambda x: x['class_idx']):
        print(f"{r['class_idx']:<5} {r['class_name']:<20} "
              f"{r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f} "
              f"{r['vertex_variance']:>9.8f}")

    return class_results, overall_acc
# Run evaluation
if 'model' in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _, test_loader = get_cifar100_dataloaders(batch_size=100)

    results, overall = evaluate_pentachora_vit(model, test_loader, device)

    # Only run the crystal analysis when the evaluation actually produced results
    if results is not None:
        # Additional crystal analysis
        print("\nCrystal Geometry Analysis:")
        print("-"*50)

        # Get crystals
        crystals = model.cls_tokens.class_pentachora.detach().cpu()

        # Compute pairwise similarities between class crystals
        crystals_flat = crystals.mean(dim=1)  # Average over 5 vertices
        crystals_norm = F.normalize(crystals_flat, dim=1)
        similarities = torch.matmul(crystals_norm, crystals_norm.T)

        # Find confused pairs (high similarity, both low accuracy)
        print("\nMost similar classes with poor performance:")
        for i in range(len(results)):
            for j in range(i + 1, len(results)):
                if results[i]['accuracy'] < 20 and results[j]['accuracy'] < 20:
                    sim = similarities[results[i]['class_idx'], results[j]['class_idx']].item()
                    if sim > 0.5:
                        print(f" {results[i]['class_name']:<15} ({results[i]['accuracy']:.1f}%) ↔ "
                              f"{results[j]['class_name']:<15} ({results[j]['accuracy']:.1f}%) : {sim:.3f}")
else:
    print("No model found in memory!")