argo committed on
Commit
70a26de
·
1 Parent(s): c2a28ed

Added gradio app

Browse files
Files changed (2) hide show
  1. app.py +243 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms
6
+ import numpy as np
7
+ from PIL import Image
8
+ import json
9
+
10
# ImageNet-1k class names, loaded from a JSON file that ships next to this
# script. The encoding is given explicitly so the labels decode the same way
# regardless of the platform's default locale encoding.
with open('imagenet_classes.json', 'r', encoding='utf-8') as f:
    IMAGENET_CLASSES = json.load(f)
14
+
15
+ # Model definition - ResNet-50 for ImageNet
16
class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1).

    The final 1x1 convolution widens the feature map to
    ``out_channels * expansion`` channels. When ``downsample`` is given it is
    applied to the block input so the shortcut matches the main path before
    the two are summed.

    Args:
        in_channels: Number of channels of the block input.
        out_channels: Width of the two inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input on the shortcut path.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        widened = out_channels * self.expansion

        # 1x1 convolution into the inner width.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3x3 convolution; this is the only layer that applies the stride.
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3,
            stride=stride, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 convolution out to the widened channel count.
        self.conv3 = nn.Conv2d(out_channels, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)

        self.downsample = downsample

    def forward(self, x):
        """Run the main path, add the (optionally projected) input, apply ReLU."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        return F.relu(out + shortcut)
53
+
54
+
55
class ResNet50(nn.Module):
    """ResNet-50 image classifier assembled from ``Bottleneck`` blocks.

    The network is a 7x7 convolution stem with max pooling, four residual
    stages of 3, 4, 6 and 3 blocks, global average pooling and a linear
    classification head.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # Channel count entering the next stage; _make_layer updates it.
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages: [3, 4, 6, 3] blocks
        self.layer1 = self._make_layer(64, 3, stride=1)
        self.layer2 = self._make_layer(128, 4, stride=2)
        self.layer3 = self._make_layer(256, 6, stride=2)
        self.layer4 = self._make_layer(512, 3, stride=2)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, out_channels, blocks, stride):
        """Build one residual stage of ``blocks`` Bottleneck blocks.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a 1x1 conv + batch-norm projection on its
        shortcut. The remaining blocks use the defaults.
        """
        widened = out_channels * Bottleneck.expansion

        projection = None
        if stride != 1 or self.in_channels != widened:
            projection = nn.Sequential(
                nn.Conv2d(self.in_channels, widened,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(widened),
            )

        stage = [Bottleneck(self.in_channels, out_channels, stride, projection)]
        self.in_channels = widened
        stage.extend(
            Bottleneck(self.in_channels, out_channels) for _ in range(blocks - 1)
        )

        return nn.Sequential(*stage)

    def _initialize_weights(self):
        """Apply Kaiming-normal init to convolutions; set batch-norm to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the class logits for an input image batch ``x``."""
        # Stem
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))

        # Residual stages
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # Head
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)
127
+
128
+
129
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50(num_classes=1000)


def _format_accuracy(value):
    """Render an accuracy value as 'xx.xx%', or as plain text if it is not numeric."""
    try:
        return f"{value:.2f}%"
    except (TypeError, ValueError):
        # Reached for the 'N/A' placeholder or any other non-numeric entry.
        return str(value)


# Load trained weights. Any failure is treated as best-effort: the app still
# starts, using the randomly initialized network.
try:
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where supported).
    checkpoint = torch.load("best_model.pt", map_location=device)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        # Format the metrics defensively: a checkpoint without accuracy
        # entries must not be reported as a weight-loading failure.
        top1 = _format_accuracy(checkpoint.get('top1_accuracy', 'N/A'))
        top5 = _format_accuracy(checkpoint.get('top5_accuracy', 'N/A'))
        print(f"Model loaded successfully! Top-1 accuracy: {top1}")
        print(f"Top-5 accuracy: {top5}")
    else:
        # The file is a bare state dict.
        model.load_state_dict(checkpoint)
        print("Model loaded successfully!")
except Exception as e:
    print(f"Warning: Could not load model weights: {e}")
    print("Using randomly initialized model for demo purposes.")

