danielhshi8224 commited on
Commit
e322980
Β·
1 Parent(s): c5564c5

standalone cls

Browse files
Files changed (1) hide show
  1. app.py +256 -229
app.py CHANGED
@@ -1,34 +1,261 @@
1
- #Main Gradio app ith image classification and object detection tabs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import gradio as gr
3
  import torch
4
  import torch.nn.functional as F
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
  from PIL import Image
7
- import os
8
- import csv
9
- import tempfile
10
- from pathlib import Path
11
- from ultralytics import YOLO
12
- # ultralytics YOLO import (for object detection)
13
- try:
14
- from ultralytics import YOLO
15
- except Exception:
16
- YOLO = None
17
 
18
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
- MODEL_ID = "dshi01/convnext-tiny-224-7clss"
 
20
 
21
- print(f"Loading model from: {MODEL_ID}")
22
- processor = AutoImageProcessor.from_pretrained("facebook/convnext-tiny-224")
23
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
24
  model.eval()
25
 
26
- # (Optional) use model's own labels if present
27
  ID2LABEL = [
28
  model.config.id2label.get(str(i), model.config.id2label.get(i, f"Label_{i}"))
29
  for i in range(model.config.num_labels)
30
  ]
 
31
  def classify_image(image):
 
32
  if not isinstance(image, Image.Image):
33
  image = Image.fromarray(image).convert("RGB")
34
 
@@ -39,23 +266,19 @@ def classify_image(image):
39
 
40
  return {ID2LABEL[i]: float(p) for i, p in enumerate(probs)}
41
 
42
- # ---------- NEW: batch classify up to 10 images ----------
43
  MAX_BATCH = 10
44
 
45
  def classify_images_batch(files):
46
  """
47
- files: list of gradio UploadedFile (paths) or None
48
- Returns:
49
- - gallery: list of (image, caption)
50
- - table: list of rows for Dataframe
51
  """
52
  if not files:
53
  return [], [], None
54
 
55
- # Keep at most 10
56
  files = files[:MAX_BATCH]
57
 
58
- # Load as PIL
59
  pil_images, names = [], []
60
  for f in files:
61
  path = getattr(f, "name", None) or getattr(f, "path", None) or f
@@ -64,19 +287,16 @@ def classify_images_batch(files):
64
  pil_images.append(img)
65
  names.append(os.path.basename(path))
66
  except Exception:
67
- # Skip unreadable file
68
  continue
69
 
70
  if not pil_images:
71
  return [], [], None
72
 
73
- # Batch preprocess + forward
74
  inputs = processor(images=pil_images, return_tensors="pt")
75
  with torch.no_grad():
76
  logits = model(**inputs).logits
77
  probs = F.softmax(logits, dim=1)
78
 
79
- # Build outputs
80
  gallery = []
81
  table_rows = [] # [filename, top1_label, top1_conf, top3_labels, top3_confs]
82
 
@@ -85,7 +305,6 @@ def classify_images_batch(files):
85
  top_idxs = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:3]
86
  top1 = top_idxs[0]
87
  caption = f"{ID2LABEL[top1]} ({p[top1]:.2%})"
88
-
89
  gallery.append((img, f"{fname}\n{caption}"))
90
 
91
  top3_labels = [ID2LABEL[i] for i in top_idxs]
@@ -101,191 +320,17 @@ def classify_images_batch(files):
101
  # Create CSV for download
102
  csv_path = None
103
  try:
104
- # Write CSV into a temp file inside project dir so Gradio can serve it
105
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", prefix="predictions_", dir=BASE_DIR, mode="w", newline='', encoding='utf-8')
 
 
106
  writer = csv.writer(tmp)
107
- # headers
108
  writer.writerow(["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"])
109
  for row in table_rows:
110
  writer.writerow(row)
111
- tmp.flush()
112
- tmp.close()
113
  csv_path = tmp.name
114
  except Exception:
115
- # If CSV can't be created, return None for the file but keep other outputs
116
- csv_path = None
117
-
118
- return gallery, table_rows, csv_path
119
-
120
-
121
- # ---------- NEW: YOLO object detection for multi-image upload ----------
122
- YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_best.pt")
123
- _yolo_model = None
124
- def _load_yolo():
125
- global _yolo_model
126
- if _yolo_model is not None:
127
- return _yolo_model
128
- if YOLO is None:
129
- raise RuntimeError("ultralytics package not installed. Please install 'ultralytics'.")
130
- if not os.path.exists(YOLO_WEIGHTS):
131
- # Try current directory too
132
- alt = Path.cwd() / "yolo11_best.pt"
133
- if alt.exists():
134
- model_path = str(alt)
135
- else:
136
- raise FileNotFoundError(f"YOLO weights not found at {YOLO_WEIGHTS}. Place yolo11_best.pt in project root.")
137
- else:
138
- model_path = YOLO_WEIGHTS
139
-
140
- _yolo_model = YOLO(model_path)
141
- return _yolo_model
142
-
143
-
144
- def detect_objects_batch(files, iou=0.25, conf=0.25):
145
- """
146
- Run YOLO detection on multiple images.
147
- Returns: gallery of annotated images, dataframe rows, csv file path
148
- """
149
- if YOLO is None:
150
- return [], [], None
151
-
152
- if not files:
153
- return [], [], None
154
-
155
- # Load model
156
- try:
157
- ymodel = _load_yolo()
158
- except Exception as e:
159
- print("YOLO load error:", e)
160
- return [], [], None
161
-
162
- annotated_paths = []
163
- table_rows = []
164
- gallery = []
165
-
166
- for f in files[:MAX_BATCH]:
167
- path = getattr(f, "name", None) or getattr(f, "path", None) or f
168
- try:
169
- # Run predict; returns a Results object list
170
- results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=640, verbose=False)
171
- except Exception as e:
172
- print(f"Detection failed for {path}:", e)
173
- continue
174
-
175
- # results is list-like; take first
176
- res = results[0]
177
-
178
- # Prepare annotation image using res.plot() so boxes+confidences are drawn
179
- ann_path = None
180
- try:
181
- ann_img = res.plot() # returns numpy array with annotations
182
- from PIL import Image as PILImage
183
- ann_pil = PILImage.fromarray(ann_img)
184
- out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
185
- os.makedirs(out_dir, exist_ok=True)
186
- ann_filename = os.path.splitext(os.path.basename(path))[0] + "_annotated.jpg"
187
- ann_path = os.path.join(out_dir, ann_filename)
188
- ann_pil.save(ann_path)
189
- except Exception:
190
- # Fallback to ultralytics save if plot() isn't available
191
- try:
192
- out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
193
- res.save(save_dir=out_dir)
194
- saved_files = res.files if hasattr(res, 'files') else []
195
- ann_path = saved_files[0] if saved_files else None
196
- except Exception:
197
- ann_path = None
198
-
199
- # Build table rows from detections
200
- boxes = res.boxes if hasattr(res, 'boxes') else None
201
- if boxes is None or len(boxes) == 0:
202
- table_rows.append([os.path.basename(path), 0, "", "", ""])
203
- if ann_path and os.path.exists(ann_path):
204
- gallery.append((Image.open(ann_path).convert('RGB'), f"{os.path.basename(path)}\nNo detections"))
205
- else:
206
- gallery.append((Image.open(path).convert('RGB'), f"{os.path.basename(path)}\nNo detections"))
207
- continue
208
-
209
- det_labels = []
210
- det_scores = []
211
- det_boxes = []
212
- for box in boxes:
213
- # box.cls, box.conf, box.xyxy
214
- cls = int(box.cls.cpu().item()) if hasattr(box, 'cls') else None
215
- # use .item() to extract scalar and avoid numpy deprecation warnings
216
- if hasattr(box, 'conf'):
217
- try:
218
- confscore = float(box.conf.cpu().item())
219
- except Exception:
220
- try:
221
- confscore = float(box.conf.item())
222
- except Exception:
223
- confscore = None
224
- else:
225
- confscore = None
226
-
227
- # extract xyxy coords; box.xyxy may be shape (1,4) -> nested list after .tolist()
228
- coords = []
229
- if hasattr(box, 'xyxy'):
230
- try:
231
- arr = box.xyxy.cpu().numpy()
232
- # handle nested shape (1,4) or (4,)
233
- if getattr(arr, 'ndim', None) == 2 and arr.shape[0] == 1:
234
- coords = arr[0].tolist()
235
- elif getattr(arr, 'ndim', None) == 1:
236
- coords = arr.tolist()
237
- else:
238
- coords = arr.reshape(-1).tolist()
239
- except Exception:
240
- # fallback: try to call tolist()
241
- try:
242
- coords = box.xyxy.tolist()
243
- except Exception:
244
- coords = []
245
-
246
- # append detection info
247
- det_labels.append(ymodel.names.get(cls, str(cls)) if cls is not None else "")
248
- det_scores.append(round(confscore, 4) if confscore is not None else "")
249
- # round and store coords
250
- try:
251
- det_boxes.append([round(float(x), 2) for x in coords])
252
- except Exception:
253
- # fallback: store raw repr
254
- det_boxes.append([str(coords)])
255
-
256
- # create readable label:confidence pairs
257
- label_conf_pairs = [f"{l}:{s}" for l, s in zip(det_labels, det_scores)]
258
- boxes_repr = ["[" + ", ".join(map(str, b)) + "]" for b in det_boxes]
259
- table_rows.append([
260
- os.path.basename(path),
261
- len(det_labels),
262
- ", ".join(label_conf_pairs),
263
- ", ".join(boxes_repr),
264
- "; ".join([str(b) for b in det_boxes])
265
- ])
266
-
267
- # Use annotated image if exists
268
- if ann_path and os.path.exists(ann_path):
269
- try:
270
- gallery.append((Image.open(ann_path).convert('RGB'), f"{os.path.basename(path)}\n{len(det_labels)} detections"))
271
- except Exception:
272
- gallery.append((Image.open(path).convert('RGB'), f"{os.path.basename(path)}\n{len(det_labels)} detections"))
273
- else:
274
- gallery.append((Image.open(path).convert('RGB'), f"{os.path.basename(path)}\n{len(det_labels)} detections"))
275
-
276
- # write CSV
277
- csv_path = None
278
- try:
279
- tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", prefix="yolo_preds_", dir=BASE_DIR, mode="w", newline='', encoding='utf-8')
280
- writer = csv.writer(tmp)
281
- writer.writerow(["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"])
282
- for r in table_rows:
283
- writer.writerow(r)
284
- tmp.flush()
285
- tmp.close()
286
- csv_path = tmp.name
287
- except Exception as e:
288
- print("Failed to write CSV:", e)
289
  csv_path = None
