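# Human vs. AI: Super-Resolution Battle
# A Gradio game where a Swin2SR model upscales 32x32 CIFAR-100 images for the human player,
# while a ResNet-56 classifier that only sees the low-res original competes to name them.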
import gradio as gr
import torch
from torchvision.datasets import CIFAR100
from PIL import Image
import random
import numpy as np
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import warnings

# Silence Pillow's decompression-bomb warning; PIL.Image.warnings is not a public API.
warnings.simplefilter('ignore', Image.DecompressionBombWarning)

# Swin2SR x4 super-resolution model: upscales the 32x32 CIFAR images shown to the human player.
try:
    sr_processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
    sr_model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
    sr_model.eval()
except Exception as e:
    print(f"Failed to load the Swin2SR model: {e}")
    sr_processor = None
    sr_model = None

# ResNet-56 classifier pretrained on CIFAR-100: the AI opponent that guesses from the low-res image.
try:
    classifier_model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_resnet56", pretrained=True)
    classifier_model.eval()
except Exception as e:
    print(f"Failed to load the CIFAR-100 classifier: {e}")
    classifier_model = None

# CIFAR-100 test split (10,000 32x32 images); .classes holds the 100 human-readable label names.
cifar100_dataset = CIFAR100(root="./cifar100_data", train=False, download=True)
cifar100_labels = cifar100_dataset.classes

def upscale_image(low_res_pil_image):
    """Upscale a CIFAR image with Swin2SR, falling back to nearest-neighbour resizing."""
    if low_res_pil_image is None:
        return None
    if sr_model is None:
        return low_res_pil_image.resize((400, 400), Image.Resampling.NEAREST)

    with torch.no_grad():
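        # Preprocess the PIL image and run the Swin2SR forward pass without gradients.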
        inputs = sr_processor(low_res_pil_image, return_tensors="pt")
        outputs = sr_model(**inputs)
        
    # Convert the reconstructed CHW float tensor into an HWC uint8 array.
    output_tensor = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1)
    output_numpy = np.moveaxis(output_tensor.numpy(), 0, -1)
    output_image = (output_numpy * 255.0).round().astype(np.uint8)
    
    return Image.fromarray(output_image)

def predict_ai(low_res_pil_image):
    # The AI opponent classifies the original low-res image (never the upscaled one).
    if classifier_model is None or low_res_pil_image is None:
        return "Error"
    try:
        from torchvision import transforms
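        # Normalise with the standard CIFAR-100 channel mean/std used when training the classifier.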
        preprocess_for_classifier = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5071, 0.4867, 0.4408], 
                std=[0.2675, 0.2565, 0.2761]
            ),
        ])
        img_t = preprocess_for_classifier(low_res_pil_image.convert("RGB"))
        batch_t = torch.unsqueeze(img_t, 0)
        
        with torch.no_grad():
            out = classifier_model(batch_t)
        
        # Pick the class with the highest logit.
        _, index = torch.max(out, 1)
        return cifar100_labels[index.item()]
    except Exception as e:
        return "Error"
        
def generate_category_markdown():
    # Render the 100 class names as a headerless 4-column markdown table.
    md = "| | | | |\n|:---|:---|:---|:---|\n"
    for i in range(0, 100, 4):
        row = cifar100_labels[i:i+4]
        md += "| " + " | ".join(row) + " |\n"
    return md

def battle(user_guess, state):
    # Guard against a missing game state (e.g. when model loading failed at startup).
    if not state:
        return 0, 0, "The game is not ready. Please restart the app.", "", None, state

    user_score = state["user_score"]
    ai_score = state["ai_score"]
    current_image_idx = state["current_image_idx"]
    played_indices = state["played_indices"]

    low_res_image, label_idx = cifar100_dataset[current_image_idx]
    current_label = cifar100_labels[label_idx]

    ai_guess = predict_ai(low_res_image)

    # Both players score a point when their guess matches the true label.
    if user_guess.lower().strip() == current_label.lower():
        user_score += 1
    if ai_guess.lower() == current_label.lower():
        ai_score += 1
    
    if len(played_indices) >= len(cifar100_dataset):
        message = f"AI's Guess: '{ai_guess}'\nCorrect Answer: '{current_label}'\n\nAll images have been played! Game Over."
        next_high_res_image = None
    else:
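        # Keep sampling until we hit an image that has not been shown yet.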
        while True:
            next_image_idx = random.randint(0, len(cifar100_dataset) - 1)
            if next_image_idx not in played_indices:
                break
        
        next_low_res_image, _ = cifar100_dataset[next_image_idx]
        next_high_res_image = upscale_image(next_low_res_image)
        message = f"AI's Guess: '{ai_guess}'\nCorrect Answer: '{current_label}'"
        state["current_image_idx"] = next_image_idx
        played_indices.add(next_image_idx)

    new_state = {
        "user_score": user_score,
        "ai_score": ai_score,
        "current_image_idx": state["current_image_idx"],
        "played_indices": played_indices
    }
    
    return user_score, ai_score, message, "", next_high_res_image, new_state

def start_game():
    # Runs once on page load: picks the first image and seeds the per-session state.
    if classifier_model is None or sr_model is None:
        return 0, 0, "A required AI model failed to load. Please restart.", "", None, {}

    first_idx = random.randint(0, len(cifar100_dataset) - 1)
    first_low_res_image, _ = cifar100_dataset[first_idx]
    first_high_res_image = upscale_image(first_low_res_image)
    
    initial_state = {
        "user_score": 0,
        "ai_score": 0,
        "current_image_idx": first_idx,
        "played_indices": {first_idx}
    }
    return 0, 0, "Game Start! What is this high-resolution image?", "", first_high_res_image, initial_state

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky")) as demo:
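    # Per-session game state: scores, current image index, and the set of already-played indices.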
    state = gr.State()

    gr.Markdown("<h1>Human vs. AI: Super-Resolution Battle</h1>")
    gr.Markdown("A Swin2SR AI has upscaled a 32x32 image for you. Can you guess it before the other AI, which only sees the original low-res image?")

    with gr.Row():
        user_score_display = gr.Number(label="Your Score", value=0, interactive=False)
        ai_score_display = gr.Number(label="AI Score", value=0, interactive=False)
    
    with gr.Row(equal_height=False):
        with gr.Column(scale=2):
            image_display = gr.Image(label="Guess this upscaled image!", type="pil", height=400, width=400, interactive=False)
            result_display = gr.Textbox(label="Round Result", interactive=False, lines=3)
        with gr.Column(scale=1):
            user_input = gr.Textbox(label="What is this image?", placeholder="e.g., apple, bicycle, cloud...")
            submit_button = gr.Button("Submit Guess", variant="primary")
            with gr.Accordion("View All 100 Categories", open=False):
                gr.Markdown(generate_category_markdown())

    # A button click or pressing Enter plays a round; loading the page starts a new game.
    submit_button.click(fn=battle, inputs=[user_input, state], outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])
    user_input.submit(fn=battle, inputs=[user_input, state], outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])
    demo.load(fn=start_game, inputs=None, outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])

if __name__ == "__main__":
    demo.launch()