File size: 13,045 Bytes
8b95acd
 
cedb103
8b95acd
3e0966c
 
 
 
9fc4060
3e0966c
 
 
 
 
619e447
cedb103
 
 
3e0966c
8b95acd
 
c101d2c
8b95acd
cedb103
8b95acd
 
 
3e0966c
cedb103
3e0966c
 
8b95acd
3e0966c
 
8b95acd
cedb103
 
3e0966c
 
8b95acd
 
3e0966c
 
 
77d9851
3e0966c
 
 
77d9851
3e0966c
 
 
77d9851
3e0966c
 
619e447
cedb103
 
619e447
 
 
cedb103
 
619e447
cedb103
 
619e447
 
cedb103
619e447
 
cedb103
619e447
 
cedb103
 
 
3e0966c
 
cedb103
 
8b95acd
 
 
 
619e447
 
 
 
 
 
cedb103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619e447
cedb103
 
 
619e447
cedb103
 
 
619e447
cedb103
619e447
 
cedb103
8b95acd
 
 
 
 
cedb103
8b95acd
 
 
cedb103
 
 
f8b2050
cedb103
3e0966c
8b95acd
3e0966c
619e447
 
8b95acd
cedb103
8b95acd
cedb103
 
 
8b95acd
 
 
3e0966c
8b95acd
cedb103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619e447
cedb103
 
 
 
 
 
 
 
 
 
8b95acd
cedb103
3e0966c
8b95acd
 
3e0966c
cedb103
 
 
 
 
3e0966c
cedb103
 
619e447
cedb103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619e447
cedb103
 
 
619e447
cedb103
 
 
 
619e447
cedb103
619e447
cedb103
 
 
 
 
 
 
 
 
 
 
 
 
3e0966c
619e447
cedb103
619e447
cedb103
619e447
cedb103
8b95acd
cedb103
3e0966c
cedb103
3e0966c
cedb103
8b95acd
cedb103
619e447
8b95acd
 
cedb103
 
8b95acd
cedb103
 
8b95acd
cedb103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b95acd
 
 
cedb103
 
8b95acd
 
cedb103
 
 
 
 
 
 
 
8b95acd
3e0966c
8b95acd
3e0966c
619e447
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
# -*- coding: utf-8 -*-
"""
Swin-Large   AI vs. Non-AI Detector   (基于注意力机制可视化)
"""
import os
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
import timm
import numpy as np
from PIL import Image
import gradio as gr

from huggingface_hub import hf_hub_download
from torchvision import transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# --- Configuration ---------------------------------------------------------
# Hugging Face repository / file that holds the fine-tuned checkpoint, and the
# local directory it is cached in.
REPO_ID = "telecomadm1145/swin-ai-detection"
HF_FILENAME = "swin_classifier_stage1_v2_epoch_3.pth"
LOCAL_CKPT_DIR = "./checkpoints"
# timm model id of the backbone (Swin-Large, patch 4, window 12, 384px input).
MODEL_NAME = "swin_large_patch4_window12_384"
NUM_CLASSES = 2
SEED = 4421
# Base dropout of the classifier head (later layers use scaled-down values).
dropout_rate = 0.1

# Position i holds the display label for class index i.
class_names = ["Non-AI Generated", "AI Generated"]

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed both RNGs the script touches for reproducibility.
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")

