telecomadm1145 commited on
Commit
b6adf0f
·
verified ·
1 Parent(s): 94778bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -10
app.py CHANGED
@@ -5,6 +5,7 @@
5
  • Swin-V7 / V8 / V9 : 4-class (photo / anime × AI / Non-AI)
6
  • CAFormer-V2.5 : 4-class (photo / anime × AI / Non-AI)
7
  • V3-Emb : 2-class (AI vs. Non-AI)
 
8
  -------------------------------------------------------------------
9
  """
10
  import os, torch, timm, numpy as np
@@ -24,7 +25,8 @@ HF_FILENAMES = {
24
  "V2": "swin_classifier_stage1_v2_epoch_3.pth",
25
  "V4": "swin_classifier_stage1_v4.pth",
26
  "V9": "swin_classifier_4class_fp16_v9_acc9861.pth",
27
- "V3-Emb": "swinv2_v3_v3.pth"
 
28
  }
29
  CKPT_META = {
30
  "V2": { "n_cls": 2, "head": "v4", "backbone": "swin_large_patch4_window12_384",
@@ -42,9 +44,21 @@ CKPT_META = {
42
  "backbone_repo_id": "SmilingWolf/wd-swinv2-tagger-v3",
43
  "backbone_filename": "model.safetensors",
44
  "names": ["Non-AI Generated", "AI Generated"]
 
 
 
 
 
 
 
 
 
 
 
 
45
  }
46
  }
47
- DEFAULT_CKPT = "V3-Emb"
48
  LOCAL_CKPT_DIR = "./checkpoints"
49
  SEED = 4421
50
  DROP_RATE = 0.1
@@ -55,6 +69,7 @@ print(f"Using device: {device}")
55
  model, current_ckpt = None, None
56
  current_meta = None
57
 
 
58
  class EmbeddingClassifier(nn.Module):
59
  def __init__(self, input_dim=1024, hidden_dim1=4096, hidden_dim2=256, output_dim=1):
60
  super().__init__()
@@ -67,12 +82,13 @@ class EmbeddingClassifier(nn.Module):
67
  nn.LayerNorm(hidden_dim2),
68
  nn.GELU(),
69
  nn.Dropout(0.4),
70
- nn.Linear(hidden_dim2, output_dim),
71
- nn.Sigmoid(),
72
  )
73
  def forward(self, x):
74
- return self.net(x)
75
 
 
76
  class EmbeddingClassifierModel(nn.Module):
77
  def __init__(self, timm_model_name, num_classes):
78
  super().__init__()
@@ -82,10 +98,134 @@ class EmbeddingClassifierModel(nn.Module):
82
 
83
  def forward(self, x):
84
  features = self.backbone(x)
85
- prob_class0 = self.classifier(features)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  prob_class1 = 1 - prob_class0
87
  return torch.cat([prob_class0, prob_class1], dim=1)
88
 
 
89
  class SwinClassifier(nn.Module):
90
  def __init__(self, model_name, num_classes, pretrained=True,
91
  head_version="v4"):
@@ -141,6 +281,8 @@ def load_model(ckpt_name: str):
141
  ckpt_filename = HF_FILENAMES[ckpt_name]
142
 
143
  head_version = meta.get("head", "v4")
 
 
144
  if head_version == "embedding_classifier":
145
  print(f"Creating backbone: {meta['timm_model_name']}")
146
  model = EmbeddingClassifierModel(
@@ -167,7 +309,40 @@ def load_model(ckpt_name: str):
167
  classifier_state = torch.load(classifier_ckpt_file, map_location=device, weights_only=False)
168
  model.classifier.load_state_dict(classifier_state)
169
  print("✅ Classifier head weights loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  else:
172
  ckpt_file = hf_hub_download(
173
  repo_id=REPO_ID,
@@ -235,9 +410,12 @@ def predict(image: Image.Image,
235
  tfm = build_transform(False, interpolation)
236
  inp = tfm(image).unsqueeze(0).to(device)
237
 
238
- if current_meta["head"] == "embedding_classifier":
 
 
239
  probs = model(inp)[0].cpu()
240
  else:
 
241
  probs = F.softmax(model(inp), dim=1)[0].cpu()
242
 
243
  class_names = current_meta["names"]
@@ -251,13 +429,13 @@ def launch():
251
  gr.Markdown("# AI Detector")
252
  gr.Markdown(
253
  "Choose a model checkpoint on the left, upload an image, "
254
- "and click **Run** to see predictions. V3-Emb produces the best results."
255
  )
256
  with gr.Row():
257
  with gr.Column(scale=1):
258
  run_btn = gr.Button("🚀 Run", variant="primary")
259
  sel_ckpt = gr.Dropdown(
260
- list(HF_FILENAMES.keys()),
261
  value=DEFAULT_CKPT, label="Checkpoint"
262
  )
263
  sel_interp = gr.Radio(
@@ -289,4 +467,5 @@ def launch():
289
  demo.launch()
290
 
291
  if __name__ == "__main__":
292
- launch()
 
 
5
  • Swin-V7 / V8 / V9 : 4-class (photo / anime × AI / Non-AI)
6
  • CAFormer-V2.5 : 4-class (photo / anime × AI / Non-AI)
7
  • V3-Emb : 2-class (AI vs. Non-AI)
8
+ • V3-Emb-MoE (新) : 2-class (AI vs. Non-AI, MoE Head)
9
  -------------------------------------------------------------------
10
  """
