umjunsik1323 commited on
Commit
efa26bc
·
verified ·
1 Parent(s): a791b0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -151
app.py CHANGED
@@ -1,152 +1,150 @@
1
- import gradio as gr
2
- import torch
3
- from torchvision.datasets import CIFAR100
4
- from PIL import Image
5
- import random
6
- import numpy as np
7
- from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
8
-
9
- Image.warnings.simplefilter('ignore', Image.DecompressionBombWarning)
10
-
11
- try:
12
- sr_processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
13
- sr_model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
14
- sr_model.eval()
15
- except Exception as e:
16
- sr_model = None
17
-
18
- try:
19
- classifier_model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_resnet56", pretrained=True)
20
- classifier_model.eval()
21
- except Exception as e:
22
- classifier_model = None
23
-
24
- cifar100_dataset = CIFAR100(root="./cifar100_data", train=False, download=True)
25
- cifar100_labels = cifar100_dataset.classes
26
-
27
- def upscale_image(low_res_pil_image):
28
- if sr_model is None or low_res_pil_image is None:
29
- return low_res_pil_image.resize((400, 400), Image.Resampling.NEAREST)
30
-
31
- with torch.no_grad():
32
- inputs = sr_processor(low_res_pil_image, return_tensors="pt")
33
- outputs = sr_model(**inputs)
34
-
35
- output_tensor = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1)
36
- output_numpy = np.moveaxis(output_tensor.numpy(), 0, -1)
37
- output_image = (output_numpy * 255.0).round().astype(np.uint8)
38
-
39
- return Image.fromarray(output_image)
40
-
41
- def predict_ai(low_res_pil_image):
42
- try:
43
- from torchvision import transforms
44
- preprocess_for_classifier = transforms.Compose([
45
- transforms.ToTensor(),
46
- transforms.Normalize(
47
- mean=[0.5071, 0.4867, 0.4408],
48
- std=[0.2675, 0.2565, 0.2761]
49
- ),
50
- ])
51
- img_t = preprocess_for_classifier(low_res_pil_image.convert("RGB"))
52
- batch_t = torch.unsqueeze(img_t, 0)
53
-
54
- with torch.no_grad():
55
- out = classifier_model(batch_t)
56
-
57
- _, index = torch.max(out, 1)
58
- return cifar100_labels[index[0]]
59
- except Exception as e:
60
- return "Error"
61
-
62
- def generate_category_markdown():
63
- md = "|||||\n|:---|:---|:---|:---|\n"
64
- for i in range(0, 100, 4):
65
- row = cifar100_labels[i:i+4]
66
- md += "| " + " | ".join(row) + " |\n"
67
- return md
68
-
69
- # --- 4. 게임 로직 ---
70
- def battle(user_guess, state):
71
- user_score = state["user_score"]
72
- ai_score = state["ai_score"]
73
- current_image_idx = state["current_image_idx"]
74
- played_indices = state["played_indices"]
75
-
76
- low_res_image, label_idx = cifar100_dataset[current_image_idx]
77
- current_label = cifar100_labels[label_idx]
78
-
79
- ai_guess = predict_ai(low_res_image)
80
-
81
- if user_guess.lower().strip() == current_label.lower():
82
- user_score += 1
83
- if ai_guess.lower() == current_label.lower():
84
- ai_score += 1
85
-
86
- if len(played_indices) >= len(cifar100_dataset):
87
- message = f"AI's Guess: '{ai_guess}'\nCorrect Answer: '{current_label}'\n\nAll images have been played! Game Over."
88
- next_high_res_image = None
89
- else:
90
- while True:
91
- next_image_idx = random.randint(0, len(cifar100_dataset) - 1)
92
- if next_image_idx not in played_indices:
93
- break
94
-
95
- next_low_res_image, _ = cifar100_dataset[next_image_idx]
96
- next_high_res_image = upscale_image(next_low_res_image)
97
- message = f"AI's Guess: '{ai_guess}'\nCorrect Answer: '{current_label}'"
98
- state["current_image_idx"] = next_image_idx
99
- played_indices.add(next_image_idx)
100
-
101
- new_state = {
102
- "user_score": user_score,
103
- "ai_score": ai_score,
104
- "current_image_idx": state["current_image_idx"],
105
- "played_indices": played_indices
106
- }
107
-
108
- return user_score, ai_score, message, "", next_high_res_image, new_state
109
-
110
- def start_game():
111
- if not classifier_model or not sr_model:
112
- return 0, 0, "A required AI model failed to load. Please restart.", "", None, {}
113
-
114
- first_idx = random.randint(0, len(cifar100_dataset) - 1)
115
- first_low_res_image, _ = cifar100_dataset[first_idx]
116
- first_high_res_image = upscale_image(first_low_res_image)
117
-
118
- initial_state = {
119
- "user_score": 0,
120
- "ai_score": 0,
121
- "current_image_idx": first_idx,
122
- "played_indices": {first_idx}
123
- }
124
- return 0, 0, "Game Start! What is this high-resolution image?", "", first_high_res_image, initial_state
125
-
126
- # --- 5. Gradio 인터페이스 ---
127
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky")) as demo:
128
- state = gr.State()
129
-
130
- gr.Markdown("<h1>Human vs. AI: Super-Resolution Battle</h1>")
131
- gr.Markdown("A Swin2SR AI has upscaled a 32x32 image for you. Can you guess it before the other AI, which only sees the original low-res image?")
132
-
133
- with gr.Row():
134
- user_score_display = gr.Number(label="Your Score", value=0, interactive=False)
135
- ai_score_display = gr.Number(label="AI Score", value=0, interactive=False)
136
-
137
- with gr.Row(equal_height=False):
138
- with gr.Column(scale=2):
139
- image_display = gr.Image(label="Guess this upscaled image!", type="pil", height=400, width=400, interactive=False)
140
- result_display = gr.Textbox(label="Round Result", interactive=False, lines=3)
141
- with gr.Column(scale=1):
142
- user_input = gr.Textbox(label="What is this image?", placeholder="e.g., apple, bicycle, cloud...")
143
- submit_button = gr.Button("Submit Guess", variant="primary")
144
- with gr.Accordion("View All 100 Categories", open=False):
145
- gr.Markdown(generate_category_markdown())
146
-
147
- submit_button.click(fn=battle, inputs=[user_input, state], outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])
148
- user_input.submit(fn=battle, inputs=[user_input, state], outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])
149
- demo.load(fn=start_game, inputs=None, outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])
150
-
151
- if __name__ == "__main__":
152
  demo.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision.datasets import CIFAR100