# ---------------------------------------------------------------------------
# 1. 修改模型结构以提取注意力
class SwinClassifierWithAttention(nn.Module):
    """Swin backbone plus MLP head that records every block's attention map.

    After each ``forward`` call, ``self.attention_weights`` maps names of the
    form ``"stage_{i}_block_{j}"`` to the detached output of that block's
    attention softmax (the post-softmax attention probabilities).
    """

    def __init__(self, model_name, num_classes, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained,
                                          num_classes=0)
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)

        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.backbone.num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.7),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(128, num_classes)
        )

        # Filled by the forward hooks; reset at the start of every forward().
        self.attention_weights = {}
        self.register_attention_hooks()

    def register_attention_hooks(self):
        """Attach a forward hook to the softmax of each block's attention.

        timm's window-attention module does not keep its attention
        probabilities as an attribute. The only place they exist as a module
        output is its ``softmax`` submodule, which is used on the non-fused
        path. So ``fused_attn`` is switched off (when present) and that
        softmax is hooked. Blocks without ``attn.softmax`` are skipped.
        """
        def make_hook(name):
            def hook(module, inputs, output):
                # For timm Swin this is presumably
                # (batch * num_windows, num_heads, N, N) -- TODO confirm per version.
                self.attention_weights[name] = output.detach()
            return hook

        for stage_idx, stage in enumerate(self.backbone.layers):
            for block_idx, block in enumerate(stage.blocks):
                attn = getattr(block, "attn", None)
                softmax = getattr(attn, "softmax", None)
                if softmax is None:
                    continue
                if hasattr(attn, "fused_attn"):
                    # The fused SDPA path never materialises the probabilities.
                    attn.fused_attn = False
                softmax.register_forward_hook(
                    make_hook(f"stage_{stage_idx}_block_{block_idx}")
                )

    def forward(self, x):
        """Classify ``x``; as a side effect, refresh ``self.attention_weights``."""
        # Drop maps from the previous call so only this pass's tensors remain.
        self.attention_weights = {}
        feats = self.backbone(x)
        return self.classifier(feats)

# ---------------------------------------------------------------------------
# 2. 注意力提取和可视化类
class AttentionExtractor:
    """Turns attention tensors captured by the model's hooks into 2-D maps."""

    def __init__(self, model):
        self.model = model
        self.attention_maps = {}

    def extract_attention_weights(self, x):
        """Run a no-grad forward pass on ``x`` and return a copy of the captured weights."""
        with torch.no_grad():
            _ = self.model(x)  # the forward pass triggers the hooks
            return self.model.attention_weights.copy()

    def _shift_for(self, layer_name):
        """Best-effort lookup of the cyclic shift used by ``layer_name``'s block.

        Returns a ``(rows, cols)`` tuple, or None when the shift is zero or the
        block cannot be located on ``self.model.backbone``.
        """
        try:
            _, stage_idx, _, block_idx = layer_name.split("_")
            block = self.model.backbone.layers[int(stage_idx)].blocks[int(block_idx)]
            shift = block.shift_size
            rows, cols = (shift, shift) if isinstance(shift, int) else shift
            shift = (int(rows), int(cols))
        except (AttributeError, IndexError, TypeError, ValueError):
            return None
        return shift if any(shift) else None

    def process_attention_for_visualization(self, attention_weights, input_size):
        """Convert captured attention tensors into 2-D maps for display.

        Each 4-D tensor is read as ``(num_windows, num_heads, N, N)`` with
        ``N`` a perfect square. Heads are averaged, then queries are averaged,
        giving the mean attention each token receives within its window.
        The per-window grids are tiled back into one map, with windows in
        row-major order. If the window count is not a perfect square, the
        layout is unknown and only the first window is used.

        For shifted-window blocks, the tiled map is rolled back by the block's
        ``shift_size`` when it can be looked up.

        Assumes a batch of one image (as used by ``infer_with_attention``).
        ``input_size`` is accepted for API compatibility and is unused.

        Returns:
            dict mapping layer name -> 2-D tensor. Layers whose tensor is
            None, is not 4-D, or has a non-square token count are skipped.
        """
        processed_maps = {}

        for layer_name, attn_weight in attention_weights.items():
            if attn_weight is None or attn_weight.dim() != 4:
                continue

            # mean over heads -> (W, N, N); mean over queries -> (W, N)
            per_token = attn_weight.mean(dim=1).mean(dim=-2)
            num_windows, seq_len = per_token.shape
            if seq_len <= 1:
                continue
            window = math.isqrt(seq_len)
            if window * window != seq_len:
                continue

            side = math.isqrt(num_windows)
            full_map = side * side == num_windows
            if not full_map:
                per_token = per_token[:1]
                side = 1

            # (wy, wx, ty, tx) -> rows = wy*window + ty, cols = wx*window + tx
            attention_2d = (per_token.reshape(side, side, window, window)
                            .permute(0, 2, 1, 3)
                            .reshape(side * window, side * window))

            shift = self._shift_for(layer_name) if full_map else None
            if shift is not None:
                # Swin rolls by -shift before partitioning; roll back by +shift.
                attention_2d = torch.roll(attention_2d, shifts=shift, dims=(0, 1))

            processed_maps[layer_name] = attention_2d

        return processed_maps

