telecomadm1145 committed
Commit f844b83 · verified · Parent: 11503dd

Update app.py

Files changed (1):
app.py  +174 -74
app.py CHANGED
@@ -1,6 +1,6 @@
  # -*- coding: utf-8 -*-
  """
- Swin-Large AI vs. Non-AI Detector (with multi-layer Grad-CAM visualization)
+ Swin-Large AI vs. Non-AI Detector (with Model Selection & Attention Visualization)
  """
  import os
  import math
@@ -9,29 +9,37 @@ import torch.nn.functional as F
  import torch.nn as nn
  import timm
  import numpy as np
- from PIL import Image
+ from PIL import Image, ImageDraw
  import gradio as gr
+ import matplotlib.pyplot as plt

- from huggingface_hub import hf_hub_download  # ← new
- from pytorch_grad_cam import GradCAM
- from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
- from pytorch_grad_cam.utils.image import show_cam_on_image
+ from huggingface_hub import hf_hub_download

  # --- Configuration ---------------------------------------------------------
- REPO_ID = "telecomadm1145/swin-ai-detection"
- HF_FILENAME = "swin_classifier_stage1_v2_epoch_3.pth"
- LOCAL_CKPT_DIR = "./checkpoints"
- MODEL_NAME = "swin_large_patch4_window12_384"  # ← use Large
- NUM_CLASSES = 2
- SEED = 4421
- dropout_rate = 0.1
-
- class_names = ["Non-AI Generated", "AI Generated"]  # 0, 1
+ REPO_ID = "telecomadm1145/swin-ai-detection"
+ HF_FILENAMES = {
+     "V2": "swin_classifier_stage1_v2_epoch_3.pth",
+     "V4": "swin_classifier_stage1_v4.pth",
+ }
+ DEFAULT_CKPT = "V4"  # must be a key of HF_FILENAMES
+ LOCAL_CKPT_DIR = "./checkpoints"
+ MODEL_NAME = "swin_large_patch4_window12_384"
+ NUM_CLASSES = 2
+ SEED = 4421
+ dropout_rate = 0.1
+
+ class_names = ["Non-AI Generated", "AI Generated"]  # 0, 1

  device = "cuda" if torch.cuda.is_available() else "cpu"
- torch.manual_seed(SEED); np.random.seed(SEED)
+ torch.manual_seed(SEED)
+ np.random.seed(SEED)
  print(f"Using device: {device}")

+ # --- Global model state ----------------------------------------------------
+ model = None
+ current_ckpt_name = None
+ attention_maps = []  # To store hooked attention maps
+
  # ---------------------------------------------------------------------------
  # 1. Model architecture
  class SwinClassifier(nn.Module):
@@ -59,97 +67,189 @@ class SwinClassifier(nn.Module):
          return self.classifier(feats)

  # ---------------------------------------------------------------------------
- # 2. Download / cache the checkpoint
- print("⏬ Download / cache checkpoint …")
- ckpt_path = hf_hub_download(
-     repo_id=REPO_ID,
-     filename=HF_FILENAME,
-     local_dir=LOCAL_CKPT_DIR,
-     force_download=False  # reuse the local copy if it already exists
- )
- print(f"Checkpoint path: {ckpt_path}")
-
- # ---------------------------------------------------------------------------
- # 3. Instantiate & load weights
- model = SwinClassifier(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
- state = torch.load(ckpt_path, map_location=device, weights_only=False)
- model.load_state_dict(state.get("model_state_dict", state), strict=True)
- model.eval()
- print("✅ Model loaded.")
+ # 2. Dynamic model loading
+ def load_model(ckpt_name: str):
+     """
+     Dynamically loads the selected model checkpoint.
+     If the model is already loaded, it does nothing.
+     """
+     global model, current_ckpt_name
+     if ckpt_name == current_ckpt_name:
+         print(f"Model '{ckpt_name}' is already loaded.")
+         return
+
+     print(f"🔄 Switching to model: '{ckpt_name}'...")
+     hf_filename = HF_FILENAMES[ckpt_name]
+
+     print("Downloading / caching checkpoint if needed…")
+     ckpt_path = hf_hub_download(
+         repo_id=REPO_ID,
+         filename=hf_filename,
+         local_dir=LOCAL_CKPT_DIR,
+         force_download=False
+     )
+     print(f"Checkpoint path: {ckpt_path}")
+
+     # Instantiate and load weights
+     model = SwinClassifier(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
+     state = torch.load(ckpt_path, map_location=device, weights_only=False)
+     model.load_state_dict(state.get("model_state_dict", state), strict=True)
+     model.eval()
+     current_ckpt_name = ckpt_name
+     print(f"✅ Model '{ckpt_name}' loaded successfully.")

  # ---------------------------------------------------------------------------
- # 4. torchvision / timm transform factory
+ # 3. torchvision / timm transform factory
  def build_transform(is_training: bool, interpolation: str):
      """
      Build timm's default transform for the given interpolation (bilinear / bicubic, etc.)
      """
+     if model is None:
+         raise RuntimeError("Model is not loaded. Please call load_model() first.")
      cfg = model.data_config.copy()
      cfg.update(dict(interpolation=interpolation))
      return timm.data.create_transform(**cfg, is_training=is_training)

  # ---------------------------------------------------------------------------
- # 5. Grad-CAM helpers
- def reshape_transform_swin(tokens):
-     """
-     Restore a [B, N, C] token sequence to a 2D feature map.
-     Works for every stage: 224 → 56 → 28 → 14 → 7
-     """
-     B, N, C = tokens.size()
-     H = W = int(math.sqrt(N))
-     assert H * W == N, "Token count is not a perfect square!"
-     return tokens.permute(0, 2, 1).view(B, C, H, W)
-
- # Take norm2 of the last block of each of the four stages
- target_layers = [layer.blocks[-1].norm2 for layer in model.backbone.layers]
- print("Target layers for Grad-CAM:", len(target_layers))
+ # 4. Attention hook & visualization
+ def get_attention_map(module, input, output):
+     """Hook to capture the attention map from the attention module."""
+     global attention_maps
+     # Assumes the attention weights are the second element of the output
+     # tuple, with shape [B, num_heads, N, N] where N is the number of patches.
+     attention_maps.append(output[1].cpu())
+
+ def create_attention_visualization(image_pil: Image.Image, attn_map: torch.Tensor) -> Image.Image:
+     """Creates an overlay of the attention map on the original image."""
+     # Average across all heads
+     attn_map = attn_map.mean(dim=1)[0]  # Shape: [N, N]
+
+     # Score each patch by the total attention it receives from all patches
+     residual_attn = attn_map.sum(dim=0)  # Sum over the query dimension
+
+     # Reshape to a 2D grid
+     num_patches = residual_attn.shape[0]
+     grid_size = int(math.sqrt(num_patches))
+
+     if grid_size * grid_size != num_patches:
+         print(f"Warning: number of patches ({num_patches}) is not a perfect square; skipping visualization.")
+         return image_pil
+
+     attn_grid = residual_attn.reshape(grid_size, grid_size).detach().numpy()
+
+     # Normalize the grid to [0, 1]
+     attn_grid = (attn_grid - attn_grid.min()) / (attn_grid.max() - attn_grid.min())
+
+     # Apply a colormap to obtain a heatmap
+     cmap = plt.get_cmap('viridis')
+     heatmap_colored = (cmap(attn_grid)[:, :, :3] * 255).astype(np.uint8)
+     heatmap_pil = Image.fromarray(heatmap_colored)
+
+     # Resize the heatmap to the original image size
+     heatmap_resized = heatmap_pil.resize(image_pil.size, Image.BICUBIC)
+
+     # Blend the original image with the heatmap
+     viz_image = Image.blend(image_pil, heatmap_resized, alpha=0.5)
+     return viz_image

  # ---------------------------------------------------------------------------
- # 6. Inference + (optional) Grad-CAM
- @torch.no_grad()
- def infer(image_pil: Image.Image,
-           interpolation: str = "bilinear",
-           show_cam: bool = True):
+ # 5. Inference + optional attention visualization
+ def predict_and_visualize(image_pil: Image.Image,
+                           ckpt_name: str,
+                           interpolation: str = "bicubic",
+                           show_attention: bool = True):
      if image_pil is None:
          return None, None

+     # Ensure the correct model is loaded
+     load_model(ckpt_name)
+
+     global attention_maps
+     attention_maps = []  # Reset before inference
+
      transform = build_transform(is_training=False, interpolation=interpolation)
      input_tensor = transform(image_pil).unsqueeze(0).to(device)

-     logits = model(input_tensor)
-     probs = F.softmax(logits, dim=1)[0]
+     # Register the hook only if visualization is requested
+     hook_handle = None
+     if show_attention:
+         target_layer = model.backbone.layers[-1].blocks[-1].attn
+         hook_handle = target_layer.register_forward_hook(get_attention_map)
+
+     with torch.no_grad():
+         logits = model(input_tensor)
+
+     # Always remove the hook after the forward pass
+     if hook_handle:
+         hook_handle.remove()
+
+     probs = F.softmax(logits, dim=1)[0]
      confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}

-     return confidences
+     # Generate the visualization if requested and possible
+     viz_image = None
+     if show_attention and attention_maps:
+         original_image = image_pil.copy().convert("RGB")
+         viz_image = create_attention_visualization(original_image, attention_maps[0])
+
+     return confidences, viz_image

  # ---------------------------------------------------------------------------
- # 7. Gradio UI
+ # 6. Gradio UI
  def launch_app():
-     with gr.Blocks() as demo:
-         gr.Markdown("# 🖼️ AI vs. Non-AI Image Classifier (with Swin-Large)")
+     # Load the default model at startup
+     load_model(DEFAULT_CKPT)
+
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown("# 🖼️ AI vs. Non-AI Image Classifier")
+         gr.Markdown("Using Swin-Large Transformer with Attention Visualization.")

-         run_btn = gr.Button("🚀 Run")
-
-         with gr.Row():
-             interp_choice = gr.Radio(
-                 ["bilinear", "bicubic", "nearest"], value="bicubic",
-                 label="Resize Interpolation (preprocessing)"
-             )
-
          with gr.Row():
-             in_img = gr.Image(type="pil", label="Upload an Image")
-             out_lbl = gr.Label(num_top_classes=2, label="Predictions")
-
-         def _run(img, inter):
-             return infer(img, interpolation=inter)
+             with gr.Column(scale=1):
+                 in_img = gr.Image(type="pil", label="Upload an Image")
+
+                 model_choice = gr.Dropdown(
+                     list(HF_FILENAMES.keys()), value=DEFAULT_CKPT, label="Select Model"
+                 )
+                 interp_choice = gr.Radio(
+                     ["bilinear", "bicubic", "nearest"], value="bicubic",
+                     label="Resize Interpolation (Preprocessing)"
+                 )
+                 viz_checkbox = gr.Checkbox(value=True, label="Show Attention Visualization")
+
+                 run_btn = gr.Button("🚀 Run Analysis", variant="primary")
+
+             with gr.Column(scale=2):
+                 out_lbl = gr.Label(num_top_classes=2, label="Predictions")
+                 out_viz = gr.Image(type="pil", label="Attention Map Visualization", visible=True)

          run_btn.click(
-             _run,
-             inputs=[in_img, interp_choice],
-             outputs=[out_lbl]
+             predict_and_visualize,
+             inputs=[in_img, model_choice, interp_choice, viz_checkbox],
+             outputs=[out_lbl, out_viz]
+         )
+
+         gr.Examples(
+             examples=[
+                 # [os.path.join(os.path.dirname(__file__), "examples/ai_1.png"), DEFAULT_CKPT, "bicubic", True],
+                 # [os.path.join(os.path.dirname(__file__), "examples/real_1.jpg"), DEFAULT_CKPT, "bicubic", True],
+             ],
+             inputs=[in_img, model_choice, interp_choice, viz_checkbox],
+             outputs=[out_lbl, out_viz],
+             fn=predict_and_visualize,
+             cache_examples=False,  # Set to True if the examples are static
          )

      demo.launch()

  # ---------------------------------------------------------------------------
  if __name__ == "__main__":
+     # Create an examples directory for Gradio
+     if not os.path.exists("examples"):
+         os.makedirs("examples")
+         print("Created 'examples' directory. Please add some sample images (e.g., ai_1.png, real_1.jpg) there for the UI examples.")
+
      launch_app()
 
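One caveat on the loader: `torch.load(..., weights_only=False)` unpickles arbitrary objects, which is acceptable for a checkpoint repo you control but worth tightening otherwise. A hypothetical hardening, not part of this commit, assuming the checkpoint is a plain state dict (or a dict wrapping one):

```python
# Hypothetical hardening (not in this commit): weights_only=True (PyTorch
# >= 1.13) restricts torch.load to tensors and simple containers, so a
# tampered checkpoint cannot run code at load time.
import torch

def load_state_safely(ckpt_path: str, model: torch.nn.Module, device: str) -> None:
    state = torch.load(ckpt_path, map_location=device, weights_only=True)
    model.load_state_dict(state.get("model_state_dict", state), strict=True)
```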
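The `get_attention_map` hook assumes the attention module returns a tuple with the weights in `output[1]`. Stock timm `WindowAttention.forward` returns a single tensor, so whether that assumption holds depends on the installed timm version (or a patched backbone). A version-dependent alternative is to hook the module's post-softmax dropout, whose output is the attention matrix itself; a sketch to verify against your environment:

```python
# Sketch: capture Swin window-attention weights via the attn_drop submodule,
# which timm's eager WindowAttention path applies right after the softmax.
# Assumptions to verify against the installed timm version: the module layout
# (layers[-1].blocks[-1].attn.attn_drop) and the TIMM_FUSED_ATTN switch, since
# a fused-SDPA path never materializes the attention matrix for hooks to see.
import os
os.environ["TIMM_FUSED_ATTN"] = "0"  # force the eager attention path (recent timm)

import timm
import torch

backbone = timm.create_model("swin_large_patch4_window12_384", pretrained=False).eval()
captured = []

def grab_attention(module, inputs, output):
    # Post-softmax weights: [B * num_windows, num_heads, window_len, window_len]
    captured.append(output.detach().cpu())

handle = backbone.layers[-1].blocks[-1].attn.attn_drop.register_forward_hook(grab_attention)
with torch.no_grad():
    backbone(torch.randn(1, 3, 384, 384))
handle.remove()

print(captured[0].shape)  # expected: torch.Size([1, 48, 144, 144])
```

Either way, what comes out is per-window attention. For this particular model the final stage is a 12×12 token grid covered by a single 12×12 window, so the square-grid reshape in `create_attention_visualization` happens to line up; hooking earlier stages, or changing the input size, would first require undoing the window partitioning.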
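Since `model`, `current_ckpt_name`, and `attention_maps` are module-level globals mutated on every request, two concurrent Gradio requests that select different checkpoints can race. A hypothetical guard, not part of this commit, would serialize the swap:

```python
# Hypothetical addition (not in this commit): serialize checkpoint swaps so
# concurrent Gradio requests cannot interleave load_model() calls.
import threading

_model_lock = threading.Lock()

def load_model_locked(ckpt_name: str) -> None:
    with _model_lock:
        load_model(ckpt_name)  # load_model as defined in app.py above
```

The per-request reset of `attention_maps` has the same issue; returning captured maps from a per-call hook closure instead of a global would avoid it entirely.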
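For a quick smoke test without launching the UI, the inference path can be exercised directly; a sketch assuming a local test image (`test.jpg` is a placeholder, not a file shipped with the Space):

```python
# Sketch: exercise the inference path without the Gradio UI.
from PIL import Image

image = Image.open("test.jpg").convert("RGB")  # placeholder path
confidences, attn_overlay = predict_and_visualize(
    image, ckpt_name="V4", interpolation="bicubic", show_attention=True
)
print(confidences)  # e.g. {"Non-AI Generated": 0.12, "AI Generated": 0.88}
if attn_overlay is not None:
    attn_overlay.save("attention_overlay.png")
```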