4
+ from PIL import Image
5
+ import random
6
+ import numpy as np
7
+ from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
8
+
9
+ Image.warnings.simplefilter('ignore', Image.DecompressionBombWarning)
10
+
11
+ try:
12
+ sr_processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
13
+ sr_model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x4-64")
14
+ sr_model.eval()
15
+ except Exception as e:
16
+ sr_model = None
17
+
18
+ try:
19
+ classifier_model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar100_resnet56", pretrained=True)
20
+ classifier_model.eval()
21
+ except Exception as e:
22
+ classifier_model = None
23
+
24
+ cifar100_dataset = CIFAR100(root="./cifar100_data", train=False, download=True)
25
+ cifar100_labels = cifar100_dataset.classes
26
+
27
+ def upscale_image(low_res_pil_image):
28
+ if sr_model is None or low_res_pil_image is None:
29
+ return low_res_pil_image.resize((400, 400), Image.Resampling.NEAREST)
30
+
31
+ with torch.no_grad():
32
+ inputs = sr_processor(low_res_pil_image, return_tensors="pt")
33
+ outputs = sr_model(**inputs)
34
+
35
+ output_tensor = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1)
36
+ output_numpy = np.moveaxis(output_tensor.numpy(), 0, -1)
37
+ output_image = (output_numpy * 255.0).round().astype(np.uint8)
38
+
39
+ return Image.fromarray(output_image)
40
+
41
+ def predict_ai(low_res_pil_image):
42
+ try:
43
+ from torchvision import transforms
44
+ preprocess_for_classifier = transforms.Compose([
45
+ transforms.ToTensor(),
46
+ transforms.Normalize(
47
+ mean=[0.5071, 0.4867, 0.4408],
48
+ std=[0.2675, 0.2565, 0.2761]
49
+ ),
50
+ ])
51
+ img_t = preprocess_for_classifier(low_res_pil_image.convert("RGB"))
52
+ batch_t = torch.unsqueeze(img_t, 0)
53
+
54
+ with torch.no_grad():
55
+ out = classifier_model(batch_t)
56
+
57
+ _, index = torch.max(out, 1)
58
+ return cifar100_labels[index[0]]
59
+ except Exception as e:
60
+ return "Error"
61
+
62
+ def generate_category_markdown():
63
+ md = "|||||\n|:---|:---|:---|:---|\n"
64
+ for i in range(0, 100, 4):
65
+ row = cifar100_labels[i:i+4]
66
+ md += "| " + " | ".join(row) + " |\n"
67
+ return md
68
+
69
+ def battle(user_guess, state):
70
+ user_score = state["user_score"]
71
+ ai_score = state["ai_score"]
72
+ current_image_idx = state["current_image_idx"]
73
+ played_indices = state["played_indices"]
74
+
75
+ low_res_image, label_idx = cifar100_dataset[current_image_idx]
76
+ current_label = cifar100_labels[label_idx]
77
+
78
+ ai_guess = predict_ai(low_res_image)
79
+
80
+ if user_guess.lower().strip() == current_label.lower():
81
+ user_score += 1
82
+ if ai_guess.lower() == current_label.lower():
83
+ ai_score += 1
84
+
85
+ if len(played_indices) >= len(cifar100_dataset):
86
+ message = f"AI's Guess: '{ai_guess}'\nCorrect Answer: '{current_label}'\n\nAll images have been played! Game Over."
87
+ next_high_res_image = None
88
+ else:
89
+ while True:
90
+ next_image_idx = random.randint(0, len(cifar100_dataset) - 1)
91
+ if next_image_idx not in played_indices:
92
+ break
93
+
94
+ next_low_res_image, _ = cifar100_dataset[next_image_idx]
95
+ next_high_res_image = upscale_image(next_low_res_image)
96
+ message = f"AI's Guess: '{ai_guess}'\nCorrect Answer: '{current_label}'"
97
+ state["current_image_idx"] = next_image_idx
98
+ played_indices.add(next_image_idx)
99
+
100
+ new_state = {
101
+ "user_score": user_score,
102
+ "ai_score": ai_score,
103
+ "current_image_idx": state["current_image_idx"],
104
+ "played_indices": played_indices
105
+ }
106
+
107
+ return user_score, ai_score, message, "", next_high_res_image, new_state
108
+
109
+ def start_game():
110
+ if not classifier_model or not sr_model:
111
+ return 0, 0, "A required AI model failed to load. Please restart.", "", None, {}
112
+
113
+ first_idx = random.randint(0, len(cifar100_dataset) - 1)
114
+ first_low_res_image, _ = cifar100_dataset[first_idx]
115
+ first_high_res_image = upscale_image(first_low_res_image)
116
+
117
+ initial_state = {
118
+ "user_score": 0,
119
+ "ai_score": 0,
120
+ "current_image_idx": first_idx,
121
+ "played_indices": {first_idx}
122
+ }
123
+ return 0, 0, "Game Start! What is this high-resolution image?", "", first_high_res_image, initial_state
124
+
125
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky")) as demo:
126
+ state = gr.State()
127
+
128
+ gr.Markdown("<h1>Human vs. AI: Super-Resolution Battle</h1>")
129
+ gr.Markdown("A Swin2SR AI has upscaled a 32x32 image for you. Can you guess it before the other AI, which only sees the original low-res image?")
130
+
131
+ with gr.Row():
132
+ user_score_display = gr.Number(label="Your Score", value=0, interactive=False)
133
+ ai_score_display = gr.Number(label="AI Score", value=0, interactive=False)
134
+
135
+ with gr.Row(equal_height=False):
136
+ with gr.Column(scale=2):
137
+ image_display = gr.Image(label="Guess this upscaled image!", type="pil", height=400, width=400, interactive=False)
138
+ result_display = gr.Textbox(label="Round Result", interactive=False, lines=3)
139
+ with gr.Column(scale=1):
140
+ user_input = gr.Textbox(label="What is this image?", placeholder="e.g., apple, bicycle, cloud...")
141
+ submit_button = gr.Button("Submit Guess", variant="primary")
142
+ with gr.Accordion("View All 100 Categories", open=False):
143
+ gr.Markdown(generate_category_markdown())
144
+
145
+ submit_button.click(fn=battle, inputs=[user_input, state], outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])
146
+ user_input.submit(fn=battle, inputs=[user_input, state], outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])
147
+ demo.load(fn=start_game, inputs=None, outputs=[user_score_display, ai_score_display, result_display, user_input, image_display, state])
148
+
149
+ if __name__ == "__main__":
 
 
150
  demo.launch()