Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| import numpy as np | |
| from PIL import Image | |
| import json | |
| import os | |
| from models import ResNet50 | |
# ImageNet-1k class names from HuggingFace
# Source: https://huggingface.co/datasets/huggingface/label-files/blob/main/imagenet-1k-id2label.json
#
# IMAGENET_CLASSES is whatever the JSON document decodes to; predict() looks
# entries up by the stringified class index.
if os.path.exists('imagenet_classes.json'):
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open('imagenet_classes.json', 'r', encoding='utf-8') as f:
        IMAGENET_CLASSES = json.load(f)
else:
    # Fallback: download if not present
    import urllib.request

    print("Downloading ImageNet class labels...")
    url = "https://huggingface.co/datasets/huggingface/label-files/raw/main/imagenet-1k-id2label.json"
    # Bound the request so an unreachable host cannot hang start-up forever.
    with urllib.request.urlopen(url, timeout=30) as response:
        IMAGENET_CLASSES = json.loads(response.read().decode('utf-8'))
    # Caching is only an optimization for the next start-up: the labels are
    # already in memory, so a failed write must not take the app down.
    try:
        with open('imagenet_classes.json', 'w', encoding='utf-8') as f:
            json.dump(IMAGENET_CLASSES, f, indent=2)
    except OSError as e:
        print(f"Warning: could not cache ImageNet class labels: {e}")
    print("ImageNet class labels downloaded successfully!")
def _format_accuracy(value):
    """Render an accuracy value as 'NN.NN%', or 'N/A' when missing or non-numeric."""
    try:
        return f"{float(value):.2f}%"
    except (TypeError, ValueError):
        return "N/A"


# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50(num_classes=1000)

# Load trained weights
try:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True if the checkpoint
    # holds nothing but tensors and plain containers).
    checkpoint = torch.load("best_model.pt", map_location=device)
    # A checkpoint is either a training snapshot that wraps the weights under
    # 'model_state_dict' (alongside metrics) or a bare state dict.
    has_metadata = 'model_state_dict' in checkpoint
    raw_state_dict = checkpoint['model_state_dict'] if has_metadata else checkpoint
    # Remove _orig_mod. prefix if present (from torch.compile)
    state_dict = {k.replace('_orig_mod.', ''): v for k, v in raw_state_dict.items()}
    model.load_state_dict(state_dict)
except Exception as e:
    print(f"Warning: Could not load model weights: {e}")
    print("Using randomly initialized model for demo purposes.")
else:
    # Reporting lives outside the try block so that a missing or oddly typed
    # metric can never be mistaken for a failure to load the weights.
    if has_metadata:
        print(f"Model loaded successfully! Top-1 accuracy: {_format_accuracy(checkpoint.get('top1_accuracy'))}")
        print(f"Top-5 accuracy: {_format_accuracy(checkpoint.get('top5_accuracy'))}")
    else:
        print("Model loaded successfully!")

model.to(device)
model.eval()
# ImageNet preprocessing
# Per-channel constants handed to transforms.Normalize below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Steps run in list order: resize, center-crop to 224, convert to a tensor,
# then normalize each channel.
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)
def predict(image):
    """Classify an image and return its highest-probability classes.

    Args:
        image: A PIL image or a NumPy array of pixel data, or ``None`` when
            nothing was uploaded.

    Returns:
        A dict mapping ``"<rank>. <class name>"`` to a percentage string such
        as ``"93.41%"`` for the top predictions (at most five), or
        ``{"Error": <message>}`` when no image was given or inference failed.
    """
    if image is None:
        return {"Error": "No image provided"}
    try:
        # Convert to PIL Image if needed. Pillow infers the mode from the
        # array shape, so grayscale (H, W) and RGBA (H, W, 4) arrays are
        # accepted too; forcing mode 'RGB' here would reject them before the
        # convert() below gets a chance to normalize the channels.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype('uint8'))
        # Ensure RGB mode
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess image and add the batch dimension the model is fed with
        img_tensor = transform(image).unsqueeze(0).to(device)

        # Make prediction
        with torch.no_grad():
            outputs = model(img_tensor)
            probabilities = F.softmax(outputs, dim=1)[0]

        # Get the top predictions; never ask topk for more entries than the
        # model has classes.
        k = min(5, probabilities.numel())
        top_prob, top_idx = torch.topk(probabilities, k)

        # Format results as a dictionary
        results = {}
        ranked = zip(top_idx.tolist(), top_prob.tolist())
        for rank, (class_idx, prob) in enumerate(ranked, 1):
            # Label keys are looked up as strings; fall back to the raw index.
            class_name = IMAGENET_CLASSES.get(str(class_idx), f"Class {class_idx}")
            results[f"{rank}. {class_name}"] = f"{prob * 100:.2f}%"
        return results
    except Exception as e:
        # UI boundary: report the failure to the caller instead of raising.
        return {"Error": str(e)}
# Create Gradio interface
# Page heading and descriptive blurb; both are handed to gr.Interface further
# down via its title= and description= arguments.
title = "ResNet-50 ImageNet-1k Classifier"
# The description uses Markdown syntax (bold markers, bullet lists, links).
# NOTE(review): the accuracy and training-time figures below are hard-coded
# text, not values read from the loaded checkpoint — keep them in sync by hand.
description = """
Upload an image to classify it into one of **1000 ImageNet categories**.
This model is a **ResNet-50** trained on the ImageNet-1k dataset with modern optimization techniques:
- **Architecture**: ResNet-50 with Bottleneck blocks [3, 4, 6, 3]
- **Parameters**: ~25.6M trainable parameters
- **Training Optimizations**:
- Progressive resizing (128→160→192→224px)
- CutMix and MixUp augmentation
- Label smoothing (0.1)
- Exponential Moving Average (EMA)
- Automatic Mixed Precision (AMP)
- PyTorch 2.0 compilation
- FFCV high-performance data loading
- **Target Accuracy**: 78%+ (Top-1), 94%+ (Top-5)
- **Training Time**: ~90 minutes on 8x A100 GPUs
**Class labels** are from the official [HuggingFace ImageNet-1k dataset](https://huggingface.co/datasets/huggingface/label-files/blob/main/imagenet-1k-id2label.json).
The model works best with natural images containing objects, animals, or scenes from the ImageNet categories.
**Training code**: [github.com/arghyaiitb/assignment_9](https://github.com/arghyaiitb/assignment_9)
"""
# Example images for demonstration
# Every example is an Unsplash photo requested at the same width, so the URLs
# are assembled from one template plus the per-photo identifier.
_UNSPLASH_URL_TEMPLATE = "https://images.unsplash.com/{photo_id}?w=400"
_EXAMPLE_PHOTO_IDS = (
    "photo-1543466835-00a7907e9de1",  # Golden Retriever
    "photo-1514888286974-6c03e2ca1dba",  # Tabby Cat
    "photo-1511367461989-f85a21fda167",  # Granny Smith Apple
)
examples = [
    _UNSPLASH_URL_TEMPLATE.format(photo_id=photo_id)
    for photo_id in _EXAMPLE_PHOTO_IDS
]
# Create the interface
# Input and output components are built first (in the same order as before)
# and then wired to predict() together with the page text and example images.
_image_input = gr.Image(type="pil", label="Upload Image")
_predictions_output = gr.JSON(label="Top 5 Predictions")
demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_predictions_output,
    title=title,
    description=description,
    examples=examples,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)
if __name__ == "__main__":
    # Start the server only when the file is run as a script: listen on
    # 0.0.0.0, port 7860, with share disabled.
    _launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**_launch_options)