# AIDetectV2 / app.py
# -*- coding: utf-8 -*-
"""
Swin-Large AI vs. Non-AI Detector (with attention-based visualization)
"""
import os
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
import timm
import numpy as np
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from torchvision import transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# --- Configuration ---------------------------------------------------------
REPO_ID = "telecomadm1145/swin-ai-detection"
HF_FILENAME = "swin_classifier_stage1_v2_epoch_3.pth"
LOCAL_CKPT_DIR = "./checkpoints"
MODEL_NAME = "swin_large_patch4_window12_384"  # use the Large variant
NUM_CLASSES = 2
SEED = 4421
dropout_rate = 0.1
class_names = ["Non-AI Generated", "AI Generated"] # 0, 1
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED); np.random.seed(SEED)
print(f"Using device: {device}")
# ---------------------------------------------------------------------------
# 1. Model wrapper that also exposes attention weights
class SwinClassifierWithAttention(nn.Module):
    def __init__(self, model_name, num_classes, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained,
                                          num_classes=0)
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.backbone.num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.7),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(128, num_classes)
        )
        # Forward hooks populate this dict with per-block attention weights.
        self.attention_weights = {}
        self.register_attention_hooks()
    def register_attention_hooks(self):
        """Register forward hooks that capture attention weights."""
        def hook_fn(name):
            def hook(module, input, output):
                # `module` here is the window-attention module itself. Stock
                # timm does not cache its attention probabilities, so this
                # only records something if the module has been patched to
                # store them in an `attention_weights` attribute.
                attn = getattr(module, 'attention_weights', None)
                if attn is not None:
                    self.attention_weights[name] = attn
            return hook
        # Register a hook on every block of every stage.
        for stage_idx, stage in enumerate(self.backbone.layers):
            for block_idx, block in enumerate(stage.blocks):
                hook_name = f"stage_{stage_idx}_block_{block_idx}"
                if hasattr(block, 'attn'):
                    block.attn.register_forward_hook(hook_fn(hook_name))
    def forward(self, x):
        # Clear attention weights from the previous forward pass.
        self.attention_weights = {}
        feats = self.backbone(x)
        return self.classifier(feats)
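
# A minimal sketch (not wired into the app) of an alternative hook that
# materializes the attention matrix by recomputing it from the block input.
# It assumes timm's WindowAttention exposes `qkv`, `num_heads` and `scale`
# (true for recent timm releases) and omits the relative position bias and
# the shifted-window mask, so the resulting map is approximate.
def recompute_attention_hook(store, name):
    def hook(module, inputs, output):
        x = inputs[0]                          # (num_windows * B, N, C)
        B_, N, _ = x.shape
        qkv = module.qkv(x).reshape(B_, N, 3, module.num_heads, -1)
        q, k = qkv.permute(2, 0, 3, 1, 4)[:2]  # each (B_, heads, N, head_dim)
        attn = (q * module.scale) @ k.transpose(-2, -1)
        store[name] = attn.softmax(dim=-1)     # (B_, heads, N, N)
    return hook
# Usage would be, e.g.:
#   block.attn.register_forward_hook(recompute_attention_hook(model.attention_weights, hook_name))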
# ---------------------------------------------------------------------------
# 2. Attention extraction and visualization helpers
class AttentionExtractor:
    def __init__(self, model):
        self.model = model
        self.attention_maps = {}

    def extract_attention_weights(self, x):
        """Run a forward pass to trigger the hooks and collect attention weights."""
        with torch.no_grad():
            _ = self.model(x)  # forward pass fires the registered hooks
        return self.model.attention_weights.copy()

    def process_attention_for_visualization(self, attention_weights, input_size):
        """Convert raw attention weights into 2D maps for visualization."""
        processed_maps = {}
        for layer_name, attn_weight in attention_weights.items():
            if attn_weight is None:
                continue
            # For Swin window attention the shape is
            # [num_windows * batch, num_heads, tokens_per_window, tokens_per_window].
            if len(attn_weight.shape) == 4:
                # Average over all attention heads.
                attn_map = attn_weight.mean(dim=1)
                # Take the first window of the first sample; note this covers
                # only a single window, not the full token grid.
                attn_map = attn_map[0]
                # Swin has no CLS token, so instead of CLS-to-token attention
                # we average the attention each token receives.
                if attn_map.shape[0] > 1:
                    avg_attention = attn_map.mean(dim=0)  # [seq_len]
                    # Reshape the per-token scores into a square 2D map.
                    seq_len = avg_attention.shape[0]
                    grid_size = int(math.sqrt(seq_len))
                    if grid_size * grid_size == seq_len:
                        attention_2d = avg_attention.reshape(grid_size, grid_size)
                        processed_maps[layer_name] = attention_2d
        return processed_maps
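
# A hypothetical helper (not used by the app) showing how the window dimension
# could be folded in by averaging over all windows instead of keeping only
# window 0, assuming the (num_windows * B, heads, N, N) layout noted above:
def average_attention_over_windows(attn_weight: torch.Tensor) -> torch.Tensor:
    # mean over windows and heads -> (N, N), then over query tokens -> (N,)
    return attn_weight.mean(dim=(0, 1)).mean(dim=0)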
# ---------------------------------------------------------------------------
# 3. Download / cache the checkpoint
print("⏬ Download / cache checkpoint …")
ckpt_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=HF_FILENAME,
    local_dir=LOCAL_CKPT_DIR,
    force_download=False  # reuse the cached file if it already exists
)
print(f"Checkpoint path: {ckpt_path}")
# ---------------------------------------------------------------------------
# 4. Instantiate the model & load weights
model = SwinClassifierWithAttention(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
state = torch.load(ckpt_path, map_location=device, weights_only=False)
model.load_state_dict(state.get("model_state_dict", state), strict=False)  # strict=False because new components were added
model.eval()
print("✅ Model loaded.")
attention_extractor = AttentionExtractor(model)
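
# Quick smoke test (hypothetical, left commented out so startup stays fast):
#   with torch.no_grad():
#       dummy = torch.randn(1, *model.data_config["input_size"]).to(device)
#       print(model(dummy).shape)  # expected: torch.Size([1, 2])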
# ---------------------------------------------------------------------------
# 5. Preprocessing transform
def build_transform(is_training: bool, interpolation: str):
    """
    Build timm's default transform with the requested interpolation
    (bilinear / bicubic / nearest).
    """
    cfg = model.data_config.copy()
    cfg.update(dict(interpolation=interpolation))
    return timm.data.create_transform(**cfg, is_training=is_training)
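
# For example, build_transform(False, "bicubic") returns timm's standard eval
# pipeline for this backbone (resize / center-crop to 384x384, ToTensor, and
# normalization with the backbone's mean/std).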
# ---------------------------------------------------------------------------
# 6. Attention visualization
def visualize_attention(attention_map, original_image, normalize=True):
    """Overlay an attention map on the original image as a heatmap."""
    if normalize:
        # Normalize the attention map to [0, 1].
        attention_map = attention_map.cpu().numpy()
        attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)
    else:
        attention_map = attention_map.cpu().numpy()
    # Resize the attention map to the original image size.
    attention_resized = Image.fromarray((attention_map * 255).astype(np.uint8)) \
                             .resize(original_image.size, Image.Resampling.BILINEAR)
    # Convert to a heatmap (dropping the alpha channel).
    attention_array = np.array(attention_resized) / 255.0
    heatmap = cm.jet(attention_array)[:, :, :3]
    # Blend the heatmap onto the original image.
    original_array = np.array(original_image) / 255.0
    if len(original_array.shape) == 3:
        overlay = 0.6 * original_array + 0.4 * heatmap
    else:
        # Grayscale input: replicate to three channels first.
        original_array = np.stack([original_array] * 3, axis=-1)
        overlay = 0.6 * original_array + 0.4 * heatmap
    overlay = np.clip(overlay, 0, 1)
    return Image.fromarray((overlay * 255).astype(np.uint8))
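
# Hypothetical usage for a quick visual check:
#   vis = visualize_attention(torch.rand(12, 12), Image.new("RGB", (384, 384)))
#   vis.save("attention_overlay.png")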
# ---------------------------------------------------------------------------
# 7. Inference + attention visualization
def infer_with_attention(image_pil: Image.Image,
                         interpolation: str = "bilinear",
                         attention_layer: str = "stage_3_block_1",
                         stage_average: bool = False,
                         normalize_attention: bool = True):
    if image_pil is None:
        return None, None
    image_pil = image_pil.convert("RGB")  # guard against RGBA / grayscale uploads
    transform = build_transform(is_training=False, interpolation=interpolation)
    input_tensor = transform(image_pil).unsqueeze(0).to(device)
    # (1) Classification.
    with torch.no_grad():
        logits = model(input_tensor)
        probs = F.softmax(logits, dim=1)[0]
    confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}
    # (2) Extract attention weights.
    attention_weights = attention_extractor.extract_attention_weights(input_tensor)
    if not attention_weights:
        return confidences, None
    # (3) Convert them into 2D maps.
    processed_attention = attention_extractor.process_attention_for_visualization(
        attention_weights, input_tensor.shape[-2:]
    )
    if not processed_attention:
        return confidences, None
    # (4) Pick the attention map(s) to visualize.
    if stage_average:
        # Average the attention over every block of the selected stage.
        stage_num = attention_layer.split('_')[1]
        stage_attentions = [
            attn_map for layer_name, attn_map in processed_attention.items()
            if f"stage_{stage_num}_" in layer_name
        ]
        if stage_attentions:
            avg_attention = torch.stack(stage_attentions).mean(dim=0)
            attention_vis = visualize_attention(avg_attention, image_pil, normalize_attention)
        else:
            return confidences, None
    else:
        # Use the requested layer, falling back to the first available one.
        if attention_layer in processed_attention:
            attention_vis = visualize_attention(
                processed_attention[attention_layer], image_pil, normalize_attention
            )
        else:
            first_layer = list(processed_attention.keys())[0]
            attention_vis = visualize_attention(
                processed_attention[first_layer], image_pil, normalize_attention
            )
    return confidences, attention_vis
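
# Example return value of infer_with_attention (numbers illustrative only):
#   ({"Non-AI Generated": 0.07, "AI Generated": 0.93}, <PIL.Image.Image>)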
# ---------------------------------------------------------------------------
# 8. Gradio UI
def launch_app():
    with gr.Blocks() as demo:
        gr.Markdown("""
        # 🖼️ AI vs. Non-AI Image Classifier (Swin-Large + Attention Visualization)

        An AI-image detector built on a Swin-Large vision backbone that uses the
        Swin Transformer's self-attention to visualize the regions the model
        focuses on, rendered as an attention heatmap.

        Notice: bicubic interpolation tends to give better results. Please use
        this tool responsibly; it is intended for research and educational
        purposes only.
        """)
        with gr.Row():
            in_img = gr.Image(type="pil", label="Upload an Image")
            out_attention = gr.Image(type="pil", label="Attention Heatmap")
        with gr.Row():
            out_lbl = gr.Label(num_top_classes=2, label="Predictions")
        with gr.Row():
            interp_choice = gr.Radio(
                ["bilinear", "bicubic", "nearest"], value="bicubic",
                label="Resize Interpolation (preprocessing)"
            )
        with gr.Row():
            attention_layer_choice = gr.Dropdown(
                choices=[
                    "stage_0_block_0", "stage_0_block_1",
                    "stage_1_block_0", "stage_1_block_1",
                    "stage_2_block_0", "stage_2_block_1", "stage_2_block_2",
                    "stage_3_block_0", "stage_3_block_1"
                ],
                value="stage_3_block_1",
                label="Attention Layer"
            )
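        # Note: per timm's swin_large_patch4_window12_384 config (depths
        # (2, 2, 18, 2), an assumption worth verifying), stage 2 actually has
        # blocks 0-17; only its first three are listed above to keep the menu
        # short.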
        with gr.Row():
            stage_avg_toggle = gr.Checkbox(
                value=False,
                label="Average Stage Attention (mean over all blocks in the stage)"
            )
            normalize_toggle = gr.Checkbox(
                value=True,
                label="Normalize Attention"
            )
        run_btn = gr.Button("🚀 Run Analysis")

        def _run(img, inter, attn_layer, stage_avg, normalize):
            return infer_with_attention(
                img,
                interpolation=inter,
                attention_layer=attn_layer,
                stage_average=stage_avg,
                normalize_attention=normalize
            )

        run_btn.click(
            _run,
            inputs=[in_img, interp_choice, attention_layer_choice, stage_avg_toggle, normalize_toggle],
            outputs=[out_lbl, out_attention]
        )
        gr.Markdown("""
        ### Notes
        - **Attention layer**: choose different Swin Transformer blocks to inspect their attention patterns.
        - **Stage average**: when checked, averages the attention over every block of the selected stage.
        - **Normalize**: rescales attention values to the 0-1 range for easier viewing.
        - **Heatmap**: red marks regions the model attends to strongly; blue marks weakly attended regions.
        """)

    demo.launch()
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    launch_app()