# HumanvsAI / app.py
# Author: umjunsik1323
# Last commit: efa26bc "Update app.py" (verified)
import logging
import random
import warnings

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision.datasets import CIFAR100
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
# Silence PIL's decompression-bomb warning. Call the stdlib ``warnings``
# module directly: ``Image.warnings`` is not public API and only exists
# because PIL.Image happens to import ``warnings`` itself.
warnings.simplefilter('ignore', Image.DecompressionBombWarning)

SR_MODEL_ID = "caidas/swin2SR-classical-sr-x4-64"

# 4x super-resolution model used to enlarge the image shown to the player.
# On failure, both names are bound to None so later code can test
# ``sr_model is None`` instead of hitting a NameError. The failure is logged
# rather than silently discarded.
try:
    sr_processor = AutoImageProcessor.from_pretrained(SR_MODEL_ID)
    sr_model = Swin2SRForImageSuperResolution.from_pretrained(SR_MODEL_ID)
    sr_model.eval()
except Exception:
    logging.getLogger(__name__).exception(
        "Failed to load super-resolution model %s", SR_MODEL_ID
    )
    sr_processor = None
    sr_model = None

# CIFAR-100 classifier that competes with the player using only the original
# low-resolution image.
try:
    classifier_model = torch.hub.load(
        "chenyaofo/pytorch-cifar-models", "cifar100_resnet56", pretrained=True
    )
    classifier_model.eval()
except Exception:
    logging.getLogger(__name__).exception("Failed to load the CIFAR-100 classifier")
    classifier_model = None

# CIFAR-100 test split (downloaded into ./cifar100_data if missing) supplies
# the game images and their class names.
cifar100_dataset = CIFAR100(root="./cifar100_data", train=False, download=True)
cifar100_labels = cifar100_dataset.classes
def upscale_image(low_res_pil_image):
    """Enlarge a low-resolution PIL image for display to the player.

    Uses the Swin2SR x4 model when it loaded successfully. Otherwise it falls
    back to a plain nearest-neighbour resize to 400x400.

    Args:
        low_res_pil_image: The image to enlarge, or None.

    Returns:
        The enlarged PIL image, or None when ``low_res_pil_image`` is None.
    """
    if low_res_pil_image is None:
        # Previously this case fell through to ``None.resize`` and raised
        # AttributeError.
        return None
    if sr_model is None:
        # Fallback keeps the game usable without the super-resolution model.
        return low_res_pil_image.resize((400, 400), Image.Resampling.NEAREST)
    with torch.no_grad():
        inputs = sr_processor(low_res_pil_image, return_tensors="pt")
        outputs = sr_model(**inputs)
        # The code treats ``reconstruction`` as a channels-first batch with
        # values in [0, 1]. It is presumably (1, C, H, W) — TODO confirm
        # against the Swin2SR docs. Drop size-1 dims, clamp, and move channels
        # last for PIL.
        output_tensor = outputs.reconstruction.squeeze().float().cpu().clamp(0, 1)
    output_numpy = np.moveaxis(output_tensor.numpy(), 0, -1)
    output_image = (output_numpy * 255.0).round().astype(np.uint8)
    return Image.fromarray(output_image)
def predict_ai(low_res_pil_image):
    """Classify a low-resolution image with the CIFAR-100 classifier.

    Args:
        low_res_pil_image: The original (not upscaled) PIL image.

    Returns:
        The predicted CIFAR-100 class name. Returns the string "Error" when the
        classifier is unavailable or inference fails. Failures are logged
        rather than silently discarded, so the game keeps running but the
        cause stays visible.
    """
    try:
        from torchvision import transforms

        if classifier_model is None:
            # Loading already failed (and was reported) at startup.
            return "Error"
        preprocess_for_classifier = transforms.Compose([
            transforms.ToTensor(),
            # Presumably the per-channel CIFAR-100 training mean/std expected
            # by the pretrained model — TODO confirm against the model repo.
            transforms.Normalize(
                mean=[0.5071, 0.4867, 0.4408],
                std=[0.2675, 0.2565, 0.2761],
            ),
        ])
        img_t = preprocess_for_classifier(low_res_pil_image.convert("RGB"))
        batch_t = img_t.unsqueeze(0)  # add the batch dimension
        with torch.no_grad():
            logits = classifier_model(batch_t)
        predicted_index = int(logits.argmax(dim=1)[0])
        return cifar100_labels[predicted_index]
    except Exception:
        logging.getLogger(__name__).exception("AI prediction failed")
        return "Error"
def generate_category_markdown(columns=4, labels=None):
    """Render class names as a Markdown table with an empty header row.

    Args:
        columns: Number of labels per table row (must be >= 1).
        labels: Sequence of class-name strings. Defaults to the CIFAR-100
            labels loaded by this module.

    Returns:
        The Markdown table as a single string. If the label count is not a
        multiple of ``columns``, the final row is padded with empty cells so
        every row has the same width.

    Raises:
        ValueError: If ``columns`` is less than 1.
    """
    if columns < 1:
        raise ValueError("columns must be >= 1")
    if labels is None:
        labels = cifar100_labels
    # Empty header cells followed by a left-aligned divider row.
    header = "|" * (columns + 1) + "\n"
    divider = "|" + ":---|" * columns + "\n"
    rows = []
    for start in range(0, len(labels), columns):
        cells = list(labels[start:start + columns])
        cells += [""] * (columns - len(cells))
        rows.append("| " + " | ".join(cells) + " |\n")
    return header + divider + "".join(rows)