11
  import os, torch, timm, numpy as np
 
25
  "V2": "swin_classifier_stage1_v2_epoch_3.pth",
26
  "V4": "swin_classifier_stage1_v4.pth",
27
  "V9": "swin_classifier_4class_fp16_v9_acc9861.pth",
28
+ "V3-Emb": "swinv2_v3_v3.pth",
29
+ "V3-Emb-MoE": "smoe_emb.pth" # <-- 新增 MoE 模型文件
30
  }
31
  CKPT_META = {
32
  "V2": { "n_cls": 2, "head": "v4", "backbone": "swin_large_patch4_window12_384",
 
44
  "backbone_repo_id": "SmilingWolf/wd-swinv2-tagger-v3",
45
  "backbone_filename": "model.safetensors",
46
  "names": ["Non-AI Generated", "AI Generated"]
47
+ },
48
+ # <-- 新增 MoE 模型元数据 -->
49
+ "V3-Emb-MoE": {
50
+ "n_cls": 2,
51
+ "head": "moe_embedding_classifier", # 新的 head 类型
52
+ "timm_model_name": "hf_hub:SmilingWolf/wd-swinv2-tagger-v3",
53
+ "backbone_repo_id": "SmilingWolf/wd-swinv2-tagger-v3",
54
+ "backbone_filename": "model.safetensors",
55
+ "names": ["Non-AI Generated", "AI Generated"],
56
+ "num_experts": 16, # <-- MoE 特定参数
57
+ "moe_hidden_dim": 1024, # <-- MoE 特定参数
58
+ "top_k": 2 # 假设 top_k=2,与训练脚本一致
59
  }
60
  }
61
+ DEFAULT_CKPT = "V3-Emb-MoE" # <-- 默认为新的 MoE 模型
62
  LOCAL_CKPT_DIR = "./checkpoints"
63
  SEED = 4421
64
  DROP_RATE = 0.1
 
69
  model, current_ckpt = None, None
70
  current_meta = None
71
 
72
+ # --- 标准分类头 (V3-Emb) ---
73
  class EmbeddingClassifier(nn.Module):
74
  def __init__(self, input_dim=1024, hidden_dim1=4096, hidden_dim2=256, output_dim=1):
75
  super().__init__()
 
82
  nn.LayerNorm(hidden_dim2),
83
  nn.GELU(),
84
  nn.Dropout(0.4),
85
+ nn.Linear(hidden_dim2, output_dim)
86
+ # <-- 修改: 移除了 nn.Sigmoid(),包装器将处理激活
87
  )
88
  def forward(self, x):
89
+ return self.net(x) # 输出 logits
90
 
91
+ # --- 标准分类头包装器 (V3-Emb) ---
92
  class EmbeddingClassifierModel(nn.Module):
93
  def __init__(self, timm_model_name, num_classes):
94
  super().__init__()
 
98
 
99
  def forward(self, x):
100
  features = self.backbone(x)
