telecomadm1145 committed on
Commit
cedb103
·
verified ·
1 Parent(s): 619e447

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -240
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Swin-Large AI vs. Non-AI Detector (带多层注意力可视化)
4
  """
5
  import os
6
  import math
@@ -11,28 +11,30 @@ import timm
11
  import numpy as np
12
  from PIL import Image
13
  import gradio as gr
14
- from collections import defaultdict
15
 
16
  from huggingface_hub import hf_hub_download
 
 
 
17
 
18
  # --- Configuration ---------------------------------------------------------
19
  REPO_ID = "telecomadm1145/swin-ai-detection"
20
  HF_FILENAME = "swin_classifier_stage1_v2_epoch_3.pth"
21
  LOCAL_CKPT_DIR = "./checkpoints"
22
- MODEL_NAME = "swin_large_patch4_window12_384"
23
  NUM_CLASSES = 2
24
  SEED = 4421
25
  dropout_rate = 0.1
26
 
27
- class_names = ["Non-AI Generated", "AI Generated"]
28
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
  torch.manual_seed(SEED); np.random.seed(SEED)
31
  print(f"Using device: {device}")
32
 
33
  # ---------------------------------------------------------------------------
34
- # 1. 模型结构(修改以支持注意力提取)
35
- class SwinClassifier(nn.Module):
36
  def __init__(self, model_name, num_classes, pretrained=True):
37
  super().__init__()
38
  self.backbone = timm.create_model(model_name, pretrained=pretrained,
@@ -52,34 +54,31 @@ class SwinClassifier(nn.Module):
52
  nn.Linear(128, num_classes)
53
  )
54
 
55
- # 存储注意力权重
56
- self.attention_maps = defaultdict(list)
57
  self.register_attention_hooks()
58
 
59
  def register_attention_hooks(self):
60
- """注册钩子函数来捕获注意力权重"""
61
- def get_attention_hook(layer_name):
62
  def hook(module, input, output):
63
- # 对于 Swin Transformer,注意力在 attention 模块中
 
64
  if hasattr(module, 'attn'):
65
  # 获取注意力权重
66
- attn_weights = module.attn.attention_weights if hasattr(module.attn, 'attention_weights') else None
67
- if attn_weights is not None:
68
- self.attention_maps[layer_name].append(attn_weights.detach().cpu())
69
  return hook
70
 
71
- # 为每个 stage 的每个 block 注册钩子
72
  for stage_idx, stage in enumerate(self.backbone.layers):
73
  for block_idx, block in enumerate(stage.blocks):
74
- layer_name = f"stage_{stage_idx}_block_{block_idx}"
75
- block.register_forward_hook(get_attention_hook(layer_name))
76
-
77
- def clear_attention_maps(self):
78
- """清空注意力映射"""
79
- self.attention_maps.clear()
80
 
81
  def forward(self, x):
82
- self.clear_attention_maps()
 
83
  feats = self.backbone(x)
84
  return self.classifier(feats)
85
 
@@ -88,282 +87,248 @@ class SwinClassifier(nn.Module):
88
  class AttentionExtractor:
89
  def __init__(self, model):
90
  self.model = model
91
- self.hooks = []
92
  self.attention_maps = {}
93
 
94
- def register_hooks(self):
95
- """注册钩子来提取注意力权重"""
96
- def get_attention_hook(name):
97
- def hook(module, input, output):
98
- # 提取 Window Attention 的注意力权重
99
- if hasattr(module, 'qkv'):
100
- B, N, C = input[0].shape
101
- qkv = module.qkv(input[0]).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
102
- q, k, v = qkv.unbind(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- attn = (q @ k.transpose(-2, -1)) * module.scale
105
- attn = attn.softmax(dim=-1)
 
106
 
107
- self.attention_maps[name] = attn.detach().cpu()
108
- return hook
 
109
 
110
- # 清除之前的钩子
111
- self.clear_hooks()
112
-
113
- # 为每个 stage 的每个 block 的 attention 模块注册钩子
114
- for stage_idx, stage in enumerate(self.model.backbone.layers):
115
- for block_idx, block in enumerate(stage.blocks):
116
- if hasattr(block, 'attn'):
117
- name = f"stage_{stage_idx}_block_{block_idx}"
118
- hook = block.attn.register_forward_hook(get_attention_hook(name))
119
- self.hooks.append(hook)
120
-
121
- def clear_hooks(self):
122
- """清除所有钩子"""
123
- for hook in self.hooks:
124
- hook.remove()
125
- self.hooks = []
126
-
127
- def clear_attention_maps(self):
128
- """清空注意力映射"""
129
- self.attention_maps.clear()
130
-
131
def create_attention_visualization(attention_weights, input_size, stage_info):
    """Render an attention tensor as a grayscale PIL image of size ``input_size``.

    Args:
        attention_weights: attention tensor indexed as [sample, head, N, N];
            only the first sample is used and its heads are averaged.
        input_size: (width, height) passed to ``Image.resize``.
        stage_info: accepted for interface compatibility; not read here.

    Returns:
        A mode-'L' PIL image, or None when ``attention_weights`` is None or empty.
    """
    if attention_weights is None or len(attention_weights) == 0:
        return None

    # First sample, averaged over heads -> [N, N]; then averaged over the
    # first remaining axis -> one score per token, [N].
    per_token = attention_weights[0].mean(dim=0).mean(dim=0)

    # Largest square grid that fits; surplus tokens (if any) are dropped.
    side = int(math.sqrt(per_token.shape[0]))
    grid = per_token[:side * side].reshape(side, side).numpy()

    # Min-max normalise to [0, 1]; the epsilon guards a constant map.
    lo, hi = grid.min(), grid.max()
    grid = (grid - lo) / (hi - lo + 1e-8)

    gray = Image.fromarray((grid * 255).astype(np.uint8), mode='L')
    return gray.resize(input_size, Image.Resampling.BILINEAR)
171
 
172
  # ---------------------------------------------------------------------------
173
- # 下载和加载模型
174
  print("⏬ Download / cache checkpoint …")
175
  ckpt_path = hf_hub_download(
176
  repo_id = REPO_ID,
177
  filename = HF_FILENAME,
178
  local_dir = LOCAL_CKPT_DIR,
179
- force_download=False
180
  )
181
  print(f"Checkpoint path: {ckpt_path}")
182
 
183
- model = SwinClassifier(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
 
 
184
  state = torch.load(ckpt_path, map_location=device, weights_only=False)
185
- model.load_state_dict(state.get("model_state_dict", state), strict=True)
186
  model.eval()
187
  print("✅ Model loaded.")
188
 
189
- # 创建注意力提取器
190
  attention_extractor = AttentionExtractor(model)
191
 
192
  # ---------------------------------------------------------------------------
193
- # 构建变换函数
194
  def build_transform(is_training: bool, interpolation: str):
 
 
 
195
  cfg = model.data_config.copy()
196
  cfg.update(dict(interpolation=interpolation))
197
  return timm.data.create_transform(**cfg, is_training=is_training)
198
 
199
  # ---------------------------------------------------------------------------
200
- # 推理函数
201
- @torch.no_grad()
202
- def infer(image_pil: Image.Image,
203
- interpolation: str = "bilinear",
204
- show_attention: bool = True,
205
- selected_layer: str = "stage_3_block_1"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
 
 
 
 
 
 
 
 
 
 
207
  if image_pil is None:
208
- return None, None, "请上传图片"
209
 
210
  transform = build_transform(is_training=False, interpolation=interpolation)
211
  input_tensor = transform(image_pil).unsqueeze(0).to(device)
212
 
213
- # 注册钩子
214
- if show_attention:
215
- attention_extractor.register_hooks()
216
- attention_extractor.clear_attention_maps()
217
-
218
- # 分类预测
219
- logits = model(input_tensor)
220
- probs = F.softmax(logits, dim=1)[0]
221
- confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}
222
 
223
- if not show_attention:
224
- attention_extractor.clear_hooks()
225
- return confidences, None, "分类完成"
226
-
227
- # 获取选定层的注意力
228
- layer_info = f"当前显示层: {selected_layer}"
229
 
230
- if selected_layer in attention_extractor.attention_maps:
231
- attention_weights = attention_extractor.attention_maps[selected_layer]
232
-
233
- # 获取 stage 信息来确定分辨率
234
- stage_num = int(selected_layer.split('_')[1])
235
- stage_info = {'stage': stage_num}
236
-
237
- attention_img = create_attention_visualization(
238
- attention_weights, image_pil.size, stage_info
239
- )
 
 
 
 
 
 
240
 
241
- # 清理钩子
242
- attention_extractor.clear_hooks()
 
243
 
244
- if attention_img is not None:
245
- # 创建彩色热力图
246
- attention_colored = Image.new('RGB', image_pil.size)
247
- attn_array = np.array(attention_img)
248
-
249
- # 创建热力图 (红色表示高注意力)
250
- colored_array = np.zeros((*attn_array.shape, 3), dtype=np.uint8)
251
- colored_array[:, :, 0] = attn_array # 红色通道
252
- colored_array[:, :, 1] = attn_array // 2 # 绿色通道(减弱)
253
-
254
- attention_colored = Image.fromarray(colored_array)
255
-
256
- # 与原图混合
257
- blended = Image.blend(image_pil.convert('RGB'), attention_colored, alpha=0.4)
258
-
259
- return confidences, blended, f"{layer_info} - 注意力可视化完成"
260
  else:
261
- return confidences, None, f"{layer_info} - 注意力提取失败"
262
  else:
263
- available_layers = list(attention_extractor.attention_maps.keys())
264
- attention_extractor.clear_hooks()
265
- return confidences, None, f"层 {selected_layer} 不可用。可用层: {available_layers[:5]}..."
266
-
267
- # ---------------------------------------------------------------------------
268
- # 获取可用的层列表
269
def get_available_layers(stage_blocks=(2, 2, 18, 2)):
    """Return the name of every attention layer, in network order.

    Args:
        stage_blocks: number of blocks in each stage. The default is the
            depth configuration this file uses for Swin-Large.

    Returns:
        A list of ``"stage_{i}_block_{j}"`` strings, one per block.
    """
    return [
        f"stage_{stage_idx}_block_{block_idx}"
        for stage_idx, depth in enumerate(stage_blocks)
        for block_idx in range(depth)
    ]
278
 
279
  # ---------------------------------------------------------------------------
280
- # Gradio UI
281
  def launch_app():
282
- available_layers = get_available_layers()
283
-
284
- with gr.Blocks(title="AI Image Detector with Attention") as demo:
285
  gr.Markdown("""
