Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						d6fe440
	
1
								Parent(s):
							
							07be122
								
Upload 3 files
Browse files- app.py +70 -0
- checpoint_epoch_4.pt +3 -0
- dog_1.jpg +0 -0
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,70 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gradio as gr
         | 
| 2 | 
            +
            import numpy as np # linear algebra
         | 
| 3 | 
            +
            import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torchvision
         | 
| 7 | 
            +
            import torch.nn as nn # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
         | 
| 8 | 
            +
            import torchvision.datasets as datasets # Has standard datasets we can import in a nice way
         | 
| 9 | 
            +
            import torchvision.transforms as transforms # Transformations we can perform on our dataset
         | 
| 10 | 
            +
            import torch.nn.functional as F # All functions that don't have any parameters
         | 
| 11 | 
            +
            from torch.utils.data import DataLoader, Dataset # Gives easier dataset managment and creates mini batches
         | 
| 12 | 
            +
            from torchvision.datasets import ImageFolder
         | 
| 13 | 
            +
            import torch.optim as optim # For all Optimization algorithms, SGD, Adam, etc.
         | 
| 14 | 
            +
            from PIL import Image
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # use gpu or cpu
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from tqdm import tqdm
         | 
| 19 | 
            +
            from torchvision import models
         | 
| 20 | 
            +
            # load pretrain model and modify...
         | 
| 21 | 
            +
            model = models.resnet50(pretrained=True)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            # If you want to do finetuning then set requires_grad = False
         | 
| 24 | 
            +
            # Remove these two lines if you want to train entire model,
         | 
| 25 | 
            +
            # and only want to load the pretrain weights.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            for param in model.parameters():
         | 
| 28 | 
            +
                param.requires_grad = False
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            num_ftrs = model.fc.in_features
         | 
| 31 | 
            +
            model.fc = nn.Linear(num_ftrs, 2)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            model.to(device)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            # Loss and optimizer
         | 
| 36 | 
            +
            criterion = nn.CrossEntropyLoss()
         | 
| 37 | 
            +
            optimizer = optim.Adam(model.parameters(), lr=0.01)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            checkpoint = torch.load("D:\cats_dogs\cats_dogs\checpoint_epoch_4.pt",
         | 
| 40 | 
            +
                                   map_location=torch.device('cpu'))
         | 
| 41 | 
            +
            model.load_state_dict(checkpoint["model_state_dict"])
         | 
| 42 | 
            +
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            def image_classifier(inp):
         | 
| 46 | 
            +
                model.eval()
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                data_transforms = transforms.Compose([
         | 
| 49 | 
            +
                    transforms.ToTensor(),
         | 
| 50 | 
            +
                    transforms.Resize((224, 224)),
         | 
| 51 | 
            +
                    transforms.Normalize([0.5] * 3, [0.5] * 3), ])
         | 
| 52 | 
            +
                img = data_transforms(inp).unsqueeze(dim=0)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                img = img.to(device)
         | 
| 55 | 
            +
                pred = model(img)
         | 
| 56 | 
            +
                _, preds = torch.max(pred, 1)
         | 
| 57 | 
            +
                print(f"class : {preds}")
         | 
| 58 | 
            +
                cur_name = ""
         | 
| 59 | 
            +
                if preds[0] == 1:
         | 
| 60 | 
            +
                    print(f"predicted ----> Dog")
         | 
| 61 | 
            +
                    cur_name = "DOG"
         | 
| 62 | 
            +
                else:
         | 
| 63 | 
            +
                    print(f"predicted ----> Cat")
         | 
| 64 | 
            +
                    cur_name = "CAT"
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                return cur_name
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            demo = gr.Interface(fn=image_classifier, inputs="image", outputs="text")
         | 
| 70 | 
            +
            demo.launch()
         | 
    	
        checpoint_epoch_4.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:f4041e5a95287674e8b88731d4521bdb52dd593f54931f84f7b75fcf7f59a6c6
         | 
| 3 | 
            +
            size 94407571
         | 
    	
        dog_1.jpg
    ADDED
    
    |   |