telecomadm1145 committed on
Commit
d66a824
·
verified ·
1 Parent(s): 5c5ed86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -8
app.py CHANGED
@@ -7,7 +7,7 @@ Swin/CAFormer/DINOv2 AI detection
7
  • CAFormer-V10 : 4-class (photo / anime × AI / Non-AI)
8
  • DINOv2-4class : 4-class (photo / anime × AI / Non-AI)
9
  • DINOv2-MeanPool-Contrastive : 4-class (photo / anime × AI / Non-AI)
10
- EmbeddingClassifier : 2-class (AI vs. Non-AI)
11
  -------------------------------------------------------------------
12
  Author: telecomadm1145
13
  """
@@ -31,7 +31,7 @@ HF_FILENAMES = {
31
  "V2-Swin": "swin_classifier_stage1_v2_epoch_3.pth",
32
  "V4-Swin": "swin_classifier_stage1_v4.pth",
33
  "V9-Swin": "swin_classifier_4class_fp16_v9_acc9861.pth",
34
- "EmbeddingClassifier": "swinv2_v3_v1.pth"
35
  }
36
  CKPT_META = {
37
  "V2": { "n_cls": 2, "head": "v4", "backbone": "swin_large_patch4_window12_384",
@@ -43,10 +43,10 @@ CKPT_META = {
43
  "V2.5-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
44
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
45
  # --- MODIFIED: Added specific keys for the new loading logic ---
46
- "EmbeddingClassifier": {
47
  "n_cls": 2,
48
  "head": "embedding_classifier",
49
- "timm_model_name": "swinv2_base_window8_256.ms_in1k",
50
  "backbone_repo_id": "SmilingWolf/wd-swinv2-tagger-v3",
51
  "backbone_filename": "model.safetensors",
52
  "names": ["Non-AI Generated", "AI Generated"]
@@ -322,6 +322,34 @@ def build_transform(is_training: bool, interpolation: str):
322
  cfg.update(dict(interpolation=interpolation))
323
  return timm.data.create_transform(**cfg, is_training=is_training)
324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  # --------------------------------------------------
326
  # 6. Inference
327
  # --------------------------------------------------
@@ -331,19 +359,39 @@ def predict(image: Image.Image,
331
  interpolation: str = "bicubic"):
332
  if image is None: return None
333
  load_model(ckpt_name)
334
- # Select transform based on the current model type
335
- if "dinov2" in current_meta.get("model_type", ""):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
  # DINOv2 specific transform
337
  tfm = transforms.Compose([
338
  transforms.Resize((224, 224)),
339
  transforms.ToTensor(),
340
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
341
  ])
 
342
  else:
343
- # Original transform logic for Swin/CAFormer/EmbeddingClassifier
344
  tfm = build_transform(False, interpolation)
 
 
 
 
345
 
346
- inp = tfm(image).unsqueeze(0).to(device)
347
  # MODIFIED: For EmbeddingClassifier, the output is already probabilities, no need for softmax.
348
  # For others, softmax is needed.
349
  if current_meta["head"] == "embedding_classifier":
 
7
  • CAFormer-V10 : 4-class (photo / anime × AI / Non-AI)
8
  • DINOv2-4class : 4-class (photo / anime × AI / Non-AI)
9
  • DINOv2-MeanPool-Contrastive : 4-class (photo / anime × AI / Non-AI)
10
+ V1-Emb : 2-class (AI vs. Non-AI)
11
  -------------------------------------------------------------------
12
  Author: telecomadm1145
13
  """
 
31
  "V2-Swin": "swin_classifier_stage1_v2_epoch_3.pth",
32
  "V4-Swin": "swin_classifier_stage1_v4.pth",
33
  "V9-Swin": "swin_classifier_4class_fp16_v9_acc9861.pth",
34
+ "V1-Emb": "swinv2_v3_v1.pth"
35
  }
