resnet50-imagenet-1k / test_model.py
argo
Added json
57e3814
"""
Test script to verify ResNet-50 model architecture and predictions
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import requests
from io import BytesIO
import json
# Import model from app.py
import sys
sys.path.append('.')
# Announce the run with a title framed by two horizontal rules.
rule = "=" * 60
print(rule)
print("ResNet-50 ImageNet Model Test")
print(rule)

# The network class, the preprocessing pipeline and the label table all come
# from app.py ('.' was appended to sys.path above so the import resolves).
from app import ResNet50, transform, IMAGENET_CLASSES

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the 1000-class network, place it on the chosen device and switch it
# to evaluation mode for the inference-only checks that follow.
model = ResNet50(num_classes=1000)
model.to(device)
model.eval()
# Test 1: Model architecture
# Reports the parameter counts and compares the total against the reference
# size of a standard ResNet-50.
print("\n✓ Test 1: Model Architecture")

# Walk model.parameters() once and reuse the numbers for both the report and
# the comparison below.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, trainable in param_sizes if trainable)

print(f" - Total parameters: {total_params:,}")
print(f" - Trainable parameters: {trainable_params:,}")
print(f" - Device: {device}")

# Expected parameters for ResNet-50: ~25.6M
expected_params = 25_557_032  # Standard ResNet-50
param_tolerance = 100_000  # Allow some variance
if abs(total_params - expected_params) < param_tolerance:
    print(" ✓ Parameter count matches expected ResNet-50 (~25.6M)")
else:
    print(" ⚠ Parameter count differs from standard ResNet-50")
    print(f" Expected: ~{expected_params:,}, Got: {total_params:,}")
# Test 2: Forward pass with dummy input
# Feeds one random tensor through the model and checks the output shape.
print("\n✓ Test 2: Forward Pass")

# Batch of one random 3-channel 224x224 input, moved to the model's device.
dummy_input = torch.randn(1, 3, 224, 224).to(device)
expected_shape = torch.Size([1, 1000])

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    output = model(dummy_input)

print(f" - Input shape: {dummy_input.shape}")
print(f" - Output shape: {output.shape}")
print(f" - Expected output shape: {expected_shape}")

# Raise explicitly instead of using `assert`: assert statements are removed
# under `python -O`, which would let a wrong shape fall through to the
# success message. AssertionError is kept so the failure type is unchanged.
if output.shape != expected_shape:
    raise AssertionError("Output shape mismatch!")
print(" ✓ Output shape is correct")
# Test 3: Test with a sample image (if available online)
# Best-effort check: downloads one image, runs it through the preprocessing
# pipeline and the model, and prints the five most probable classes.
print("\n✓ Test 3: Sample Prediction")
try:
    # Download a sample image (Golden Retriever)
    url = "https://images.unsplash.com/photo-1633722715463-d30f4f325e24?w=400"
    response = requests.get(url, timeout=10)
    # Report HTTP failures (4xx/5xx) as HTTP failures; without this the error
    # page body is handed to PIL and shows up as an undecodable image.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content)).convert('RGB')

    # Preprocess: apply the app's transform and add a leading batch dimension.
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Predict without gradient tracking; softmax over dim 1 turns the
    # model outputs into probabilities, and [0] selects the single batch item.
    with torch.no_grad():
        outputs = model(img_tensor)
        probabilities = F.softmax(outputs, dim=1)[0]

    # Get top 5 and convert both tensors to plain Python numbers in one go.
    top5_prob, top5_idx = torch.topk(probabilities, 5)
    print(f" - Image size: {img.size}")
    print(" - Top 5 predictions:")
    for i, (idx, prob) in enumerate(zip(top5_idx.tolist(), top5_prob.tolist()), 1):
        # The label table is looked up by the class index as a string; fall
        # back to a generic label when the index has no entry.
        class_name = IMAGENET_CLASSES.get(str(idx), f"Class {idx}")
        print(f" {i}. {class_name}: {prob * 100:.2f}%")
    print("\n ⚠ Note: Model is randomly initialized, so predictions are random")
    print(" Load a trained checkpoint to get meaningful predictions")
except Exception as e:
    # Deliberately broad: this check is optional, so any failure (network,
    # decoding, inference) is reported and the script carries on.
    print(f" ⚠ Could not test with sample image: {e}")
    print(" This is normal if you don't have internet connection")
# Test 4: Check ImageNet classes
# Reports how many labels are defined and prints a few sample entries.
print("\n✓ Test 4: ImageNet Classes")

# Measure the table once and reuse the count in every message.
num_classes = len(IMAGENET_CLASSES)
print(f" - Total classes defined: {num_classes}")
print(" - Expected: 1000 classes")
if num_classes >= 1000:
    print(" ✓ Class definitions available")
    print(" - Sample classes:")
    # Spot-check a handful of indices; entries are looked up by the index
    # converted to a string, and missing ones are skipped silently.
    for idx in [0, 285, 341, 388, 980]:
        key = str(idx)
        if key in IMAGENET_CLASSES:
            print(f" {idx}: {IMAGENET_CLASSES[key]}")
else:
    print(f" ⚠ Warning: Only {num_classes} classes defined")
# Test 5: Model layers
# Counts the blocks in each of the model's four stages and compares them
# with the [3, 4, 6, 3] layout expected for ResNet-50.
print("\n✓ Test 5: Model Structure")

# Measure every stage once; the same list feeds the report and the check.
stages = (model.layer1, model.layer2, model.layer3, model.layer4)
actual_blocks = [len(stage) for stage in stages]
for stage_number, block_count in enumerate(actual_blocks, 1):
    print(f" - Layer {stage_number} blocks: {block_count}")

expected_blocks = [3, 4, 6, 3]
if actual_blocks == expected_blocks:
    print(f" ✓ ResNet-50 architecture confirmed: {actual_blocks}")
else:
    print(f" ⚠ Warning: Expected {expected_blocks}, got {actual_blocks}")
# Closing banner followed by the checklist of follow-up steps.
footer_rule = "=" * 60
print("\n" + footer_rule)
print("✓ All tests completed!")
print(footer_rule)

print("\nNext steps:")
next_steps = (
    "1. Train your model using the training code from:",
    " https://github.com/arghyaiitb/assignment_9",
    "2. Save the best checkpoint as 'best_model.pt'",
    "3. Place the checkpoint in this directory",
    "4. Run: python app.py",
    "5. Upload to Hugging Face Spaces",
)
for step in next_steps:
    print(step)
print(footer_rule)