Update app.py

app.py CHANGED
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 """
-Swin-Large AI vs. Non-AI Detector (
+Swin-Large AI vs. Non-AI Detector (attention-based visualization)
 """
 import os
 import math
@@ -11,28 +11,30 @@ import timm
 import numpy as np
 from PIL import Image
 import gradio as gr
-from collections import defaultdict
 
 from huggingface_hub import hf_hub_download
+from torchvision import transforms
+import matplotlib.pyplot as plt
+import matplotlib.cm as cm
 
 # --- Configuration ---------------------------------------------------------
 REPO_ID = "telecomadm1145/swin-ai-detection"
 HF_FILENAME = "swin_classifier_stage1_v2_epoch_3.pth"
 LOCAL_CKPT_DIR = "./checkpoints"
-MODEL_NAME = "swin_large_patch4_window12_384"
+MODEL_NAME = "swin_large_patch4_window12_384"  # ← the large variant
 NUM_CLASSES = 2
 SEED = 4421
 dropout_rate = 0.1
 
-class_names = ["Non-AI Generated", "AI Generated"]
+class_names = ["Non-AI Generated", "AI Generated"]  # 0, 1
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 torch.manual_seed(SEED); np.random.seed(SEED)
 print(f"Using device: {device}")
 
 # ---------------------------------------------------------------------------
-# 1.
-class SwinClassifier(nn.Module):
+# 1. Model structure, modified to extract attention
+class SwinClassifierWithAttention(nn.Module):
     def __init__(self, model_name, num_classes, pretrained=True):
         super().__init__()
         self.backbone = timm.create_model(model_name, pretrained=pretrained,
@@ -52,34 +54,31 @@ class SwinClassifier(nn.Module):
             nn.Linear(128, num_classes)
         )
 
+        # Storage for the attention weights captured by the hooks
+        self.attention_weights = {}
         self.register_attention_hooks()
 
     def register_attention_hooks(self):
-        """
+        """Register hook functions that capture attention weights."""
+        def hook_fn(name):
             def hook(module, input, output):
+                # For the Swin Transformer window-attention modules;
+                # the output is usually in (B, N, C) format
                 if hasattr(module, 'attn'):
                     # Grab the attention weights
-                    if attn_weights is not None:
-                        self.attention_maps[layer_name].append(attn_weights.detach().cpu())
+                    self.attention_weights[name] = module.attn.attention_weights
             return hook
 
+        # Register a hook on every block of every stage
         for stage_idx, stage in enumerate(self.backbone.layers):
             for block_idx, block in enumerate(stage.blocks):
-    def clear_attention_maps(self):
-        """Clear the stored attention maps."""
-        self.attention_maps.clear()
+                hook_name = f"stage_{stage_idx}_block_{block_idx}"
+                if hasattr(block, 'attn'):
+                    block.attn.register_forward_hook(hook_fn(hook_name))
 
     def forward(self, x):
+        # Clear the attention weights from the previous pass
+        self.attention_weights = {}
         feats = self.backbone(x)
         return self.classifier(feats)
 
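
Note on the hook above: it is registered on `block.attn` itself, yet the callback checks `hasattr(module, 'attn')` and reads `module.attn.attention_weights`, and stock timm `WindowAttention` modules expose no `attention_weights` attribute, so whether anything is captured depends on the timm build in use. A minimal sketch of an alternative that works with the classic (non-fused) timm implementation is to hook the attention module's `attn_drop` submodule, whose output is the post-softmax attention matrix; `swin_tiny_patch4_window7_224` is used here only as a small stand-in:

import timm
import torch

demo_model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=False).eval()
captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        # attn_drop runs right after the softmax, so its output is the
        # [num_windows * B, num_heads, N, N] window-attention matrix
        captured[name] = output.detach().cpu()
    return hook

for s_idx, stage in enumerate(demo_model.layers):
    for b_idx, block in enumerate(stage.blocks):
        if hasattr(block, "attn") and hasattr(block.attn, "attn_drop"):
            block.attn.attn_drop.register_forward_hook(make_hook(f"stage_{s_idx}_block_{b_idx}"))

with torch.no_grad():
    demo_model(torch.randn(1, 3, 224, 224))

print({k: tuple(v.shape) for k, v in captured.items()})
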
@@ -88,282 +87,248 @@ class SwinClassifier(nn.Module):
 class AttentionExtractor:
     def __init__(self, model):
         self.model = model
-        self.hooks = []
         self.attention_maps = {}
 
-        self.clear_hooks()
-
-        # Register a hook on the attention module of every block in every stage
-        for stage_idx, stage in enumerate(self.model.backbone.layers):
-            for block_idx, block in enumerate(stage.blocks):
-                if hasattr(block, 'attn'):
-                    name = f"stage_{stage_idx}_block_{block_idx}"
-                    hook = block.attn.register_forward_hook(get_attention_hook(name))
-                    self.hooks.append(hook)
-
-    def clear_hooks(self):
-        """Remove all registered hooks."""
-        for hook in self.hooks:
-            hook.remove()
-        self.hooks = []
-
-    def clear_attention_maps(self):
-        """Clear the stored attention maps."""
-        self.attention_maps.clear()
-
-def create_attention_visualization(attention_weights, input_size, stage_info):
-    """
-    Build an attention visualization.
-
-    Args:
-        attention_weights: [B, num_heads, N, N] attention weights
-        input_size: input image size (H, W)
-        stage_info: stage information, used to determine the resolution
-    """
-    if attention_weights is None or len(attention_weights) == 0:
-        return None
-
-    # Take the first sample and average over all heads
-    attn = attention_weights[0].mean(dim=0)  # [N, N]
-
-    # Use the [CLS] token's attention to the other tokens (if present);
-    # Swin usually has no CLS token, so compute the mean attention instead
-    attn_map = attn.mean(dim=0)  # [N]
-
-    # Determine the feature-map size
-    N = attn_map.shape[0]
-    feat_size = int(math.sqrt(N))
-
-    if feat_size * feat_size != N:
-        # Not a perfect square, so extra tokens may be included
-        feat_size = int(math.sqrt(N))
-        attn_map = attn_map[:feat_size*feat_size]
-
-    # Reshape to 2D
-    attn_2d = attn_map.reshape(feat_size, feat_size)
-
-    # Convert to numpy and normalize
-    attn_np = attn_2d.numpy()
-    attn_np = (attn_np - attn_np.min()) / (attn_np.max() - attn_np.min() + 1e-8)
-
-    # Resize to the input image size
-    attn_img = Image.fromarray((attn_np * 255).astype(np.uint8), mode='L')
-    attn_img = attn_img.resize(input_size, Image.Resampling.BILINEAR)
-
-    return attn_img
+    def extract_attention_weights(self, x):
+        """Extract the attention weights of every hooked layer."""
+        with torch.no_grad():
+            _ = self.model(x)  # the forward pass triggers the hooks
+        return self.model.attention_weights.copy()
+
+    def process_attention_for_visualization(self, attention_weights, input_size):
+        """Turn raw attention weights into 2D maps for visualization."""
+        processed_maps = {}
+
+        for layer_name, attn_weight in attention_weights.items():
+            if attn_weight is None:
+                continue
+
+            # attn_weight shape: [batch_size, num_heads, seq_len, seq_len]
+            if len(attn_weight.shape) == 4:
+                # Average over all attention heads
+                attn_map = attn_weight.mean(dim=1)  # [batch_size, seq_len, seq_len]
+
+                # Take the first sample
+                attn_map = attn_map[0]  # [seq_len, seq_len]
+
+                # Self-attention visualizations usually track the CLS token,
+                # but Swin has no CLS token, so average over all tokens instead
+                if attn_map.shape[0] > 1:
+                    # Mean attention score received by each position
+                    avg_attention = attn_map.mean(dim=0)  # [seq_len]
+
+                    # Reshape the attention scores into a 2D feature map
+                    seq_len = avg_attention.shape[0]
+                    grid_size = int(math.sqrt(seq_len))
+
+                    if grid_size * grid_size == seq_len:
+                        attention_2d = avg_attention.reshape(grid_size, grid_size)
+                        processed_maps[layer_name] = attention_2d
+
+        return processed_maps
 
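
For intuition, `process_attention_for_visualization` collapses a `[B, num_heads, N, N]` weight tensor to a single √N × √N map: average the heads, take the mean attention each token receives, then reshape. A toy run with hypothetical sizes (4 heads, N = 144 tokens, i.e. a 12 × 12 grid):

import math
import torch

attn_weight = torch.softmax(torch.randn(1, 4, 144, 144), dim=-1)  # [B, heads, N, N]

attn_map = attn_weight.mean(dim=1)[0]   # average heads, first sample -> [N, N]
avg_attention = attn_map.mean(dim=0)    # mean attention received per token -> [N]

grid_size = int(math.sqrt(avg_attention.shape[0]))            # 12
assert grid_size * grid_size == avg_attention.numel()
attention_2d = avg_attention.reshape(grid_size, grid_size)    # [12, 12] map
print(attention_2d.shape)               # torch.Size([12, 12])
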
 # ---------------------------------------------------------------------------
+# 3. Download / cache the checkpoint
 print("⏬ Download / cache checkpoint …")
 ckpt_path = hf_hub_download(
     repo_id  = REPO_ID,
     filename = HF_FILENAME,
     local_dir = LOCAL_CKPT_DIR,
-    force_download=False
+    force_download=False  # reuse the file if it already exists
 )
 print(f"Checkpoint path: {ckpt_path}")
 
+# ---------------------------------------------------------------------------
+# 4. Instantiate the model & load the weights
+model = SwinClassifierWithAttention(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
 state = torch.load(ckpt_path, map_location=device, weights_only=False)
-model.load_state_dict(state.get("model_state_dict", state), strict=
+model.load_state_dict(state.get("model_state_dict", state), strict=False)  # strict=False: new components were added
 model.eval()
 print("✅ Model loaded.")
 
-# Create the attention extractor
 attention_extractor = AttentionExtractor(model)
 
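
Because `strict=False` silently skips any keys that do not line up, it is worth inspecting what was actually loaded; `load_state_dict` returns the missing and unexpected key lists. A small check, reusing `model` and `state` from the lines above:

result = model.load_state_dict(state.get("model_state_dict", state), strict=False)
print("missing keys:   ", result.missing_keys)     # weights the checkpoint lacked
print("unexpected keys:", result.unexpected_keys)  # checkpoint entries with no match
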
 # ---------------------------------------------------------------------------
+# 5. Transform builder
 def build_transform(is_training: bool, interpolation: str):
+    """
+    Build the default timm transform for the given interpolation
+    (bilinear / bicubic / nearest).
+    """
     cfg = model.data_config.copy()
     cfg.update(dict(interpolation=interpolation))
     return timm.data.create_transform(**cfg, is_training=is_training)
 
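
`build_transform` assumes a `model.data_config` attribute, which neither PyTorch nor timm sets automatically on a wrapper module, so it presumably gets attached elsewhere in the app. In recent timm releases the equivalent configuration can be resolved straight from the backbone; a sketch:

import timm

backbone = timm.create_model("swin_large_patch4_window12_384", pretrained=False)
data_config = timm.data.resolve_model_data_config(backbone)
# e.g. input_size, interpolation, mean, std resolved from the model's pretrained cfg
print(data_config)

transform = timm.data.create_transform(**data_config, is_training=False)
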
 # ---------------------------------------------------------------------------
+# 6. Attention visualization helper
+def visualize_attention(attention_map, original_image, normalize=True):
+    """Overlay an attention map on the original image."""
+    if normalize:
+        # Normalize the attention map to [0, 1]
+        attention_map = attention_map.cpu().numpy()
+        attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)
+    else:
+        attention_map = attention_map.cpu().numpy()
+
+    # Resize the attention map to the original image size
+    attention_resized = Image.fromarray((attention_map * 255).astype(np.uint8)) \
+                             .resize(original_image.size, Image.Resampling.BILINEAR)
+
+    # Convert to a heat map
+    attention_array = np.array(attention_resized) / 255.0
+    heatmap = cm.jet(attention_array)[:, :, :3]  # drop the alpha channel
+
+    # Blend with the original image
+    original_array = np.array(original_image) / 255.0
+    if len(original_array.shape) == 3:
+        overlay = 0.6 * original_array + 0.4 * heatmap
+    else:
+        # Grayscale input: replicate to three channels
+        original_array = np.stack([original_array] * 3, axis=-1)
+        overlay = 0.6 * original_array + 0.4 * heatmap
+
+    overlay = np.clip(overlay, 0, 1)
+    return Image.fromarray((overlay * 255).astype(np.uint8))
+
+# ---------------------------------------------------------------------------
+# 7. Inference + attention visualization
+def infer_with_attention(image_pil: Image.Image,
+                         interpolation: str = "bilinear",
+                         attention_layer: str = "stage_3_block_1",
+                         stage_average: bool = False,
+                         normalize_attention: bool = True):
     if image_pil is None:
         return None, None
 
     transform = build_transform(is_training=False, interpolation=interpolation)
     input_tensor = transform(image_pil).unsqueeze(0).to(device)
 
-    # Classification
-    logits = model(input_tensor)
-    probs = F.softmax(logits, dim=1)[0]
-    confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}
-
-        return confidences, None, "Classification complete"
-
-    # Get the attention of the selected layer
-    layer_info = f"Currently displayed layer: {selected_layer}"
-
-            # Build the heat map (red marks high attention)
-            colored_array = np.zeros((*attn_array.shape, 3), dtype=np.uint8)
-            colored_array[:, :, 0] = attn_array       # red channel
-            colored_array[:, :, 1] = attn_array // 2  # green channel (attenuated)
-
-            attention_colored = Image.fromarray(colored_array)
-
-            # Blend with the original image
-            blended = Image.blend(image_pil.convert('RGB'), attention_colored, alpha=0.4)
-
-            return confidences, blended, f"{layer_info} - attention visualization complete"
+    # (1) Classification
+    with torch.no_grad():
+        logits = model(input_tensor)
+        probs = F.softmax(logits, dim=1)[0]
+    confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}
+
+    # (2) Extract the attention weights
+    attention_weights = attention_extractor.extract_attention_weights(input_tensor)
+
+    if not attention_weights:
+        return confidences, None
+
+    # (3) Process the attention weights
+    processed_attention = attention_extractor.process_attention_for_visualization(
+        attention_weights, input_tensor.shape[-2:]
+    )
+
+    if not processed_attention:
+        return confidences, None
+
+    # (4) Pick the attention layer to visualize
+    if stage_average:
+        # Average the attention over every block of the chosen stage
+        stage_num = attention_layer.split('_')[1]
+        stage_attentions = []
+
+        for layer_name, attn_map in processed_attention.items():
+            if f"stage_{stage_num}_" in layer_name:
+                stage_attentions.append(attn_map)
+
+        if stage_attentions:
+            # Mean attention across the collected blocks
+            avg_attention = torch.stack(stage_attentions).mean(dim=0)
+            attention_vis = visualize_attention(avg_attention, image_pil, normalize_attention)
         else:
             return confidences, None
     else:
-            layers.append(f"stage_{stage_idx}_block_{block_idx}")
-    return layers
+        # Use the attention of the chosen layer
+        if attention_layer in processed_attention:
+            attention_vis = visualize_attention(
+                processed_attention[attention_layer], image_pil, normalize_attention
+            )
+        else:
+            # Fall back to the first available layer
+            first_layer = list(processed_attention.keys())[0]
+            attention_vis = visualize_attention(
+                processed_attention[first_layer], image_pil, normalize_attention
+            )
+
+    return confidences, attention_vis
 
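
`infer_with_attention` can also be exercised outside the UI. A quick smoke test, where `sample.jpg` is only a hypothetical local file:

from PIL import Image

img = Image.open("sample.jpg").convert("RGB")  # hypothetical test image
confidences, heatmap = infer_with_attention(
    img,
    interpolation="bicubic",
    attention_layer="stage_3_block_1",
    stage_average=True,
)
print(confidences)  # dict mapping class name -> probability
if heatmap is not None:
    heatmap.save("attention_overlay.png")
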
 # ---------------------------------------------------------------------------
-# Gradio UI
+# 8. Gradio UI
 def launch_app():
-    with gr.Blocks(title="AI Image Detector with Attention") as demo:
+    with gr.Blocks() as demo:
         gr.Markdown("""
-        # 🖼️ AI vs. Non-AI Image Classifier
-
-        - 🎯 Uses the Swin Transformer's native attention weights
-        - 🎨 Visualizes the attention patterns of different layers
-        - 🔄 Supports several interpolation methods to tune preprocessing
-
-        1. Upload an image
-        2. Pick the attention layer to visualize
-        3. Pick an interpolation method (bicubic recommended)
-        4. Click Run
-
-        - For research and educational use only
-        - 100% accuracy is not guaranteed
-        - Please use responsibly
+        # 🖼️ AI vs. Non-AI Image Classifier (Swin-Large + Attention Visualization)
+
+        🖼️ AI detector (Swin-Large vision backbone, with attention heat-map output)
+
+        Uses the Swin Transformer's self-attention to visualize the regions the model focuses on.
+
+        Notice: bicubic usually gives the best results. Please use this tool responsibly.
+
+        This tool is for research and educational use only.
         """)
 
         with gr.Row():
-            with gr.Row():
-                interp_choice = gr.Radio(
-                    ["bilinear", "bicubic", "nearest"],
-                    value="bicubic",
-                    label="Interpolation"
-                )
-
-                attention_toggle = gr.Checkbox(
-                    value=True,
-                    label="Show attention visualization"
-                )
-
-                layer_choice = gr.Dropdown(
-                    choices=available_layers,
-                    value="stage_3_block_1",
-                    label="Attention layer",
-                    info="Different layers attend to different levels of features"
-                )
-
-                run_btn = gr.Button("🚀 Run Detection", variant="primary")
-
-            with gr.Column():
-                out_lbl = gr.Label(
-                    num_top_classes=2,
-                    label="Predictions"
-                )
-                out_attention = gr.Image(
-                    type="pil",
-                    label="Attention visualization"
-                )
-                status_text = gr.Textbox(
-                    label="Status",
-                    interactive=False
-                )
-
-        # Layer-selection notes
-        gr.Markdown("""
-        ### Layer selection guide:
-        - **Stage 0-1**: low-level features (edges, texture)
-        - **Stage 2**: mid-level features (shapes, local patterns)
-        - **Stage 3**: high-level features (semantics, global structure)
-        """)
+            in_img = gr.Image(type="pil", label="Upload an Image")
+            out_attention = gr.Image(type="pil", label="Attention Heatmap")
+
+        with gr.Row():
+            out_lbl = gr.Label(num_top_classes=2, label="Predictions")
+
+        with gr.Row():
+            interp_choice = gr.Radio(
+                ["bilinear", "bicubic", "nearest"], value="bicubic",
+                label="Resize Interpolation"
+            )
+
+        with gr.Row():
+            attention_layer_choice = gr.Dropdown(
+                choices=[
+                    "stage_0_block_0", "stage_0_block_1",
+                    "stage_1_block_0", "stage_1_block_1",
+                    "stage_2_block_0", "stage_2_block_1", "stage_2_block_2",
+                    "stage_3_block_0", "stage_3_block_1"
+                ],
+                value="stage_3_block_1",
+                label="Attention Layer"
+            )
+
+        with gr.Row():
+            stage_avg_toggle = gr.Checkbox(
+                value=False,
+                label="Average Stage Attention"
+            )
+            normalize_toggle = gr.Checkbox(
+                value=True,
+                label="Normalize Attention"
+            )
+
+        run_btn = gr.Button("🚀 Run Analysis")
+
+        def _run(img, inter, attn_layer, stage_avg, normalize):
+            return infer_with_attention(
+                img,
+                interpolation=inter,
+                attention_layer=attn_layer,
+                stage_average=stage_avg,
+                normalize_attention=normalize
+            )
 
         run_btn.click(
             _run,
-            inputs=[in_img, interp_choice,
-            outputs=[out_lbl, out_attention
+            inputs=[in_img, interp_choice, attention_layer_choice, stage_avg_toggle, normalize_toggle],
+            outputs=[out_lbl, out_attention]
         )
 
+        gr.Markdown("""
+        ### Notes:
+        - **Attention layer**: pick different Swin Transformer layers to inspect their attention patterns
+        - **Stage average**: when checked, averages the attention of every block in the selected stage
+        - **Normalize**: rescales the attention values into the 0-1 range for easier viewing
+        - **Heat map**: red marks the regions the model attends to most, blue the least
+        """)
+
     demo.launch()
 
 # ---------------------------------------------------------------------------
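
The diff does not show how `launch_app` is invoked; assuming the file ends here, a conventional entry-point guard would be:

if __name__ == "__main__":
    launch_app()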