286
- # 🖼️ AI vs. Non-AI Image Classifier (Swin-Large + Attention Visualization)
287
 
288
- 🔍 基于 Swin-Large 的 AI 图片检测器,使用注意力机制可视化模型关注的区域
289
 
290
- ## 功能特点:
291
- - 🎯 使用 Swin Transformer 原生注意力权重
292
- - 🎨 可视化不同层的注意力模式
293
- - 🔄 支持多种插值方法优化预处理
294
 
295
- ## 使用说明:
296
- 1. 上传图片
297
- 2. 选择要可视化的注意力层
298
- 3. 选择插值方法(推荐 bicubic)
299
- 4. 点击运行
300
 
301
- ## 注意事项:
302
- - 工具仅供研究和教育用途
303
- - 不保证 100% 准确率
304
- - 请负责任地使用
305
  """)
306
 
307
  with gr.Row():
308
- with gr.Column():
309
- in_img = gr.Image(type="pil", label="上传图片")
310
-
311
- with gr.Row():
312
- interp_choice = gr.Radio(
313
- ["bilinear", "bicubic", "nearest"],
314
- value="bicubic",
315
- label="插值方法"
316
- )
317
-
318
- attention_toggle = gr.Checkbox(
319
- value=True,
320
- label="显示注意力可视化"
321
- )
322
-
323
- layer_choice = gr.Dropdown(
324
- choices=available_layers,
325
- value="stage_3_block_1",
326
- label="选择注意力层",
327
- info="不同层关注不同级别的特征"
328
- )
329
-
330
- run_btn = gr.Button("🚀 开始检测", variant="primary")
331
-
332
- with gr.Column():
333
- out_lbl = gr.Label(
334
- num_top_classes=2,
335
- label="分类结果"
336
- )
337
- out_attention = gr.Image(
338
- type="pil",
339
- label="注意力可视化"
340
- )
341
- status_text = gr.Textbox(
342
- label="状态信息",
343
- interactive=False
344
- )
345
-
346
- # 层选择说明
347
- gr.Markdown("""
348
- ### 层选择指南:
349
- - **Stage 0-1**: 关注底层特征(边缘、纹理)
350
- - **Stage 2**: 关注中层特征(形状、局部模式)
351
- - **Stage 3**: 关注高层特征(语义、全局结构)
352
 