# ---------------------------------------------------------------------------
# 3. 下载 / 缓存 checkpoint
print("⏬ Download / cache checkpoint …")
ckpt_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=HF_FILENAME,
    local_dir=LOCAL_CKPT_DIR,
    force_download=False,  # reuse the cached file if it already exists
)
print(f"Checkpoint path: {ckpt_path}")

# ---------------------------------------------------------------------------
# 4. Instantiate the model & load weights
model = SwinClassifierWithAttention(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
# NOTE(security): weights_only=False unpickles arbitrary objects from the file;
# only acceptable because the checkpoint comes from a trusted repository.
state = torch.load(ckpt_path, map_location=device, weights_only=False)
# strict=False tolerates key mismatches, so surface them instead of silently
# running a partially initialised model.
load_result = model.load_state_dict(state.get("model_state_dict", state), strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"⚠️  Missing keys: {load_result.missing_keys}")
    print(f"⚠️  Unexpected keys: {load_result.unexpected_keys}")
model.eval()
print("✅  Model loaded.")

attention_extractor = AttentionExtractor(model)

# ---------------------------------------------------------------------------
# 5. 变换函数
def build_transform(is_training: bool, interpolation: str):
    """
    Create timm's default transform for the backbone, overriding only the
    resize interpolation (e.g. "bilinear", "bicubic", "nearest").
    """
    # Copy the backbone's resolved data config with the interpolation replaced.
    config = {**model.data_config, "interpolation": interpolation}
    return timm.data.create_transform(**config, is_training=is_training)

# ---------------------------------------------------------------------------
# 6. 注意力可视化函数
def visualize_attention(attention_map, original_image, normalize=True):
    """Blend an attention map onto an image as a jet-coloured heatmap.

    Args:
        attention_map: 2-D tensor of attention scores (any device/dtype).
        original_image: PIL image. It is converted to RGB before blending, so
            grayscale, RGBA and palette images all work.
        normalize: if True, min-max scale the map to [0, 1]. Otherwise the raw
            scores are used as-is.

    Returns:
        New RGB PIL image (60% original + 40% heatmap) at the original size.
    """
    scores = attention_map.detach().float().cpu().numpy()
    if normalize:
        scores = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    # Clip so raw values outside [0, 1] don't wrap around in the uint8 cast.
    scores = np.clip(scores, 0.0, 1.0)

    # Blending needs exactly three channels.
    rgb_image = original_image.convert("RGB")

    # Upscale the map to the image size, then colour it (alpha channel dropped).
    attention_resized = Image.fromarray((scores * 255).astype(np.uint8)) \
                            .resize(rgb_image.size, Image.Resampling.BILINEAR)
    heatmap = cm.jet(np.asarray(attention_resized) / 255.0)[:, :, :3]

    original_array = np.asarray(rgb_image) / 255.0
    overlay = np.clip(0.6 * original_array + 0.4 * heatmap, 0, 1)
    return Image.fromarray((overlay * 255).astype(np.uint8))

# ---------------------------------------------------------------------------
# 7. 推理 + 注意力可视化
def _select_attention_map(processed_attention, attention_layer, stage_average):
    """Pick the requested layer's map (or its stage average); None if no match."""
    if stage_average:
        # "stage_3_block_1" -> average every map whose name contains "stage_3_".
        stage_num = attention_layer.split('_')[1]
        prefix = f"stage_{stage_num}_"
        stage_maps = [attn_map for layer_name, attn_map in processed_attention.items()
                      if prefix in layer_name]
        if not stage_maps:
            return None
        return torch.stack(stage_maps).mean(dim=0)

    if attention_layer in processed_attention:
        return processed_attention[attention_layer]
    # Requested layer unavailable: fall back to the first processed layer.
    return next(iter(processed_attention.values()))


def infer_with_attention(image_pil: Image.Image,
                        interpolation: str = "bilinear",
                        attention_layer: str = "stage_3_block_1",
                        stage_average: bool = False,
                        normalize_attention: bool = True):
    """Classify an image and render an attention heatmap for it.

    Returns:
        ``(confidences, heatmap)``. ``confidences`` maps each class name to a
        probability. ``heatmap`` is a PIL image, or None when no attention
        could be captured or selected. Both are None when ``image_pil`` is
        None.
    """
    if image_pil is None:
        return None, None
    if image_pil.mode != "RGB":
        # The normalisation transform expects exactly three channels.
        image_pil = image_pil.convert("RGB")

    transform = build_transform(is_training=False, interpolation=interpolation)
    input_tensor = transform(image_pil).unsqueeze(0).to(device)

    # (1) Classification. This single forward pass also fires the attention hooks.
    with torch.no_grad():
        logits = model(input_tensor)
    probs = F.softmax(logits, dim=1)[0]
    confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}

    # (2) Reuse the weights captured above instead of running the backbone again.
    attention_weights = dict(model.attention_weights)
    if not attention_weights:
        return confidences, None

    # (3) Convert the raw tensors into 2-D maps.
    processed_attention = attention_extractor.process_attention_for_visualization(
        attention_weights, input_tensor.shape[-2:]
    )
    if not processed_attention:
        return confidences, None

    # (4) Choose the map to show and overlay it on the input image.
    attention_map = _select_attention_map(processed_attention, attention_layer, stage_average)
    if attention_map is None:
        return confidences, None
    return confidences, visualize_attention(attention_map, image_pil, normalize_attention)

