File size: 3,081 Bytes
e8590af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3197cb6
e8590af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTConfig
import random
import numpy as np
import transformers
from skimage.metrics import structural_similarity as ssim
import requests
import os


def set_seed(seed):
    """Seed every random number generator used by the app for reproducible runs.

    Seeds Python's ``random`` module, NumPy, PyTorch (CPU and all CUDA devices)
    and the ``transformers`` seeding helper, in that order, then puts cuDNN in
    deterministic mode with autotuning disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        transformers.set_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic kernels and no benchmark-driven algorithm selection:
    # slower, but repeatable across runs.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _download_if_missing(url, path, timeout=60):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    The HTTP status is checked before anything is written. The body is
    streamed into a temporary ``<path>.part`` file, which is renamed onto
    ``path`` only after the transfer completes. An error page or a truncated
    download therefore never ends up cached under ``path``, where the
    existence check would cause it to be reused on every later start.

    Args:
        url: Remote file to fetch.
        path: Local destination file.
        timeout: Seconds passed to ``requests`` as the connect/read timeout.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    if os.path.exists(path):
        return
    partial_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            # Stream in 1 MiB chunks instead of holding the whole file in memory.
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(partial_path, path)


set_seed(42)

device = "cpu"
config = ViTConfig.from_pretrained("google/vit-base-patch16-224")
config.num_labels = 2  # Binary classification

# Download the model file
model_url = "https://huggingface.co/spuun/yummy-paws/resolve/main/best_model.pth"
model_path = "best_model.pth"
_download_if_missing(model_url, model_path)

# Load the trained model.
# SECURITY NOTE(review): weights_only=False lets torch unpickle arbitrary
# objects from the downloaded checkpoint, which can execute code. Keep it only
# if the checkpoint really needs it and its source is trusted.
model = ViTForImageClassification.from_pretrained(
    model_path, config=config, ignore_mismatched_sizes=True, weights_only=False
)
# Keep the classifier that from_pretrained built (hidden_size -> num_labels=2),
# which is populated from the checkpoint when it contains classifier weights.
# Re-assigning model.classifier after loading would discard that trained head
# and classify with freshly, randomly initialised weights.
model.to(device)

# Download the reference image
reference_image_url = (
    "https://huggingface.co/spuun/yummy-paws/resolve/main/images%20(15).jpeg"
)
reference_image_path = "reference_image.jpeg"
_download_if_missing(reference_image_url, reference_image_path)

# Load the reference image for SSIM comparison. Decode it eagerly inside the
# context manager so the file handle is closed. Normalise to RGB so it has a
# 3-channel (H, W, 3) layout like RGB uploads.
with Image.open(reference_image_path) as _ref_img:
    reference_image = _ref_img.convert("RGB")


def calculate_ssim(img1, img2):
    """Return the structural similarity (SSIM) index between two images.

    Both inputs are converted with ``np.array``, so PIL images and NumPy
    arrays are accepted, and they must end up with identical shapes
    (skimage raises otherwise).

    - 3-D arrays, e.g. ``(H, W, C)`` colour images, are compared per channel
      along the last axis.
    - 2-D grayscale arrays are compared directly.

    Returns:
        The SSIM score as a float; 1.0 means the images are identical.
    """
    a = np.array(img1)
    b = np.array(img2)
    # Only multi-channel arrays have a channel axis. Passing channel_axis=2
    # for a 2-D grayscale array would name an axis that does not exist.
    channel_axis = 2 if a.ndim == 3 else None
    return ssim(a, b, channel_axis=channel_axis)


# Classifier preprocessing: resize to the ViT input size, convert to a tensor,
# and normalise with ImageNet channel statistics. Built once, not per request.
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def predict_and_compare(image):
    """Classify ``image`` and report its SSIM against ``reference_image``.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without uploading anything.

    Returns:
        A multi-line string with the predicted class ("False"/"True"), that
        class's softmax probability, and the SSIM score against the reference
        image. If no image was given, a short prompt asking for an upload.
    """
    if image is None:
        return "Please upload an image."

    # Force 3-channel RGB so RGBA/grayscale/palette uploads fit the 3-channel
    # Normalize step. This also gives the same channel layout as an RGB
    # reference for SSIM (assumes the reference is RGB).
    image = image.convert("RGB")

    # SSIM needs identical shapes, so compare a copy resized to the reference.
    ssim_value = calculate_ssim(image.resize(reference_image.size), reference_image)

    # Classify the upload itself, not the reference-sized copy. Otherwise the
    # model input would first be resampled to the reference's resolution and
    # aspect ratio, which can discard detail before the 224x224 resize.
    image_tensor = _PREPROCESS(image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(image_tensor).logits
        probabilities = torch.softmax(logits, dim=1)[0]
        predicted_class_index = torch.argmax(probabilities).item()

    class_names = ["False", "True"]  # Assuming 0 index is False, 1 is True
    predicted_class = class_names[predicted_class_index]
    probability = probabilities[predicted_class_index].item()

    return f"Predicted: {predicted_class}\nProbability: {probability:.4f}\nSSIM with reference: {ssim_value:.4f}"


# Upload-and-report UI: a single PIL image in, a text summary out.
iface = gr.Interface(
    predict_and_compare,
    gr.Image(type="pil"),
    "text",
    description="Upload an image to classify it and compare with a reference image.",
    title="Image Classification and Comparison",
)

# Start the Gradio server; this runs at module load.
iface.launch()