353
- 推荐从 `stage_3_block_1` 开始尝试,然后对比不同层的关注点。
354
- """)
355
 
356
- def _run(img, inter, attention_flag, layer):
357
- return infer(img, interpolation=inter,
358
- show_attention=attention_flag,
359
- selected_layer=layer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
  run_btn.click(
362
  _run,
363
- inputs=[in_img, interp_choice, attention_toggle, layer_choice],
364
- outputs=[out_lbl, out_attention, status_text]
365
  )
366
 
 
 
 
 
 
 
 
 
367
  demo.launch()
368
 
369
  # ---------------------------------------------------------------------------
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Swin-Large AI vs. Non-AI Detector (基于注意力机制可视化)
4
  """
5
  import os
6
  import math
 
11
  import numpy as np
12
  from PIL import Image
13
  import gradio as gr
 
14
 
15
  from huggingface_hub import hf_hub_download
16
+ from torchvision import transforms
17
+ import matplotlib.pyplot as plt
18
+ import matplotlib.cm as cm
19
 
20
  # --- Configuration ---------------------------------------------------------
21
  REPO_ID = "telecomadm1145/swin-ai-detection"
22
  HF_FILENAME = "swin_classifier_stage1_v2_epoch_3.pth"
23
  LOCAL_CKPT_DIR = "./checkpoints"