# ---------------------------------------------------------------------------
# 8. Gradio UI
def launch_app():
    """Build and launch the Gradio UI for classification + attention heatmaps."""
    # Offer every hooked block. The old hard-coded list skipped most of
    # stage 2, which has many more blocks in Swin-Large (18).
    layer_choices = [
        f"stage_{stage_idx}_block_{block_idx}"
        for stage_idx, stage in enumerate(model.backbone.layers)
        for block_idx in range(len(stage.blocks))
    ]
    if "stage_3_block_1" in layer_choices:
        default_layer = "stage_3_block_1"
    else:
        default_layer = layer_choices[-1] if layer_choices else None

    with gr.Blocks() as demo:
        gr.Markdown("""
# 🖼️ AI vs. Non-AI Image Classifier  (Swin-Large + Attention Visualization)

🖼️ AI 鉴别器(基于 Swin-Large 视觉骨干,输出注意力热力图)

基于Swin Transformer的自注意力机制来可视化模型关注的区域。

Notice: 使用 bicubic 效果较好。请负责任地使用此工具。

此工具仅供研究和教育用途。
""")

        with gr.Row():
            in_img = gr.Image(type="pil", label="Upload an Image")
            out_attention = gr.Image(type="pil", label="Attention Heatmap")

        with gr.Row():
            out_lbl = gr.Label(num_top_classes=2, label="Predictions")

        with gr.Row():
            interp_choice = gr.Radio(
                ["bilinear", "bicubic", "nearest"], value="bicubic",
                label="Resize Interpolation (预处理插值)"
            )

        with gr.Row():
            attention_layer_choice = gr.Dropdown(
                choices=layer_choices,
                value=default_layer,
                label="选择注意力层 (Attention Layer)"
            )

        with gr.Row():
            stage_avg_toggle = gr.Checkbox(
                value=False,
                label="计算整个Stage的平均注意力 (Average Stage Attention)"
            )
            normalize_toggle = gr.Checkbox(
                value=True,
                label="归一化注意力 (Normalize Attention)"
            )

        run_btn = gr.Button("🚀 Run Analysis")

        def _run(img, inter, attn_layer, stage_avg, normalize):
            # Map the UI components positionally onto the keyword arguments.
            return infer_with_attention(
                img,
                interpolation=inter,
                attention_layer=attn_layer,
                stage_average=stage_avg,
                normalize_attention=normalize
            )

        run_btn.click(
            _run,
            inputs=[in_img, interp_choice, attention_layer_choice, stage_avg_toggle, normalize_toggle],
            outputs=[out_lbl, out_attention]
        )

        gr.Markdown("""
### 说明:
- **注意力层选择**: 可以选择不同的Swin Transformer层来查看注意力模式
- **Stage平均**: 勾选后会计算选中stage中所有block的平均注意力
- **归一化**: 将注意力值归一化到0-1范围内,便于可视化
- **热力图**: 红色区域表示模型更关注的区域,蓝色区域表示关注度较低的区域
        """)

    demo.launch()

# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Start the Gradio UI only when this file is executed as a script.
    launch_app()