telecomadm1145 committed
Commit 94778bf · verified · 1 Parent(s): 8cbab40

Update app.py

Files changed (1)
  1. app.py +17 -183
app.py CHANGED
@@ -1,16 +1,11 @@
 # -*- coding: utf-8 -*-
 """
-Swin/CAFormer/DINOv2 AI detection
 -------------------------------------------------------------------
 • Swin-V2 / V4 : 2-class (AI vs. Non-AI)
 • Swin-V7 / V8 / V9 : 4-class (photo / anime × AI / Non-AI)
-• CAFormer-V10 : 4-class (photo / anime × AI / Non-AI)
-DINOv2-4class : 4-class (photo / anime × AI / Non-AI)
-• DINOv2-MeanPool-Contrastive : 4-class (photo / anime × AI / Non-AI)
-• V1-Emb : 2-class (AI vs. Non-AI)
-• V2-Emb : 2-class (AI vs. Non-AI)
+• CAFormer-V2.5 : 4-class (photo / anime × AI / Non-AI)
+V3-Emb : 2-class (AI vs. Non-AI)
 -------------------------------------------------------------------
-Author: telecomadm1145
 """
 import os, torch, timm, numpy as np
 import torch.nn as nn
@@ -18,14 +13,11 @@ import torch.nn.functional as F
 from PIL import Image
 import gradio as gr
 from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file # Added for .safetensors support
-# Added for DINOv2 model
+from safetensors.torch import load_file
 from transformers import AutoModel
 from torchvision import transforms
 
-# --------------------------------------------------
-# 1. Model & Checkpoint Meta-data
-# --------------------------------------------------
+
 REPO_ID = "telecomadm1145/swin-ai-detection"
 HF_FILENAMES = {
     "V2.5-CAFormer": "caformer_b36_4class_96.safetensors",
@@ -56,7 +48,7 @@ DEFAULT_CKPT = "V3-Emb"
 LOCAL_CKPT_DIR = "./checkpoints"
 SEED = 4421
 DROP_RATE = 0.1
-DROPOUT_RATE = 0.1 # From train.py for DINOv2
+DROPOUT_RATE = 0.1
 device = "cuda" if torch.cuda.is_available() else "cpu"
 torch.manual_seed(SEED); np.random.seed(SEED)
 print(f"Using device: {device}")
@@ -81,7 +73,6 @@ class EmbeddingClassifier(nn.Module):
     def forward(self, x):
         return self.net(x)
 
-# MODIFIED: Changed __init__ to accept timm_model_name and use pretrained=False
 class EmbeddingClassifierModel(nn.Module):
     def __init__(self, timm_model_name, num_classes):
         super().__init__()
@@ -91,96 +82,10 @@ class EmbeddingClassifierModel(nn.Module):
 
     def forward(self, x):
         features = self.backbone(x)
-        # The classifier returns a single value (probability of being Non-AI)
         prob_class0 = self.classifier(features)
-
-        # To maintain compatibility with the `predict` function which expects multi-class outputs
-        # and applies softmax, we construct a 2-class output.
-        # prob_class1 is simply 1 - prob_class0
         prob_class1 = 1 - prob_class0
-
-        # The final output is for ["Non-AI", "AI"], i.e., [prob_class0, prob_class1].
-        # The softmax in predict() will be applied to this, so we should return logits.
-        # However, since the original output is a sigmoid, we can work with probabilities
-        # and just return them directly. The gr.Label will normalize this.
-        # A simpler way is to construct logits that would result in these probabilities.
-        # Let's stick to the original logic's output format.
         return torch.cat([prob_class0, prob_class1], dim=1)
 
-
-# --- Original DINOv2 Classifier (Weighted Attention Pooling) ---
-class DINOv2Classifier_WeightedPool(nn.Module):
-    def __init__(self, model_name, num_classes):
-        super().__init__()
-        self.backbone = AutoModel.from_pretrained(model_name)
-        self.weight_self_attention = nn.MultiheadAttention(
-            embed_dim=self.backbone.config.hidden_size,
-            num_heads=self.backbone.config.num_attention_heads,
-            dropout=self.backbone.config.hidden_dropout_prob,
-            batch_first=True
-        )
-        self.weight_mlp = nn.Sequential(
-            nn.Linear(self.backbone.config.hidden_size, self.backbone.config.hidden_size * 4),
-            nn.LayerNorm(self.backbone.config.hidden_size * 4),
-            nn.GELU(),
-            nn.Linear(self.backbone.config.hidden_size * 4, 1)
-        )
-        self.classifier = nn.Sequential(
-            nn.Dropout(DROPOUT_RATE),
-            nn.Linear(self.backbone.config.hidden_size, self.backbone.config.hidden_size),
-            nn.LayerNorm(self.backbone.config.hidden_size),
-            nn.GELU(),
-            nn.Dropout(DROPOUT_RATE),
-            nn.Linear(self.backbone.config.hidden_size, num_classes)
-        )
-        nn.init.xavier_uniform_(self.weight_self_attention.in_proj_weight)
-        nn.init.xavier_uniform_(self.weight_self_attention.out_proj.weight)
-        nn.init.constant_(self.weight_self_attention.out_proj.bias, 0)
-
-        for module in [self.weight_mlp, self.classifier]:
-            if isinstance(module, nn.Linear):
-                nn.init.xavier_uniform_(module.weight)
-                nn.init.constant_(module.bias, 0)
-
-    def forward(self, x):
-        outputs = self.backbone(x)
-        attn_output, _ = self.weight_self_attention(
-            outputs.last_hidden_state,
-            outputs.last_hidden_state,
-            outputs.last_hidden_state,
-        )
-        raw_weights = self.weight_mlp(attn_output)
-        raw_weights = raw_weights.squeeze(-1)
-        pooling_weights = torch.softmax(raw_weights, dim=-1)
-        pooled_output = torch.sum(outputs.last_hidden_state * pooling_weights.unsqueeze(-1), dim=1)
-        return self.classifier(pooled_output)
-
-# --- New DINOv2 Classifier (Mean Pooling) ---
-class DINOv2Classifier_MeanPool(nn.Module):
-    def __init__(self, model_name, num_classes):
-        super().__init__()
-        self.backbone = AutoModel.from_pretrained(model_name)
-        self.classifier = nn.Sequential(
-            nn.Dropout(DROPOUT_RATE),
-            nn.Linear(self.backbone.config.hidden_size, self.backbone.config.hidden_size),
-            nn.LayerNorm(self.backbone.config.hidden_size),
-            nn.GELU(),
-            nn.Dropout(DROPOUT_RATE),
-            nn.Linear(self.backbone.config.hidden_size, num_classes)
-        )
-        for module in self.classifier:
-            if isinstance(module, nn.Linear):
-                nn.init.xavier_uniform_(module.weight)
-                nn.init.constant_(module.bias, 0)
-    def forward(self, x, return_features=False):
-        outputs = self.backbone(x)
-        pooled_output = outputs.last_hidden_state.mean(dim=1)
-        if return_features:
-            return pooled_output
-
-        return self.classifier(pooled_output)
-
-# --- SwinClassifier ---
 class SwinClassifier(nn.Module):
     def __init__(self, model_name, num_classes, pretrained=True,
                  head_version="v4"):
@@ -189,7 +94,6 @@ class SwinClassifier(nn.Module):
             model_name, pretrained=pretrained, num_classes=0
         )
         self.data_config = timm.data.resolve_data_config({}, model=self.backbone)
-        # ------- choose a different head depending on the version -------
         if head_version == "v7":  # <-- V7, V8, V9, V10: minimal 64-hidden, GELU
             self.classifier = nn.Sequential(
                 nn.Dropout(DROP_RATE),
@@ -228,11 +132,7 @@ class SwinClassifier(nn.Module):
     def forward(self, x):
         return self.classifier(self.backbone(x))
 
-# --------------------------------------------------
-# 4. Dynamic model loading
-# --------------------------------------------------
 def load_model(ckpt_name: str):
-    """Load model only when `ckpt_name` changes."""
     global model, current_ckpt, current_meta
     if ckpt_name == current_ckpt and model is not None:
         return
@@ -240,17 +140,14 @@ def load_model(ckpt_name: str):
     meta = CKPT_META[ckpt_name]
     ckpt_filename = HF_FILENAMES[ckpt_name]
 
-    # --- MODIFIED: Special handling for EmbeddingClassifier ---
     head_version = meta.get("head", "v4")
     if head_version == "embedding_classifier":
-        # 1. Create the model structure with a non-pretrained backbone
         print(f"Creating backbone: {meta['timm_model_name']}")
         model = EmbeddingClassifierModel(
            timm_model_name=meta["timm_model_name"],
            num_classes=meta["n_cls"]
        ).to(device)
 
-        # 2. Download and load backbone weights from SmilingWolf's repo
         print(f"Loading backbone weights from {meta['backbone_repo_id']}...")
         backbone_ckpt_file = hf_hub_download(
             repo_id=meta["backbone_repo_id"],
@@ -261,11 +158,10 @@ def load_model(ckpt_name: str):
         model.backbone.load_state_dict(backbone_state, strict=False)
         print("✅ Backbone weights loaded.")
 
-        # 3. Download and load classifier (head) weights from the main repo
         print(f"Loading classifier head weights from {REPO_ID}...")
         classifier_ckpt_file = hf_hub_download(
             repo_id=REPO_ID,
-            filename=ckpt_filename, # This is 'swinv2_v3_v1.pth'
+            filename=ckpt_filename,
             local_dir=LOCAL_CKPT_DIR, force_download=False
         )
         classifier_state = torch.load(classifier_ckpt_file, map_location=device, weights_only=False)
@@ -273,7 +169,6 @@ def load_model(ckpt_name: str):
         print("✅ Classifier head weights loaded.")
 
     else:
-        # --- Original logic for all other models ---
         ckpt_file = hf_hub_download(
             repo_id=REPO_ID,
             filename=ckpt_filename,
@@ -281,27 +176,13 @@ def load_model(ckpt_name: str):
         )
         print(f"Checkpoint: {ckpt_file}")
 
-        # Build model structure based on model_type or head
-        model_type = meta.get("model_type")
-        if model_type == "dinov2_weighted_pool":
-            model = DINOv2Classifier_WeightedPool(
-                model_name=meta["backbone"],
-                num_classes=meta["n_cls"]
-            ).to(device)
-        elif model_type == "dinov2_mean_pool":
-            model = DINOv2Classifier_MeanPool(
-                model_name=meta["backbone"],
-                num_classes=meta["n_cls"]
-            ).to(device)
-        else: # Existing logic for Swin/CAFormer
-            model = SwinClassifier(
-                meta["backbone"],
-                num_classes=meta["n_cls"],
-                pretrained=False,
-                head_version=head_version
-            ).to(device)
-
-        # Compatible load for .pth and .safetensors
+        model = SwinClassifier(
+            meta["backbone"],
+            num_classes=meta["n_cls"],
+            pretrained=False,
+            head_version=head_version
+        ).to(device)
+
         if ckpt_filename.endswith(".safetensors"):
             state = load_file(ckpt_file, device=device)
         else:
@@ -313,23 +194,15 @@ def load_model(ckpt_name: str):
     current_ckpt, current_meta = ckpt_name, meta
     print(f"✅ {ckpt_name} loaded (classes = {meta['n_cls']}).")
 
-# --------------------------------------------------
-# 5. Transform factory
-# --------------------------------------------------
 def build_transform(is_training: bool, interpolation: str):
     if model is None: raise RuntimeError("Model not loaded yet.")
     cfg = model.data_config.copy()
     cfg.update(dict(interpolation=interpolation))
     return timm.data.create_transform(**cfg, is_training=is_training)
 
-# ######################################################################
-# START: Preprocessing functions for V1-Emb model, copied from 2nd script
-# ######################################################################
 def pil_ensure_rgb(image: Image.Image) -> Image.Image:
-    # convert to RGB/RGBA if not already (deals with palette images etc.)
     if image.mode not in ["RGB", "RGBA"]:
         image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
-    # convert RGBA to RGB with white background
     if image.mode == "RGBA":
         canvas = Image.new("RGBA", image.size, (255, 255, 255))
         canvas.alpha_composite(image)
@@ -339,20 +212,11 @@ def pil_ensure_rgb(image: Image.Image) -> Image.Image:
 
 def pil_pad_square(image: Image.Image) -> Image.Image:
     w, h = image.size
-    # get the largest dimension so we can pad to a square
     px = max(image.size)
-    # pad to square with white background
     canvas = Image.new("RGB", (px, px), (255, 255, 255))
     canvas.paste(image, ((px - w) // 2, (px - h) // 2))
     return canvas
-# ####################################################################
-# END: Preprocessing functions for V1-Emb model
-# ####################################################################
 
-
-# --------------------------------------------------
-# 6. Inference
-# --------------------------------------------------
 @torch.no_grad()
 def predict(image: Image.Image,
             ckpt_name: str,
@@ -360,60 +224,34 @@ def predict(image: Image.Image,
     if image is None: return None
     load_model(ckpt_name)
 
-    # ####################################################################
-    # START: MODIFIED preprocessing logic
-    # ####################################################################
     if "Emb" in ckpt_name:
-        # Specific preprocessing for the V1-Emb model based on the tagger script
-        # 1. Ensure RGB and pad to a square to prevent distortion
         processed_image = pil_ensure_rgb(image)
         processed_image = pil_pad_square(processed_image)
-
-        # 2. Apply standard timm transforms (resize, tensor, normalize)
         tfm = build_transform(False, interpolation)
         inp = tfm(processed_image).unsqueeze(0).to(device)
-
-        # 3. Convert from RGB to BGR as required by the original model
         inp = inp[:, [2, 1, 0]]
-
-    elif "dinov2" in current_meta.get("model_type", ""):
-        # DINOv2 specific transform
-        tfm = transforms.Compose([
-            transforms.Resize((224, 224)),
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-        ])
-        inp = tfm(image).unsqueeze(0).to(device)
+
     else:
-        # Original transform logic for Swin/CAFormer
         tfm = build_transform(False, interpolation)
         inp = tfm(image).unsqueeze(0).to(device)
-    # ####################################################################
-    # END: MODIFIED preprocessing logic
-    # ####################################################################
 
-    # MODIFIED: For EmbeddingClassifier, the output is already probabilities, no need for softmax.
-    # For others, softmax is needed.
     if current_meta["head"] == "embedding_classifier":
        probs = model(inp)[0].cpu()
     else:
        probs = F.softmax(model(inp), dim=1)[0].cpu()
 
     class_names = current_meta["names"]
-    # make sure gr.Label displays correctly for both 2-class and 4-class outputs
+
     return {class_names[i]: float(probs[i])
             for i in range(len(class_names))}
 
-# --------------------------------------------------
-# 7. Gradio UI
-# --------------------------------------------------
 def launch():
-    load_model(DEFAULT_CKPT) # preload
+    load_model(DEFAULT_CKPT)
     with gr.Blocks(theme=gr.themes.Soft()) as demo:
         gr.Markdown("# AI Detector")
         gr.Markdown(
             "Choose a model checkpoint on the left, upload an image, "
-            "and click **Run** to see predictions. V2-Emb produces the best results."
+            "and click **Run** to see predictions. V3-Emb produces the best results."
         )
         with gr.Row():
             with gr.Column(scale=1):
@@ -429,17 +267,14 @@ def launch():
 
                 in_img = gr.Image(type="pil", label="Upload Image")
             with gr.Column(scale=1):
-                # num_top_classes is set to 4 so both 2-class and 4-class models display correctly
                 out_lbl = gr.Label(num_top_classes=4, label="Predictions")
         run_btn.click(
             predict,
             inputs=[in_img, sel_ckpt, sel_interp],
             outputs=[out_lbl]
         )
-        # optional example folder
         if not os.path.exists("examples"):
             os.makedirs("examples")
-            print("Put some jpg/png files inside ./examples for demo examples")
         example_files = [os.path.join("examples", f)
                          for f in os.listdir("examples")
                          if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
@@ -453,6 +288,5 @@ def launch():
         )
     demo.launch()
 
-# --------------------------------------------------
 if __name__ == "__main__":
     launch()
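
A note for readers skimming the diff: the surviving `EmbeddingClassifierModel.forward` emits a single sigmoid probability for "Non-AI" and returns it as a two-column tensor `[p, 1 - p]`, which is why `predict()` skips the softmax for the `embedding_classifier` head. Below is a minimal, self-contained sketch of that output format; the tiny linear backbone and head are illustrative stand-ins, not the real checkpoints.

import torch
import torch.nn as nn

class TwoClassFromSigmoid(nn.Module):
    # Illustrative stand-in for EmbeddingClassifierModel's output format.
    def __init__(self, feat_dim=8):
        super().__init__()
        self.backbone = nn.Linear(3, feat_dim)           # stand-in backbone
        self.classifier = nn.Sequential(
            nn.Linear(feat_dim, 1), nn.Sigmoid())        # single P(Non-AI) score

    def forward(self, x):
        prob_class0 = self.classifier(self.backbone(x))  # shape (B, 1)
        prob_class1 = 1 - prob_class0
        return torch.cat([prob_class0, prob_class1], dim=1)

probs = TwoClassFromSigmoid()(torch.randn(2, 3))
print(probs.sum(dim=1))  # each row sums to 1, so no softmax is needed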
 
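The checkpoint-loading branch kept by this commit is also format-agnostic: `.safetensors` files go through `safetensors.torch.load_file`, everything else through `torch.load`. A standalone sketch of that branch, assuming a local checkpoint path like the ones `hf_hub_download` places under `./checkpoints`:

import torch
from safetensors.torch import load_file

def load_state(ckpt_file: str, device: str = "cpu"):
    # Mirrors the branch in load_model(): safetensors vs. pickle-based .pth
    if ckpt_file.endswith(".safetensors"):
        return load_file(ckpt_file, device=device)
    return torch.load(ckpt_file, map_location=device, weights_only=False)

# e.g. state = load_state("./checkpoints/caformer_b36_4class_96.safetensors")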