24
+ MODEL_NAME = "swin_large_patch4_window12_384" # ← 使用 large
25
  NUM_CLASSES = 2
26
  SEED = 4421
27
  dropout_rate = 0.1
28
 
29
+ class_names = ["Non-AI Generated", "AI Generated"] # 0, 1
30
 
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
32
  torch.manual_seed(SEED); np.random.seed(SEED)
33
  print(f"Using device: {device}")
34
 
35
  # ---------------------------------------------------------------------------
36
+ # 1. 修改模型结构以提取注意力
37
+ class SwinClassifierWithAttention(nn.Module):
38
  def __init__(self, model_name, num_classes, pretrained=True):
39
  super().__init__()
40
  self.backbone = timm.create_model(model_name, pretrained=pretrained,
 
54
  nn.Linear(128, num_classes)
55
  )
56
 
57
+ # 存储注意力权重的钩子
58
+ self.attention_weights = {}
59
  self.register_attention_hooks()
60
 
61
  def register_attention_hooks(self):
62
+ """注册钩子函数来提取注意力权重"""
63
+ def hook_fn(name):
64
  def hook(module, input, output):
65
+ # 对于Swin Transformer的窗口注意力机制
66
+ # output通常是 (B, N, C) 格式
67
  if hasattr(module, 'attn'):