101
+ logits = self.classifier(features) # 获取 logits
102
+ # <-- 修改: 在此处应用 sigmoid 将 logits 转为 prob ---
103
+ prob_class0 = torch.sigmoid(logits)
104
+ prob_class1 = 1 - prob_class0
105
+ return torch.cat([prob_class0, prob_class1], dim=1)
106
+
107
+ # --- 新增: MoE 模型定义 (V3-Emb-MoE) ---
108
+ class Expert(nn.Module):
109
+ def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.4):
110
+ super().__init__()
111
+ self.net = nn.Sequential(
112
+ nn.Linear(input_dim, hidden_dim),
113
+ nn.GELU(),
114
+ nn.Dropout(dropout),
115
+ nn.Linear(hidden_dim, output_dim)
116
+ )
117
+ def forward(self, x):
118
+ return self.net(x)
119
+
120
+ class SparseMoE(nn.Module):
121
+ def __init__(self, input_dim, num_experts, top_k, expert_hidden_dim, load_balancing_alpha=1e-2):
122
+ super().__init__()
123
+ self.input_dim = input_dim
124
+ self.num_experts = num_experts
125
+ self.top_k = top_k
126
+ self.load_balancing_alpha = load_balancing_alpha
127
+ self.gate = nn.Linear(input_dim, num_experts)
128
+ self.experts = nn.ModuleList([
129
+ Expert(input_dim, expert_hidden_dim, input_dim) for _ in range(num_experts)
130
+ ])
131
+
132
+ def forward(self, x):
133
+ batch_size, _ = x.shape
134
+ gate_logits = self.gate(x)
135
+ gate_probs = torch.softmax(gate_logits, dim=-1)
136
+ top_k_weights, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
137
+ top_k_weights = top_k_weights / torch.sum(top_k_weights, dim=-1, keepdim=True)
138
+
139
+ # 辅助损失 (仅在训练时重要)
140
+ tokens_per_expert_onehot = nn.functional.one_hot(top_k_indices, self.num_experts).sum(dim=1).float()
141
+ f_i = tokens_per_expert_onehot.mean(dim=0)
142
+ P_i = gate_probs.mean(dim=0)
143
+ aux_loss = self.load_balancing_alpha * self.num_experts * torch.mean(f_i * P_i)
144
+
145
+ expanded_x = x.unsqueeze(1).expand(-1, self.top_k, -1)
146
+ flat_x = expanded_x.flatten(0, 1)
147
+ flat_top_k_indices = top_k_indices.flatten()
148
+ flat_output = torch.zeros_like(flat_x)
149
+
150
+ for i in range(self.num_experts):
151
+ mask = (flat_top_k_indices == i)
152
+ if mask.any():
153
+ expert_inputs = flat_x[mask]
154
+ expert_outputs = self.experts[i](expert_inputs)
155
+ flat_output[mask] = expert_outputs
156
+
157
+ expert_outputs_grouped = flat_output.view(batch_size, self.top_k, self.input_dim)
158
+ weighted_outputs = top_k_weights.unsqueeze(-1) * expert_outputs_grouped
159
+ final_output = torch.sum(weighted_outputs, dim=1)
160
+
161
+ return final_output, aux_loss
162
+
163
+ class MoEClassifier(nn.Module):
164
+ def __init__(self, input_dim=1024, output_dim=1, num_experts=8, top_k=2,
165
+ moe_hidden_dim=2048, head_hidden_dim=256, load_balancing_alpha=1e-2):
166
+ super().__init__()
167
+ self.input_dim = input_dim
168
+ self.num_experts = num_experts
169
+ self.top_k = top_k
170
+ self.moe_hidden_dim = moe_hidden_dim
171
+ self.head_hidden_dim = head_hidden_dim
172
+ self.load_balancing_alpha = load_balancing_alpha
173
+ self.pre_moe_net = nn.Sequential(
174
+ nn.Linear(input_dim, input_dim),
175
+ nn.LayerNorm(input_dim),
176
+ nn.GELU()
177
+ )
178
+ self.moe_layer = SparseMoE(
179
+ input_dim=input_dim,
180
+ num_experts=num_experts,
181
+ top_k=top_k,
182
+ expert_hidden_dim=moe_hidden_dim,
183
+ load_balancing_alpha=load_balancing_alpha
184
+ )
185
+ self.moe_ln = nn.LayerNorm(input_dim)
186
+ self.moe_dropout = nn.Dropout(0.4)
187
+ self.head = nn.Sequential(
188
+ nn.Linear(input_dim, head_hidden_dim),
189
+ nn.LayerNorm(head_hidden_dim),
190
+ nn.GELU(),
191
+ nn.Dropout(0.4),
192
+ nn.Linear(head_hidden_dim, output_dim) # 输出 logits
193
+ )
194
+
195
+ def forward(self, x):
196
+ pre_moe_out = self.pre_moe_net(x)
197
+ moe_input = pre_moe_out
198
+ moe_output, aux_loss = self.moe_layer(moe_input)
199
+ moe_output = self.moe_dropout(moe_output)
200
+ post_moe = self.moe_ln(moe_output + moe_input)
201
+ logits = self.head(post_moe)
202
+ return logits, aux_loss
203
+
204
+ # --- 新增: MoE 分类头包装器 (V3-Emb-MoE) ---
205
+ class MoEEmbeddingClassifierModel(nn.Module):
206
+ def __init__(self, timm_model_name, num_classes, num_experts, moe_hidden_dim, top_k=2):
207
+ super().__init__()
208
+ self.backbone = timm.create_model(timm_model_name, pretrained=False, num_classes=0)
209
+ self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
210
+ # 使用 MoEClassifier 作为分类头
211
+ self.classifier = MoEClassifier(
212
+ input_dim=self.backbone.num_features,
213
+ output_dim=1, # 2-class (AI vs Non-AI)
214
+ num_experts=num_experts,
215
+ top_k=top_k,
216
+ moe_hidden_dim=moe_hidden_dim,
217
+ head_hidden_dim=256 # 保持与 V3-Emb 的 head_hidden_dim 一致
218
+ )
219
+
220
+ def forward(self, x):
221
+ features = self.backbone(x)
222
+ logits, aux_loss = self.classifier(features) # MoE 返回 (logits, aux_loss)
223
+ # 推理时我们只关心 logits
224
+ prob_class0 = torch.sigmoid(logits)
225
  prob_class1 = 1 - prob_class0