36
  CKPT_META = {
37
  "V2": { "n_cls": 2, "head": "v4", "backbone": "swin_large_patch4_window12_384",
 
43
  "V2.5-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
44
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
45
  # --- MODIFIED: Added specific keys for the new loading logic ---
46
+ "V1-Emb": {
47
  "n_cls": 2,
48
  "head": "embedding_classifier",
49
+ "timm_model_name": "hf_hub:SmilingWolf/wd-swinv2-tagger-v3",
50
  "backbone_repo_id": "SmilingWolf/wd-swinv2-tagger-v3",
51
  "backbone_filename": "model.safetensors",
52
  "names": ["Non-AI Generated", "AI Generated"]
 
322
  cfg.update(dict(interpolation=interpolation))
323
  return timm.data.create_transform(**cfg, is_training=is_training)
324
 
325
+ # ######################################################################
326
+ # START: Preprocessing functions for V1-Emb model, copied from 2nd script
327
+ # ######################################################################
328
+ def pil_ensure_rgb(image: Image.Image) -> Image.Image:
329
+ # convert to RGB/RGBA if not already (deals with palette images etc.)
330
+ if image.mode not in ["RGB", "RGBA"]:
331
+ image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
332
+ # convert RGBA to RGB with white background
333
+ if image.mode == "RGBA":
334
+ canvas = Image.new("RGBA", image.size, (255, 255, 255))
335
+ canvas.alpha_composite(image)
336
+ image = canvas.convert("RGB")
337
+ return image
338
+
339
+
340
+ def pil_pad_square(image: Image.Image) -> Image.Image:
341
+ w, h = image.size
342
+ # get the largest dimension so we can pad to a square
343
+ px = max(image.size)
344
+ # pad to square with white background
345
+ canvas = Image.new("RGB", (px, px), (255, 255, 255))
346
+ canvas.paste(image, ((px - w) // 2, (px - h) // 2))
347
+ return canvas
348
+ # ####################################################################
349
+ # END: Preprocessing functions for V1-Emb model
350
+ # ####################################################################
351
+
352
+
353
  # --------------------------------------------------
354
  # 6. Inference
355
  # --------------------------------------------------
 
359
  interpolation: str = "bicubic"):
360
  if image is None: return None
361
  load_model(ckpt_name)
362
+
363
+ # ####################################################################
364
+ # START: MODIFIED preprocessing logic
365
+ # ####################################################################
366
+ if ckpt_name == "V1-Emb":
367
+ # Specific preprocessing for the V1-Emb model based on the tagger script
368
+ # 1. Ensure RGB and pad to a square to prevent distortion
369
+ processed_image = pil_ensure_rgb(image)
370
+ processed_image = pil_pad_square(processed_image)
371
+
372
+ # 2. Apply standard timm transforms (resize, tensor, normalize)
373
+ tfm = build_transform(False, interpolation)
374
+ inp = tfm(processed_image).unsqueeze(0).to(device)
375
+
376
+ # 3. Convert from RGB to BGR as required by the original model
377
+ inp = inp[:, [2, 1, 0]]
378
+
379
+ elif "dinov2" in current_meta.get("model_type", ""):
380
  # DINOv2 specific transform
381
  tfm = transforms.Compose([
382
  transforms.Resize((224, 224)),
383
  transforms.ToTensor(),
384
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
385
  ])
386
+ inp = tfm(image).unsqueeze(0).to(device)
387
  else:
388
+ # Original transform logic for Swin/CAFormer
389
  tfm = build_transform(False, interpolation)
390
+ inp = tfm(image).unsqueeze(0).to(device)
391
+ # ####################################################################
392
+ # END: MODIFIED preprocessing logic
393
+ # ####################################################################
394
 
 
395
  # MODIFIED: For EmbeddingClassifier, the output is already probabilities, no need for softmax.
396
  # For others, softmax is needed.
397
  if current_meta["head"] == "embedding_classifier":