danielhshi8224 commited on
Commit
879e1cd
Β·
1 Parent(s): c458c3e

add object detection

Browse files
Files changed (2) hide show
  1. app.py +198 -112
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,115 +1,4 @@
1
- # import gradio as gr
2
- # import torch
3
- # from transformers import AutoImageProcessor, AutoModelForImageClassification
4
- # from PIL import Image
5
- # import os
6
-
7
- # # Get model path (Windows compatible)
8
- # BASE_DIR = os.path.dirname(os.path.abspath(__file__))
9
- # MODEL_ID = "dshi01/convnext-tiny-224-7clss"
10
-
11
- # # Try different possible filenames
12
- # # possible_names = ['ConvNextmodel.pth', 'convnextmodel.pth', 'ConvNext_model.pth']
13
- # # model_path = None
14
-
15
- # # for name in possible_names:
16
- # # test_path = os.path.join(BASE_DIR, name)
17
- # # if os.path.exists(test_path):
18
- # # model_path = test_path
19
- # # print(f"βœ“ Found model: {name}")
20
- # # break
21
-
22
- # # if model_path is None:
23
- # # raise FileNotFoundError(f"Could not find model file. Tried: {possible_names}")
24
-
25
- # # Species categories (7 classes)
26
- # SPECIES_CATEGORIES = [
27
- # 'Eel',
28
- # 'Scallop',
29
- # 'Crab',
30
- # 'Flatfish',
31
- # 'Roundfish',
32
- # 'Skate',
33
- # 'Whelk'
34
- # ]
35
-
36
- # # Load model
37
- # print(f"Loading model from: {MODEL_ID}")
38
- # # model = AutoModelForImageClassification.from_pretrained(
39
- # # 'facebook/convnext-tiny-224',
40
- # # num_labels=7,
41
- # # ignore_mismatched_sizes=True
42
- # # )
43
- # processor=AutoImageProcessor.from_pretrained('facebook/convnext-tiny-224')
44
- # model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
45
-
46
- # # Load weights
47
- # # checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
48
- # # if isinstance(checkpoint, dict):
49
- # # if 'model' in checkpoint:
50
- # # checkpoint = checkpoint['model']
51
- # # elif 'state_dict' in checkpoint:
52
- # # checkpoint = checkpoint['state_dict']
53
-
54
- # # model.load_state_dict(checkpoint, strict=False)
55
- # # model.eval()
56
-
57
- # # Load processor
58
- # # processor = AutoImageProcessor.from_pretrained('facebook/convnext-tiny-224')
59
- # # print("βœ“ Model loaded successfully!")
60
-
61
- # def classify_image(image):
62
- # """
63
- # Classify a benthic species image.
64
-
65
- # Args:
66
- # image: PIL Image or numpy array
67
-
68
- # Returns:
69
- # dict: Predictions with species names and confidence scores
70
- # """
71
- # # Convert to PIL if needed
72
- # if not isinstance(image, Image.Image):
73
- # image = Image.fromarray(image).convert('RGB')
74
-
75
- # # Preprocess
76
- # inputs = processor(images=image, return_tensors="pt")
77
-
78
- # # Predict
79
- # with torch.no_grad():
80
- # outputs = model(**inputs)
81
- # logits = outputs.logits
82
- # probabilities = torch.nn.functional.softmax(logits, dim=1)
83
-
84
- # # Create results dictionary for Gradio
85
- # results = {}
86
- # for idx, prob in enumerate(probabilities[0]):
87
- # results[SPECIES_CATEGORIES[idx]] = float(prob)
88
-
89
- # return results
90
-
91
- # # Create Gradio interface
92
- # demo = gr.Interface(
93
- # fn=classify_image,
94
- # inputs=gr.Image(type="pil", label="Upload Underwater Image"),
95
- # outputs=gr.Label(num_top_classes=7, label="Species Classification"),
96
- # title="🌊 BenthicAI - Benthic Species Classifier",
97
- # description="Upload an image of a benthic organism to classify it into one of 7 species categories. Built with ConvNeXT transformer model.",
98
- # examples=[
99
- # [os.path.join("examples", "eel.jpg")],
100
- # [os.path.join("examples", "scallop.jpg")],
101
- # [os.path.join("examples", "crab.jpg")],
102
- # ] if os.path.exists("examples") else None,
103
- # theme=gr.themes.Soft(),
104
- # allow_flagging="never"
105
- # )
106
-
107
- # if __name__ == "__main__":
108
- # demo.launch(
109
- # server_name="0.0.0.0",
110
- # server_port=7860,
111
- # share=True # Set to True to get a public URL
112
- # )
113
  import gradio as gr
114
  import torch
115
  import torch.nn.functional as F
@@ -118,6 +7,13 @@ from PIL import Image
118
  import os
119
  import csv
120
  import tempfile
 
 
 
 
 
 
 
121
 
