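"""app.py -- "Human vs. AI: Super-Resolution Battle" Gradio Space.

A Swin2SR model upscales a 32x32 CIFAR-100 test image for the human player,
while a CIFAR-100 ResNet-56 classifier guesses the label from the original
low-resolution image. Scores for both sides are tracked across rounds.
"""
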
import gradio as gr
import torch
from torchvision.datasets import CIFAR100
from PIL import Image
import random
import numpy as np
from torchvision import transforms
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
import warnings

# Silence Pillow's decompression-bomb warning for large upscaled images.
warnings.simplefilter("ignore", Image.DecompressionBombWarning)

# Load the Swin2SR super-resolution model; set it to None on failure so the app can degrade gracefully.
try:
    sr_processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
    sr_model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
    sr_model.eval()
except Exception:
    sr_model = None

# Load the CIFAR-100 ResNet-56 classifier that plays against the user.
try:
    classifier_model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_resnet56", pretrained=True)
    classifier_model.eval()
except Exception:
    classifier_model = None

# CIFAR-100 test split (32x32 PIL images) and its 100 class names.
cifar100_dataset = CIFAR100(root="./cifar100_data", train=False, download=True)
cifar100_labels = cifar100_dataset.classes

def upscale_image(low_res_pil_image):
    """Upscale a 32x32 CIFAR image with Swin2SR; fall back to nearest-neighbor resizing."""
    if low_res_pil_image is None:
        return None
    if sr_model is None:
        return low_res_pil_image.resize((400, 400), Image.Resampling.NEAREST)
    with torch.no_grad():
        inputs = sr_processor(low_res_pil_image, return_tensors="pt")
        outputs = sr_model(**inputs)
    # Convert the reconstructed CHW float tensor into an HWC uint8 PIL image.
    output_tensor = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1)
    output_numpy = np.moveaxis(output_tensor.numpy(), 0, -1)
    output_image = (output_numpy * 255.0).round().astype(np.uint8)
    return Image.fromarray(output_image)

def predict_ai(low_res_pil_image):
    """Let the CIFAR-100 classifier guess the label from the original low-res image."""
    try:
        # Standard CIFAR-100 normalization statistics used by the pretrained ResNet-56.
        preprocess_for_classifier = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5071, 0.4867, 0.4408],
                std=[0.2675, 0.2565, 0.2761]
            ),
        ])
        img_t = preprocess_for_classifier(low_res_pil_image.convert("RGB"))
        batch_t = torch.unsqueeze(img_t, 0)
        with torch.no_grad():
            out = classifier_model(batch_t)
        _, index = torch.max(out, 1)
        return cifar100_labels[index[0]]
    except Exception:
        return "Error"

def generate_category_markdown():
    """Render the 100 CIFAR-100 class names as a four-column Markdown table."""
    md = "|||||\n|:---|:---|:---|:---|\n"
    for i in range(0, 100, 4):
        row = cifar100_labels[i:i+4]
        md += "| " + " | ".join(row) + " |\n"
    return md

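# Per-session game state carried through gr.State:
#   user_score / ai_score   -- running scores
#   current_image_idx       -- dataset index of the image currently shown
#   played_indices          -- set of dataset indices already used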
def battle(user_guess, state):
    """Score one round: compare the user's and the AI's guesses, then pick the next image."""
    if not state:
        # start_game() returns an empty state when a required model failed to load.
        return 0, 0, "The game is not initialized. Please reload the page.", "", None, state
    user_score = state["user_score"]
    ai_score = state["ai_score"]
    current_image_idx = state["current_image_idx"]
    played_indices = state["played_indices"]

    low_res_image, label_idx = cifar100_dataset[current_image_idx]
    current_label = cifar100_labels[label_idx]
    ai_guess = predict_ai(low_res_image)

    if user_guess.lower().strip() == current_label.lower():
        user_score += 1
    if ai_guess.lower() == current_label.lower():
        ai_score += 1

    if len(played_indices) >= len(cifar100_dataset):
        message = f"AI's Guess: '{ai_guess}'\nCorrect Answer: '{current_label}'\n\nAll images have been played! Game Over."
        next_high_res_image = None
    else:
        # Draw a random image that has not been played yet.
        while True:
            next_image_idx = random.randint(0, len(cifar100_dataset) - 1)
            if next_image_idx not in played_indices:
                break
        next_low_res_image, _ = cifar100_dataset[next_image_idx]
        next_high_res_image = upscale_image(next_low_res_image)
        message = f"AI's Guess: '{ai_guess}'\nCorrect Answer: '{current_label}'"
        state["current_image_idx"] = next_image_idx
        played_indices.add(next_image_idx)

    new_state = {
        "user_score": user_score,
        "ai_score": ai_score,
        "current_image_idx": state["current_image_idx"],
        "played_indices": played_indices
    }
    return user_score, ai_score, message, "", next_high_res_image, new_state

def start_game():
    """Initialize scores and show the first upscaled image when the page loads."""
    if classifier_model is None or sr_model is None:
        return 0, 0, "A required AI model failed to load. Please restart.", "", None, {}
    first_idx = random.randint(0, len(cifar100_dataset) - 1)
    first_low_res_image, _ = cifar100_dataset[first_idx]
    first_high_res_image = upscale_image(first_low_res_image)
    initial_state = {
        "user_score": 0,
        "ai_score": 0,
        "current_image_idx": first_idx,
        "played_indices": {first_idx}
    }
    return 0, 0, "Game Start! What is this high-resolution image?", "", first_high_res_image, initial_state

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky")) as demo:
    state = gr.State()
    gr.Markdown("<h1>Human vs. AI: Super-Resolution Battle</h1>")
    gr.Markdown("A Swin2SR AI has upscaled a 32x32 image for you. Can you guess it before the other AI, which only sees the original low-res image?")
    with gr.Row():
        user_score_display = gr.Number(label="Your Score", value=0, interactive=False)
        ai_score_display = gr.Number(label="AI Score", value=0, interactive=False)
    with gr.Row(equal_height=False):
        with gr.Column(scale=2):
            image_display = gr.Image(label="Guess this upscaled image!", type="pil", height=400, width=400, interactive=False)
            result_display = gr.Textbox(label="Round Result", interactive=False, lines=3)
        with gr.Column(scale=1):
            user_input = gr.Textbox(label="What is this image?", placeholder="e.g., apple, bicycle, cloud...")
            submit_button = gr.Button("Submit Guess", variant="primary")
            with gr.Accordion("View All 100 Categories", open=False):
                gr.Markdown(generate_category_markdown())

    submit_button.click(fn=battle, inputs=[user_input, state], outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])
    user_input.submit(fn=battle, inputs=[user_input, state], outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])
    demo.load(fn=start_game, inputs=None, outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])

if __name__ == "__main__":
    demo.launch()