290
 
291
  return gallery, table_rows, csv_path
@@ -295,7 +340,7 @@ single = gr.Interface(
295
  fn=classify_image,
296
  inputs=gr.Image(type="pil", label="Upload Underwater Image"),
297
  outputs=gr.Label(num_top_classes=len(ID2LABEL), label="Species Classification"),
298
- title="🌊 BenthicAI - Single Image",
299
  description="Classify one image into one of 7 benthic species."
300
  )
301
 
@@ -308,32 +353,14 @@ batch = gr.Interface(
308
  headers=["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"],
309
  label="Predictions Table",
310
  wrap=True
311
- )
312
- , gr.File(label="Download CSV")
313
  ],
314
- title="🌊 BenthicAI - Batch (up to 10)",
315
- description="Upload multiple images (max 10). Outputs a gallery with captions and a table of top predictions.",
316
  )
317
 
318
  demo = gr.TabbedInterface([single, batch], ["Single", "Batch"])
319
- print(YOLO==None, flush=True)
320
- # Add Object Detection tab if ultralytics available
321
- if YOLO is not None:
322
- detection_iface = gr.Interface(
323
- fn=detect_objects_batch,
324
- inputs=[gr.Files(label="Upload images for detection (max 10)"), gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="conf threshold"), gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="IOU threshold")],
325
- outputs=[
326
- gr.Gallery(label="Detections (annotated)", height=500, rows=3),
327
- gr.Dataframe(headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"], label="Detection Table"),
328
- gr.File(label="Download CSV")
329
- ],
330
- title="🌊 BenthicAI - Object Detection",
331
- description="Run YOLO object detection on multiple images. Requires 'yolo11_best.pt' in project root."
332
- )
333
-
334
- # extend tabs
335
- demo = gr.TabbedInterface([single, batch, detection_iface], ["Single", "Batch", "Detection"])
336
 
337
  if __name__ == "__main__":
338
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
339
-
 
