# -*- coding: utf-8 -*-
"""
Swin-Large AI vs. Non-AI Detector (attention-mechanism visualization)
"""
import os
import math

import torch
import torch.nn.functional as F
import torch.nn as nn
import timm
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from torchvision import transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# --- Configuration ---------------------------------------------------------
REPO_ID = "telecomadm1145/swin-ai-detection"
HF_FILENAME = "swin_classifier_stage1_v2_epoch_3.pth"
LOCAL_CKPT_DIR = "./checkpoints"
MODEL_NAME = "swin_large_patch4_window12_384"  # Swin-Large backbone
NUM_CLASSES = 2
SEED = 4421
dropout_rate = 0.1
class_names = ["Non-AI Generated", "AI Generated"]  # class indices 0 and 1

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")
# ---------------------------------------------------------------------------
# 1. Model wrapper that also exposes attention weights
class SwinClassifierWithAttention(nn.Module):
    def __init__(self, model_name, num_classes, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained,
                                          num_classes=0)
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.backbone.num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.7),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(128, num_classes)
        )
        # Attention weights captured by the forward hooks
        self.attention_weights = {}
        self.register_attention_hooks()

    def register_attention_hooks(self):
        """Register forward hooks that capture per-block attention weights."""
        def hook_fn(name):
            def hook(module, inputs, output):
                # The hook is registered on each block's window-attention module,
                # so `module` is that attention module itself. Stock timm
                # WindowAttention does not store its attention matrix, so this
                # records `attention_weights` only if the module exposes it
                # (e.g. via a patched attention class); otherwise None is stored
                # and the layer is skipped during visualization.
                self.attention_weights[name] = getattr(module, 'attention_weights', None)
            return hook

        # One hook per block in every stage
        for stage_idx, stage in enumerate(self.backbone.layers):
            for block_idx, block in enumerate(stage.blocks):
                hook_name = f"stage_{stage_idx}_block_{block_idx}"
                if hasattr(block, 'attn'):
                    block.attn.register_forward_hook(hook_fn(hook_name))

    def forward(self, x):
        # Clear attention weights from the previous forward pass
        self.attention_weights = {}
        feats = self.backbone(x)
        return self.classifier(feats)
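
# Alternative capture strategy (sketch, not wired into the app): if your timm
# build's Swin attention exposes a `softmax` submodule and is not running the
# fused-attention path, hooking that submodule yields the post-softmax
# attention probabilities directly, without a patched attention class.
# This is an assumption about timm internals; verify against your version.
def register_softmax_hooks(swin_model, store):
    for stage_idx, stage in enumerate(swin_model.backbone.layers):
        for block_idx, block in enumerate(stage.blocks):
            softmax = getattr(getattr(block, "attn", None), "softmax", None)
            if softmax is None:
                continue  # attention module does not expose a softmax submodule
            key = f"stage_{stage_idx}_block_{block_idx}"

            def make_hook(name):
                def hook(module, inputs, output):
                    # output: [num_windows * B, num_heads, N, N] attention probabilities
                    store[name] = output.detach()
                return hook

            softmax.register_forward_hook(make_hook(key))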
# ---------------------------------------------------------------------------
# 2. Attention extraction and post-processing
class AttentionExtractor:
    def __init__(self, model):
        self.model = model
        self.attention_maps = {}

    def extract_attention_weights(self, x):
        """Run a forward pass to trigger the hooks and collect attention weights."""
        with torch.no_grad():
            _ = self.model(x)  # forward pass triggers the hooks
        return self.model.attention_weights.copy()

    def process_attention_for_visualization(self, attention_weights, input_size):
        """Reduce raw attention weights to 2D maps for visualization."""
        processed_maps = {}
        for layer_name, attn_weight in attention_weights.items():
            if attn_weight is None:
                continue
            # attn_weight shape: [batch_size, num_heads, seq_len, seq_len]
            if len(attn_weight.shape) == 4:
                # Average over attention heads
                attn_map = attn_weight.mean(dim=1)  # [batch_size, seq_len, seq_len]
                # Keep the first element of the batch dimension
                attn_map = attn_map[0]  # [seq_len, seq_len]
                # Swin has no CLS token, so use the average attention each
                # position receives from all query positions.
                if attn_map.shape[0] > 1:
                    avg_attention = attn_map.mean(dim=0)  # [seq_len]
                    # Reshape the per-token scores into a square 2D map.
                    # Note: for Swin window attention, seq_len is the window
                    # size squared, so this is a window-level map rather than
                    # a full-image map.
                    seq_len = avg_attention.shape[0]
                    grid_size = int(math.sqrt(seq_len))
                    if grid_size * grid_size == seq_len:
                        attention_2d = avg_attention.reshape(grid_size, grid_size)
                        processed_maps[layer_name] = attention_2d
        return processed_maps
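
# Sketch of reassembling window-level attention into a full token grid
# (illustrative only, not used by the app). It assumes timm's window-partition
# ordering (batch-major, then window rows, then window columns) and ignores the
# cyclic shift used by every other Swin block, so shifted-window blocks are
# only approximated.
def window_attention_to_grid(attn, feat_h, feat_w, window_size=12):
    # attn: [num_windows * B, num_heads, N, N] with N == window_size ** 2
    nh, nw = feat_h // window_size, feat_w // window_size
    B = attn.shape[0] // (nh * nw)
    # Average over heads, then average the attention each token receives
    per_token = attn.mean(dim=1).mean(dim=1)  # [num_windows * B, N]
    per_token = per_token.view(B, nh, nw, window_size, window_size)
    # Undo the window partition: (B, nh, nw, ws, ws) -> (B, nh*ws, nw*ws)
    grid = per_token.permute(0, 1, 3, 2, 4).reshape(B, feat_h, feat_w)
    return grid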
# ---------------------------------------------------------------------------
# 3. Download / cache the checkpoint
print("⏬ Downloading / caching checkpoint …")
ckpt_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=HF_FILENAME,
    local_dir=LOCAL_CKPT_DIR,
    force_download=False,  # reuse the local copy if it already exists
)
print(f"Checkpoint path: {ckpt_path}")
# ---------------------------------------------------------------------------
# 4. Instantiate the model and load weights
model = SwinClassifierWithAttention(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
state = torch.load(ckpt_path, map_location=device, weights_only=False)
# strict=False because the wrapper adds components not present in the checkpoint
model.load_state_dict(state.get("model_state_dict", state), strict=False)
model.eval()
print("✅ Model loaded.")

attention_extractor = AttentionExtractor(model)
# ---------------------------------------------------------------------------
# 5. Preprocessing transform
def build_transform(is_training: bool, interpolation: str):
    """
    Build the default timm transform with the chosen resize interpolation
    (bilinear, bicubic, etc.).
    """
    cfg = model.data_config.copy()
    cfg.update(dict(interpolation=interpolation))
    return timm.data.create_transform(**cfg, is_training=is_training)
# ---------------------------------------------------------------------------
# 6. Attention visualization
def visualize_attention(attention_map, original_image, normalize=True):
    """Overlay an attention map on the original image as a heatmap."""
    attention_map = attention_map.cpu().numpy()
    if normalize:
        # Normalize the attention map to [0, 1]
        attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)

    # Resize the attention map to the original image size
    attention_resized = Image.fromarray((attention_map * 255).astype(np.uint8)) \
        .resize(original_image.size, Image.Resampling.BILINEAR)

    # Convert to a heatmap
    attention_array = np.array(attention_resized) / 255.0
    heatmap = cm.jet(attention_array)[:, :, :3]  # drop the alpha channel

    # Blend the heatmap with the original image
    original_array = np.array(original_image) / 255.0
    if len(original_array.shape) != 3:
        # Grayscale image: replicate to three channels before blending
        original_array = np.stack([original_array] * 3, axis=-1)
    overlay = 0.6 * original_array + 0.4 * heatmap
    overlay = np.clip(overlay, 0, 1)
    return Image.fromarray((overlay * 255).astype(np.uint8))
# ---------------------------------------------------------------------------
# 7. Inference + attention visualization
def infer_with_attention(image_pil: Image.Image,
                         interpolation: str = "bilinear",
                         attention_layer: str = "stage_3_block_1",
                         stage_average: bool = False,
                         normalize_attention: bool = True):
    if image_pil is None:
        return None, None

    image_pil = image_pil.convert("RGB")  # guard against RGBA / grayscale inputs
    transform = build_transform(is_training=False, interpolation=interpolation)
    input_tensor = transform(image_pil).unsqueeze(0).to(device)

    # (1) Classification
    with torch.no_grad():
        logits = model(input_tensor)
        probs = F.softmax(logits, dim=1)[0]
    confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}

    # (2) Extract attention weights
    attention_weights = attention_extractor.extract_attention_weights(input_tensor)
    if not attention_weights:
        return confidences, None

    # (3) Reduce the raw weights to 2D maps
    processed_attention = attention_extractor.process_attention_for_visualization(
        attention_weights, input_tensor.shape[-2:]
    )
    if not processed_attention:
        return confidences, None

    # (4) Pick the attention map to visualize
    if stage_average:
        # Average the attention of every block in the selected stage
        stage_num = attention_layer.split('_')[1]
        stage_attentions = [
            attn_map for layer_name, attn_map in processed_attention.items()
            if f"stage_{stage_num}_" in layer_name
        ]
        if not stage_attentions:
            return confidences, None
        avg_attention = torch.stack(stage_attentions).mean(dim=0)
        attention_vis = visualize_attention(avg_attention, image_pil, normalize_attention)
    else:
        # Use the selected layer, falling back to the first available one
        layer = attention_layer if attention_layer in processed_attention \
            else next(iter(processed_attention))
        attention_vis = visualize_attention(
            processed_attention[layer], image_pil, normalize_attention
        )

    return confidences, attention_vis
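
# Optional helper for running the pipeline outside the Gradio UI, e.g. from a
# Python shell: run_on_file("sample.jpg"). The file names are placeholders.
def run_on_file(path, interpolation="bicubic", attention_layer="stage_3_block_1"):
    preds, heatmap = infer_with_attention(Image.open(path),
                                          interpolation=interpolation,
                                          attention_layer=attention_layer)
    print(preds)
    if heatmap is not None:
        heatmap.save("attention_overlay.png")
    return preds, heatmap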
# ---------------------------------------------------------------------------
# 8. Gradio UI
def launch_app():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # 🖼️ AI vs. Non-AI Image Classifier (Swin-Large + Attention Visualization)
        Detects AI-generated images with a Swin-Large vision backbone and uses the
        Swin Transformer's self-attention to visualize the regions the model focuses on.

        Notice: bicubic interpolation tends to give better results. Please use this
        tool responsibly; it is intended for research and educational purposes only.
        """)

        with gr.Row():
            in_img = gr.Image(type="pil", label="Upload an Image")
            out_attention = gr.Image(type="pil", label="Attention Heatmap")
        with gr.Row():
            out_lbl = gr.Label(num_top_classes=2, label="Predictions")
        with gr.Row():
            interp_choice = gr.Radio(
                ["bilinear", "bicubic", "nearest"], value="bicubic",
                label="Resize Interpolation"
            )
        with gr.Row():
            attention_layer_choice = gr.Dropdown(
                choices=[
                    "stage_0_block_0", "stage_0_block_1",
                    "stage_1_block_0", "stage_1_block_1",
                    "stage_2_block_0", "stage_2_block_1", "stage_2_block_2",
                    "stage_3_block_0", "stage_3_block_1"
                ],
                value="stage_3_block_1",
                label="Attention Layer"
            )
        with gr.Row():
            stage_avg_toggle = gr.Checkbox(
                value=False,
                label="Average Stage Attention"
            )
            normalize_toggle = gr.Checkbox(
                value=True,
                label="Normalize Attention"
            )

        run_btn = gr.Button("🚀 Run Analysis")

        def _run(img, inter, attn_layer, stage_avg, normalize):
            return infer_with_attention(
                img,
                interpolation=inter,
                attention_layer=attn_layer,
                stage_average=stage_avg,
                normalize_attention=normalize
            )

        run_btn.click(
            _run,
            inputs=[in_img, interp_choice, attention_layer_choice, stage_avg_toggle, normalize_toggle],
            outputs=[out_lbl, out_attention]
        )

        gr.Markdown("""
        ### Notes
        - **Attention layer**: choose which Swin Transformer block's attention pattern to display.
        - **Stage average**: when checked, averages the attention of every block in the selected stage.
        - **Normalize**: rescales attention values to the 0-1 range for easier visualization.
        - **Heatmap**: red areas receive more attention from the model; blue areas receive less.
        """)

    demo.launch()
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    launch_app()