import torch
import torch.nn.functional as F
from collections import defaultdict
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
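
# `get_cifar100_class_names` and `get_cifar100_dataloaders` are assumed to be
# defined in an earlier notebook cell. Minimal fallback sketches built on
# torchvision, used only if they are missing; note these use plain ToTensor
# transforms, which may not match the model's original training pipeline.
if 'get_cifar100_class_names' not in globals():
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    def get_cifar100_class_names():
        return datasets.CIFAR100(root='./data', train=False, download=True).classes

    def get_cifar100_dataloaders(batch_size=100):
        tf = transforms.ToTensor()
        train_set = datasets.CIFAR100('./data', train=True, download=True, transform=tf)
        test_set = datasets.CIFAR100('./data', train=False, download=True, transform=tf)
        return (DataLoader(train_set, batch_size=batch_size, shuffle=True),
                DataLoader(test_set, batch_size=batch_size, shuffle=False))
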

def evaluate_pentachora_vit(model, test_loader, device='cuda'):
    """Evaluate a PentachoraViT model, reporting per-class accuracy and crystal-geometry diagnostics."""
    model.eval()

    class_names = get_cifar100_class_names()

    print("Model Configuration:")
    print(f"  Internal dim: {model.dim}")
    print(f"  Vocab dim: {model.vocab_dim}")
    print(f"  Num classes: {model.num_classes}")

    # The per-class pentachoron "crystals" live on the classification tokens.
    if hasattr(model, 'cls_tokens') and hasattr(model.cls_tokens, 'class_pentachora'):
        crystals = model.cls_tokens.class_pentachora
        print(f"  Crystal shape: {crystals.shape}")
    else:
        print("  No crystals found!")
        # Return a 2-tuple so the call site's unpacking still works.
        return None, None

    all_predictions = []
    all_targets = []
    all_confidences = []
    geometric_alignments_by_class = defaultdict(list)
    aux_predictions = []
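
    # Assumed output contract of model(images), inferred from the accesses
    # below rather than from the model's source: a dict with 'logits' [B, C],
    # optionally 'aux_logits' [B, C], and optionally 'geometric_alignments',
    # whose mean over dim=1 gives per-class alignment scores of shape [B, C].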
    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            targets = targets.to(device)

            outputs = model(images)

            logits = outputs['logits']
            probs = F.softmax(logits, dim=1)
            confidence, predicted = torch.max(probs, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_confidences.extend(confidence.cpu().numpy())

            # Optional auxiliary classification head.
            if 'aux_logits' in outputs:
                aux_probs = F.softmax(outputs['aux_logits'], dim=1)
                _, aux_pred = torch.max(aux_probs, 1)
                aux_predictions.extend(aux_pred.cpu().numpy())

            # Record each sample's alignment with its own class geometry.
            if 'geometric_alignments' in outputs:
                geo_align = outputs['geometric_alignments']
                geo_align_mean = geo_align.mean(dim=1)
                for i, target_class in enumerate(targets):
                    class_idx = target_class.item()
                    geometric_alignments_by_class[class_idx].append(
                        geo_align_mean[i, class_idx].item()
                    )

    all_predictions = np.array(all_predictions)
    all_targets = np.array(all_targets)
    all_confidences = np.array(all_confidences)

    # Per-class statistics.
    class_results = []
    for class_idx in range(len(class_names)):
        mask = all_targets == class_idx
        if mask.sum() == 0:
            continue

        class_preds = all_predictions[mask]
        correct = (class_preds == class_idx).sum()
        total = mask.sum()
        accuracy = 100.0 * correct / total

        class_conf = all_confidences[mask].mean()

        geo_align = (np.mean(geometric_alignments_by_class[class_idx])
                     if geometric_alignments_by_class[class_idx] else 0.0)

        # Geometry diagnostics for this class's crystal.
        class_crystal = crystals[class_idx].detach().cpu()
        vertex_variance = class_crystal.var(dim=0).mean().item()
        crystal_norm = class_crystal.norm(dim=-1).mean().item()

        class_results.append({
            'class_idx': class_idx,
            'class_name': class_names[class_idx],
            'accuracy': accuracy,
            'correct': int(correct),
            'total': int(total),
            'avg_confidence': class_conf,
            'geometric_alignment': geo_align,
            'vertex_variance': vertex_variance,
            'crystal_norm': crystal_norm,
        })

    # Sort best-to-worst by accuracy.
    class_results.sort(key=lambda x: x['accuracy'], reverse=True)

    overall_acc = 100.0 * (all_predictions == all_targets).mean()

    aux_acc = None
    if aux_predictions:
        aux_predictions = np.array(aux_predictions)
        aux_acc = 100.0 * (aux_predictions == all_targets).mean()

    print("\n" + "=" * 80)
    print("EVALUATION RESULTS")
    print("=" * 80)
    print(f"\nOverall Accuracy: {overall_acc:.2f}%")
    # `is not None`, so an aux accuracy of exactly 0.0 would still be reported.
    if aux_acc is not None:
        print(f"Auxiliary Head Accuracy: {aux_acc:.2f}%")

    print("\nTop 10 Classes:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-" * 70)
    for r in class_results[:10]:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")

    print("\nBottom 10 Classes:")
    print(f"{'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12}")
    print("-" * 70)
    for r in class_results[-10:]:
        print(f"{r['class_name']:<20} {r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f}")

    accuracies = [r['accuracy'] for r in class_results]
    geo_aligns = [r['geometric_alignment'] for r in class_results]
    crystal_norms = [r['crystal_norm'] for r in class_results]
    vertex_vars = [r['vertex_variance'] for r in class_results]

    # Pearson correlations between per-class accuracy and crystal diagnostics.
    print("\nCorrelations with Accuracy:")
    print(f"  Geometric Alignment: {np.corrcoef(accuracies, geo_aligns)[0, 1]:.3f}")
    print(f"  Crystal Norm: {np.corrcoef(accuracies, crystal_norms)[0, 1]:.3f}")
    print(f"  Vertex Variance: {np.corrcoef(accuracies, vertex_vars)[0, 1]:.3f}")

    # Four-panel diagnostic figure.
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Panel 1: distribution of per-class accuracies.
    ax = axes[0, 0]
    ax.hist(accuracies, bins=20, edgecolor='black', alpha=0.7)
    ax.axvline(overall_acc, color='red', linestyle='--', label=f'Overall: {overall_acc:.1f}%')
    ax.set_xlabel('Accuracy (%)')
    ax.set_ylabel('Count')
    ax.set_title('Per-Class Accuracy Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Panel 2: accuracy vs geometric alignment, colored by crystal norm.
    ax = axes[0, 1]
    scatter = ax.scatter(geo_aligns, accuracies, c=crystal_norms, cmap='viridis', alpha=0.6)
    ax.set_xlabel('Geometric Alignment Score')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Geometric Alignment\n(color = crystal norm)')
    plt.colorbar(scatter, ax=ax)
    ax.grid(True, alpha=0.3)

    # Panel 3: accuracy vs crystal norm.
    ax = axes[1, 0]
    ax.scatter(crystal_norms, accuracies, alpha=0.6)
    ax.set_xlabel('Crystal Norm (avg magnitude)')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Accuracy vs Crystal Norm')
    ax.grid(True, alpha=0.3)

    # Panel 4: top-10 vs bottom-10 class accuracies side by side.
    ax = axes[1, 1]
    top10_acc = [r['accuracy'] for r in class_results[:10]]
    bottom10_acc = [r['accuracy'] for r in class_results[-10:]]

    x = np.arange(10)
    width = 0.35
    ax.bar(x - width / 2, top10_acc, width, label='Top 10 Accuracy', color='green', alpha=0.7)
    ax.bar(x + width / 2, bottom10_acc, width, label='Bottom 10 Accuracy', color='red', alpha=0.7)
    ax.set_xlabel('Rank within group')
    ax.set_ylabel('Accuracy (%)')
    ax.set_title('Top 10 vs Bottom 10 Classes')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Full per-class table in class-index order.
    print(f"\n{'=' * 90}")
    print("Sparky — Full Class Spectrum")
    print(f"{'=' * 90}")
    print(f"{'Idx':<5} {'Class':<20} {'Acc%':<8} {'Conf':<8} {'GeoAlign':<10} {'CrystalNorm':<12} {'Variance':<10}")
    print("-" * 90)

    for r in sorted(class_results, key=lambda x: x['class_idx']):
        print(f"{r['class_idx']:<5} {r['class_name']:<20} "
              f"{r['accuracy']:>6.1f} {r['avg_confidence']:>7.3f} "
              f"{r['geometric_alignment']:>9.3f} {r['crystal_norm']:>11.3f} "
              f"{r['vertex_variance']:>9.6f}")

    return class_results, overall_acc


if 'model' in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    _, test_loader = get_cifar100_dataloaders(batch_size=100)

    results, overall = evaluate_pentachora_vit(model, test_loader, device)

    # Guard against the early exit above (no crystals found -> results is None).
    if results is not None:
        print("\nCrystal Geometry Analysis:")
        print("-" * 50)

        crystals = model.cls_tokens.class_pentachora.detach().cpu()

        # Collapse each crystal to its vertex centroid, then compare classes
        # via cosine similarity of the centroids.
        crystals_flat = crystals.mean(dim=1)
        crystals_norm = F.normalize(crystals_flat, dim=1)
        similarities = torch.matmul(crystals_norm, crystals_norm.T)
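
        # Quick sanity check: the single most-similar crystal pair overall,
        # with the diagonal self-similarities pushed out of range (cosine
        # values lie in [-1, 1]) before taking the argmax.
        off_diag = similarities - 2.0 * torch.eye(similarities.shape[0])
        ci, cj = divmod(off_diag.argmax().item(), similarities.shape[0])
        print(f"\nMost similar crystal pair overall: class {ci} / class {cj} "
              f"(cosine {similarities[ci, cj].item():.3f})")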

        # Pairs of weak classes (<20% accuracy) whose crystals are highly
        # similar: candidates for geometric confusion between classes.
        print("\nMost similar classes with poor performance:")
        n = len(results)
        for i in range(n):
            for j in range(i + 1, n):
                if results[i]['accuracy'] < 20 and results[j]['accuracy'] < 20:
                    sim = similarities[results[i]['class_idx'], results[j]['class_idx']].item()
                    if sim > 0.5:
                        print(f"  {results[i]['class_name']:<15} ({results[i]['accuracy']:.1f}%) ↔ "
                              f"{results[j]['class_name']:<15} ({results[j]['accuracy']:.1f}%) : {sim:.3f}")
else:
    print("No model found in memory!")