68
  # 获取注意力权重
69
+ self.attention_weights[name] = module.attn.attention_weights
 
 
70
  return hook
71
 
72
+ # 为每个stage的每个block注册钩子
73
  for stage_idx, stage in enumerate(self.backbone.layers):
74
  for block_idx, block in enumerate(stage.blocks):
75
+ hook_name = f"stage_{stage_idx}_block_{block_idx}"
76
+ if hasattr(block, 'attn'):
77
+ block.attn.register_forward_hook(hook_fn(hook_name))
 
 
 
78
 
79
  def forward(self, x):
80
+ # 清空之前的注意力权重
81
+ self.attention_weights = {}
82
  feats = self.backbone(x)
83
  return self.classifier(feats)
84
 
 
87
  class AttentionExtractor:
88
  def __init__(self, model):
89
  self.model = model
 
90
  self.attention_maps = {}
91
 
92
+ def extract_attention_weights(self, x):
93
+ """提取所有层的注意力权重"""
94
+ with torch.no_grad():
95
+ _ = self.model(x) # 前向传播以触发钩子
96
+ return self.model.attention_weights.copy()
97
+
98
+ def process_attention_for_visualization(self, attention_weights, input_size):
99
+ """处理注意力权重用于可视化"""
100
+ processed_maps = {}
101
+
102
+ for layer_name, attn_weight in attention_weights.items():
103
+ if attn_weight is None:
104
+ continue
105
+
106
+ # attn_weight shape: [batch_size, num_heads, seq_len, seq_len]
107
+ if len(attn_weight.shape) == 4:
108
+ # 取平均池化所有注意力头
109
+ attn_map = attn_weight.mean(dim=1) # [batch_size, seq_len, seq_len]
110
+
111
+ # 取第一个样本
112
+ attn_map = attn_map[0] # [seq_len, seq_len]
113
+
114
+ # 对于自注意力,我们通常关注CLS token对其他token的注意力
115
+ # 或者计算所有token的平均注意力
116
+ if attn_map.shape[0] > 1:
117
+ # 计算每个位置的平均注意力分数
118
+ avg_attention = attn_map.mean(dim=0) # [seq_len]
119
 
120
+ # 将注意力分数reshape为2D特征图
121
+ seq_len = avg_attention.shape[0]
122
+ grid_size = int(math.sqrt(seq_len))
123
 
124
+ if grid_size * grid_size == seq_len:
125
+ attention_2d = avg_attention.reshape(grid_size, grid_size)
126
+ processed_maps[layer_name] = attention_2d
127
 
128
+ return processed_maps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  # ---------------------------------------------------------------------------
131
+ # 3. 下载 / 缓存 checkpoint
132
  print("⏬ Download / cache checkpoint …")
133
  ckpt_path = hf_hub_download(
134
  repo_id = REPO_ID,
135
  filename = HF_FILENAME,
136
  local_dir = LOCAL_CKPT_DIR,
137
+ force_download=False # 已存在则直接用
138
  )
139
  print(f"Checkpoint path: {ckpt_path}")
140
 