1
+ # # app.py β€” Object Detection only (multi-image YOLO, up to 10)
2
+ # import os
3
+ # import csv
4
+ # import tempfile
5
+ # from pathlib import Path
6
+ # from typing import List, Tuple
7
+
8
+ # import gradio as gr
9
+ # from PIL import Image
10
+
11
+ # # Try import ultralytics (ensure it's in requirements.txt)
12
+ # try:
13
+ # from ultralytics import YOLO
14
+ # except Exception:
15
+ # YOLO = None
16
+
17
+ # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18
+ # MAX_BATCH = 10
19
+
20
+ # # Option A: local file baked into Space (easiest if allowed)
21
+ # YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_best.pt")
22
+
23
+ # # Option B (optional): pull from a private HF model repo using a Space secret
24
+ # # Set these env vars in your Space if you want auto-download:
25
+ # # HF_TOKEN=<read token> YOLO_REPO_ID="yourname/yolo-detector"
26
+ # HF_TOKEN = os.environ.get("HF_TOKEN")
27
+ # YOLO_REPO_ID = os.environ.get("YOLO_REPO_ID")
28
+
29
+ # def _download_from_hub_if_needed() -> str | None:
30
+ # """If YOLO_REPO_ID is set, download weights with huggingface_hub; else return None."""
31
+ # if not YOLO_REPO_ID:
32
+ # return None
33
+ # try:
34
+ # from huggingface_hub import snapshot_download
35
+ # local_dir = snapshot_download(
36
+ # repo_id=YOLO_REPO_ID, repo_type="model", token=HF_TOKEN
37
+ # )
38
+ # # try common filenames
39
+ # for name in ("yolo11_best.pt", "best.pt", "yolo.pt", "weights.pt"):
40
+ # cand = Path(local_dir) / name
41
+ # if cand.exists():
42
+ # return str(cand)
43
+ # except Exception as e:
44
+ # print("[YOLO] Hub download failed:", e)
45
+ # return None
46
+
47
+ # _yolo_model = None
48
+ # def _load_yolo():
49
+ # """Load YOLO weights either from local file or HF Hub."""
50
+ # global _yolo_model
51
+ # if _yolo_model is not None:
52
+ # return _yolo_model
53
+ # if YOLO is None:
54
+ # raise RuntimeError("ultralytics package not installed. Add 'ultralytics' to requirements.txt")
55
+
56
+ # model_path = None
57
+ # if os.path.exists(YOLO_WEIGHTS):
58
+ # model_path = YOLO_WEIGHTS
59
+ # else:
60
+ # hub_path = _download_from_hub_if_needed()
61
+ # if hub_path:
62
+ # model_path = hub_path
63
+
64
+ # if not model_path:
65
+ # raise FileNotFoundError(
66
+ # "YOLO weights not found. Either include 'yolo11_best.pt' in the repo root, "
67
+ # "or set YOLO_REPO_ID (+ HF_TOKEN if private) to pull from the Hub."
68
+ # )
69
+
70
+ # _yolo_model = YOLO(model_path)
71
+ # return _yolo_model
72
+
73
+ # def detect_objects_batch(files, conf=0.25, iou=0.25):
74
+ # """
75
+ # Run YOLO detection on multiple images (up to 10).
76
+ # Returns: gallery of annotated images, rows table, csv filepath
77
+ # """
78
+ # if YOLO is None:
79
+ # return [], [], None
80
+ # if not files:
81
+ # return [], [], None
82
+
83
+ # try:
84
+ # ymodel = _load_yolo()
85
+ # except Exception as e:
86
+ # print("YOLO load error:", e)
87
+ # return [], [], None
88
+
89
+ # gallery, table_rows = [], []
90
+
91
+ # for f in files[:MAX_BATCH]:
92
+ # path = getattr(f, "name", None) or getattr(f, "path", None) or f
93
+ # try:
94
+ # results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=640, verbose=False)
95
+ # except Exception as e:
96
+ # print(f"Detection failed for {path}:", e)
97
+ # continue
98
+ # res = results[0]
99
+
100
+ # # annotated image
101
+ # ann_path = None
102
+ # try:
103
+ # ann_img = res.plot()
104
+ # ann_pil = Image.fromarray(ann_img)
105
+ # out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
106
+ # os.makedirs(out_dir, exist_ok=True)
107
+ # ann_filename = Path(path).stem + "_annotated.jpg"
108
+ # ann_path = os.path.join(out_dir, ann_filename)
109
+ # ann_pil.save(ann_path)
110
+ # except Exception:
111
+ # try:
112
+ # out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
113
+ # res.save(save_dir=out_dir)
114
+ # saved_files = getattr(res, "files", [])
115
+ # ann_path = saved_files[0] if saved_files else None
116
+ # except Exception:
117
+ # ann_path = None
118
+
119
+ # # extract detections
120
+ # boxes = getattr(res, "boxes", None)
121
+ # if boxes is None or len(boxes) == 0:
122
+ # table_rows.append([os.path.basename(path), 0, "", "", ""])
123
+ # img_for_gallery = Image.open(ann_path).convert("RGB") if ann_path and os.path.exists(ann_path) \
124
+ # else Image.open(path).convert("RGB")
125
+ # gallery.append((img_for_gallery, f"{os.path.basename(path)}\nNo detections"))
126
+ # continue
127
+
128
+ # det_labels, det_scores, det_boxes = [], [], []
129
+ # for box in boxes:
130
+ # cls = int(box.cls.cpu().item()) if hasattr(box, "cls") else None
131
+ # # conf
132
+ # try:
133
+ # confscore = float(box.conf.cpu().item()) if hasattr(box, "conf") else None
134
+ # except Exception:
135
+ # try:
136
+ # confscore = float(box.conf.item())
137
+ # except Exception:
138
+ # confscore = None
139
+ # # xyxy
140
+ # coords = []
141
+ # if hasattr(box, "xyxy"):
142
+ # try:
143
+ # arr = box.xyxy.cpu().numpy()
144
+ # if getattr(arr, "ndim", None) == 2 and arr.shape[0] == 1:
145
+ # coords = arr[0].tolist()
146
+ # elif getattr(arr, "ndim", None) == 1:
147
+ # coords = arr.tolist()
148
+ # else:
149
+ # coords = arr.reshape(-1).tolist()
150
+ # except Exception:
151
+ # try:
152
+ # coords = box.xyxy.tolist()
153
+ # except Exception:
154
+ # coords = []
155
+
156
+ # det_labels.append(ymodel.names.get(cls, str(cls)) if cls is not None else "")
157
+ # det_scores.append(round(confscore, 4) if confscore is not None else "")
158
+ # try:
159
+ # det_boxes.append([round(float(x), 2) for x in coords])
160
+ # except Exception:
161
+ # det_boxes.append([str(coords)])
162
+
163
+ # label_conf_pairs = [f"{l}:{s}" for l, s in zip(det_labels, det_scores)]
164
+ # boxes_repr = ["[" + ", ".join(map(str, b)) + "]" for b in det_boxes]
165
+ # table_rows.append([
166
+ # os.path.basename(path),
167
+ # len(det_labels),
168
+ # ", ".join(label_conf_pairs),
169
+ # ", ".join(boxes_repr),
170
+ # "; ".join([str(b) for b in det_boxes]),
171
+ # ])
172
+
173
+ # img_for_gallery = Image.open(ann_path).convert("RGB") if ann_path and os.path.exists(ann_path) \
174
+ # else Image.open(path).convert("RGB")
175
+ # gallery.append((img_for_gallery, f"{os.path.basename(path)}\n{len(det_labels)} detections"))
176
+
177
+ # # write CSV
178
+ # csv_path = None
179
+ # try:
180
+ # tmp = tempfile.NamedTemporaryFile(
181
+ # delete=False, suffix=".csv", prefix="yolo_preds_", dir=BASE_DIR,
182
+ # mode="w", newline='', encoding='utf-8'
183
+ # )
184
+ # writer = csv.writer(tmp)
185
+ # writer.writerow(["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"])
186
+ # for r in table_rows:
187
+ # writer.writerow(r)
188
+ # tmp.flush(); tmp.close()
189
+ # csv_path = tmp.name
190
+ # except Exception as e:
191
+ # print("Failed to write CSV:", e)
192
+ # csv_path = None
193
+
194
+ # return gallery, table_rows, csv_path
195
+
196
+ # # ---------- UI ----------
197
+ # if YOLO is None:
198
+ # demo = gr.Interface(
199
+ # fn=lambda *a, **k: ("Ultralytics not installed; add 'ultralytics' to requirements.txt",),
200
+ # inputs=[],
201
+ # outputs="text",
202
+ # title="🌊 BenthicAI β€” Object Detection",
203
+ # description="Ultralytics is not installed."
204
+ # )
205
+ # else:
206
+ # demo = gr.Interface(
207
+ # fn=detect_objects_batch,
208
+ # inputs=[
209
+ # gr.Files(label="Upload images (max 10)"),
210
+ # gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Conf threshold"),
211
+ # gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="IoU threshold"),
212
+ # ],
213
+ # outputs=[
214
+ # gr.Gallery(label="Detections (annotated)", height=500, rows=3),
215
+ # gr.Dataframe(headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"],
216
+ # label="Detection Table"),
217
+ # gr.File(label="Download CSV"),
218
+ # ],
219
+ # title="🌊 BenthicAI β€” Object Detection",
220
+ # description=(
221
+ # "Run YOLO object detection on multiple images. "
222
+ # "Place 'yolo11_best.pt' in the repo root, OR set YOLO_REPO_ID (+ HF_TOKEN if private) "
223
+ # "to fetch from the Hub."
224
+ # ),
225
+ # )
226
+
227
+ # if __name__ == "__main__":
228
+ # demo.launch(server_name="0.0.0.0", server_port=7860)
229
+ # app.py β€” Image Classification only (single + batch up to 10)
230
+ import os
231
+ import csv
232
+ import tempfile
233
+ from pathlib import Path
234
+ from typing import List, Tuple
235
+
236
  import gradio as gr
