# -*- coding: utf-8 -*-
"""
Swin-Large AI vs. Non-AI Detector (with Model Selection & Attention Visualization)
"""
import os
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
import timm
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
# --- Configuration ---------------------------------------------------------
REPO_ID = "telecomadm1145/swin-ai-detection"
HF_FILENAMES = {
"V2": "swin_classifier_stage1_v2_epoch_3.pth",
"V4": "swin_classifier_stage1_v4.pth",
}
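# NOTE: the keys above double as the labels in the Gradio model dropdown, and
# DEFAULT_CKPT below must be one of them.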
DEFAULT_CKPT = "Swin-V4 (Final)"
LOCAL_CKPT_DIR = "./checkpoints"
MODEL_NAME = "swin_large_patch4_window12_384"
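# swin_large_patch4_window12_384 -> 384x384 inputs, patch size 4, window size 12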
NUM_CLASSES = 2
SEED = 4421
dropout_rate = 0.1
class_names = ["Non-AI Generated", "AI Generated"] # 0, 1
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")
# --- Global model state ----------------------------------------------------
model = None
current_ckpt_name = None
attention_maps = [] # To store hooked attention maps
# ---------------------------------------------------------------------------
# 1. Model architecture
class SwinClassifier(nn.Module):
def __init__(self, model_name, num_classes, pretrained=True):
super().__init__()
self.backbone = timm.create_model(model_name, pretrained=pretrained,
num_classes=0)
self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
self.classifier = nn.Sequential(
nn.Dropout(dropout_rate),
nn.Linear(self.backbone.num_features, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Dropout(dropout_rate * 0.7),
nn.Linear(512, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Dropout(dropout_rate * 0.5),
nn.Linear(128, num_classes)
)
def forward(self, x):
feats = self.backbone(x)
return self.classifier(feats)
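# Shape sketch of the forward pass (the 1536-dim feature width assumes timm's
# swin_large_patch4_window12_384 backbone; other backbones differ):
#   x [B, 3, 384, 384] --backbone--> feats [B, 1536]
#   --classifier--> [B, 512] -> [B, 128] -> [B, NUM_CLASSES] logits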
# ---------------------------------------------------------------------------
# 2. Dynamic model loading
def load_model(ckpt_name: str):
"""
Dynamically loads the selected model checkpoint.
If the model is already loaded, it does nothing.
"""
global model, current_ckpt_name
if ckpt_name == current_ckpt_name:
print(f"✅ Model '{ckpt_name}' is already loaded.")
return
print(f"🔄 Switching to model: '{ckpt_name}'...")
hf_filename = HF_FILENAMES[ckpt_name]
print("⏬ Downloading / caching checkpoint if needed…")
ckpt_path = hf_hub_download(
repo_id=REPO_ID,
filename=hf_filename,
local_dir=LOCAL_CKPT_DIR,
force_download=False
)
print(f"Checkpoint path: {ckpt_path}")
# Instantiate and load weights
model = SwinClassifier(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
state = torch.load(ckpt_path, map_location=device, weights_only=False)
model.load_state_dict(state.get("model_state_dict", state), strict=True)
model.eval()
current_ckpt_name = ckpt_name
print(f"✅ Model '{ckpt_name}' loaded successfully.")
# ---------------------------------------------------------------------------
# 3. torchvision / timm transform factory
def build_transform(is_training: bool, interpolation: str):
"""
Build the default timm transform using the given interpolation mode (bilinear / bicubic, etc.).
"""
if model is None:
raise RuntimeError("Model is not loaded. Please call load_model() first.")
cfg = model.data_config.copy()
cfg.update(dict(interpolation=interpolation))
return timm.data.create_transform(**cfg, is_training=is_training)
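# For this backbone the resolved transform is typically equivalent to
# Resize(384, interpolation=<chosen>) -> CenterCrop(384) -> ToTensor() -> Normalize(mean, std),
# with the exact values taken from the model's pretrained config and timm version.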
# ---------------------------------------------------------------------------
# 4. Attention Hook & Visualization
def get_attention_map(module, input, output):
"""Hook to capture the attention map from the attention module."""
global attention_maps
# This assumes the hooked attention module returns an (output, attention) tuple;
# the attention map then has shape [B*num_windows, num_heads, N, N], where N is
# the number of tokens per window (window_size**2 for Swin)
attention_maps.append(output[1].cpu())
def create_attention_visualization(image_pil: Image.Image, attn_map: torch.Tensor) -> Image.Image:
"""Creates an overlay of the attention map on the original image."""
# Average across all heads
attn_map = attn_map.mean(dim=1)[0] # Shape: [N, N]
# Score each token by the total attention it receives from all query positions
# (summing over rows; averaging would only differ by a constant factor).
residual_attn = attn_map.sum(dim=0) # Sum over rows
# Reshape to 2D grid
patch_size = model.backbone.patch_embed.patch_size[0]
num_patches = residual_attn.shape[0]
grid_size = int(math.sqrt(num_patches))
if grid_size * grid_size != num_patches:
print(f"Warning: Number of patches ({num_patches}) is not a perfect square. Visualization may be incorrect.")
# Fallback for non-square patch layouts if needed, but Swin usually has square.
return image_pil
attn_grid = residual_attn.reshape(grid_size, grid_size).detach().numpy()
# Normalize the grid
attn_grid = (attn_grid - attn_grid.min()) / (attn_grid.max() - attn_grid.min())
# Use a colormap to create a heatmap
cmap = plt.get_cmap('viridis')
heatmap_colored = (cmap(attn_grid)[:, :, :3] * 255).astype(np.uint8)
heatmap_pil = Image.fromarray(heatmap_colored)
# Resize heatmap to original image size
heatmap_resized = heatmap_pil.resize(image_pil.size, Image.BICUBIC)
# Blend original image with the heatmap
viz_image = Image.blend(image_pil, heatmap_resized, alpha=0.5)
return viz_image
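# NOTE: Swin attention is window-local, so the hooked map covers one window of
# tokens (window_size**2 = 144 for window12) rather than the full patch grid, and
# only the first window/batch entry (index [0] above) is visualized.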
# ---------------------------------------------------------------------------
# 5. Inference + optional attention visualization
def predict_and_visualize(image_pil: Image.Image,
ckpt_name: str,
interpolation: str = "bicubic",
show_attention: bool = True):
if image_pil is None:
return None, None
# Ensure the correct model is loaded
load_model(ckpt_name)
global attention_maps
attention_maps = [] # Reset before inference
transform = build_transform(is_training=False, interpolation=interpolation)
input_tensor = transform(image_pil).unsqueeze(0).to(device)
# Register hook if visualization is requested
hook_handle = None
if show_attention:
target_layer = model.backbone.layers[-1].blocks[-1].attn
hook_handle = target_layer.register_forward_hook(get_attention_map)
with torch.no_grad():
logits = model(input_tensor)
# Always remove the hook after the forward pass
if hook_handle:
hook_handle.remove()
probs = F.softmax(logits, dim=1)[0]
confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}
# Generate visualization if requested and possible
viz_image = None
if show_attention and attention_maps:
original_image = image_pil.copy().convert("RGB")
viz_image = create_attention_visualization(original_image, attention_maps[0])
return confidences, viz_image
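# Example (hypothetical path; any RGB PIL image works):
#   from PIL import Image
#   confidences, heatmap = predict_and_visualize(Image.open("examples/ai_1.png"), "V4")
#   print(confidences)  # e.g. {"Non-AI Generated": 0.12, "AI Generated": 0.88}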
# ---------------------------------------------------------------------------
# 6. Gradio UI
def launch_app():
# Load default model at startup
load_model(DEFAULT_CKPT)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🖼️ AI vs. Non-AI Image Classifier")
gr.Markdown("Using Swin-Large Transformer with Attention Visualization.")
with gr.Row():
with gr.Column(scale=1):
in_img = gr.Image(type="pil", label="Upload an Image")
model_choice = gr.Dropdown(
list(HF_FILENAMES.keys()), value=DEFAULT_CKPT, label="Select Model"
)
interp_choice = gr.Radio(
["bilinear", "bicubic", "nearest"], value="bicubic",
label="Resize Interpolation (Preprocessing)"
)
viz_checkbox = gr.Checkbox(value=True, label="Show Attention Visualization")
run_btn = gr.Button("🚀 Run Analysis", variant="primary")
with gr.Column(scale=2):
out_lbl = gr.Label(num_top_classes=2, label="Predictions")
out_viz = gr.Image(type="pil", label="Attention Map Visualization", visible=True)
run_btn.click(
predict_and_visualize,
inputs=[in_img, model_choice, interp_choice, viz_checkbox],
outputs=[out_lbl, out_viz]
)
gr.Examples(
examples=[
#[os.path.join(os.path.dirname(__file__), "examples/ai_1.png"), DEFAULT_CKPT, "bicubic", True],
#[os.path.join(os.path.dirname(__file__), "examples/real_1.jpg"), DEFAULT_CKPT, "bicubic", True],
],
inputs=[in_img, model_choice, interp_choice, viz_checkbox],
outputs=[out_lbl, out_viz],
fn=predict_and_visualize,
cache_examples=False, # Set to True if examples are static
)
demo.launch()
# ---------------------------------------------------------------------------
if __name__ == "__main__":
# Create an examples directory for Gradio
if not os.path.exists("examples"):
os.makedirs("examples")
print("Created 'examples' directory. Please add some sample images (e.g., ai_1.png, real_1.jpg) there for the UI examples.")
launch_app()