226
  return torch.cat([prob_class0, prob_class1], dim=1)
227
 
228
+
229
  class SwinClassifier(nn.Module):
230
  def __init__(self, model_name, num_classes, pretrained=True,
231
  head_version="v4"):
 
281
  ckpt_filename = HF_FILENAMES[ckpt_name]
282
 
283
  head_version = meta.get("head", "v4")
284
+
285
+ # --- 修改: 扩展加载逻辑 ---
286
  if head_version == "embedding_classifier":
287
  print(f"Creating backbone: {meta['timm_model_name']}")
288
  model = EmbeddingClassifierModel(
 
309
  classifier_state = torch.load(classifier_ckpt_file, map_location=device, weights_only=False)
310
  model.classifier.load_state_dict(classifier_state)
311
  print("✅ Classifier head weights loaded.")
312
+
313
+ # --- 新增: MoE 加载逻辑 ---
314
+ elif head_version == "moe_embedding_classifier":
315
+ print(f"Creating MoE model with backbone: {meta['timm_model_name']}")
316
+ model = MoEEmbeddingClassifierModel(
317
+ timm_model_name=meta["timm_model_name"],
318
+ num_classes=meta["n_cls"],
319
+ num_experts=meta["num_experts"],
320
+ moe_hidden_dim=meta["moe_hidden_dim"],
321
+ top_k=meta.get("top_k", 2) # 从 meta 或 默认值
322
+ ).to(device)
323
+
324
+ print(f"Loading backbone weights from {meta['backbone_repo_id']}...")
325
+ backbone_ckpt_file = hf_hub_download(
326
+ repo_id=meta["backbone_repo_id"],
327
+ filename=meta["backbone_filename"],
328
+ local_dir=LOCAL_CKPT_DIR, force_download=False
329
+ )
330
+ backbone_state = load_file(backbone_ckpt_file, device=device)
331
+ model.backbone.load_state_dict(backbone_state,strict=False)
332
+ print("✅ Backbone weights loaded.")
333
 
334
+ print(f"Loading MoE classifier head weights from {REPO_ID}...")
335
+ classifier_ckpt_file = hf_hub_download(
336
+ repo_id=REPO_ID,
337
+ filename=ckpt_filename,
338
+ local_dir=LOCAL_CKPT_DIR, force_download=False
339
+ )
340
+ # 假设 MoE 头部保存的也是 state_dict
341
+ classifier_state = torch.load(classifier_ckpt_file, map_location=device, weights_only=False)
342
+ model.classifier.load_state_dict(classifier_state)
343
+ print("✅ MoE Classifier head weights loaded.")
344
+
345
+ # --- 原始 Swin 加载逻辑 ---
346
  else:
347
  ckpt_file = hf_hub_download(
348
  repo_id=REPO_ID,
 
410
  tfm = build_transform(False, interpolation)
411
  inp = tfm(image).unsqueeze(0).to(device)
412
 
413
+ # --- 修改: 扩展 logits/prob 处理 ---
414
+ # V3-Emb 和 V3-Emb-MoE 包装器都已在其 forward 中转换为 2 类概率
415
+ if current_meta["head"] in ["embedding_classifier", "moe_embedding_classifier"]:
416
  probs = model(inp)[0].cpu()
417
  else:
418
+ # 其他模型 (V2, V4, V9, CAFormer) 输出 logits,需要 softmax
419
  probs = F.softmax(model(inp), dim=1)[0].cpu()
420
 
421
  class_names = current_meta["names"]
 
429
  gr.Markdown("# AI Detector")
430
  gr.Markdown(
431
  "Choose a model checkpoint on the left, upload an image, "
432
+ "and click **Run** to see predictions. V3-Emb-MoE produces the best results."
433
  )
434
  with gr.Row():
435
  with gr.Column(scale=1):
436
  run_btn = gr.Button("🚀 Run", variant="primary")
437
  sel_ckpt = gr.Dropdown(
438
+ list(HF_FILENAMES.keys()), # 自动包含 "V3-Emb-MoE"
439
  value=DEFAULT_CKPT, label="Checkpoint"
440
  )
441
  sel_interp = gr.Radio(
 
467
  demo.launch()
468
 
469
  if __name__ == "__main__":
470
+ launch()
471
+