# Source: Hugging Face Spaces file viewer, commit 57e3814 (4,503 bytes).
"""
Test script to verify ResNet-50 model architecture and predictions
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import requests
from io import BytesIO
import json
# Make the directory that contains this script importable, so `app` resolves
# to the sibling app.py even when the script is launched from another CWD.
import os
import sys
try:
    _script_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:  # e.g. executed interactively / via exec(): fall back to CWD
    _script_dir = os.getcwd()
sys.path.insert(0, _script_dir)
print("=" * 60)
print("ResNet-50 ImageNet Model Test")
print("=" * 60)
# Load model architecture, preprocessing `transform` and class-name table from app.py
from app import ResNet50, transform, IMAGENET_CLASSES
# Create model on GPU when available, otherwise CPU, and put it in eval mode
# (no checkpoint is loaded in this script).
model = ResNet50(num_classes=1000)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# Test 1: Model architecture - report parameter counts and compare the total
# with the reference ResNet-50 size.
print("\n✓ Test 1: Model Architecture")
# Count once and reuse (the original iterated the parameters three times).
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f" - Total parameters: {total_params:,}")
print(f" - Trainable parameters: {trainable_params:,}")
print(f" - Device: {device}")
# Expected parameters for ResNet-50: ~25.6M
expected_params = 25_557_032  # Standard ResNet-50 (1000 classes)
if abs(total_params - expected_params) < 100_000:  # Allow some variance
    print(" ✓ Parameter count matches expected ResNet-50 (~25.6M)")
else:
    print(" ⚠ Parameter count differs from standard ResNet-50")
    print(f" Expected: ~{expected_params:,}, Got: {total_params:,}")
# Test 2: Forward pass - one random 3x224x224 input must produce an output of
# shape (1, 1000).
print("\n✓ Test 2: Forward Pass")
dummy_input = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    output = model(dummy_input)
print(f" - Input shape: {dummy_input.shape}")
print(f" - Output shape: {output.shape}")
print(" - Expected output shape: torch.Size([1, 1000])")
# Explicit check instead of `assert`, which is stripped under `python -O`.
if output.shape != torch.Size([1, 1000]):
    raise AssertionError("Output shape mismatch!")
print(" ✓ Output shape is correct")
# Test 3: Test with a sample image (if available online).
# Only the download/decode step is best-effort; preprocessing and inference run
# outside the `try` so genuine model bugs are not misreported as network issues.
print("\n✓ Test 3: Sample Prediction")
try:
    # Download a sample image (Golden Retriever)
    url = "https://images.unsplash.com/photo-1633722715463-d30f4f325e24?w=400"
    response = requests.get(url, timeout=10)
    # Fail on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content)).convert('RGB')
except (requests.RequestException, OSError) as e:
    # RequestException: connection/timeout/HTTP errors.
    # OSError: covers PIL's UnidentifiedImageError / truncated-image errors.
    print(f" ⚠ Could not test with sample image: {e}")
    print(" This is normal if you don't have internet connection")
else:
    # Preprocess with app.py's `transform` (expected to return a tensor) and
    # add a batch dimension.
    img_tensor = transform(img).unsqueeze(0).to(device)
    # Predict
    with torch.no_grad():
        outputs = model(img_tensor)
    probabilities = F.softmax(outputs, dim=1)[0]
    # Get top 5
    top5_prob, top5_idx = torch.topk(probabilities, 5)
    print(f" - Image size: {img.size}")
    print(" - Top 5 predictions:")
    for rank, (idx, prob) in enumerate(zip(top5_idx.tolist(), top5_prob.tolist()), 1):
        # NOTE(review): assumes IMAGENET_CLASSES is keyed by str(class_index); confirm in app.py
        class_name = IMAGENET_CLASSES.get(str(idx), f"Class {idx}")
        print(f" {rank}. {class_name}: {prob*100:.2f}%")
    print("\n ⚠ Note: Model is randomly initialized, so predictions are random")
    print(" Load a trained checkpoint to get meaningful predictions")
# Test 4: Check ImageNet classes - the class-name table should cover all
# 1000 classes; print a few entries as a spot check.
print("\n✓ Test 4: ImageNet Classes")
num_classes = len(IMAGENET_CLASSES)
print(f" - Total classes defined: {num_classes}")
print(" - Expected: 1000 classes")
if num_classes >= 1000:
    print(" ✓ Class definitions available")
    print(" - Sample classes:")
    for idx in [0, 285, 341, 388, 980]:
        # Entries are looked up by the stringified class index.
        key = str(idx)
        if key in IMAGENET_CLASSES:
            print(f" {idx}: {IMAGENET_CLASSES[key]}")
else:
    print(f" ⚠ Warning: Only {num_classes} classes defined")
# Test 5: Model layers - number of blocks in each of the four stages
# (standard ResNet-50 is [3, 4, 6, 3]).
print("\n✓ Test 5: Model Structure")
actual_blocks = [len(stage) for stage in (model.layer1, model.layer2, model.layer3, model.layer4)]
for stage_no, n_blocks in enumerate(actual_blocks, 1):
    print(f" - Layer {stage_no} blocks: {n_blocks}")
expected_blocks = [3, 4, 6, 3]
if actual_blocks == expected_blocks:
    print(f" ✓ ResNet-50 architecture confirmed: {actual_blocks}")
else:
    print(f" ⚠ Warning: Expected {expected_blocks}, got {actual_blocks}")
# Summary banner and follow-up instructions for training and deployment.
print("\n" + "=" * 60)
print("✓ All tests completed!")
print("=" * 60)
print("\nNext steps:")
print("1. Train your model using the training code from:")
print(" https://github.com/arghyaiitb/assignment_9")
print("2. Save the best checkpoint as 'best_model.pt'")
print("3. Place the checkpoint in this directory")
print("4. Run: python app.py")
print("5. Upload to Hugging Face Spaces")
print("=" * 60)