# arghyaiitb: fixed the model code (commit a2b32c9)
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
import json
import os
from models import ResNet50
# ImageNet-1k class names from HuggingFace
# Source: https://huggingface.co/datasets/huggingface/label-files/blob/main/imagenet-1k-id2label.json
def _load_imagenet_classes(
    cache_path='imagenet_classes.json',
    url="https://huggingface.co/datasets/huggingface/label-files/raw/main/imagenet-1k-id2label.json",
    timeout=30,
):
    """Return the ImageNet-1k id-to-label mapping, downloading it if not cached.

    Args:
        cache_path: Local JSON file holding the mapping; read if present,
            written after a successful download.
        url: Location of the JSON label file used when the cache is missing.
        timeout: Seconds to wait on the download before giving up, so a
            stalled connection cannot hang application start-up forever.

    Returns:
        The parsed JSON object (a dict keyed by class-index strings, as the
        lookups in this file assume).

    Raises:
        urllib.error.URLError / OSError: if the download fails.
        json.JSONDecodeError: if the cached or downloaded payload is not JSON.
    """
    try:
        with open(cache_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        pass  # No cache yet: fall through to the download path.

    # Fallback: download if not present
    import urllib.request
    print("Downloading ImageNet class labels...")
    with urllib.request.urlopen(url, timeout=timeout) as response:
        classes = json.loads(response.read().decode('utf-8'))
    try:
        with open(cache_path, 'w', encoding='utf-8') as f:
            json.dump(classes, f, indent=2)
    except OSError as e:
        # Caching is only an optimisation; a read-only filesystem must not
        # stop the app when the labels are already in memory.
        print(f"Warning: could not cache ImageNet class labels: {e}")
    print("ImageNet class labels downloaded successfully!")
    return classes


IMAGENET_CLASSES = _load_imagenet_classes()
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50(num_classes=1000)


def _format_accuracy(value):
    """Render a checkpoint accuracy as 'NN.NN%', or 'N/A' if it cannot be formatted.

    A missing metric (None) or a non-numeric value must not raise: a
    formatting error here used to be reported as a weight-loading failure.
    """
    try:
        return f"{value:.2f}%"
    except (TypeError, ValueError):
        return "N/A"


# Load trained weights
try:
    # NOTE(review): torch.load unpickles the file, so only ship checkpoints
    # from a trusted source. Consider weights_only=True if the checkpoint
    # holds nothing but tensors and primitives - confirm before changing.
    checkpoint = torch.load("best_model.pt", map_location=device)
    # A checkpoint is either a bare state dict or a dict wrapping one under
    # 'model_state_dict' next to optional accuracy metrics.
    has_metadata = 'model_state_dict' in checkpoint
    raw_state_dict = checkpoint['model_state_dict'] if has_metadata else checkpoint
    # Remove _orig_mod. prefix if present (from torch.compile)
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in raw_state_dict.items()}
    model.load_state_dict(state_dict)
except Exception as e:
    # Deliberate best-effort: the demo still starts without trained weights.
    print(f"Warning: Could not load model weights: {e}")
    print("Using randomly initialized model for demo purposes.")
else:
    # Report success outside the try body so a logging problem can never be
    # mistaken for a failure to load the weights.
    if has_metadata:
        print(f"Model loaded successfully! Top-1 accuracy: {_format_accuracy(checkpoint.get('top1_accuracy'))}")
        print(f"Top-5 accuracy: {_format_accuracy(checkpoint.get('top5_accuracy'))}")
    else:
        print("Model loaded successfully!")
model.to(device)
model.eval()
# ImageNet preprocessing: resize to 256, center-crop to 224, convert to a
# tensor, then normalize each channel with the mean/std values below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)
def predict(image):
    """Classify an image and return the model's top predictions.

    Args:
        image: A PIL image, a NumPy array of pixel values (assumed to be
            H x W, H x W x 3 or H x W x 4 - TODO confirm against callers),
            or None when nothing was uploaded.

    Returns:
        A dict mapping "<rank>. <class name>" to the probability formatted
        as "NN.NN%", for up to five classes in descending probability.
        On failure, or when ``image`` is None, a dict with a single
        "Error" key holding the message.
    """
    if image is None:
        return {"Error": "No image provided"}
    try:
        # Convert to PIL Image if needed
        if isinstance(image, np.ndarray):
            # Let Pillow infer the mode from the array shape: forcing 'RGB'
            # breaks on grayscale or RGBA input, which the conversion below
            # handles correctly.
            image = Image.fromarray(image.astype('uint8'))
        # Ensure RGB mode
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Preprocess image (unsqueeze adds the leading batch dimension)
        img_tensor = transform(image).unsqueeze(0).to(device)
        # Make prediction
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = F.softmax(outputs, dim=1)[0]
        # Get top 5 predictions (fewer if the model has fewer classes)
        k = min(5, probabilities.numel())
        top_prob, top_idx = torch.topk(probabilities, k)
        # Format results as a dictionary
        results = {}
        ranked = zip(top_idx.tolist(), top_prob.tolist())
        for rank, (class_idx, prob) in enumerate(ranked, 1):
            # Label keys are strings because the mapping comes from JSON.
            class_name = IMAGENET_CLASSES.get(str(class_idx), f"Class {class_idx}")
            results[f"{rank}. {class_name}"] = f"{prob * 100:.2f}%"
        return results
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        return {"Error": str(e)}
# Text for the Gradio interface: both values are passed to gr.Interface
# below as its `title` and `description` arguments. `description` is written
# with Markdown syntax (bold markers, bullet lists, links).
title = "ResNet-50 ImageNet-1k Classifier"
description = """
Upload an image to classify it into one of **1000 ImageNet categories**.
This model is a **ResNet-50** trained on the ImageNet-1k dataset with modern optimization techniques:
- **Architecture**: ResNet-50 with Bottleneck blocks [3, 4, 6, 3]
- **Parameters**: ~25.6M trainable parameters
- **Training Optimizations**:
- Progressive resizing (128→160→192→224px)
- CutMix and MixUp augmentation
- Label smoothing (0.1)
- Exponential Moving Average (EMA)
- Automatic Mixed Precision (AMP)
- PyTorch 2.0 compilation
- FFCV high-performance data loading
- **Target Accuracy**: 78%+ (Top-1), 94%+ (Top-5)
- **Training Time**: ~90 minutes on 8x A100 GPUs
**Class labels** are from the official [HuggingFace ImageNet-1k dataset](https://huggingface.co/datasets/huggingface/label-files/blob/main/imagenet-1k-id2label.json).
The model works best with natural images containing objects, animals, or scenes from the ImageNet categories.
**Training code**: [github.com/arghyaiitb/assignment_9](https://github.com/arghyaiitb/assignment_9)
"""
# Example images for demonstration, keyed by the subject each photo is
# labelled with. Insertion order is the order the examples are listed in.
_EXAMPLE_IMAGE_URLS = {
    "Golden Retriever": "https://images.unsplash.com/photo-1543466835-00a7907e9de1?w=400",
    "Tabby Cat": "https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=400",
    "Granny Smith Apple": "https://images.unsplash.com/photo-1511367461989-f85a21fda167?w=400",
}
examples = list(_EXAMPLE_IMAGE_URLS.values())
# Create the interface: one image input (delivered to `predict` as a PIL
# image) and one JSON output showing the prediction dict.
_image_input = gr.Image(type="pil", label="Upload Image")
_predictions_output = gr.JSON(label="Top 5 Predictions")
demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_predictions_output,
    title=title,
    description=description,
    examples=examples,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)
if __name__ == "__main__":
    # Start the server only when run as a script, not on import.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)