def battle(user_guess, state):
    """Score the current round, then advance to a new, unseen image.

    Args:
        user_guess: Text the player typed for the image currently shown.
        state: Per-session dict created by ``start_game`` with keys
            ``user_score``, ``ai_score``, ``current_image_idx`` and
            ``played_indices`` (set of dataset indices already shown).
            May be None or empty if the game never started.

    Returns:
        A 6-tuple in the order of the Gradio outputs wired to this handler:
        (user score, AI score, round-result message, cleared guess text,
        next upscaled image or None, updated state).
    """
    if not state:
        # gr.State() starts as None, and start_game returns {} when a model
        # failed to load. Either way there is no round to score, and indexing
        # the state would crash.
        return 0, 0, "No game in progress. Please reload the page.", "", None, state

    user_score = state["user_score"]
    ai_score = state["ai_score"]
    current_image_idx = state["current_image_idx"]
    played_indices = state["played_indices"]

    if state.get("game_over"):
        # Without this, the final image could be re-scored on every submit.
        return user_score, ai_score, "All images have been played! Game Over.", "", None, state

    low_res_image, label_idx = cifar100_dataset[current_image_idx]
    current_label = cifar100_labels[label_idx]
    ai_guess = predict_ai(low_res_image)

    # Labels such as "maple_tree" use underscores; also accept "maple tree".
    normalized_guess = "_".join((user_guess or "").lower().split())
    if normalized_guess == current_label.lower():
        user_score += 1
    if ai_guess.lower() == current_label.lower():
        ai_score += 1

    message = f"AI's Guess: '{ai_guess}'\nCorrect Answer: '{current_label}'"
    remaining = [i for i in range(len(cifar100_dataset)) if i not in played_indices]
    if remaining:
        # Choosing from the unplayed indices always terminates, unlike
        # rejection sampling over the whole dataset.
        next_image_idx = random.choice(remaining)
        next_low_res_image, _ = cifar100_dataset[next_image_idx]
        next_high_res_image = upscale_image(next_low_res_image)
        played_indices.add(next_image_idx)
        game_over = False
    else:
        message += "\n\nAll images have been played! Game Over."
        next_image_idx = current_image_idx
        next_high_res_image = None
        game_over = True

    new_state = {
        "user_score": user_score,
        "ai_score": ai_score,
        "current_image_idx": next_image_idx,
        "played_indices": played_indices,
        "game_over": game_over,
    }
    return user_score, ai_score, message, "", next_high_res_image, new_state
def start_game():
    """Begin a new session by picking a random test image and upscaling it.

    Returns:
        A 6-tuple in the order of the Gradio outputs wired to ``demo.load``:
        (user score, AI score, status message, cleared guess text,
        first upscaled image or None, initial state dict). The state is an
        empty dict when either model failed to load.
    """
    # Compare against None explicitly. The truthiness of an nn.Module is not
    # a reliable "loaded" test, because containers such as nn.Sequential
    # define __len__.
    if classifier_model is None or sr_model is None:
        return 0, 0, "A required AI model failed to load. Please restart.", "", None, {}
    first_idx = random.randint(0, len(cifar100_dataset) - 1)
    first_low_res_image, _ = cifar100_dataset[first_idx]
    first_high_res_image = upscale_image(first_low_res_image)
    initial_state = {
        "user_score": 0,
        "ai_score": 0,
        "current_image_idx": first_idx,
        # The first image counts as played so it is never drawn again.
        "played_indices": {first_idx},
    }
    return 0, 0, "Game Start! What is this high-resolution image?", "", first_high_res_image, initial_state
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky")) as demo:
    # Per-session game dict; filled in by start_game when the page loads.
    state = gr.State()

    gr.Markdown("<h1>Human vs. AI: Super-Resolution Battle</h1>")
    gr.Markdown("A Swin2SR AI has upscaled a 32x32 image for you. Can you guess it before the other AI, which only sees the original low-res image?")

    # Scoreboard.
    with gr.Row():
        user_score_display = gr.Number(label="Your Score", value=0, interactive=False)
        ai_score_display = gr.Number(label="AI Score", value=0, interactive=False)

    # Left: the upscaled image and round result. Right: guess entry and the
    # reference list of categories.
    with gr.Row(equal_height=False):
        with gr.Column(scale=2):
            image_display = gr.Image(
                label="Guess this upscaled image!",
                type="pil",
                height=400,
                width=400,
                interactive=False,
            )
            result_display = gr.Textbox(label="Round Result", interactive=False, lines=3)
        with gr.Column(scale=1):
            user_input = gr.Textbox(label="What is this image?", placeholder="e.g., apple, bicycle, cloud...")
            submit_button = gr.Button("Submit Guess", variant="primary")
            with gr.Accordion("View All 100 Categories", open=False):
                gr.Markdown(generate_category_markdown())

    # Every handler refreshes the same six components, in this order.
    round_inputs = [user_input, state]
    round_outputs = [
        user_score_display,
        ai_score_display,
        result_display,
        user_input,
        image_display,
        state,
    ]
    # Clicking the button and pressing Enter in the textbox both play a round.
    for register_trigger in (submit_button.click, user_input.submit):
        register_trigger(fn=battle, inputs=round_inputs, outputs=round_outputs)
    demo.load(fn=start_game, inputs=None, outputs=round_outputs)


if __name__ == "__main__":
    demo.launch()