# -*- coding: utf-8 -*-
"""
Swin-Large AI vs. Non-AI Detector
(with Model Selection & Attention Visualization)
"""

import os
import math

import torch
import torch.nn.functional as F
import torch.nn as nn
import timm
import numpy as np
from PIL import Image
import gradio as gr
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

# --- Configuration ---------------------------------------------------------
REPO_ID = "telecomadm1145/swin-ai-detection"
HF_FILENAMES = {
    "V2": "swin_classifier_stage1_v2_epoch_3.pth",
    "V4": "swin_classifier_stage1_v4.pth",
}
DEFAULT_CKPT = "V4"  # must be a key of HF_FILENAMES
LOCAL_CKPT_DIR = "./checkpoints"

MODEL_NAME = "swin_large_patch4_window12_384"
NUM_CLASSES = 2
SEED = 4421
dropout_rate = 0.1
class_names = ["Non-AI Generated", "AI Generated"]  # 0, 1

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")

# --- Global model state ----------------------------------------------------
model = None
current_ckpt_name = None
attention_maps = []  # To store hooked attention maps

# ---------------------------------------------------------------------------
# 1. Model architecture
class SwinClassifier(nn.Module):
    def __init__(self, model_name, num_classes, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.backbone.num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.7),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        feats = self.backbone(x)
        return self.classifier(feats)

# ---------------------------------------------------------------------------
# 2. Dynamic model loading
def load_model(ckpt_name: str):
    """
    Dynamically loads the selected model checkpoint.
    If the requested checkpoint is already loaded, this is a no-op.
    """
    global model, current_ckpt_name

    if ckpt_name == current_ckpt_name:
        print(f"✅ Model '{ckpt_name}' is already loaded.")
        return

    print(f"🔄 Switching to model: '{ckpt_name}'...")
    hf_filename = HF_FILENAMES[ckpt_name]

    print("⏬ Downloading / caching checkpoint if needed…")
    ckpt_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=hf_filename,
        local_dir=LOCAL_CKPT_DIR,
        force_download=False
    )
    print(f"Checkpoint path: {ckpt_path}")

    # Instantiate and load weights
    model = SwinClassifier(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
    state = torch.load(ckpt_path, map_location=device, weights_only=False)
    model.load_state_dict(state.get("model_state_dict", state), strict=True)
    model.eval()

    current_ckpt_name = ckpt_name
    print(f"✅ Model '{ckpt_name}' loaded successfully.")

# ---------------------------------------------------------------------------
# 3. torchvision / timm transform factory
def build_transform(is_training: bool, interpolation: str):
    """Builds the default timm transform with the chosen resize interpolation
    (e.g. bilinear / bicubic)."""
    if model is None:
        raise RuntimeError("Model is not loaded. Please call load_model() first.")
    cfg = model.data_config.copy()
    cfg.update(dict(interpolation=interpolation))
    return timm.data.create_transform(**cfg, is_training=is_training)
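
# Illustrative usage sketch (not executed): once a checkpoint has been loaded,
# the resolved timm data config can be inspected and a transform rebuilt with a
# different resize interpolation. The printed values are only typical for
# swin_large_patch4_window12_384 (384x384 input, bicubic resize) and may vary
# with the timm version; "some_image.png" is a hypothetical path.
#
#   load_model("V4")
#   print(model.data_config)  # e.g. {'input_size': (3, 384, 384), 'interpolation': 'bicubic', ...}
#   tfm = build_transform(is_training=False, interpolation="bilinear")
#   batch = tfm(Image.open("some_image.png").convert("RGB")).unsqueeze(0).to(device)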

# ---------------------------------------------------------------------------
# 4. Attention hook & visualization
def get_attention_map(module, input, output):
    """Forward hook that captures the attention map from the attention module."""
    global attention_maps
    # Some attention implementations return (features, attention_weights); others
    # return only the feature tensor, in which case no map can be captured and
    # the visualization is skipped downstream.
    if isinstance(output, (tuple, list)) and len(output) > 1:
        # Expected shape: [B, num_heads, N, N], where N is the number of patches.
        attention_maps.append(output[1].detach().cpu())

def create_attention_visualization(image_pil: Image.Image, attn_map: torch.Tensor) -> Image.Image:
    """Creates an overlay of the attention map on the original image."""
    # Average across all heads
    attn_map = attn_map.mean(dim=1)[0]  # Shape: [N, N]

    # Score each patch by the total attention it receives from all other patches.
    residual_attn = attn_map.sum(dim=0)  # Sum over query positions

    # Reshape to a 2D grid
    num_patches = residual_attn.shape[0]
    grid_size = int(math.sqrt(num_patches))
    if grid_size * grid_size != num_patches:
        print(f"Warning: Number of patches ({num_patches}) is not a perfect square. "
              "Visualization may be incorrect.")
        # Fallback for non-square patch layouts; Swin windows are usually square.
        return image_pil

    attn_grid = residual_attn.reshape(grid_size, grid_size).detach().numpy()

    # Normalize the grid to [0, 1]
    attn_grid = (attn_grid - attn_grid.min()) / (attn_grid.max() - attn_grid.min())

    # Use a colormap to create a heatmap
    cmap = plt.get_cmap('viridis')
    heatmap_colored = (cmap(attn_grid)[:, :, :3] * 255).astype(np.uint8)
    heatmap_pil = Image.fromarray(heatmap_colored)

    # Resize the heatmap to the original image size
    heatmap_resized = heatmap_pil.resize(image_pil.size, Image.BICUBIC)

    # Blend the original image with the heatmap
    viz_image = Image.blend(image_pil, heatmap_resized, alpha=0.5)
    return viz_image

# ---------------------------------------------------------------------------
# 5. Inference + optional attention visualization
def predict_and_visualize(image_pil: Image.Image,
                          ckpt_name: str,
                          interpolation: str = "bicubic",
                          show_attention: bool = True):
    if image_pil is None:
        return None, None

    # Ensure the correct model is loaded
    load_model(ckpt_name)

    global attention_maps
    attention_maps = []  # Reset before inference

    transform = build_transform(is_training=False, interpolation=interpolation)
    input_tensor = transform(image_pil).unsqueeze(0).to(device)

    # Register the hook only if visualization is requested
    hook_handle = None
    if show_attention:
        target_layer = model.backbone.layers[-1].blocks[-1].attn
        hook_handle = target_layer.register_forward_hook(get_attention_map)

    with torch.no_grad():
        logits = model(input_tensor)

    # Always remove the hook after the forward pass
    if hook_handle:
        hook_handle.remove()

    probs = F.softmax(logits, dim=1)[0]
    confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}

    # Generate the visualization if requested and an attention map was captured
    viz_image = None
    if show_attention and attention_maps:
        original_image = image_pil.copy().convert("RGB")
        viz_image = create_attention_visualization(original_image, attention_maps[0])

    return confidences, viz_image
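
# Illustrative usage sketch (not executed): predict_and_visualize can also be
# called directly, without the Gradio UI. The image path below is hypothetical,
# and the attention overlay is only produced when the hook actually captured a map.
#
#   img = Image.open("examples/ai_1.png").convert("RGB")
#   confidences, attn_viz = predict_and_visualize(img, "V4", "bicubic", show_attention=True)
#   print(confidences)  # {"Non-AI Generated": ..., "AI Generated": ...}
#   if attn_viz is not None:
#       attn_viz.save("attention_overlay.png")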

# ---------------------------------------------------------------------------
# 6. Gradio UI
def launch_app():
    # Load the default model at startup
    load_model(DEFAULT_CKPT)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🖼️ AI vs. Non-AI Image Classifier")
        gr.Markdown("Using Swin-Large Transformer with Attention Visualization.")

        with gr.Row():
            with gr.Column(scale=1):
                in_img = gr.Image(type="pil", label="Upload an Image")
                model_choice = gr.Dropdown(
                    list(HF_FILENAMES.keys()),
                    value=DEFAULT_CKPT,
                    label="Select Model"
                )
                interp_choice = gr.Radio(
                    ["bilinear", "bicubic", "nearest"],
                    value="bicubic",
                    label="Resize Interpolation (Preprocessing)"
                )
                viz_checkbox = gr.Checkbox(value=True, label="Show Attention Visualization")
                run_btn = gr.Button("🚀 Run Analysis", variant="primary")

            with gr.Column(scale=2):
                out_lbl = gr.Label(num_top_classes=2, label="Predictions")
                out_viz = gr.Image(type="pil", label="Attention Map Visualization", visible=True)

        run_btn.click(
            predict_and_visualize,
            inputs=[in_img, model_choice, interp_choice, viz_checkbox],
            outputs=[out_lbl, out_viz]
        )

        gr.Examples(
            examples=[
                # [os.path.join(os.path.dirname(__file__), "examples/ai_1.png"), DEFAULT_CKPT, "bicubic", True],
                # [os.path.join(os.path.dirname(__file__), "examples/real_1.jpg"), DEFAULT_CKPT, "bicubic", True],
            ],
            inputs=[in_img, model_choice, interp_choice, viz_checkbox],
            outputs=[out_lbl, out_viz],
            fn=predict_and_visualize,
            cache_examples=False,  # Set to True if examples are static
        )

    demo.launch()

# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Create an examples directory for Gradio
    if not os.path.exists("examples"):
        os.makedirs("examples")
        print("Created 'examples' directory. Please add some sample images "
              "(e.g., ai_1.png, real_1.jpg) there for the UI examples.")

    launch_app()