import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
import json
import os
from models import ResNet50
# ImageNet-1k class names from HuggingFace
# Source: https://huggingface.co/datasets/huggingface/label-files/blob/main/imagenet-1k-id2label.json
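# The JSON maps string class indices ("0"-"999") to human-readable label names,
# which is why predictions are looked up with str(class_idx) in predict() below.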
if os.path.exists('imagenet_classes.json'):
    with open('imagenet_classes.json', 'r') as f:
        IMAGENET_CLASSES = json.load(f)
else:
    # Fallback: download the labels if the local copy is not present
    import urllib.request
    print("Downloading ImageNet class labels...")
    url = "https://huggingface.co/datasets/huggingface/label-files/raw/main/imagenet-1k-id2label.json"
    with urllib.request.urlopen(url) as response:
        IMAGENET_CLASSES = json.loads(response.read().decode())
    with open('imagenet_classes.json', 'w') as f:
        json.dump(IMAGENET_CLASSES, f, indent=2)
    print("ImageNet class labels downloaded successfully!")
# Load model
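# ResNet50 is the local implementation imported from models.py, not torchvision's resnet50.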
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50(num_classes=1000)
# Load trained weights
try:
    checkpoint = torch.load("best_model.pt", map_location=device)
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        # Remove the _orig_mod. prefix if present (added by torch.compile)
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        print("Model loaded successfully!")
        # Only format the metrics if the checkpoint actually stores them;
        # applying :.2f to the 'N/A' fallback string would raise a ValueError.
        top1 = checkpoint.get('top1_accuracy')
        top5 = checkpoint.get('top5_accuracy')
        if top1 is not None:
            print(f"Top-1 accuracy: {top1:.2f}%")
        if top5 is not None:
            print(f"Top-5 accuracy: {top5:.2f}%")
    else:
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in checkpoint.items()}
        model.load_state_dict(state_dict)
        print("Model loaded successfully!")
except Exception as e:
    print(f"Warning: Could not load model weights: {e}")
    print("Using randomly initialized model for demo purposes.")
model.to(device)
model.eval()
# ImageNet preprocessing
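# Standard ImageNet evaluation pipeline: resize the shorter side to 256, take a
# 224x224 center crop, convert to a tensor, and normalize with the usual
# ImageNet channel means and standard deviations.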
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
def predict(image):
    """Predict the class of the input image and return the top-5 results."""
    if image is None:
        return {"Error": "No image provided"}

    try:
        # Convert to PIL Image if needed (Gradio passes PIL directly when type="pil")
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype('uint8'), 'RGB')

        # Ensure RGB mode
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess image and add a batch dimension
        img_tensor = transform(image).unsqueeze(0).to(device)

        # Make prediction
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = F.softmax(outputs, dim=1)[0]

        # Get top-5 predictions
        top5_prob, top5_idx = torch.topk(probabilities, 5)

        # Format results as a dictionary
        results = {}
        for i, (idx, prob) in enumerate(zip(top5_idx, top5_prob), 1):
            class_idx = idx.item()
            class_name = IMAGENET_CLASSES.get(str(class_idx), f"Class {class_idx}")
            results[f"{i}. {class_name}"] = f"{prob.item() * 100:.2f}%"
        return results
    except Exception as e:
        return {"Error": str(e)}
# Create Gradio interface
title = "ResNet-50 ImageNet-1k Classifier"
description = """
Upload an image to classify it into one of **1000 ImageNet categories**.
This model is a **ResNet-50** trained on the ImageNet-1k dataset with modern optimization techniques:
- **Architecture**: ResNet-50 with Bottleneck blocks [3, 4, 6, 3]
- **Parameters**: ~25.6M trainable parameters
- **Training Optimizations**:
- Progressive resizing (128→160→192→224px)
- CutMix and MixUp augmentation
- Label smoothing (0.1)
- Exponential Moving Average (EMA)
- Automatic Mixed Precision (AMP)
- PyTorch 2.0 compilation
- FFCV high-performance data loading
- **Target Accuracy**: 78%+ (Top-1), 94%+ (Top-5)
- **Training Time**: ~90 minutes on 8x A100 GPUs
**Class labels** are from the official [HuggingFace ImageNet-1k dataset](https://huggingface.co/datasets/huggingface/label-files/blob/main/imagenet-1k-id2label.json).
The model works best with natural images containing objects, animals, or scenes from the ImageNet categories.
**Training code**: [github.com/arghyaiitb/assignment_9](https://github.com/arghyaiitb/assignment_9)
"""
# Example images for demonstration
examples = [
    "https://images.unsplash.com/photo-1543466835-00a7907e9de1?w=400",   # Golden Retriever
    "https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=400",  # Tabby Cat
    "https://images.unsplash.com/photo-1511367461989-f85a21fda167?w=400",  # Granny Smith Apple
]
# Create the interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.JSON(label="Top 5 Predictions"),
    title=title,
    description=description,
    examples=examples,
    theme=gr.themes.Soft(),
    allow_flagging="never"
)
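# Bind to all interfaces on port 7860, the default port Hugging Face Spaces expects.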
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )