telecomadm1145 commited on
Commit
0358510
·
verified ·
1 Parent(s): 6211642

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -16
app.py CHANGED
@@ -1,10 +1,11 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Swin/CAFormer AI detection
4
  -------------------------------------------------------------------
5
  • Swin-V2 / V4 : 2-class (AI vs. Non-AI)
6
  • Swin-V7 / V8 / V9 : 4-class (photo / anime × AI / Non-AI)
7
  • CAFormer-V10 : 4-class (photo / anime × AI / Non-AI)
 
8
  -------------------------------------------------------------------
9
  Author: telecomadm1145
10
  """
@@ -16,6 +17,9 @@ from PIL import Image
16
  import gradio as gr
17
  from huggingface_hub import hf_hub_download
18
  from safetensors.torch import load_file # Added for .safetensors support
 
 
 
19
 
20
  # --------------------------------------------------
21
  # 1. Model & Checkpoint Meta-data
@@ -29,6 +33,9 @@ HF_FILENAMES = {
29
  "V8": "swin_classifier_4class_fp16_v8_epoch7_acc9740.pth",
30
  "V9": "swin_classifier_4class_fp16_v9_acc9861.pth",
31
  "V1-CAFormer": "caformer_b36_4class.safetensors",
 
 
 
32
  }
33
 
34
  CKPT_META = {
@@ -46,12 +53,24 @@ CKPT_META = {
46
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
47
  "V1-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
48
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
 
 
 
 
 
 
 
 
 
 
 
49
  }
50
 
51
  DEFAULT_CKPT = "V1-CAFormer"
52
  LOCAL_CKPT_DIR = "./checkpoints"
53
  SEED = 4421
54
  DROP_RATE = 0.1
 
55
 
56
  device = "cuda" if torch.cuda.is_available() else "cpu"
57
  torch.manual_seed(SEED); np.random.seed(SEED)
@@ -60,6 +79,55 @@ print(f"Using device: {device}")
60
  model, current_ckpt = None, None
61
  current_meta = None
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # Renamed to ImageClassifier for clarity, but keeping original name to avoid breaking changes if subclassed elsewhere.
64
  class SwinClassifier(nn.Module):
65
  def __init__(self, model_name, num_classes, pretrained=True,
@@ -126,20 +194,36 @@ def load_model(ckpt_name: str):
126
  print(f"\n🔄 Switching to {ckpt_name} ...")
127
  meta = CKPT_META[ckpt_name]
128
  ckpt_filename = HF_FILENAMES[ckpt_name]
129
- ckpt_file = hf_hub_download(
130
- repo_id=REPO_ID,
131
- filename=ckpt_filename,
132
- local_dir=LOCAL_CKPT_DIR, force_download=False
133
- )
 
 
 
 
 
 
 
 
 
134
  print(f"Checkpoint: {ckpt_file}")
135
 
136
- # Build model structure
137
- model = SwinClassifier(
138
- meta["backbone"], # Use backbone from meta
139
- num_classes = meta["n_cls"],
140
- pretrained = False,
141
- head_version = meta["head"]
142
- ).to(device)
 
 
 
 
 
 
 
143
 
144
  # Compatible load for .pth and .safetensors
145
  if ckpt_filename.endswith(".safetensors"):
@@ -173,7 +257,19 @@ def predict(image: Image.Image,
173
  if image is None: return None
174
 
175
  load_model(ckpt_name)
176
- tfm = build_transform(False, interpolation)
 
 
 
 
 
 
 
 
 
 
 
 
177
  inp = tfm(image).unsqueeze(0).to(device)
178
 
179
  probs = F.softmax(model(inp), dim=1)[0].cpu()
@@ -193,7 +289,7 @@ def launch():
193
  gr.Markdown("# AI Detector")
194
  gr.Markdown(
195
  "Choose a model checkpoint on the left, upload an image, "
196
- "and click **Run** to see predictions. Checkpoint V7+ outputs 4 classes."
197
  )
198
 
199
  with gr.Row():
@@ -206,7 +302,7 @@ def launch():
206
  )
207
  sel_interp = gr.Radio(
208
  ["bilinear", "bicubic", "nearest"],
209
- value="bicubic", label="Resize Interpolation"
210
  )
211
 
212
  in_img = gr.Image(type="pil", label="Upload Image")
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Swin/CAFormer/DINOv2 AI detection
4
  -------------------------------------------------------------------
5
  • Swin-V2 / V4 : 2-class (AI vs. Non-AI)
6
  • Swin-V7 / V8 / V9 : 4-class (photo / anime × AI / Non-AI)
7
  • CAFormer-V10 : 4-class (photo / anime × AI / Non-AI)
8
+ • DINOv2-4class : 4-class (photo / anime × AI / Non-AI)
9
  -------------------------------------------------------------------
10
  Author: telecomadm1145
11
  """
 
17
  import gradio as gr
18
  from huggingface_hub import hf_hub_download
19
  from safetensors.torch import load_file # Added for .safetensors support
20
+ # Added for DINOv2 model
21
+ from transformers import AutoModel
22
+ from torchvision import transforms
23
 
24
  # --------------------------------------------------
25
  # 1. Model & Checkpoint Meta-data
 
33
  "V8": "swin_classifier_4class_fp16_v8_epoch7_acc9740.pth",
34
  "V9": "swin_classifier_4class_fp16_v9_acc9861.pth",
35
  "V1-CAFormer": "caformer_b36_4class.safetensors",
36
+ "V2-CAFormer": "caformer_b36_4class_95.safetensors",
37
+ "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
38
+ "DINOv2-4class": "dinov2_4class.safetensors", # Added DINOv2 checkpoint
39
  }
40
 
41
  CKPT_META = {
 
53
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
54
  "V1-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
55
  "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
56
+ "V2-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
57
+ "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
58
+ "V2.5-CAFormer": { "n_cls": 4, "head": "v7", "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
59
+ "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]},
60
+ # Added DINOv2 metadata
61
+ "DINOv2-4class": {
62
+ "model_type": "dinov2",
63
+ "backbone": 'facebook/dinov2-base',
64
+ "n_cls": 4,
65
+ "names": ["non_ai", "ai", "ani_non_ai", "ani_ai"]
66
+ },
67
  }
68
 
69
  DEFAULT_CKPT = "V1-CAFormer"
70
  LOCAL_CKPT_DIR = "./checkpoints"
71
  SEED = 4421
72
  DROP_RATE = 0.1
73
+ DROPOUT_RATE = 0.1 # From train.py for DINOv2
74
 
75
  device = "cuda" if torch.cuda.is_available() else "cpu"
76
  torch.manual_seed(SEED); np.random.seed(SEED)
 
79
  model, current_ckpt = None, None
80
  current_meta = None
81
 
82
+ # --- Start of code from train.py ---
83
+ class DINOv2Classifier(nn.Module):
84
+ def __init__(self, model_name, num_classes):
85
+ super().__init__()
86
+ self.backbone = AutoModel.from_pretrained(model_name)
87
+ self.weight_self_attention = nn.MultiheadAttention(
88
+ embed_dim=self.backbone.config.hidden_size,
89
+ num_heads=self.backbone.config.num_attention_heads,
90
+ dropout=self.backbone.config.hidden_dropout_prob,
91
+ batch_first=True
92
+ )
93
+ self.weight_mlp = nn.Sequential(
94
+ nn.Linear(self.backbone.config.hidden_size, self.backbone.config.hidden_size * 4),
95
+ nn.LayerNorm(self.backbone.config.hidden_size * 4),
96
+ nn.GELU(),
97
+ nn.Linear(self.backbone.config.hidden_size * 4, 1)
98
+ )
99
+ self.classifier = nn.Sequential(
100
+ nn.Dropout(DROPOUT_RATE),
101
+ nn.Linear(self.backbone.config.hidden_size, self.backbone.config.hidden_size),
102
+ nn.LayerNorm(self.backbone.config.hidden_size),
103
+ nn.GELU(),
104
+ nn.Dropout(DROPOUT_RATE),
105
+ nn.Linear(self.backbone.config.hidden_size, num_classes)
106
+ )
107
+ nn.init.xavier_uniform_(self.weight_self_attention.in_proj_weight)
108
+ nn.init.xavier_uniform_(self.weight_self_attention.out_proj.weight)
109
+ nn.init.constant_(self.weight_self_attention.out_proj.bias, 0)
110
+
111
+ for module in [self.weight_mlp, self.classifier]:
112
+ if isinstance(module, nn.Linear):
113
+ nn.init.xavier_uniform_(module.weight)
114
+ nn.init.constant_(module.bias, 0)
115
+
116
+ def forward(self, x):
117
+ outputs = self.backbone(x)
118
+ attn_output, _ = self.weight_self_attention(
119
+ outputs.last_hidden_state,
120
+ outputs.last_hidden_state,
121
+ outputs.last_hidden_state,
122
+ )
123
+ raw_weights = self.weight_mlp(attn_output)
124
+ raw_weights = raw_weights.squeeze(-1)
125
+ pooling_weights = torch.softmax(raw_weights, dim=-1)
126
+ pooled_output = torch.sum(outputs.last_hidden_state * pooling_weights.unsqueeze(-1), dim=1)
127
+ return self.classifier(pooled_output)
128
+ # --- End of code from train.py ---
129
+
130
+
131
  # Renamed to ImageClassifier for clarity, but keeping original name to avoid breaking changes if subclassed elsewhere.
132
  class SwinClassifier(nn.Module):
133
  def __init__(self, model_name, num_classes, pretrained=True,
 
194
  print(f"\n🔄 Switching to {ckpt_name} ...")
195
  meta = CKPT_META[ckpt_name]
196
  ckpt_filename = HF_FILENAMES[ckpt_name]
197
+
198
+ # Check if the checkpoint is DINOv2 and handle its local path
199
+ if meta.get("model_type") == "dinov2":
200
+ # Assume DINOv2 model is local, as generated by train.py
201
+ ckpt_file = ckpt_filename
202
+ if not os.path.exists(ckpt_file):
203
+ raise FileNotFoundError(f"DINOv2 checkpoint not found at {ckpt_file}. Please run train.py first.")
204
+ else:
205
+ # Download other models from HF Hub
206
+ ckpt_file = hf_hub_download(
207
+ repo_id=REPO_ID,
208
+ filename=ckpt_filename,
209
+ local_dir=LOCAL_CKPT_DIR, force_download=False
210
+ )
211
  print(f"Checkpoint: {ckpt_file}")
212
 
213
+ # Build model structure based on model_type
214
+ if meta.get("model_type") == "dinov2":
215
+ model = DINOv2Classifier(
216
+ model_name=meta["backbone"],
217
+ num_classes=meta["n_cls"]
218
+ ).to(device)
219
+ else: # Existing logic for Swin/CAFormer
220
+ model = SwinClassifier(
221
+ meta["backbone"],
222
+ num_classes=meta["n_cls"],
223
+ pretrained=False,
224
+ head_version=meta.get("head", "v4")
225
+ ).to(device)
226
+
227
 
228
  # Compatible load for .pth and .safetensors
229
  if ckpt_filename.endswith(".safetensors"):
 
257
  if image is None: return None
258
 
259
  load_model(ckpt_name)
260
+
261
+ # Select transform based on the current model type
262
+ if current_meta.get("model_type") == "dinov2":
263
+ # DINOv2 specific transform from train.py
264
+ tfm = transforms.Compose([
265
+ transforms.Resize((224, 224)),
266
+ transforms.ToTensor(),
267
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
268
+ ])
269
+ else:
270
+ # Original transform logic for Swin/CAFormer
271
+ tfm = build_transform(False, interpolation)
272
+
273
  inp = tfm(image).unsqueeze(0).to(device)
274
 
275
  probs = F.softmax(model(inp), dim=1)[0].cpu()
 
289
  gr.Markdown("# AI Detector")
290
  gr.Markdown(
291
  "Choose a model checkpoint on the left, upload an image, "
292
+ "and click **Run** to see predictions. Checkpoint V7+ and DINOv2 outputs 4 classes."
293
  )
294
 
295
  with gr.Row():
 
302
  )
303
  sel_interp = gr.Radio(
304
  ["bilinear", "bicubic", "nearest"],
305
+ value="bicubic", label="Resize Interpolation (for Swin/CAFormer)"
306
  )
307
 
308
  in_img = gr.Image(type="pil", label="Upload Image")