# -*- coding: utf-8 -*-
"""
Swin-Large AI vs. Non-AI Detector (with Model Selection & Attention Visualization)
"""
import os
import math
import torch
import torch.nn.functional as F
import torch.nn as nn
import timm
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
# --- Configuration ---------------------------------------------------------
REPO_ID = "telecomadm1145/swin-ai-detection"
HF_FILENAMES = {
    "V2": "swin_classifier_stage1_v2_epoch_3.pth",
    "V4": "swin_classifier_stage1_v4.pth",
}
DEFAULT_CKPT = "V4"  # must be a key of HF_FILENAMES (also used as the dropdown default)
LOCAL_CKPT_DIR = "./checkpoints"
MODEL_NAME = "swin_large_patch4_window12_384"
NUM_CLASSES = 2
SEED = 4421
dropout_rate = 0.1
class_names = ["Non-AI Generated", "AI Generated"]  # 0, 1

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Using device: {device}")
# --- Global model state ----------------------------------------------------
model = None
current_ckpt_name = None
attention_maps = []  # To store hooked attention maps
# ---------------------------------------------------------------------------
# 1. Model architecture
class SwinClassifier(nn.Module):
    def __init__(self, model_name, num_classes, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained,
                                          num_classes=0)
        self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.backbone.num_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.7),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate * 0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        feats = self.backbone(x)
        return self.classifier(feats)
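# Shape sketch (an illustration, assuming timm's swin_large_patch4_window12_384,
# whose pooled feature width `backbone.num_features` is 1536):
#   x [B, 3, 384, 384] -> backbone -> feats [B, 1536] -> classifier -> logits [B, 2]
# Note: the BatchNorm1d layers in the head require batch size > 1 in training mode;
# inference below always runs the model in eval mode, so single images are fine.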
# ---------------------------------------------------------------------------
# 2. Dynamic model loading
def load_model(ckpt_name: str):
    """
    Dynamically loads the selected model checkpoint.
    If the requested model is already loaded, this is a no-op.
    """
    global model, current_ckpt_name
    if ckpt_name == current_ckpt_name:
        print(f"✅ Model '{ckpt_name}' is already loaded.")
        return

    print(f"🔄 Switching to model: '{ckpt_name}'...")
    hf_filename = HF_FILENAMES[ckpt_name]

    print("⏬ Downloading / caching checkpoint if needed…")
    ckpt_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=hf_filename,
        local_dir=LOCAL_CKPT_DIR,
        force_download=False
    )
    print(f"Checkpoint path: {ckpt_path}")

    # Instantiate and load weights
    model = SwinClassifier(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
    state = torch.load(ckpt_path, map_location=device, weights_only=False)
    model.load_state_dict(state.get("model_state_dict", state), strict=True)
    model.eval()

    current_ckpt_name = ckpt_name
    print(f"✅ Model '{ckpt_name}' loaded successfully.")
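# Checkpoint format note (an assumption about these particular .pth files): each
# may be either a raw state_dict or a training dict containing "model_state_dict";
# the state.get("model_state_dict", state) fallback above covers both layouts.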
# ---------------------------------------------------------------------------
# 3. torchvision / timm transform factory
def build_transform(is_training: bool, interpolation: str):
    """
    Builds the default timm transform for the loaded backbone, using the
    requested resize interpolation (bilinear / bicubic, etc.).
    """
    if model is None:
        raise RuntimeError("Model is not loaded. Please call load_model() first.")
    cfg = model.data_config.copy()
    cfg.update(dict(interpolation=interpolation))
    return timm.data.create_transform(**cfg, is_training=is_training)
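# For reference, the eval-mode pipeline produced here is roughly (exact sizes and
# statistics are assumptions taken from timm's resolved data config for this model):
#   Resize(384, interpolation=<choice>) -> CenterCrop(384) -> ToTensor()
#   -> Normalize(mean, std)   # mean/std come from model.data_config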
# ---------------------------------------------------------------------------
# 4. Attention Hook & Visualization
def get_attention_map(module, input, output):
    """Hook that stores the post-softmax attention weights of the last block.

    Registered on the attention-dropout submodule (see predict_and_visualize),
    whose output has shape [B * num_windows, num_heads, N, N] in timm's
    non-fused attention path, where N is the number of tokens per window.
    """
    global attention_maps
    attention_maps.append(output.detach().cpu())


def create_attention_visualization(image_pil: Image.Image, attn_map: torch.Tensor) -> Image.Image:
    """Creates an overlay of the attention map on the original image."""
    # Average across all heads and take the first window. For this backbone the
    # last stage is a single 12x12 window covering the whole feature map.
    attn_map = attn_map.mean(dim=1)[0]  # Shape: [N, N]

    # Score each patch by the total attention it receives from all queries.
    residual_attn = attn_map.sum(dim=0)  # Sum over the query dimension

    # Reshape to a 2D grid
    num_patches = residual_attn.shape[0]
    grid_size = int(math.sqrt(num_patches))
    if grid_size * grid_size != num_patches:
        print(f"Warning: Number of patches ({num_patches}) is not a perfect square. Visualization may be incorrect.")
        # Fallback for non-square patch layouts if needed, but Swin windows are square.
        return image_pil

    attn_grid = residual_attn.reshape(grid_size, grid_size).detach().numpy()
    # Normalize the grid to [0, 1]
    attn_grid = (attn_grid - attn_grid.min()) / (attn_grid.max() - attn_grid.min() + 1e-8)

    # Use a colormap to create a heatmap
    cmap = plt.get_cmap('viridis')
    heatmap_colored = (cmap(attn_grid)[:, :, :3] * 255).astype(np.uint8)
    heatmap_pil = Image.fromarray(heatmap_colored)

    # Resize heatmap to the original image size
    heatmap_resized = heatmap_pil.resize(image_pil.size, Image.BICUBIC)

    # Blend the original image with the heatmap
    viz_image = Image.blend(image_pil, heatmap_resized, alpha=0.5)
    return viz_image
# ---------------------------------------------------------------------------
# 5. Inference + optional attention visualization
def predict_and_visualize(image_pil: Image.Image,
                          ckpt_name: str,
                          interpolation: str = "bicubic",
                          show_attention: bool = True):
    if image_pil is None:
        return None, None

    # Ensure the correct model is loaded
    load_model(ckpt_name)

    global attention_maps
    attention_maps = []  # Reset before inference

    transform = build_transform(is_training=False, interpolation=interpolation)
    input_tensor = transform(image_pil.convert("RGB")).unsqueeze(0).to(device)

    # Register hook if visualization is requested. The hook targets the
    # attention-dropout submodule of the last block: in timm's non-fused
    # attention path its output is the softmaxed attention map; with the fused
    # (SDPA) path the hook never fires and the visualization is simply skipped.
    hook_handle = None
    if show_attention:
        target_layer = model.backbone.layers[-1].blocks[-1].attn.attn_drop
        hook_handle = target_layer.register_forward_hook(get_attention_map)

    with torch.no_grad():
        logits = model(input_tensor)

    # Always remove the hook after the forward pass
    if hook_handle is not None:
        hook_handle.remove()

    probs = F.softmax(logits, dim=1)[0]
    confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}

    # Generate visualization if requested and possible
    viz_image = None
    if show_attention and attention_maps:
        original_image = image_pil.copy().convert("RGB")
        viz_image = create_attention_visualization(original_image, attention_maps[0])

    return confidences, viz_image
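# Direct (non-UI) usage sketch — the file path below is only an illustration:
#   img = Image.open("examples/ai_1.png")
#   confidences, viz = predict_and_visualize(img, "V4", "bicubic", True)
#   # confidences: {class name -> probability}; viz: PIL heatmap overlay or None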
# ---------------------------------------------------------------------------
# 6. Gradio UI
def launch_app():
    # Load default model at startup
    load_model(DEFAULT_CKPT)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🖼️ AI vs. Non-AI Image Classifier")
        gr.Markdown("Using Swin-Large Transformer with Attention Visualization.")

        with gr.Row():
            with gr.Column(scale=1):
                in_img = gr.Image(type="pil", label="Upload an Image")
                model_choice = gr.Dropdown(
                    list(HF_FILENAMES.keys()), value=DEFAULT_CKPT, label="Select Model"
                )
                interp_choice = gr.Radio(
                    ["bilinear", "bicubic", "nearest"], value="bicubic",
                    label="Resize Interpolation (Preprocessing)"
                )
                viz_checkbox = gr.Checkbox(value=True, label="Show Attention Visualization")
                run_btn = gr.Button("🚀 Run Analysis", variant="primary")
            with gr.Column(scale=2):
                out_lbl = gr.Label(num_top_classes=2, label="Predictions")
                out_viz = gr.Image(type="pil", label="Attention Map Visualization", visible=True)

        run_btn.click(
            predict_and_visualize,
            inputs=[in_img, model_choice, interp_choice, viz_checkbox],
            outputs=[out_lbl, out_viz]
        )

        gr.Examples(
            examples=[
                # [os.path.join(os.path.dirname(__file__), "examples/ai_1.png"), DEFAULT_CKPT, "bicubic", True],
                # [os.path.join(os.path.dirname(__file__), "examples/real_1.jpg"), DEFAULT_CKPT, "bicubic", True],
            ],
            inputs=[in_img, model_choice, interp_choice, viz_checkbox],
            outputs=[out_lbl, out_viz],
            fn=predict_and_visualize,
            cache_examples=False,  # Set to True if examples are static
        )

    demo.launch()
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Create an examples directory for the Gradio examples
    if not os.path.exists("examples"):
        os.makedirs("examples")
        print("Created 'examples' directory. Please add some sample images (e.g., ai_1.png, real_1.jpg) there for the UI examples.")
    launch_app()
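# To run locally (a sketch; the filename app.py and the pip packages are assumptions
# based on the imports above):
#   pip install torch timm gradio huggingface_hub matplotlib pillow numpy
#   python app.py   # then open the local URL that Gradio prints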