237
  import torch
238
  import torch.nn.functional as F
239
  from transformers import AutoImageProcessor, AutoModelForImageClassification
240
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
241
 
242
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
243
+ MODEL_ID = "dshi01/convnext-tiny-224-7clss" # your HF model repo id
244
+ PROCESSOR_ID = "facebook/convnext-tiny-224" # feature extractor
245
 
246
+ print(f"[IC] Loading model: {MODEL_ID}")
247
+ processor = AutoImageProcessor.from_pretrained(PROCESSOR_ID)
248
  model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
249
  model.eval()
250
 
251
+ # Build id2label list (stable order)
252
  ID2LABEL = [
253
  model.config.id2label.get(str(i), model.config.id2label.get(i, f"Label_{i}"))
254
  for i in range(model.config.num_labels)
255
  ]
256
+
257
  def classify_image(image):
258
+ """Single-image classification."""
259
  if not isinstance(image, Image.Image):
260
  image = Image.fromarray(image).convert("RGB")
261
 
 
266
 
267
  return {ID2LABEL[i]: float(p) for i, p in enumerate(probs)}
268
 
 
269
  MAX_BATCH = 10
270
 
271
  def classify_images_batch(files):
272
  """
273
+ Batch classification (up to 10).
274
+ Returns: gallery [(img, caption)], table rows, CSV filepath
 
 
275
  """