141
+ # ---------------------------------------------------------------------------
142
+ # 4. 实例化 & 加载权重
143
+ model = SwinClassifierWithAttention(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
144
  state = torch.load(ckpt_path, map_location=device, weights_only=False)
145
+ model.load_state_dict(state.get("model_state_dict", state), strict=False) # strict=False 因为添加了新的组件
146
  model.eval()
147
  print("✅ Model loaded.")
148
 
 
149
  attention_extractor = AttentionExtractor(model)
150
 
151
  # ---------------------------------------------------------------------------
152
+ # 5. 变换函数
153
  def build_transform(is_training: bool, interpolation: str):
154
+ """
155
+ 根据插值方式(双线性 / 三次等)构建 timm 默认变换
156
+ """
157
  cfg = model.data_config.copy()
158
  cfg.update(dict(interpolation=interpolation))
159
  return timm.data.create_transform(**cfg, is_training=is_training)
160
 
161
  # ---------------------------------------------------------------------------
162
+ # 6. 注意力可视化函数
163
def visualize_attention(attention_map, original_image, normalize=True):
    """Overlay a 2-D attention map on an image as a jet-coloured heat map.

    Args:
        attention_map: 2-D torch tensor of attention scores.
        original_image: PIL image in any mode. It is converted to RGB before
            blending, so RGBA, palette and grayscale inputs all work; the
            3-channel heat map previously could not be added to a 4-channel
            RGBA array.
        normalize: when True, min-max rescale the map to [0, 1] first.

    Returns:
        An RGB PIL image, the size of ``original_image``, made of 60% image
        and 40% heat map.
    """
    scores = attention_map.detach().cpu().numpy()
    if normalize:
        # The epsilon guards against division by zero on a constant map.
        scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    # Without normalisation the values are not guaranteed to lie in [0, 1];
    # clip so the uint8 cast below cannot wrap around.
    scores = np.clip(scores, 0.0, 1.0)

    # Upsample the small attention grid to the image size.
    resized = Image.fromarray((scores * 255).astype(np.uint8)).resize(
        original_image.size, Image.Resampling.BILINEAR
    )

    # Map intensities through the jet colormap and drop the alpha channel.
    heatmap = cm.jet(np.asarray(resized) / 255.0)[:, :, :3]

    base = np.asarray(original_image.convert("RGB")) / 255.0
    overlay = np.clip(0.6 * base + 0.4 * heatmap, 0, 1)
    return Image.fromarray((overlay * 255).astype(np.uint8))
191
+
192
+ # ---------------------------------------------------------------------------
193
+ # 7. 推理 + 注意力可视化
194
def infer_with_attention(image_pil: Image.Image,
                         interpolation: str = "bilinear",
                         attention_layer: str = "stage_3_block_1",
                         stage_average: bool = False,
                         normalize_attention: bool = True):
    """Classify an image and build an attention heat-map overlay for it.

    Args:
        image_pil: input image, or None when nothing was uploaded.
        interpolation: resize interpolation passed to ``build_transform``.
        attention_layer: layer name of the form ``"stage_{i}_block_{j}"``.
        stage_average: when True, average the maps of every block in the
            stage named by ``attention_layer`` instead of using one block.
        normalize_attention: forwarded to ``visualize_attention``.

    Returns:
        ``(confidences, overlay)``. ``confidences`` maps each class name to
        its probability. ``overlay`` is a PIL image, or None when no usable
        attention map is available. ``(None, None)`` when ``image_pil`` is None.
    """
    if image_pil is None:
        return None, None

    # NOTE(review): presumably the timm transform normalises three channels;
    # converting up front keeps RGBA / grayscale inputs from failing there.
    image_pil = image_pil.convert("RGB")

    transform = build_transform(is_training=False, interpolation=interpolation)
    input_tensor = transform(image_pil).unsqueeze(0).to(device)

    # (1) One forward pass yields both the logits and, through the model's
    # hooks, the attention weights. The earlier code ran the whole network a
    # second time just to read them.
    with torch.no_grad():
        logits = model(input_tensor)
        probs = F.softmax(logits, dim=1)[0]
    confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}

    # (2) Snapshot what the hooks recorded during that pass.
    attention_weights = dict(model.attention_weights)
    if not attention_weights:
        return confidences, None

    # (3) Reduce every layer to a square 2-D map.
    processed_attention = attention_extractor.process_attention_for_visualization(
        attention_weights, input_tensor.shape[-2:]
    )
    if not processed_attention:
        return confidences, None

    # (4) Pick the map to display.
    if stage_average:
        stage_prefix = f"stage_{attention_layer.split('_')[1]}_"
        stage_maps = [
            attn_map
            for layer_name, attn_map in processed_attention.items()
            if layer_name.startswith(stage_prefix)
        ]
        if not stage_maps:
            return confidences, None
        chosen_map = torch.stack(stage_maps).mean(dim=0)
    elif attention_layer in processed_attention:
        chosen_map = processed_attention[attention_layer]
    else:
        # Requested layer unavailable: fall back to the first one recorded.
        chosen_map = next(iter(processed_attention.values()))

    return confidences, visualize_attention(chosen_map, image_pil, normalize_attention)
 
 
