Update app.py
app.py (CHANGED)
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Swin-Large
+Swin-Large AI vs. Non-AI Detector (with Model Selection & Attention Visualization)
"""
import os
import math

@@ -9,29 +9,37 @@ import torch.nn.functional as F
import torch.nn as nn
import timm
import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
import gradio as gr
+import matplotlib.pyplot as plt

-from huggingface_hub import hf_hub_download
-from pytorch_grad_cam import GradCAM
-from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
-from pytorch_grad_cam.utils.image import show_cam_on_image

# --- Configuration ---------------------------------------------------------
-REPO_ID
-
-
-
-
-
-
-
-
+REPO_ID = "telecomadm1145/swin-ai-detection"
+HF_FILENAMES = {
+    "V2": "swin_classifier_stage1_v2_epoch_3.pth",
+    "V4": "swin_classifier_stage1_v4.pth",
+}
+DEFAULT_CKPT = "Swin-V4 (Final)"
+LOCAL_CKPT_DIR = "./checkpoints"
+MODEL_NAME = "swin_large_patch4_window12_384"
+NUM_CLASSES = 2
+SEED = 4421
+dropout_rate = 0.1
+
+class_names = ["Non-AI Generated", "AI Generated"]  # 0, 1

device = "cuda" if torch.cuda.is_available() else "cpu"
-torch.manual_seed(SEED)
+torch.manual_seed(SEED)
+np.random.seed(SEED)
print(f"Using device: {device}")

+# --- Global model state ----------------------------------------------------
+model = None
+current_ckpt_name = None
+attention_maps = []  # To store hooked attention maps
+
# ---------------------------------------------------------------------------
# 1. Model architecture
class SwinClassifier(nn.Module):

@@ -59,97 +67,189 @@ class SwinClassifier(nn.Module):
        return self.classifier(feats)

# ---------------------------------------------------------------------------
-# 2.
-
-
-
-
-
-
-
-print(f"
-
-
-
-
-
-
-
-
+# 2. Dynamic model loading
+def load_model(ckpt_name: str):
+    """
+    Dynamically loads the selected model checkpoint.
+    If the model is already loaded, it does nothing.
+    """
+    global model, current_ckpt_name
+    if ckpt_name == current_ckpt_name:
+        print(f"✅ Model '{ckpt_name}' is already loaded.")
+        return
+
+    print(f"🔄 Switching to model: '{ckpt_name}'...")
+    hf_filename = HF_FILENAMES[ckpt_name]
+
+    print("⏬ Downloading / caching checkpoint if needed…")
+    ckpt_path = hf_hub_download(
+        repo_id=REPO_ID,
+        filename=hf_filename,
+        local_dir=LOCAL_CKPT_DIR,
+        force_download=False
+    )
+    print(f"Checkpoint path: {ckpt_path}")
+
+    # Instantiate and load weights
+    model = SwinClassifier(MODEL_NAME, NUM_CLASSES, pretrained=False).to(device)
+    state = torch.load(ckpt_path, map_location=device, weights_only=False)
+    model.load_state_dict(state.get("model_state_dict", state), strict=True)
+    model.eval()
+    current_ckpt_name = ckpt_name
+    print(f"✅ Model '{ckpt_name}' loaded successfully.")

# ---------------------------------------------------------------------------
-#
+# 3. torchvision / timm transform factory
def build_transform(is_training: bool, interpolation: str):
    """
    Build the default timm transform for the given interpolation mode (bilinear / bicubic, etc.)
    """
+    if model is None:
+        raise RuntimeError("Model is not loaded. Please call load_model() first.")
    cfg = model.data_config.copy()
    cfg.update(dict(interpolation=interpolation))
    return timm.data.create_transform(**cfg, is_training=is_training)

# ---------------------------------------------------------------------------
-#
-def
-"""
-
-
-
-
-
-
-
+# 4. Attention Hook & Visualization
+def get_attention_map(module, input, output):
+    """Hook to capture the attention map from the attention module."""
+    global attention_maps
+    # The attention map is typically the second element of the output tuple
+    # It has shape [B, num_heads, N, N] where N is num_patches
+    attention_maps.append(output[1].cpu())
+
+def create_attention_visualization(image_pil: Image.Image, attn_map: torch.Tensor) -> Image.Image:
+    """Creates an overlay of the attention map on the original image."""
+    # Average across all heads
+    attn_map = attn_map.mean(dim=1)[0]  # Shape: [N, N]
+
+    # To get the attention score for each patch, we can average the attention
+    # it receives from all other patches.
+    residual_attn = attn_map.sum(dim=0)  # Sum over rows
+
+    # Reshape to 2D grid
+    patch_size = model.backbone.patch_embed.patch_size[0]
+    num_patches = residual_attn.shape[0]
+    grid_size = int(math.sqrt(num_patches))
+
+    if grid_size * grid_size != num_patches:
+        print(f"Warning: Number of patches ({num_patches}) is not a perfect square. Visualization may be incorrect.")
+        # Fallback for non-square patch layouts if needed, but Swin usually has square.
+        return image_pil
+
+    attn_grid = residual_attn.reshape(grid_size, grid_size).detach().numpy()
+
+    # Normalize the grid
+    attn_grid = (attn_grid - attn_grid.min()) / (attn_grid.max() - attn_grid.min())
+
+    # Use a colormap to create a heatmap
+    cmap = plt.get_cmap('viridis')
+    heatmap_colored = (cmap(attn_grid)[:, :, :3] * 255).astype(np.uint8)
+    heatmap_pil = Image.fromarray(heatmap_colored)
+
+    # Resize heatmap to original image size
+    heatmap_resized = heatmap_pil.resize(image_pil.size, Image.BICUBIC)

-#
-
-
+    # Blend original image with the heatmap
+    viz_image = Image.blend(image_pil, heatmap_resized, alpha=0.5)
+    return viz_image

# ---------------------------------------------------------------------------
-#
-
-
-
-
+# 5. Inference + optional attention visualization
+def predict_and_visualize(image_pil: Image.Image,
+                          ckpt_name: str,
+                          interpolation: str = "bicubic",
+                          show_attention: bool = True):
    if image_pil is None:
        return None, None

+    # Ensure the correct model is loaded
+    load_model(ckpt_name)
+
+    global attention_maps
+    attention_maps = []  # Reset before inference
+
    transform = build_transform(is_training=False, interpolation=interpolation)
    input_tensor = transform(image_pil).unsqueeze(0).to(device)

-
-
+    # Register hook if visualization is requested
+    hook_handle = None
+    if show_attention:
+        target_layer = model.backbone.layers[-1].blocks[-1].attn
+        hook_handle = target_layer.register_forward_hook(get_attention_map)
+
+    with torch.no_grad():
+        logits = model(input_tensor)
+
+    # Always remove the hook after the forward pass
+    if hook_handle:
+        hook_handle.remove()
+
+    probs = F.softmax(logits, dim=1)[0]
    confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}

-
+    # Generate visualization if requested and possible
+    viz_image = None
+    if show_attention and attention_maps:
+        original_image = image_pil.copy().convert("RGB")
+        viz_image = create_attention_visualization(original_image, attention_maps[0])
+
+    return confidences, viz_image

# ---------------------------------------------------------------------------
-#
+# 6. Gradio UI
def launch_app():
-
-
+    # Load default model at startup
+    load_model(DEFAULT_CKPT)
+
+    with gr.Blocks(theme=gr.themes.Soft()) as demo:
+        gr.Markdown("# 🖼️ AI vs. Non-AI Image Classifier")
+        gr.Markdown("Using Swin-Large Transformer with Attention Visualization.")

-        run_btn = gr.Button("🚀 Run")
-
-        with gr.Row():
-            interp_choice = gr.Radio(
-                ["bilinear", "bicubic", "nearest"], value="bicubic",
-                label="Resize Interpolation (预处理插值)"
-            )
-
        with gr.Row():
-
-
-
-
-
+            with gr.Column(scale=1):
+                in_img = gr.Image(type="pil", label="Upload an Image")
+
+                model_choice = gr.Dropdown(
+                    list(HF_FILENAMES.keys()), value=DEFAULT_CKPT, label="Select Model"
+                )
+                interp_choice = gr.Radio(
+                    ["bilinear", "bicubic", "nearest"], value="bicubic",
+                    label="Resize Interpolation (Preprocessing)"
+                )
+                viz_checkbox = gr.Checkbox(value=True, label="Show Attention Visualization")
+
+                run_btn = gr.Button("🚀 Run Analysis", variant="primary")
+
+            with gr.Column(scale=2):
+                out_lbl = gr.Label(num_top_classes=2, label="Predictions")
+                out_viz = gr.Image(type="pil", label="Attention Map Visualization", visible=True)

        run_btn.click(
-
-            inputs=[in_img, interp_choice],
-            outputs=[out_lbl]
+            predict_and_visualize,
+            inputs=[in_img, model_choice, interp_choice, viz_checkbox],
+            outputs=[out_lbl, out_viz]
+        )
+
+        gr.Examples(
+            examples=[
+                #[os.path.join(os.path.dirname(__file__), "examples/ai_1.png"), DEFAULT_CKPT, "bicubic", True],
+                #[os.path.join(os.path.dirname(__file__), "examples/real_1.jpg"), DEFAULT_CKPT, "bicubic", True],
+            ],
+            inputs=[in_img, model_choice, interp_choice, viz_checkbox],
+            outputs=[out_lbl, out_viz],
+            fn=predict_and_visualize,
+            cache_examples=False,  # Set to True if examples are static
        )

    demo.launch()

# ---------------------------------------------------------------------------
if __name__ == "__main__":
+    # Create an examples directory for Gradio
+    if not os.path.exists("examples"):
+        os.makedirs("examples")
+        print("Created 'examples' directory. Please add some sample images (e.g., ai_1.png, real_1.jpg) there for the UI examples.")
+
    launch_app()
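For reference, the entry points added in this commit (load_model and predict_and_visualize) can also be exercised without the Gradio UI. Below is a minimal local sketch, not part of the diff: it assumes the file above is importable as a module named app and that a local test image test.jpg exists; the checkpoint key "V4" comes from HF_FILENAMES.

    # Hypothetical local driver for the updated app.py (assumed module name "app",
    # assumed local file "test.jpg"); bypasses Gradio and calls the new functions directly.
    from PIL import Image
    import app

    img = Image.open("test.jpg").convert("RGB")
    # predict_and_visualize loads the requested checkpoint on demand via load_model()
    confidences, attention_overlay = app.predict_and_visualize(
        img, ckpt_name="V4", interpolation="bicubic", show_attention=True
    )
    print(confidences)  # e.g. {"Non-AI Generated": ..., "AI Generated": ...}
    if attention_overlay is not None:
        attention_overlay.save("attention_overlay.png")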