276
  if not files:
277
  return [], [], None
278
 
 
279
  files = files[:MAX_BATCH]
280
 
281
+ # Load PILs
282
  pil_images, names = [], []
283
  for f in files:
284
  path = getattr(f, "name", None) or getattr(f, "path", None) or f
 
287
  pil_images.append(img)
288
  names.append(os.path.basename(path))
289
  except Exception:
 
290
  continue
291
 
292
  if not pil_images:
293
  return [], [], None
294
 
 
295
  inputs = processor(images=pil_images, return_tensors="pt")
296
  with torch.no_grad():
297
  logits = model(**inputs).logits
298
  probs = F.softmax(logits, dim=1)
299
 
 
300
  gallery = []
301
  table_rows = [] # [filename, top1_label, top1_conf, top3_labels, top3_confs]
302
 
 
305
  top_idxs = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:3]
306
  top1 = top_idxs[0]
307
  caption = f"{ID2LABEL[top1]} ({p[top1]:.2%})"
 
308
  gallery.append((img, f"{fname}\n{caption}"))
309
 
310
  top3_labels = [ID2LABEL[i] for i in top_idxs]
 
320
  # Create CSV for download
321
  csv_path = None
322
  try:
323
+ tmp = tempfile.NamedTemporaryFile(
324
+ delete=False, suffix=".csv", prefix="predictions_", dir=BASE_DIR,
325
+ mode="w", newline='', encoding='utf-8'
326
+ )
327
  writer = csv.writer(tmp)
 
328
  writer.writerow(["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"])
329
  for row in table_rows:
330
  writer.writerow(row)
331
+ tmp.flush(); tmp.close()
 
332
  csv_path = tmp.name
333
  except Exception:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  csv_path = None
335
 
336
  return gallery, table_rows, csv_path
 
340
  fn=classify_image,
341
  inputs=gr.Image(type="pil", label="Upload Underwater Image"),
342
  outputs=gr.Label(num_top_classes=len(ID2LABEL), label="Species Classification"),
343
+ title="🌊 BenthicAI β€” Single Image",
344
  description="Classify one image into one of 7 benthic species."
345
  )
346
 
 
353
  headers=["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"],
354
  label="Predictions Table",
355
  wrap=True
356
+ ),
357
+ gr.File(label="Download CSV")
358
  ],
359
+ title="🌊 BenthicAI β€” Batch (up to 10)",
360
+ description="Upload multiple images (max 10)."
361
  )
362
 
363
  demo = gr.TabbedInterface([single, batch], ["Single", "Batch"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  if __name__ == "__main__":
366
+ demo.launch(server_name="0.0.0.0", server_port=7860)