255
 
256
  # ---------------------------------------------------------------------------
257
+ # 8. Gradio UI
258
  def launch_app():
259
+ with gr.Blocks() as demo:
 
 
260
  gr.Markdown("""
261
+ # 🖼️ AI vs. Non-AI Image Classifier (Swin-Large + Attention Visualization)
262
 
263
+ 🖼️ AI 鉴别器(基于 Swin-Large 视觉骨干,输出注意力热力图)
264
 
265
+ 基于Swin Transformer的自注意力机制来可视化模型关注的区域。
 
 
 
266
 
267
+ Notice: 使用 bicubic 效果较好。请负责任地使用此工具。
 
 
 
 
268
 
269
+ 此工具仅供研究和教育用途。
 
 
 
270
  """)
271
 
272
  with gr.Row():
273
+ in_img = gr.Image(type="pil", label="Upload an Image")
274
+ out_attention = gr.Image(type="pil", label="Attention Heatmap")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
+ with gr.Row():
277
+ out_lbl = gr.Label(num_top_classes=2, label="Predictions")
278
 
279
+ with gr.Row():
280
+ interp_choice = gr.Radio(
281
+ ["bilinear", "bicubic", "nearest"], value="bicubic",
282
+ label="Resize Interpolation (预处理插值)"
283
+ )
284
+
285
+ with gr.Row():
286
+ attention_layer_choice = gr.Dropdown(
287
+ choices=[
288
+ "stage_0_block_0", "stage_0_block_1",
289
+ "stage_1_block_0", "stage_1_block_1",
290
+ "stage_2_block_0", "stage_2_block_1", "stage_2_block_2",
291
+ "stage_3_block_0", "stage_3_block_1"
292
+ ],
293
+ value="stage_3_block_1",
294
+ label="选择注意力层 (Attention Layer)"
295
+ )
296
+
297
+ with gr.Row():
298
+ stage_avg_toggle = gr.Checkbox(
299
+ value=False,
300
+ label="计算整个Stage的平均注意力 (Average Stage Attention)"
301
+ )
302
+ normalize_toggle = gr.Checkbox(
303
+ value=True,
304
+ label="归一化注意力 (Normalize Attention)"
305
+ )
306
+
307
+ run_btn = gr.Button("🚀 Run Analysis")
308
+
309
+ def _run(img, inter, attn_layer, stage_avg, normalize):
310
+ return infer_with_attention(
311
+ img,
312
+ interpolation=inter,
313
+ attention_layer=attn_layer,
314
+ stage_average=stage_avg,
315
+ normalize_attention=normalize
316
+ )
317
 
318
  run_btn.click(
319
  _run,
320
+ inputs=[in_img, interp_choice, attention_layer_choice, stage_avg_toggle, normalize_toggle],
321
+ outputs=[out_lbl, out_attention]
322
  )
323
 
324
+ gr.Markdown("""
325
+ ### 说明:
326
+ - **注意力层选择**: 可以选择不同的Swin Transformer层来查看注意力模式
327
+ - **Stage平均**: 勾选后会计算选中stage中所有block的平均注意力
328
+ - **归一化**: 将注意力值归一化到0-1范围内,便于可视化
329
+ - **热力图**: 红色区域表示模型更关注的区域,蓝色区域表示关注度较低的区域
330
+ """)
331
+
332
  demo.launch()
333
 
334
  # ---------------------------------------------------------------------------