import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from PIL import Image
import json
import os
from models import ResNet50

# ImageNet-1k class names from HuggingFace
# Source: https://huggingface.co/datasets/huggingface/label-files/blob/main/imagenet-1k-id2label.json
if os.path.exists('imagenet_classes.json'):
    with open('imagenet_classes.json', 'r') as f:
        IMAGENET_CLASSES = json.load(f)
else:
    # Fallback: download if not present
    import urllib.request
    print("Downloading ImageNet class labels...")
    url = "https://huggingface.co/datasets/huggingface/label-files/raw/main/imagenet-1k-id2label.json"
    with urllib.request.urlopen(url) as response:
        IMAGENET_CLASSES = json.loads(response.read().decode())
    with open('imagenet_classes.json', 'w') as f:
        json.dump(IMAGENET_CLASSES, f, indent=2)
    print("ImageNet class labels downloaded successfully!")


# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50(num_classes=1000)

# Load trained weights
try:
    checkpoint = torch.load("best_model.pt", map_location=device)
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        # Remove _orig_mod. prefix if present (from torch.compile)
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        print(f"Model loaded successfully! Top-1 accuracy: {checkpoint.get('top1_accuracy', 'N/A'):.2f}%")
        print(f"Top-5 accuracy: {checkpoint.get('top5_accuracy', 'N/A'):.2f}%")
    else:
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in checkpoint.items()}
        model.load_state_dict(state_dict)
        print("Model loaded successfully!")
except Exception as e:
    print(f"Warning: Could not load model weights: {e}")
    print("Using randomly initialized model for demo purposes.")

model.to(device)
model.eval()

# ImageNet preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def predict(image):
    """Predict the class of the input image"""
    if image is None:
        return {"Error": "No image provided"}

    try:
        # Convert to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype('uint8'))  # let PIL infer the mode; RGB is enforced below
        
        # Ensure RGB mode
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess image
        img_tensor = transform(image).unsqueeze(0).to(device)

        # Make prediction
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = F.softmax(outputs, dim=1)[0]

        # Get top 5 predictions
        top5_prob, top5_idx = torch.topk(probabilities, 5)

        # Format results as a dictionary
        results = {}
        for i, (idx, prob) in enumerate(zip(top5_idx, top5_prob), 1):
            class_idx = idx.item()
            class_name = IMAGENET_CLASSES.get(str(class_idx), f"Class {class_idx}")
            results[f"{i}. {class_name}"] = f"{float(prob.item()) * 100:.2f}%"

        return results
    
    except Exception as e:
        return {"Error": str(e)}


# Create Gradio interface
title = "ResNet-50 ImageNet-1k Classifier"

description = """
Upload an image to classify it into one of **1000 ImageNet categories**.

This model is a **ResNet-50** trained on the ImageNet-1k dataset with modern optimization techniques:
- **Architecture**: ResNet-50 with Bottleneck blocks [3, 4, 6, 3]
- **Parameters**: ~25.6M trainable parameters
- **Training Optimizations**: 
  - Progressive resizing (128→160→192→224px)
  - CutMix and MixUp augmentation
  - Label smoothing (0.1)
  - Exponential Moving Average (EMA)
  - Automatic Mixed Precision (AMP)
  - PyTorch 2.0 compilation
  - FFCV high-performance data loading
- **Target Accuracy**: 78%+ (Top-1), 94%+ (Top-5)
- **Training Time**: ~90 minutes on 8x A100 GPUs

**Class labels** are from the official [HuggingFace ImageNet-1k dataset](https://huggingface.co/datasets/huggingface/label-files/blob/main/imagenet-1k-id2label.json).

The model works best with natural images containing objects, animals, or scenes from the ImageNet categories.

**Training code**: [github.com/arghyaiitb/assignment_9](https://github.com/arghyaiitb/assignment_9)
"""

# Example images for demonstration
examples = [
    "https://images.unsplash.com/photo-1543466835-00a7907e9de1?w=400",  # Golden Retriever
    "https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=400",  # Tabby Cat  
    "https://images.unsplash.com/photo-1511367461989-f85a21fda167?w=400",  # Granny Smith Apple
]

# Create the interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.JSON(label="Top 5 Predictions"),
    title=title,
    description=description,
    examples=examples,
    theme=gr.themes.Soft(),
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )