# -*- coding: utf-8 -*-
"""
Swin-Large AI vs. Non-AI Detector (with Model Selection & Attention Visualization)
"""
import os
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
import timm
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
# --- Configuration ---------------------------------------------------------
REPO_ID = "telecomadm1145/swin-ai-detection"
HF_FILENAMES = {
"V2": "swin_classifier_stage1_v2_epoch_3.pth",
"V4": "swin_classifier_stage1_v4.pth",
}
DEFAULT_CKPT = "V4"  # must be a key of HF_FILENAMES
LOCAL_CKPT_DIR = "./checkpoints"
MODEL_NAME = "swin_large_patch4_window12_384"
NUM_CLASSES = 2
SEED = 4421
dropout_rate = 0.1
class_names = ["Non-AI Generated", "AI Generated"] # 0, 1
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")
# --- Global model state ----------------------------------------------------
model = None
current_ckpt_name = None
attention_maps = [] # To store hooked attention maps
# ---------------------------------------------------------------------------
# 1. Model architecture
class SwinClassifier(nn.Module):
def __init__(self, model_name, num_classes, pretrained=True):
super().__init__()
self.backbone = timm.create_model(model_name, pretrained=pretrained,
num_classes=0)
self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
self.classifier = nn.Sequential(
nn.Dropout(dropout_rate),
nn.Linear(self.backbone.num_features, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Dropout(dropout_rate * 0.7),
nn.Linear(512, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
nn.Dropout(dropout_rate * 0.5),
nn.Linear(128, num_classes)
)
def forward(self, x):
feats = self.backbone(x)
return self.classifier(feats)
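# A minimal usage sketch (hypothetical, kept commented out so it does not run on import).
# It assumes a 384x384 input, matching swin_large_patch4_window12_384, and only checks
# the output shape of the classifier head:
#   m = SwinClassifier(MODEL_NAME, NUM_CLASSES, pretrained=False).eval()
#   with torch.no_grad():
#       logits = m(torch.randn(1, 3, 384, 384))
#   assert logits.shape == (1, NUM_CLASSES)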
# ---------------------------------------------------------------------------
# 2. Dynamic model loading
def load_model(ckpt_name: str):
"""
Dynamically loads the selected model checkpoint.
If the model is already loaded, it does nothing.
"""
global model, current_ckpt_name
if ckpt_name == current_ckpt_name:
print(f"✅ Model '{ckpt_name}' is already loaded.")
return
print(f"🔄 Switching to model: '{ckpt_name}'...")
hf_filename = HF_FILENAMES[ckpt_name]
print("⏬ Downloading / caching checkpoint if needed…")
ckpt_path = hf_hub_download(
repo_id=REPO_ID,
filename=hf_filename,
local_dir=LOCAL_CKPT_DIR,
force_download=False
)
print(f"Checkpoint path: {ckpt_path}")
# Instantiate and load weights
model = SwinClassifier(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
state = torch.load(ckpt_path, map_location=device, weights_only=False)
model.load_state_dict(state.get("model_state_dict", state), strict=True)
model.eval()
current_ckpt_name = ckpt_name
print(f"✅ Model '{ckpt_name}' loaded successfully.")
# ---------------------------------------------------------------------------
# 3. torchvision / timm transform factory
def build_transform(is_training: bool, interpolation: str):
"""
    Build the default timm transform for the given interpolation mode
    (bilinear / bicubic / nearest).
"""
if model is None:
raise RuntimeError("Model is not loaded. Please call load_model() first.")
cfg = model.data_config.copy()
cfg.update(dict(interpolation=interpolation))
return timm.data.create_transform(**cfg, is_training=is_training)
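# Example sketch (hypothetical, assumes load_model() has been called): the timm
# transform resizes and normalizes the image into a CHW float tensor sized for the
# backbone, ready for batching with unsqueeze(0).
#   t = build_transform(is_training=False, interpolation="bicubic")
#   x = t(Image.new("RGB", (512, 512)))
#   print(x.shape)   # expected: torch.Size([3, 384, 384])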
# ---------------------------------------------------------------------------
# 4. Attention Hook & Visualization
def get_attention_map(module, input, output):
"""Hook to capture the attention map from the attention module."""
global attention_maps
    # NOTE: this assumes the hooked module returns a (features, attention_weights)
    # tuple, with the weights shaped [B, num_heads, N, N] (N = number of patches).
    # Plain timm Swin attention blocks return only features, so depending on the
    # timm version a sub-module that exposes the weights may need to be hooked instead.
    attention_maps.append(output[1].cpu())
def create_attention_visualization(image_pil: Image.Image, attn_map: torch.Tensor) -> Image.Image:
"""Creates an overlay of the attention map on the original image."""
# Average across all heads
attn_map = attn_map.mean(dim=1)[0] # Shape: [N, N]
    # To get a saliency score for each patch, sum the attention it receives
    # from every query patch (i.e. sum over the query dimension).
    residual_attn = attn_map.sum(dim=0)  # Sum over the query (row) dimension
# Reshape to 2D grid
patch_size = model.backbone.patch_embed.patch_size[0]
num_patches = residual_attn.shape[0]
grid_size = int(math.sqrt(num_patches))
if grid_size * grid_size != num_patches:
print(f"Warning: Number of patches ({num_patches}) is not a perfect square. Visualization may be incorrect.")
# Fallback for non-square patch layouts if needed, but Swin usually has square.
return image_pil
attn_grid = residual_attn.reshape(grid_size, grid_size).detach().numpy()
# Normalize the grid
attn_grid = (attn_grid - attn_grid.min()) / (attn_grid.max() - attn_grid.min())
# Use a colormap to create a heatmap
cmap = plt.get_cmap('viridis')
heatmap_colored = (cmap(attn_grid)[:, :, :3] * 255).astype(np.uint8)
heatmap_pil = Image.fromarray(heatmap_colored)
# Resize heatmap to original image size
heatmap_resized = heatmap_pil.resize(image_pil.size, Image.BICUBIC)
# Blend original image with the heatmap
viz_image = Image.blend(image_pil, heatmap_resized, alpha=0.5)
return viz_image
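# Illustration sketch (hypothetical, commented out): the function expects the raw
# hooked tensor of shape [B, num_heads, N, N] where N is a perfect square (for a
# 12x12 Swin window, N = 144), plus a loaded model for the patch-size lookup.
#   dummy_attn = torch.rand(1, 8, 144, 144)
#   overlay = create_attention_visualization(Image.new("RGB", (384, 384)), dummy_attn)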
# ---------------------------------------------------------------------------
# 5. Inference + optional attention visualization
def predict_and_visualize(image_pil: Image.Image,
ckpt_name: str,
interpolation: str = "bicubic",
show_attention: bool = True):
if image_pil is None:
return None, None
# Ensure the correct model is loaded
load_model(ckpt_name)
global attention_maps
attention_maps = [] # Reset before inference
transform = build_transform(is_training=False, interpolation=interpolation)
input_tensor = transform(image_pil).unsqueeze(0).to(device)
# Register hook if visualization is requested
hook_handle = None
if show_attention:
target_layer = model.backbone.layers[-1].blocks[-1].attn
hook_handle = target_layer.register_forward_hook(get_attention_map)
with torch.no_grad():
logits = model(input_tensor)
# Always remove the hook after the forward pass
if hook_handle:
hook_handle.remove()
probs = F.softmax(logits, dim=1)[0]
confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}
# Generate visualization if requested and possible
viz_image = None
if show_attention and attention_maps:
original_image = image_pil.copy().convert("RGB")
viz_image = create_attention_visualization(original_image, attention_maps[0])
return confidences, viz_image
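# Programmatic usage sketch (hypothetical, commented out), bypassing the Gradio UI:
#   img = Image.open("examples/ai_1.png").convert("RGB")
#   scores, overlay = predict_and_visualize(img, "V4", "bicubic", show_attention=False)
#   print(scores)   # dict mapping class name -> probability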
# ---------------------------------------------------------------------------
# 6. Gradio UI
def launch_app():
# Load default model at startup
load_model(DEFAULT_CKPT)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🖼️ AI vs. Non-AI Image Classifier")
gr.Markdown("Using Swin-Large Transformer with Attention Visualization.")
with gr.Row():
with gr.Column(scale=1):
in_img = gr.Image(type="pil", label="Upload an Image")
model_choice = gr.Dropdown(
list(HF_FILENAMES.keys()), value=DEFAULT_CKPT, label="Select Model"
)
interp_choice = gr.Radio(
["bilinear", "bicubic", "nearest"], value="bicubic",
label="Resize Interpolation (Preprocessing)"
)
viz_checkbox = gr.Checkbox(value=True, label="Show Attention Visualization")
run_btn = gr.Button("🚀 Run Analysis", variant="primary")
with gr.Column(scale=2):
out_lbl = gr.Label(num_top_classes=2, label="Predictions")
out_viz = gr.Image(type="pil", label="Attention Map Visualization", visible=True)
run_btn.click(
predict_and_visualize,
inputs=[in_img, model_choice, interp_choice, viz_checkbox],
outputs=[out_lbl, out_viz]
)
gr.Examples(
examples=[
#[os.path.join(os.path.dirname(__file__), "examples/ai_1.png"), DEFAULT_CKPT, "bicubic", True],
#[os.path.join(os.path.dirname(__file__), "examples/real_1.jpg"), DEFAULT_CKPT, "bicubic", True],
],
inputs=[in_img, model_choice, interp_choice, viz_checkbox],
outputs=[out_lbl, out_viz],
fn=predict_and_visualize,
cache_examples=False, # Set to True if examples are static
)
demo.launch()
# ---------------------------------------------------------------------------
if __name__ == "__main__":
# Create an examples directory for Gradio
if not os.path.exists("examples"):
os.makedirs("examples")
print("Created 'examples' directory. Please add some sample images (e.g., ai_1.png, real_1.jpg) there for the UI examples.")
    launch_app()