telecomadm1145 commited on
Commit
f014c13
·
verified ·
1 Parent(s): af382d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -130
app.py CHANGED
@@ -1,217 +1,235 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Swin-Large AI vs. Non-AI Detector (with Model Selection & Attention Visualization) - V5 Update
 
 
 
 
 
4
  """
5
- import os
6
- import math
7
- import torch
8
- import torch.nn.functional as F
9
  import torch.nn as nn
10
- import timm
11
- import numpy as np
12
  from PIL import Image
13
  import gradio as gr
14
- import matplotlib.pyplot as plt
15
-
16
  from huggingface_hub import hf_hub_download
17
 
18
- # --- Configuration ---------------------------------------------------------
19
- REPO_ID = "telecomadm1145/swin-ai-detection"
 
 
20
  HF_FILENAMES = {
21
- "V2": "swin_classifier_stage1_v2_epoch_3.pth",
22
- "V4": "swin_classifier_stage1_v4.pth",
23
- "V5(underfitting)": "swin_classifier_stage1_v5_fp16.pth",
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  }
25
- DEFAULT_CKPT = "V4" # Set V5 as the new default
 
 
26
  LOCAL_CKPT_DIR = "./checkpoints"
27
- MODEL_NAME = "swin_large_patch4_window12_384"
28
- NUM_CLASSES = 2
29
  SEED = 4421
30
- dropout_rate = 0.1
31
-
32
- class_names = ["Non-AI Generated", "AI Generated"] # 0, 1
33
 
34
  device = "cuda" if torch.cuda.is_available() else "cpu"
35
- torch.manual_seed(SEED)
36
- np.random.seed(SEED)
37
  print(f"Using device: {device}")
38
 
39
- # --- Global model state ----------------------------------------------------
40
- model = None
41
- current_ckpt_name = None
42
- attention_maps = [] # To store hooked attention maps
 
43
 
44
- # ---------------------------------------------------------------------------
45
- # 1. 模型结构 (Model Structure)
 
46
  class SwinClassifier(nn.Module):
47
- """
48
- Swin Transformer based classifier.
49
- The MLP head can be configured for different model versions (V2/V4 vs. V5).
50
- """
51
- def __init__(self, model_name, num_classes, pretrained=True, classifier_version='v4'):
52
  super().__init__()
53
- self.backbone = timm.create_model(model_name, pretrained=pretrained,
54
- num_classes=0) # Get features only
 
55
  self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
56
 
57
- # Select classifier head based on version
58
- if classifier_version == 'v5':
59
- print("Using V5 classifier head with GELU activation.")
 
 
 
 
 
 
 
 
 
60
  self.classifier = nn.Sequential(
61
- nn.Dropout(dropout_rate),
62
  nn.Linear(self.backbone.num_features, 512),
63
  nn.BatchNorm1d(512),
64
- nn.GELU(), # V5 uses GELU
65
- nn.Dropout(dropout_rate * 0.7),
66
  nn.Linear(512, 128),
67
  nn.BatchNorm1d(128),
68
- nn.GELU(), # V5 uses GELU
69
- nn.Dropout(dropout_rate * 0.5),
70
- nn.Linear(128, num_classes)
71
  )
72
- else:
73
- print("Using V2/V4 classifier head with ReLU activation.")
74
  self.classifier = nn.Sequential(
75
- nn.Dropout(dropout_rate),
76
  nn.Linear(self.backbone.num_features, 512),
77
  nn.BatchNorm1d(512),
78
  nn.ReLU(),
79
- nn.Dropout(dropout_rate * 0.7),
80
  nn.Linear(512, 128),
81
  nn.BatchNorm1d(128),
82
  nn.ReLU(),
83
- nn.Dropout(dropout_rate * 0.5),
84
- nn.Linear(128, num_classes)
85
  )
86
 
87
  def forward(self, x):
88
- feats = self.backbone(x)
89
- return self.classifier(feats)
90
 
91
- # ---------------------------------------------------------------------------
92
- # 2. 动态模型加载函数 (Dynamic Model Loading Function)
 
93
  def load_model(ckpt_name: str):
94
- """
95
- Dynamically loads the selected model checkpoint.
96
- If the model is already loaded, it does nothing.
97
- It selects the correct classifier head based on the checkpoint name.
98
- """
99
- global model, current_ckpt_name
100
- if ckpt_name == current_ckpt_name and model is not None:
101
- #print(f"✅ Model '{ckpt_name}' is already loaded.")
102
- return
103
 
104
- print(f"🔄 Switching to model: '{ckpt_name}'...")
105
- hf_filename = HF_FILENAMES[ckpt_name]
106
 
107
- print(" Downloading / caching checkpoint if needed…")
108
- ckpt_path = hf_hub_download(
 
109
  repo_id=REPO_ID,
110
- filename=hf_filename,
111
- local_dir=LOCAL_CKPT_DIR,
112
- force_download=False
113
  )
114
- print(f"Checkpoint path: {ckpt_path}")
115
-
116
- # Determine which classifier version to use based on the checkpoint name
117
- classifier_version = 'v5' if 'V5' in ckpt_name else 'v4'
118
 
119
- # Instantiate and load weights
120
  model = SwinClassifier(
121
  MODEL_NAME,
122
- NUM_CLASSES,
123
- pretrained=False,
124
- classifier_version=classifier_version
125
  ).to(device)
126
 
127
- state = torch.load(ckpt_path, map_location=device, weights_only=False)
 
128
  model.load_state_dict(state.get("model_state_dict", state), strict=True)
129
  model.eval()
130
- current_ckpt_name = ckpt_name
131
- print(f"✅ Model '{ckpt_name}' loaded successfully.")
132
 
133
- # ---------------------------------------------------------------------------
134
- # 3. torchvision / timm transform 工厂函数 (Transform Factory Function)
 
 
 
 
135
  def build_transform(is_training: bool, interpolation: str):
136
- if model is None:
137
- raise RuntimeError("Model is not loaded. Please call load_model() first.")
138
  cfg = model.data_config.copy()
139
  cfg.update(dict(interpolation=interpolation))
140
  return timm.data.create_transform(**cfg, is_training=is_training)
141
 
142
- # ---------------------------------------------------------------------------
143
- # 5. 推理 (Inference)
144
- def predict_and_visualize(image_pil: Image.Image,
145
- ckpt_name: str,
146
- interpolation: str = "bicubic"):
147
- if image_pil is None:
148
- return None, None
149
 
150
- load_model(ckpt_name)
151
-
152
- transform = build_transform(is_training=False, interpolation=interpolation)
153
- input_tensor = transform(image_pil).unsqueeze(0).to(device)
154
 
155
- with torch.no_grad():
156
- logits = model(input_tensor)
 
157
 
158
- probs = F.softmax(logits, dim=1)[0]
159
- confidences = {class_names[i]: float(probs[i]) for i in range(NUM_CLASSES)}
160
 
161
- return confidences
 
 
162
 
163
- # ---------------------------------------------------------------------------
164
- # 6. Gradio UI
165
- def launch_app():
166
- # Pre-load the default model on startup
167
- load_model(DEFAULT_CKPT)
168
 
169
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
170
- gr.Markdown("# 🖼️ AI vs. Non-AI Image Classifier")
171
- gr.Markdown("Using Swin-Large Transformer with multiple model checkpoints.")
 
 
 
172
 
173
  with gr.Row():
174
  with gr.Column(scale=1):
 
175
  run_btn = gr.Button("🚀 Run", variant="primary")
176
 
177
- model_choice = gr.Dropdown(
178
- list(HF_FILENAMES.keys()), value=DEFAULT_CKPT, label="Select Model"
 
179
  )
180
- interp_choice = gr.Radio(
181
- ["bilinear", "bicubic", "nearest"], value="bicubic",
182
- label="Resize Interpolation (Preprocessing)"
183
  )
184
 
185
- in_img = gr.Image(type="pil", label="Upload an Image")
186
-
187
- with gr.Column(scale=2):
188
- out_lbl = gr.Label(num_top_classes=2, label="Predictions")
189
 
190
  run_btn.click(
191
- predict_and_visualize,
192
- inputs=[in_img, model_choice, interp_choice],
193
  outputs=[out_lbl]
194
  )
195
 
196
- # Create a dummy examples directory if it doesn't exist
197
- example_dir = "examples"
198
- if not os.path.exists(example_dir):
199
- os.makedirs(example_dir)
200
- print(f"Created '{example_dir}' directory. Please add sample images there for UI examples.")
201
 
202
- # Check for example files before creating the component
203
- example_files = [os.path.join(example_dir, f) for f in os.listdir(example_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
 
204
  if example_files:
205
  gr.Examples(
206
  examples=[[f, DEFAULT_CKPT, "bicubic"] for f in example_files],
207
- inputs=[in_img, model_choice, interp_choice],
208
  outputs=[out_lbl],
209
- fn=predict_and_visualize,
210
  cache_examples=False,
211
  )
212
 
213
  demo.launch()
214
 
215
- # ---------------------------------------------------------------------------
216
  if __name__ == "__main__":
217
- launch_app()
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Swin-Large AI / Non-AI ‑- now with V7 (4-class) support
4
+ -------------------------------------------------------------------
5
+ • V2 / V4 / V5(underfitting) : 2-class (photo-style AI vs. Non-AI)
6
+ • NEW V7 : 4-class (photo / anime × AI / Non-AI)
7
+ -------------------------------------------------------------------
8
+ Author : you 😊
9
  """
