# -*- coding: utf-8 -*-
"""
Swin-Large AI vs. Non-AI Detector (attention-based visualization)
"""
import os
import math

import torch
import torch.nn.functional as F
import torch.nn as nn
import timm
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from torchvision import transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# --- Configuration ---------------------------------------------------------
REPO_ID        = "telecomadm1145/swin-ai-detection"
HF_FILENAME    = "swin_classifier_stage1_v2_epoch_3.pth"
LOCAL_CKPT_DIR = "./checkpoints"
MODEL_NAME     = "swin_large_patch4_window12_384"  # Swin-Large backbone
NUM_CLASSES    = 2
SEED           = 4421
dropout_rate   = 0.1
class_names    = ["Non-AI Generated", "AI Generated"]  # 0, 1

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")

# ---------------------------------------------------------------------------
# 1. Model wrapper that exposes attention
class SwinClassifierWithAttention(nn.Module):
    def __init__(self, model_name, num_classes, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.backbone.num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.7),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(128, num_classes)
        )

        # Attention weights collected by the forward hooks
        self.attention_weights = {}
        self.register_attention_hooks()

    def register_attention_hooks(self):
        """Register forward hooks that collect attention weights."""
        def hook_fn(name):
            def hook(module, input, output):
                # The hook is attached to each block's window-attention module,
                # so `module` is the attention module itself. Stock timm
                # WindowAttention does not cache its softmax weights; this
                # expects the module to store them in `attention_weights`
                # (see the optional capture sketch below).
                if hasattr(module, 'attention_weights'):
                    self.attention_weights[name] = module.attention_weights
            return hook

        # Register a hook on every block of every stage
        for stage_idx, stage in enumerate(self.backbone.layers):
            for block_idx, block in enumerate(stage.blocks):
                hook_name = f"stage_{stage_idx}_block_{block_idx}"
                if hasattr(block, 'attn'):
                    block.attn.register_forward_hook(hook_fn(hook_name))

    def forward(self, x):
        # Clear attention weights from the previous forward pass
        self.attention_weights = {}
        feats = self.backbone(x)
        return self.classifier(feats)

# ---------------------------------------------------------------------------
# 2. Attention extraction and visualization helpers
class AttentionExtractor:
    def __init__(self, model):
        self.model = model
        self.attention_maps = {}

    def extract_attention_weights(self, x):
        """Run a forward pass to trigger the hooks and return the collected weights."""
        with torch.no_grad():
            _ = self.model(x)
        return self.model.attention_weights.copy()

    def process_attention_for_visualization(self, attention_weights, input_size):
        """Turn raw attention weights into 2-D maps suitable for visualization."""
        processed_maps = {}

        for layer_name, attn_weight in attention_weights.items():
            if attn_weight is None:
                continue

            # For Swin window attention the shape is
            # [batch * num_windows, num_heads, N, N] with N = window_size ** 2
            if len(attn_weight.shape) == 4:
                # Average over the attention heads
                attn_map = attn_weight.mean(dim=1)  # [batch * num_windows, N, N]
                # Keep only the first window of the first sample
                attn_map = attn_map[0]              # [N, N]

                # There is no CLS token in Swin, so instead of CLS-to-token
                # attention we average the attention each query pays to every position
                if attn_map.shape[0] > 1:
                    avg_attention = attn_map.mean(dim=0)  # [N]

                    # Reshape the per-position scores into a square grid
                    seq_len = avg_attention.shape[0]
                    grid_size = int(math.sqrt(seq_len))
                    if grid_size * grid_size == seq_len:
                        attention_2d = avg_attention.reshape(grid_size, grid_size)
                        processed_maps[layer_name] = attention_2d

        return processed_maps
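# ---------------------------------------------------------------------------
# Capture sketch (optional): the hooks above only read an `attention_weights`
# attribute, which stock timm WindowAttention does not provide. Below is a
# minimal sketch of one way to populate it, assuming the attention module
# exposes a `softmax` submodule (the non-fused attention path in many timm
# versions). `enable_attention_caching` is a helper added for illustration and
# is not part of the original checkpoint or training code.
def enable_attention_caching(backbone):
    """Stash each WindowAttention's softmax output on the module itself."""
    def make_cache_hook(attn_module):
        def cache_hook(softmax_module, inputs, output):
            # output: softmax-normalized attention, [batch * num_windows, heads, N, N]
            attn_module.attention_weights = output.detach()
        return cache_hook

    for stage in backbone.layers:
        for block in stage.blocks:
            attn = getattr(block, "attn", None)
            softmax = getattr(attn, "softmax", None) if attn is not None else None
            if softmax is not None:
                if hasattr(attn, "fused_attn"):
                    attn.fused_attn = False  # force the explicit softmax path in newer timm
                softmax.register_forward_hook(make_cache_hook(attn))

# If the heatmap output stays empty, one option is to call
# `enable_attention_caching(model.backbone)` after the model is built in section 4.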
# ---------------------------------------------------------------------------
# 3. Download / cache checkpoint
print("⏬ Download / cache checkpoint …")
ckpt_path = hf_hub_download(
    repo_id        = REPO_ID,
    filename       = HF_FILENAME,
    local_dir      = LOCAL_CKPT_DIR,
    force_download = False  # reuse the file if it already exists
)
print(f"Checkpoint path: {ckpt_path}")

# ---------------------------------------------------------------------------
# 4. Instantiate model & load weights
model = SwinClassifierWithAttention(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)

state = torch.load(ckpt_path, map_location=device, weights_only=False)
model.load_state_dict(state.get("model_state_dict", state), strict=False)  # strict=False: the wrapper adds new components
model.eval()
print("✅ Model loaded.")

attention_extractor = AttentionExtractor(model)

# ---------------------------------------------------------------------------
# 5. Preprocessing transforms
def build_transform(is_training: bool, interpolation: str):
    """Build the default timm transform with the requested interpolation (bilinear / bicubic / ...)."""
    cfg = model.data_config.copy()
    cfg.update(dict(interpolation=interpolation))
    return timm.data.create_transform(**cfg, is_training=is_training)

# ---------------------------------------------------------------------------
# 6. Attention visualization
def visualize_attention(attention_map, original_image, normalize=True):
    """Overlay an attention map on the original image as a heatmap."""
    if normalize:
        # Rescale the attention map to [0, 1]
        attention_map = attention_map.cpu().numpy()
        attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)
    else:
        attention_map = attention_map.cpu().numpy()

    # Resize the attention map to the original image size
    attention_resized = Image.fromarray((attention_map * 255).astype(np.uint8)) \
                             .resize(original_image.size, Image.Resampling.BILINEAR)

    # Convert to a heatmap (drop the alpha channel)
    attention_array = np.array(attention_resized) / 255.0
    heatmap = cm.jet(attention_array)[:, :, :3]

    # Blend the heatmap with the original image
    original_array = np.array(original_image) / 255.0
    if len(original_array.shape) == 3:
        overlay = 0.6 * original_array + 0.4 * heatmap
    else:
        # Grayscale image: replicate to three channels first
        original_array = np.stack([original_array] * 3, axis=-1)
        overlay = 0.6 * original_array + 0.4 * heatmap

    overlay = np.clip(overlay, 0, 1)
    return Image.fromarray((overlay * 255).astype(np.uint8))
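# ---------------------------------------------------------------------------
# Introspection sketch (optional): list the hook names registered in
# SwinClassifierWithAttention so the layer choices offered in the Gradio
# dropdown below can be checked against the actual architecture. For
# swin_large_patch4_window12_384 the stage depths are (2, 2, 18, 2), so
# stage_2 has many more blocks than the UI lists. `list_attention_layers` is a
# helper added here for illustration only.
def list_attention_layers(m: SwinClassifierWithAttention):
    names = []
    for s_idx, stage in enumerate(m.backbone.layers):
        for b_idx, _ in enumerate(stage.blocks):
            names.append(f"stage_{s_idx}_block_{b_idx}")
    return names

# print(list_attention_layers(model))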
# ---------------------------------------------------------------------------
# 7. Inference + attention visualization
def infer_with_attention(image_pil: Image.Image,
                         interpolation: str = "bilinear",
                         attention_layer: str = "stage_3_block_1",
                         stage_average: bool = False,
                         normalize_attention: bool = True):
    if image_pil is None:
        return None, None

    # Ensure a 3-channel input; uploaded PNGs may be RGBA
    image_pil = image_pil.convert("RGB")

    transform = build_transform(is_training=False, interpolation=interpolation)
    input_tensor = transform(image_pil).unsqueeze(0).to(device)

    # (1) Classification
    with torch.no_grad():
        logits = model(input_tensor)
        probs = F.softmax(logits, dim=1)[0]
    confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}

    # (2) Extract attention weights (this forward pass triggers the hooks)
    attention_weights = attention_extractor.extract_attention_weights(input_tensor)

    if not attention_weights:
        return confidences, None

    # (3) Convert the raw weights into 2-D maps
    processed_attention = attention_extractor.process_attention_for_visualization(
        attention_weights, input_tensor.shape[-2:]
    )

    if not processed_attention:
        return confidences, None

    # (4) Pick the attention map to visualize
    if stage_average:
        # Average the attention of every block in the selected stage
        stage_num = attention_layer.split('_')[1]
        stage_attentions = []
        for layer_name, attn_map in processed_attention.items():
            if f"stage_{stage_num}_" in layer_name:
                stage_attentions.append(attn_map)

        if stage_attentions:
            avg_attention = torch.stack(stage_attentions).mean(dim=0)
            attention_vis = visualize_attention(avg_attention, image_pil, normalize_attention)
        else:
            return confidences, None
    else:
        # Use the requested layer if available
        if attention_layer in processed_attention:
            attention_vis = visualize_attention(
                processed_attention[attention_layer], image_pil, normalize_attention
            )
        else:
            # Fall back to the first available layer
            first_layer = list(processed_attention.keys())[0]
            attention_vis = visualize_attention(
                processed_attention[first_layer], image_pil, normalize_attention
            )

    return confidences, attention_vis
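# ---------------------------------------------------------------------------
# Standalone usage sketch (no Gradio). `run_local_example` and the
# "sample.jpg" / "attention_overlay.png" paths are illustrative names only,
# not part of the original app; the function is never called automatically.
def run_local_example(path: str = "sample.jpg"):
    """Classify a local image file and save the attention overlay next to it."""
    img = Image.open(path).convert("RGB")
    confidences, attn_vis = infer_with_attention(img, interpolation="bicubic")
    print(confidences)
    if attn_vis is not None:
        attn_vis.save("attention_overlay.png")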
# ---------------------------------------------------------------------------
# 8. Gradio UI
def launch_app():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # 🖼️ AI vs. Non-AI Image Classifier (Swin-Large + Attention Visualization)

        An AI-image detector built on a Swin-Large vision backbone that also outputs an attention heatmap,
        using the Swin Transformer's self-attention to visualize the regions the model focuses on.

        Notice: bicubic interpolation usually gives better results. Please use this tool responsibly.
        It is intended for research and educational purposes only.
        """)

        with gr.Row():
            in_img = gr.Image(type="pil", label="Upload an Image")
            out_attention = gr.Image(type="pil", label="Attention Heatmap")

        with gr.Row():
            out_lbl = gr.Label(num_top_classes=2, label="Predictions")

        with gr.Row():
            interp_choice = gr.Radio(
                ["bilinear", "bicubic", "nearest"],
                value="bicubic",
                label="Resize Interpolation"
            )

        with gr.Row():
            attention_layer_choice = gr.Dropdown(
                choices=[
                    "stage_0_block_0", "stage_0_block_1",
                    "stage_1_block_0", "stage_1_block_1",
                    "stage_2_block_0", "stage_2_block_1", "stage_2_block_2",
                    "stage_3_block_0", "stage_3_block_1"
                ],
                value="stage_3_block_1",
                label="Attention Layer"
            )

        with gr.Row():
            stage_avg_toggle = gr.Checkbox(
                value=False,
                label="Average Stage Attention (mean over all blocks in the selected stage)"
            )
            normalize_toggle = gr.Checkbox(
                value=True,
                label="Normalize Attention"
            )

        run_btn = gr.Button("🚀 Run Analysis")

        def _run(img, inter, attn_layer, stage_avg, normalize):
            return infer_with_attention(
                img,
                interpolation=inter,
                attention_layer=attn_layer,
                stage_average=stage_avg,
                normalize_attention=normalize
            )

        run_btn.click(
            _run,
            inputs=[in_img, interp_choice, attention_layer_choice,
                    stage_avg_toggle, normalize_toggle],
            outputs=[out_lbl, out_attention]
        )

        gr.Markdown("""
        ### Notes
        - **Attention layer**: pick different Swin Transformer layers to inspect their attention patterns
        - **Stage average**: when checked, the attention of every block in the selected stage is averaged
        - **Normalize**: rescales attention values to the 0-1 range for easier visualization
        - **Heatmap**: red marks regions the model attends to strongly, blue marks regions with low attention
        """)

    demo.launch()

# ---------------------------------------------------------------------------
if __name__ == "__main__":
    launch_app()
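# ---------------------------------------------------------------------------
# Deployment note (sketch): gr.Blocks.launch() accepts standard options such as
# share=True (temporary public URL) or server_name / server_port for binding a
# specific interface; whether these suit your setup is an assumption about the
# deployment environment. For example, replacing the call above with
#
#     demo.launch(server_name="0.0.0.0", server_port=7860)
#
# serves the app on the local network.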