telecomadm1145 committed
Commit 9094a6f · verified · 1 Parent(s): 8e53ad8

Update app.py

Files changed (1):
  app.py +71 -39
app.py CHANGED
@@ -42,10 +42,13 @@ CKPT_META = {
                "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
     "V2.5-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
                "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
+    # --- MODIFIED: Added specific keys for the new loading logic ---
     "EmbeddingClassifier": {
         "n_cls": 2,
         "head": "embedding_classifier",
-        "backbone": "hf-hub:SmilingWolf/wd-swinv2-tagger-v3",
+        "timm_model_name": "swinv2_base_window8_256.ms_in1k",
+        "backbone_repo_id": "SmilingWolf/wd-swinv2-tagger-v3",
+        "backbone_filename": "model.safetensors",
         "names": ["Non-AI Generated", "AI Generated"]
     }
 }
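The new keys separate the architecture from the weight source: timm needs a model name to build the module, while the tagger repo only ships weights. With num_classes=0 the backbone becomes a pure feature extractor whose width can be read back instead of hard-coded. A quick sketch of that assumption (SwinV2-B's final feature width is 1024, which matches the head's old hard-coded input_dim):

import timm

# Build the architecture only; weights are loaded later from the tagger repo.
backbone = timm.create_model("swinv2_base_window8_256.ms_in1k", pretrained=False, num_classes=0)
print(backbone.num_features)  # 1024 -> becomes EmbeddingClassifier's input_dim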
@@ -60,7 +63,6 @@ print(f"Using device: {device}")
 model, current_ckpt = None, None
 current_meta = None
 
-# --- EmbeddingClassifier Model ---
 class EmbeddingClassifier(nn.Module):
     def __init__(self, input_dim=1024, hidden_dim1=512, hidden_dim2=256, output_dim=1):
         super().__init__()
@@ -79,28 +81,32 @@ class EmbeddingClassifier(nn.Module):
     def forward(self, x):
         return self.net(x)
 
+# MODIFIED: __init__ now takes a timm model name and builds the backbone with pretrained=False
 class EmbeddingClassifierModel(nn.Module):
-    def __init__(self, backbone_name, num_classes):
+    def __init__(self, timm_model_name, num_classes):
         super().__init__()
-        self.backbone = timm.create_model(backbone_name, pretrained=True, num_classes=0, strict=False)
+        self.backbone = timm.create_model(timm_model_name, pretrained=False, num_classes=0)
         self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
-        # This specific model uses a binary classifier with a single output logit
-        # that indicates if an image is real (Non-AI). So, output_dim is fixed to 1.
-        self.classifier = EmbeddingClassifier(input_dim=1024, hidden_dim1=512, hidden_dim2=256, output_dim=1)
+        self.classifier = EmbeddingClassifier(input_dim=self.backbone.num_features)
 
     def forward(self, x):
         features = self.backbone(x)
-        # The classifier returns a single logit. A positive value indicates "real" (Non-AI).
-        # This corresponds to class 0 in `current_meta["names"]`.
-        logit_class0 = self.classifier(features)
+        # The classifier returns a single value: the probability of being Non-AI.
+        prob_class0 = self.classifier(features)
 
-        # To maintain compatibility with the `predict` function which expects multi-class logits
-        # and applies softmax, we construct a 2-class logit tensor.
-        # We assume the logit for the other class (AI-generated, class 1) is 0.
-        logit_class1 = torch.zeros_like(logit_class0)
+        # To stay compatible with `predict`, which expects multi-class outputs,
+        # construct a 2-class tensor; prob_class1 is simply 1 - prob_class0.
+        prob_class1 = 1 - prob_class0
 
-        # The final logits are for ["Non-AI", "AI"], i.e., [logit_class0, logit_class1].
-        return torch.cat([logit_class0, logit_class1], dim=1)
+        # The output is ordered ["Non-AI", "AI"], i.e., [prob_class0, prob_class1].
+        # These are already probabilities, so predict() skips softmax for this head
+        # and gr.Label can display the values directly.
+        return torch.cat([prob_class0, prob_class1], dim=1)
 
 # --- Original DINOv2 Classifier (Weighted Attention Pooling) ---
 class DINOv2Classifier_WeightedPool(nn.Module):
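Why the predict() path below must skip softmax for this head: forward() already returns complementary probabilities, and re-normalizing them with softmax would compress both values toward 0.5. A minimal check with made-up numbers:

import torch

prob_class0 = torch.tensor([[0.90], [0.20]])   # sigmoid outputs for two images (made up)
prob_class1 = 1 - prob_class0
probs = torch.cat([prob_class0, prob_class1], dim=1)
print(probs)                        # [[0.90, 0.10], [0.20, 0.80]] -- each row already sums to 1
print(torch.softmax(probs, dim=1))  # ~[[0.69, 0.31], ...] -- distorted if softmax is applied again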
@@ -234,22 +240,48 @@ def load_model(ckpt_name: str):
     meta = CKPT_META[ckpt_name]
     ckpt_filename = HF_FILENAMES[ckpt_name]
 
-    ckpt_file = hf_hub_download(
-        repo_id=REPO_ID,
-        filename=ckpt_filename,
-        local_dir=LOCAL_CKPT_DIR, force_download=False
-    )
-    print(f"Checkpoint: {ckpt_file}")
-
-    # Build model structure based on model_type or head
+    # --- MODIFIED: Special handling for EmbeddingClassifier ---
     head_version = meta.get("head", "v4")
     if head_version == "embedding_classifier":
+        # 1. Create the model structure with a non-pretrained backbone
+        print(f"Creating backbone: {meta['timm_model_name']}")
         model = EmbeddingClassifierModel(
-            backbone_name=meta["backbone"],
+            timm_model_name=meta["timm_model_name"],
             num_classes=meta["n_cls"]
         ).to(device)
+
+        # 2. Download and load backbone weights from SmilingWolf's repo
+        print(f"Loading backbone weights from {meta['backbone_repo_id']}...")
+        backbone_ckpt_file = hf_hub_download(
+            repo_id=meta["backbone_repo_id"],
+            filename=meta["backbone_filename"],
+            local_dir=LOCAL_CKPT_DIR, force_download=False
+        )
+        backbone_state = load_file(backbone_ckpt_file, device=device)
+        model.backbone.load_state_dict(backbone_state)
+        print("✅ Backbone weights loaded.")
+
+        # 3. Download and load classifier (head) weights from the main repo
+        print(f"Loading classifier head weights from {REPO_ID}...")
+        classifier_ckpt_file = hf_hub_download(
+            repo_id=REPO_ID,
+            filename=ckpt_filename,  # This is 'swinv2_v3_v1.pth'
+            local_dir=LOCAL_CKPT_DIR, force_download=False
+        )
+        classifier_state = torch.load(classifier_ckpt_file, map_location=device, weights_only=False)
+        model.classifier.load_state_dict(classifier_state)
+        print("✅ Classifier head weights loaded.")
+
     else:
-        # Existing logic for other models
+        # --- Original logic for all other models ---
+        ckpt_file = hf_hub_download(
+            repo_id=REPO_ID,
+            filename=ckpt_filename,
+            local_dir=LOCAL_CKPT_DIR, force_download=False
+        )
+        print(f"Checkpoint: {ckpt_file}")
+
+        # Build model structure based on model_type or head
         model_type = meta.get("model_type")
         if model_type == "dinov2_weighted_pool":
             model = DINOv2Classifier_WeightedPool(
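One caveat worth noting for step 2: a tagger checkpoint can contain keys (e.g. its own head) that the num_classes=0 timm backbone lacks, and a strict load would then raise. A hedged sketch of a more tolerant variant that surfaces mismatches instead of crashing (not what the commit does, just a debugging aid; it assumes `model` and `backbone_ckpt_file` from the branch above):

from safetensors.torch import load_file

state = load_file(backbone_ckpt_file, device="cpu")
# strict=False returns the mismatched keys instead of raising, making any
# mismatch between the tagger checkpoint and the timm architecture visible.
incompat = model.backbone.load_state_dict(state, strict=False)
if incompat.missing_keys or incompat.unexpected_keys:
    print("missing:", incompat.missing_keys[:3], "unexpected:", incompat.unexpected_keys[:3])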
@@ -269,18 +301,12 @@ def load_model(ckpt_name: str):
             head_version=head_version
         ).to(device)
 
-    # Compatible load for .pth and .safetensors
-    if ckpt_filename.endswith(".safetensors"):
-        state = load_file(ckpt_file, device=device)
-    else:
-        state = torch.load(ckpt_file, map_location=device, weights_only=False)
-
-    # Load state dict
-    if head_version == "embedding_classifier":
-        # For EmbeddingClassifierModel, we need to load the state dict for the classifier part
-        # Assuming the checkpoint only contains the classifier state dict
-        model.classifier.load_state_dict(state)
-    else:
+        # Compatible load for .pth and .safetensors
+        if ckpt_filename.endswith(".safetensors"):
+            state = load_file(ckpt_file, device=device)
+        else:
+            state = torch.load(ckpt_file, map_location=device, weights_only=False)
+
         model.load_state_dict(state.get("model_state_dict", state), strict=True)
 
     model.eval()
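The state.get("model_state_dict", state) idiom is what keeps this branch compatible with both checkpoint layouts: a raw state dict saved via torch.save(model.state_dict(), ...) has no wrapper key, so .get falls back to the dict itself. An illustration with made-up tensors:

import torch

raw = {"fc.weight": torch.zeros(2, 4)}            # torch.save(model.state_dict(), path)
wrapped = {"model_state_dict": raw, "epoch": 12}  # a full training checkpoint
for state in (raw, wrapped):
    print(list(state.get("model_state_dict", state).keys()))  # ['fc.weight'] both times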
@@ -318,7 +344,13 @@ def predict(image: Image.Image,
     tfm = build_transform(False, interpolation)
 
     inp = tfm(image).unsqueeze(0).to(device)
-    probs = F.softmax(model(inp), dim=1)[0].cpu()
+    # MODIFIED: the EmbeddingClassifier head already outputs probabilities, so
+    # softmax must be skipped for it; every other head still returns logits.
+    if current_meta.get("head") == "embedding_classifier":
+        probs = model(inp)[0].cpu()
+    else:
+        probs = F.softmax(model(inp), dim=1)[0].cpu()
+
     class_names = current_meta["names"]
     # Ensure gr.Label renders correctly for both the 2-class and 4-class models
     return {class_names[i]: float(probs[i])
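The returned dict feeds gr.Label, which accepts a {label: confidence} mapping directly, so the same predict() serves the 2-class and 4-class checkpoints unchanged. Roughly:

probs = [0.70, 0.10, 0.15, 0.05]  # made-up 4-class probabilities
class_names = ["non_ai", "ai", "ani_non_ai", "ani_ai"]
result = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
# gr.Label renders one confidence bar per key, whether there are 2 or 4.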
 