10
+
11
+ import os, torch, timm, math, numpy as np
 
 
12
  import torch.nn as nn
13
+ import torch.nn.functional as F
 
14
  from PIL import Image
15
  import gradio as gr
 
 
16
  from huggingface_hub import hf_hub_download
17
 
18
+ # --------------------------------------------------
19
+ # 1. Model & Checkpoint Meta-data
20
+ # --------------------------------------------------
21
+ REPO_ID = "telecomadm1145/swin-ai-detection" # 同一个 repo 存两种 ckpt 也 OK
22
  HF_FILENAMES = {
23
+ "V2": "swin_classifier_stage1_v2_epoch_3.pth",
24
+ "V4": "swin_classifier_stage1_v4.pth",
25
+ "V5(underfitting)": "swin_classifier_stage1_v5_fp16.pth",
26
+ "V7": "swin_classifier_4class_fp16_v7.pth" # <-- NEW
27
+ }
28
+
29
+ CKPT_META = {
30
+ "V2": { "n_cls": 2, "head": "v4",
31
+ "names": ["Non-AI Generated", "AI Generated"]},
32
+ "V4": { "n_cls": 2, "head": "v4",
33
+ "names": ["Non-AI Generated", "AI Generated"]},
34
+ "V5(underfitting)": { "n_cls": 2, "head": "v5",
35
+ "names": ["Non-AI Generated", "AI Generated"]},
36
+ # ---------- NEW ----------
37
+ "V7": { "n_cls": 4, "head": "v7",
38
+ "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
39
  }
40
+
41
+ DEFAULT_CKPT = "V4" # 默认仍然先加载较小的 2-类模型
42
+ MODEL_NAME = "swin_large_patch4_window12_384"
43
  LOCAL_CKPT_DIR = "./checkpoints"
 
 
44
  SEED = 4421
45
+ DROP_RATE = 0.1
 
 
46
 
47
  device = "cuda" if torch.cuda.is_available() else "cpu"
48
+ torch.manual_seed(SEED); np.random.seed(SEED)
 
49
  print(f"Using device: {device}")
50
 
51
+ # --------------------------------------------------
52
+ # 2. Global State
53
+ # --------------------------------------------------
54
+ model, current_ckpt = None, None
55
+ current_meta = None # 记录当前模型的 meta(类别数 / 名称)
56
 
57
+ # --------------------------------------------------
58
+ # 3. SwinClassifier 添加 v7 专属 MLP
59
+ # --------------------------------------------------
60
  class SwinClassifier(nn.Module):
61
+ def __init__(self, model_name, num_classes, pretrained=True,
62
+ head_version="v4"):
 
 
 
63
  super().__init__()
64
+ self.backbone = timm.create_model(
65
+ model_name, pretrained=pretrained, num_classes=0
66
+ )
67
  self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
68
 
69
+ # ------- 根据版本选择不同 head -------
70
+ if head_version == "v7": # <-- V7: 极简 64-hidden, GELU
71
+ self.classifier = nn.Sequential(
72
+ nn.Dropout(DROP_RATE),
73
+ nn.Linear(self.backbone.num_features, 64),
74
+ nn.BatchNorm1d(64),
75
+ nn.GELU(),
76
+ nn.Dropout(DROP_RATE * 0.8),
77
+ nn.Linear(64, num_classes),
78
+ )
79
+
80
+ elif head_version == "v5": # V5: 512-128, GELU
81
  self.classifier = nn.Sequential(
82
+ nn.Dropout(DROP_RATE),
83
  nn.Linear(self.backbone.num_features, 512),
84
  nn.BatchNorm1d(512),
85
+ nn.GELU(),
86
+ nn.Dropout(DROP_RATE * 0.7),
87
  nn.Linear(512, 128),
88
  nn.BatchNorm1d(128),
89
+ nn.GELU(),
90
+ nn.Dropout(DROP_RATE * 0.5),
91
+ nn.Linear(128, num_classes),
92
  )
93
+
94
+ else: # V2 / V4: 512-128, ReLU
95
  self.classifier = nn.Sequential(
96
+ nn.Dropout(DROP_RATE),
97
  nn.Linear(self.backbone.num_features, 512),
98
  nn.BatchNorm1d(512),
99
  nn.ReLU(),
100
+ nn.Dropout(DROP_RATE * 0.7),
101
  nn.Linear(512, 128),
102
  nn.BatchNorm1d(128),
103
  nn.ReLU(),
104
+ nn.Dropout(DROP_RATE * 0.5),
105
+ nn.Linear(128, num_classes),
106
  )
107
 
108
  def forward(self, x):
109
+ return self.classifier(self.backbone(x))
110
+
111
 
112
+ # --------------------------------------------------
113
+ # 4. 动态加载模型
114
+ # --------------------------------------------------
115
  def load_model(ckpt_name: str):
116
+ """Load model only when `ckpt_name` changes."""
117
+ global model, current_ckpt, current_meta
 
 
 
 
 
 
 
118
 
119
+ if ckpt_name == current_ckpt and model is not None:
120
+ return
121
 
122
+ print(f"\n🔄 Switching to {ckpt_name} ...")
123
+ meta = CKPT_META[ckpt_name]
124
+ ckpt_file = hf_hub_download(
125
  repo_id=REPO_ID,
126
+ filename=HF_FILENAMES[ckpt_name],
127
+ local_dir=LOCAL_CKPT_DIR, force_download=False
 
128
  )
129
+ print(f"Checkpoint: {ckpt_file}")
 
 
 
130
 
131
+ # Build model structure
132
  model = SwinClassifier(
133
  MODEL_NAME,
134
+ num_classes = meta["n_cls"],
135
+ pretrained = False,
136
+ head_version = meta["head"]
137
  ).to(device)
138
 
139
+ # compatible load
140
+ state = torch.load(ckpt_file, map_location=device, weights_only=False)
141
  model.load_state_dict(state.get("model_state_dict", state), strict=True)
142
  model.eval()
 
 
143
 
144
+ current_ckpt, current_meta = ckpt_name, meta
145
+ print(f"✅ {ckpt_name} loaded (classes = {meta['n_cls']}).")
146
+
147
+ # --------------------------------------------------
148
+ # 5. Transform 工厂
149
+ # --------------------------------------------------
150
  def build_transform(is_training: bool, interpolation: str):
151
+ if model is None: raise RuntimeError("Model not loaded yet.")
 
152
  cfg = model.data_config.copy()
153
  cfg.update(dict(interpolation=interpolation))
154
  return timm.data.create_transform(**cfg, is_training=is_training)
155
 
156
+ # --------------------------------------------------
157
+ # 6. Inference
158
+ # --------------------------------------------------
159
+ @torch.no_grad()
160
+ def predict(image: Image.Image,
161
+ ckpt_name: str,
162
+ interpolation: str = "bicubic"):
163
 
164
+ if image is None: return None
 
 
 
165
 
166
+ load_model(ckpt_name)
167
+ tfm = build_transform(False, interpolation)
168
+ inp = tfm(image).unsqueeze(0).to(device)
169
 
170
+ probs = F.softmax(model(inp), dim=1)[0].cpu()
171
+ class_names = current_meta["names"]
172
 
173
+ # 保证 gr.Label 在 2 / 4 类都能正常显示
174
+ return {class_names[i]: float(probs[i])
175
+ for i in range(len(class_names))}
176
 
177
+ # --------------------------------------------------
178
+ # 7. Gradio UI
179
+ # --------------------------------------------------
180
+ def launch():
181
+ load_model(DEFAULT_CKPT) # 预加载
182
 
183
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
184
+ gr.Markdown("# 🖼️ Swin-Large — AI / Non-AI Detector (V2-V7)")
185
+ gr.Markdown(
186
+ "Choose a model checkpoint on the left, upload an image, "
187
+ "and click **Run** to see predictions. V7 outputs 4 classes."
188
+ )
189
 
190
  with gr.Row():
191
  with gr.Column(scale=1):
192
+ in_img = gr.Image(type="pil", label="Upload Image")
193
  run_btn = gr.Button("🚀 Run", variant="primary")
194
 
195
+ sel_ckpt = gr.Dropdown(
196
+ list(HF_FILENAMES.keys()),
197
+ value=DEFAULT_CKPT, label="Checkpoint"
198
  )
199
+ sel_interp = gr.Radio(
200
+ ["bilinear", "bicubic", "nearest"],
201
+ value="bicubic", label="Resize Interpolation"
202
  )
203
 
204
+ with gr.Column(scale=1):
205
+ # num_top_classes 设为 4,兼容 2-class / 4-class
206
+ out_lbl = gr.Label(num_top_classes=4, label="Predictions")
 
207
 
208
  run_btn.click(
209
+ predict,
210
+ inputs=[in_img, sel_ckpt, sel_interp],
211
  outputs=[out_lbl]
212
  )
213
 
214
+ # optional example folder
215
+ if not os.path.exists("examples"):
216
+ os.makedirs("examples")
217
+ print("Put some jpg/png files inside ./examples for demo examples")
 
218
 
219
+ example_files = [os.path.join("examples", f)
220
+ for f in os.listdir("examples")
221
+ if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
222
  if example_files:
223
  gr.Examples(
224
  examples=[[f, DEFAULT_CKPT, "bicubic"] for f in example_files],
225
+ inputs=[in_img, sel_ckpt, sel_interp],
226
  outputs=[out_lbl],
227
+ fn=predict,
228
  cache_examples=False,
229
  )
230
 
231
  demo.launch()
232
 
233
+ # --------------------------------------------------
234
  if __name__ == "__main__":
235
+ launch()