model.to(device)
model.eval()
149
+
150
# ImageNet preprocessing: Resize(256) -> CenterCrop(224) -> tensor ->
# per-channel normalisation with the mean/std values below.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
]
transform = transforms.Compose(_preprocess_steps)
158
+
159
+
160
def predict(image):
    """Classify an image and return its top-5 predictions.

    Args:
        image: A PIL image or a NumPy array (grayscale, RGB or RGBA), or
            ``None`` when nothing was uploaded.

    Returns:
        A dict mapping ``"<rank>. <class name>"`` to a percentage string such
        as ``"12.34%"``. On failure, or when ``image`` is ``None``, a dict
        with a single ``"Error"`` key holding the message is returned.
    """
    if image is None:
        return {"Error": "No image provided"}

    try:
        # Convert to PIL Image if needed. Pillow infers the mode from the
        # array shape, so grayscale and RGBA arrays are accepted too instead
        # of being forced into an 'RGB' layout they do not have.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype('uint8'))

        # Ensure RGB mode
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess and add the batch dimension
        img_tensor = transform(image).unsqueeze(0).to(device)

        # Make prediction
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = F.softmax(outputs, dim=1)[0]

        # Get top 5 predictions (fewer if the model has fewer than 5 classes)
        top_k = min(5, probabilities.numel())
        top_prob, top_idx = torch.topk(probabilities, top_k)

        # Format results as a dictionary
        results = {}
        ranked = zip(top_idx.tolist(), top_prob.tolist())
        for rank, (class_idx, prob) in enumerate(ranked, 1):
            # The label file may be a {"<index>": name} mapping or a plain
            # list of names ordered by class index; support both.
            if isinstance(IMAGENET_CLASSES, dict):
                class_name = IMAGENET_CLASSES.get(str(class_idx), f"Class {class_idx}")
            elif 0 <= class_idx < len(IMAGENET_CLASSES):
                class_name = IMAGENET_CLASSES[class_idx]
            else:
                class_name = f"Class {class_idx}"
            results[f"{rank}. {class_name}"] = f"{prob * 100:.2f}%"

        return results

    except Exception as e:
        # Surface the problem in the UI instead of crashing the request.
        return {"Error": str(e)}
196
+
197
+
198
# Create Gradio interface
title = "ResNet-50 ImageNet-1k Classifier"

description = """
Upload an image to classify it into one of 1000 ImageNet categories.

This model is a **ResNet-50** trained on the ImageNet-1k dataset with modern optimization techniques:
- **Architecture**: ResNet-50 with Bottleneck blocks [3, 4, 6, 3]
- **Training Optimizations**:
- Progressive resizing (128→160→192→224px)
- CutMix and MixUp augmentation
- Label smoothing (0.1)
- Exponential Moving Average (EMA)
- Automatic Mixed Precision (AMP)
- PyTorch 2.0 compilation
- **Target Accuracy**: 78%+ (Top-1), 94%+ (Top-5)
- **Training Time**: ~90 minutes on 8x A100 GPUs

The model works best with natural images containing objects, animals, or scenes from the ImageNet categories.
"""

# One value per input component: the interface has a single image input, so
# each example row holds only the image URL. The intended subject of each
# picture is kept as a comment.
examples = [
    ["https://images.unsplash.com/photo-1543466835-00a7907e9de1?w=400"],  # Golden Retriever
    ["https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=400"],  # Tabby Cat
    ["https://images.unsplash.com/photo-1511367461989-f85a21fda167?w=400"],  # Granny Smith Apple
]
224
+
225
# Create the interface: one image input, JSON output with the top-5 results.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.JSON(label="Top 5 Predictions"),
    title=title,
    description=description,
    examples=examples,
    theme=gr.themes.Soft(),
    # `allow_flagging` is deprecated in Gradio 5 (the version pinned in
    # requirements.txt); `flagging_mode` is its replacement.
    flagging_mode="never",
)

if __name__ == "__main__":
    # Bind to all interfaces on port 7860; no public share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
243
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ gradio>=5.49.1
4
+ numpy>=1.24.0
5
+ Pillow>=9.0.0
6
+ pydantic==2.10.6