122
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
123
  MODEL_ID = "dshi01/convnext-tiny-224-7clss"
@@ -221,6 +117,179 @@ def classify_images_batch(files):
221
 
222
  return gallery, table_rows, csv_path
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  # ---------- UI ----------
225
  single = gr.Interface(
226
  fn=classify_image,
@@ -247,6 +316,23 @@ batch = gr.Interface(
247
  )
248
 
249
  demo = gr.TabbedInterface([single, batch], ["Single", "Batch"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  if __name__ == "__main__":
252
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
1
+ #Main Gradio app ith image classification and object detection tabs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import gradio as gr
3
  import torch
4
  import torch.nn.functional as F
 
7
  import os
8
  import csv
9
  import tempfile
10
+ from pathlib import Path
11
+ from ultralytics import YOLO
12
+ # ultralytics YOLO import (for object detection)
13
+ try:
14
+ from ultralytics import YOLO
15
+ except Exception:
16
+ YOLO = None
17
 
18
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
19
  MODEL_ID = "dshi01/convnext-tiny-224-7clss"
 
117
 
118
  return gallery, table_rows, csv_path
119
 
120
+
121
+ # ---------- NEW: YOLO object detection for multi-image upload ----------
122
+ YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_best.pt")
123
+ _yolo_model = None
124
+ def _load_yolo():
125
+ global _yolo_model
126
+ if _yolo_model is not None:
127
+ return _yolo_model
128
+ if YOLO is None:
129
+ raise RuntimeError("ultralytics package not installed. Please install 'ultralytics'.")
130
+ if not os.path.exists(YOLO_WEIGHTS):
131
+ # Try current directory too
132
+ alt = Path.cwd() / "yolo11_best.pt"
133
+ if alt.exists():
134
+ model_path = str(alt)
135
+ else:
136
+ raise FileNotFoundError(f"YOLO weights not found at {YOLO_WEIGHTS}. Place yolo11_best.pt in project root.")
137
+ else:
138
+ model_path = YOLO_WEIGHTS
139
+
140
+ _yolo_model = YOLO(model_path)
141
+ return _yolo_model
142
+
143
+
144
+ def detect_objects_batch(files, iou=0.25, conf=0.25):
145
+ """
146
+ Run YOLO detection on multiple images.
147
+ Returns: gallery of annotated images, dataframe rows, csv file path
148
+ """
149
+ if YOLO is None:
150
+ return [], [], None
151
+
152
+ if not files:
153
+ return [], [], None
154
+
155
+ # Load model
156
+ try:
157
+ ymodel = _load_yolo()
158
+ except Exception as e:
159
+ print("YOLO load error:", e)
160
+ return [], [], None
161
+
162
+ annotated_paths = []
163
+ table_rows = []
164
+ gallery = []
165
+
166
+ for f in files[:MAX_BATCH]:
167
+ path = getattr(f, "name", None) or getattr(f, "path", None) or f
168
+ try:
169
+ # Run predict; returns a Results object list
170
+ results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=640, verbose=False)
171
+ except Exception as e:
172
+ print(f"Detection failed for {path}:", e)
173
+ continue
174
+
175
+ # results is list-like; take first
176
+ res = results[0]
177
+
178
+ # Prepare annotation image using res.plot() so boxes+confidences are drawn
179
+ ann_path = None
180
+ try:
181
+ ann_img = res.plot() # returns numpy array with annotations
182
+ from PIL import Image as PILImage
183
+ ann_pil = PILImage.fromarray(ann_img)
184
+ out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
185
+ os.makedirs(out_dir, exist_ok=True)
186
+ ann_filename = os.path.splitext(os.path.basename(path))[0] + "_annotated.jpg"
187
+ ann_path = os.path.join(out_dir, ann_filename)
188
+ ann_pil.save(ann_path)
189
+ except Exception:
190
+ # Fallback to ultralytics save if plot() isn't available
191
+ try:
192
+ out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
193
+ res.save(save_dir=out_dir)
194
+ saved_files = res.files if hasattr(res, 'files') else []
195
+ ann_path = saved_files[0] if saved_files else None
196
+ except Exception:
197
+ ann_path = None
198
+
199
+ # Build table rows from detections
200
+ boxes = res.boxes if hasattr(res, 'boxes') else None
201
+ if boxes is None or len(boxes) == 0:
202
+ table_rows.append([os.path.basename(path), 0, "", "", ""])
203
+ if ann_path and os.path.exists(ann_path):
204
+ gallery.append((Image.open(ann_path).convert('RGB'), f"{os.path.basename(path)}\nNo detections"))
205
+ else:
206
+ gallery.append((Image.open(path).convert('RGB'), f"{os.path.basename(path)}\nNo detections"))
207
+ continue
208
+
209
+ det_labels = []
210
+ det_scores = []
211
+ det_boxes = []
212
+ for box in boxes:
213
+ # box.cls, box.conf, box.xyxy
214
+ cls = int(box.cls.cpu().item()) if hasattr(box, 'cls') else None
215
+ # use .item() to extract scalar and avoid numpy deprecation warnings
216
+ if hasattr(box, 'conf'):
217
+ try:
218
+ confscore = float(box.conf.cpu().item())
219
+ except Exception:
220
+ try:
221
+ confscore = float(box.conf.item())
222
+ except Exception:
223
+ confscore = None
224
+ else:
225
+ confscore = None
226
+
227
+ # extract xyxy coords; box.xyxy may be shape (1,4) -> nested list after .tolist()
228
+ coords = []
229
+ if hasattr(box, 'xyxy'):
230
+ try:
231
+ arr = box.xyxy.cpu().numpy()
232
+ # handle nested shape (1,4) or (4,)
233
+ if getattr(arr, 'ndim', None) == 2 and arr.shape[0] == 1:
234
+ coords = arr[0].tolist()
235
+ elif getattr(arr, 'ndim', None) == 1:
236
+ coords = arr.tolist()
237
+ else:
238
+ coords = arr.reshape(-1).tolist()
239
+ except Exception:
240
+ # fallback: try to call tolist()
241
+ try:
242
+ coords = box.xyxy.tolist()
243
+ except Exception:
244
+ coords = []
245
+
246
+ # append detection info
247
+ det_labels.append(ymodel.names.get(cls, str(cls)) if cls is not None else "")
248
+ det_scores.append(round(confscore, 4) if confscore is not None else "")
249
+ # round and store coords
250
+ try:
251
+ det_boxes.append([round(float(x), 2) for x in coords])
252
+ except Exception:
253
+ # fallback: store raw repr
254
+ det_boxes.append([str(coords)])
255
+
256
+ # create readable label:confidence pairs
257
+ label_conf_pairs = [f"{l}:{s}" for l, s in zip(det_labels, det_scores)]
258
+ boxes_repr = ["[" + ", ".join(map(str, b)) + "]" for b in det_boxes]
259
+ table_rows.append([
260
+ os.path.basename(path),
261
+ len(det_labels),
262
+ ", ".join(label_conf_pairs),
263
+ ", ".join(boxes_repr),
264
+ "; ".join([str(b) for b in det_boxes])
265
+ ])
266
+
267
+ # Use annotated image if exists
268
+ if ann_path and os.path.exists(ann_path):
269
+ try:
270
+ gallery.append((Image.open(ann_path).convert('RGB'), f"{os.path.basename(path)}\n{len(det_labels)} detections"))
271
+ except Exception:
272
+ gallery.append((Image.open(path).convert('RGB'), f"{os.path.basename(path)}\n{len(det_labels)} detections"))
273
+ else:
274
+ gallery.append((Image.open(path).convert('RGB'), f"{os.path.basename(path)}\n{len(det_labels)} detections"))
275
+
276
+ # write CSV
277
+ csv_path = None
278
+ try:
279
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv", prefix="yolo_preds_", dir=BASE_DIR, mode="w", newline='', encoding='utf-8')
280
+ writer = csv.writer(tmp)
281
+ writer.writerow(["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"])
282
+ for r in table_rows:
283
+ writer.writerow(r)
284
+ tmp.flush()
285
+ tmp.close()
286
+ csv_path = tmp.name
287
+ except Exception as e:
288
+ print("Failed to write CSV:", e)
289
+ csv_path = None
290
+
291
+ return gallery, table_rows, csv_path
292
+
293
  # ---------- UI ----------
294
  single = gr.Interface(
295
  fn=classify_image,
 
316
  )
317
 
318
  demo = gr.TabbedInterface([single, batch], ["Single", "Batch"])
319
+ print(YOLO==None, flush=True)
320
+ # Add Object Detection tab if ultralytics available
321
+ if YOLO is not None:
322
+ detection_iface = gr.Interface(
323
+ fn=detect_objects_batch,
324
+ inputs=[gr.Files(label="Upload images for detection (max 10)"), gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="conf threshold"), gr.Slider(minimum=0.0, maximum=1.0, value=0.25, label="IOU threshold")],
325
+ outputs=[
326
+ gr.Gallery(label="Detections (annotated)", height=500, rows=3),
327
+ gr.Dataframe(headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"], label="Detection Table"),
328
+ gr.File(label="Download CSV")
329
+ ],
330
+ title="🌊 BenthicAI - Object Detection",
331
+ description="Run YOLO object detection on multiple images. Requires 'yolo11_best.pt' in project root."
332
+ )
333
+
334
+ # extend tabs
335
+ demo = gr.TabbedInterface([single, batch, detection_iface], ["Single", "Batch", "Detection"])
336
 
337
  if __name__ == "__main__":
338
  demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch
2
  torchvision
3
  transformers
4
  gradio
5
- Pillow
 
 
2
  torchvision
3
  transformers
4
  gradio
5
+ Pillow
6
+ ultralytics