# # app.py — Object Detection only (multi-image YOLO, up to 10)
# import os
# import csv
# import tempfile
# from pathlib import Path
# from typing import List, Tuple
#
# import gradio as gr
# from PIL import Image
#
# # Try import ultralytics (ensure it's in requirements.txt)
# try:
#     from ultralytics import YOLO
# except Exception:
#     YOLO = None
#
# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# MAX_BATCH = 10
#
# # Option A: local file baked into Space (easiest if allowed)
# YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_best.pt")
#
# # Option B (optional): pull from a private HF model repo using a Space secret
# # Set these env vars in your Space if you want auto-download:
# #   HF_TOKEN=
# #   YOLO_REPO_ID="yourname/yolo-detector"
# HF_TOKEN = os.environ.get("HF_TOKEN")
# YOLO_REPO_ID = os.environ.get("YOLO_REPO_ID")
#
#
# def _download_from_hub_if_needed() -> str | None:
#     """If YOLO_REPO_ID is set, download weights with huggingface_hub; else return None."""
#     if not YOLO_REPO_ID:
#         return None
#     try:
#         from huggingface_hub import snapshot_download
#         local_dir = snapshot_download(
#             repo_id=YOLO_REPO_ID, repo_type="model", token=HF_TOKEN
#         )
#         # try common filenames
#         for name in ("yolo11_best.pt", "best.pt", "yolo.pt", "weights.pt"):
#             cand = Path(local_dir) / name
#             if cand.exists():
#                 return str(cand)
#     except Exception as e:
#         print("[YOLO] Hub download failed:", e)
#     return None
#
#
# _yolo_model = None
#
#
# def _load_yolo():
#     """Load YOLO weights either from local file or HF Hub."""
#     global _yolo_model
#     if _yolo_model is not None:
#         return _yolo_model
#     if YOLO is None:
#         raise RuntimeError("ultralytics package not installed. Add 'ultralytics' to requirements.txt")
#     model_path = None
#     if os.path.exists(YOLO_WEIGHTS):
#         model_path = YOLO_WEIGHTS
#     else:
#         hub_path = _download_from_hub_if_needed()
#         if hub_path:
#             model_path = hub_path
#     if not model_path:
#         raise FileNotFoundError(
#             "YOLO weights not found. Either include 'yolo11_best.pt' in the repo root, "
#             "or set YOLO_REPO_ID (+ HF_TOKEN if private) to pull from the Hub."
#         )
#     _yolo_model = YOLO(model_path)
#     return _yolo_model
#
#
# def detect_objects_batch(files, conf=0.25, iou=0.25):
#     """
#     Run YOLO detection on multiple images (up to 10).
#     Returns: gallery of annotated images, rows table, csv filepath
#     """
#     if YOLO is None:
#         return [], [], None
#     if not files:
#         return [], [], None
#     try:
#         ymodel = _load_yolo()
#     except Exception as e:
#         print("YOLO load error:", e)
#         return [], [], None
#
#     gallery, table_rows = [], []
#     for f in files[:MAX_BATCH]:
#         path = getattr(f, "name", None) or getattr(f, "path", None) or f
#         try:
#             results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=640, verbose=False)
#         except Exception as e:
#             print(f"Detection failed for {path}:", e)
#             continue
#         res = results[0]
#
#         # annotated image
#         ann_path = None
#         try:
#             ann_img = res.plot()
#             ann_pil = Image.fromarray(ann_img)
#             out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
#             os.makedirs(out_dir, exist_ok=True)
#             ann_filename = Path(path).stem + "_annotated.jpg"
#             ann_path = os.path.join(out_dir, ann_filename)
#             ann_pil.save(ann_path)
#         except Exception:
#             try:
#                 out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
#                 res.save(save_dir=out_dir)
#                 saved_files = getattr(res, "files", [])
#                 ann_path = saved_files[0] if saved_files else None
#             except Exception:
#                 ann_path = None
#
#         # extract detections
#         boxes = getattr(res, "boxes", None)
#         if boxes is None or len(boxes) == 0:
#             table_rows.append([os.path.basename(path), 0, "", "", ""])
#             img_for_gallery = Image.open(ann_path).convert("RGB") if ann_path and os.path.exists(ann_path) \
#                 else Image.open(path).convert("RGB")
#             gallery.append((img_for_gallery, f"{os.path.basename(path)}\nNo detections"))
#             continue
#
#         det_labels, det_scores, det_boxes = [], [], []
#         for box in boxes:
#             cls = int(box.cls.cpu().item()) if hasattr(box, "cls") else None
#             # conf
#             try:
#                 confscore = float(box.conf.cpu().item()) if hasattr(box, "conf") else None
#             except Exception:
#                 try:
#                     confscore = float(box.conf.item())
#                 except Exception:
#                     confscore = None
#             # xyxy
#             coords = []
#             if hasattr(box, "xyxy"):
#                 try:
#                     arr = box.xyxy.cpu().numpy()
#                     if getattr(arr, "ndim", None) == 2 and arr.shape[0] == 1:
#                         coords = arr[0].tolist()
#                     elif getattr(arr, "ndim", None) == 1:
#                         coords = arr.tolist()
#                     else:
#                         coords = arr.reshape(-1).tolist()
#                 except Exception:
#                     try:
#                         coords = box.xyxy.tolist()
#                     except Exception:
#                         coords = []
#             det_labels.append(ymodel.names.get(cls, str(cls)) if cls is not None else "")
#             det_scores.append(round(confscore, 4) if confscore is not None else "")
#             try:
#                 det_boxes.append([round(float(x), 2) for x in coords])
#             except Exception:
#                 det_boxes.append([str(coords)])
#
#         label_conf_pairs = [f"{l}:{s}" for l, s in zip(det_labels, det_scores)]
#         boxes_repr = ["[" + ", ".join(map(str, b)) + "]" for b in det_boxes]
#         table_rows.append([
#             os.path.basename(path),
#             len(det_labels),
#             ", ".join(label_conf_pairs),
#             ", ".join(boxes_repr),
#             "; ".join([str(b) for b in det_boxes]),
#         ])
#         img_for_gallery = Image.open(ann_path).convert("RGB") if ann_path and os.path.exists(ann_path) \
#             else Image.open(path).convert("RGB")
#         gallery.append((img_for_gallery, f"{os.path.basename(path)}\n{len(det_labels)} detections"))
#
#     # write CSV
#     csv_path = None
#     try:
#         tmp = tempfile.NamedTemporaryFile(
#             delete=False, suffix=".csv", prefix="yolo_preds_", dir=BASE_DIR,
#             mode="w", newline='', encoding='utf-8'
#         )
#         writer = csv.writer(tmp)
#         writer.writerow(["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"])
#         for r in table_rows:
#             writer.writerow(r)
#         tmp.flush(); tmp.close()
#         csv_path = tmp.name
#     except Exception as e:
#         print("Failed to write CSV:", e)
#         csv_path = None
#     return gallery, table_rows, csv_path
#
#
# # ---------- UI ----------
# if YOLO is None:
#     demo = gr.Interface(
#         fn=lambda *a, **k: ("Ultralytics not
#             installed; add 'ultralytics' to requirements.txt",),
#         inputs=[],
#         outputs="text",
#         title="🌊 BenthicAI — Object Detection",
#         description="Ultralytics is not installed."
#     )
# else:
#     demo = gr.Interface(
#         fn=detect_objects_batch,
#         inputs=[
#             gr.Files(label="Upload images (max 10)"),
#             gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Conf threshold"),
#             gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="IoU threshold"),
#         ],
#         outputs=[
#             gr.Gallery(label="Detections (annotated)", height=500, rows=3),
#             gr.Dataframe(headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"],
#                          label="Detection Table"),
#             gr.File(label="Download CSV"),
#         ],
#         title="🌊 BenthicAI — Object Detection",
#         description=(
#             "Run YOLO object detection on multiple images. "
#             "Place 'yolo11_best.pt' in the repo root, OR set YOLO_REPO_ID (+ HF_TOKEN if private) "
#             "to fetch from the Hub."
#         ),
#     )
#
# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", server_port=7860)


# app.py — Image Classification only (single + batch up to 10)
import os
import csv
import tempfile
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_ID = "dshi01/convnext-tiny-224-7clss"  # your HF model repo id
PROCESSOR_ID = "facebook/convnext-tiny-224"  # feature extractor
MAX_BATCH = 10  # maximum number of uploads classified per batch request

# Column names shared by the batch results table and the downloadable CSV,
# so the two can never drift apart.
CSV_HEADER = ["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"]

print(f"[IC] Loading model: {MODEL_ID}")
processor = AutoImageProcessor.from_pretrained(PROCESSOR_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()  # inference mode

# Label names indexed by class id (stable order). id2label keys may be ints
# or strings depending on how the config was stored, so look up both.
ID2LABEL = [
    model.config.id2label.get(str(i), model.config.id2label.get(i, f"Label_{i}"))
    for i in range(model.config.num_labels)
]


def _to_rgb(image):
    """Return *image* (a PIL image or an array accepted by ``Image.fromarray``) as RGB PIL.

    Non-RGB modes (RGBA, palette, grayscale) are always converted, so the
    processor receives three channels even when the function is called
    directly rather than through ``gr.Image``.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    return image.convert("RGB")


def _file_path(f):
    """Resolve one ``gr.Files`` entry to a filesystem path.

    Depending on the Gradio version, an entry may be an object exposing
    ``.name`` or ``.path``, or already a plain path string.
    """
    return getattr(f, "name", None) or getattr(f, "path", None) or f


def _predict_probs(images):
    """Run the classifier on a list of RGB PIL images.

    Returns:
        Tensor of softmax probabilities, one row per image and one column
        per entry of ``ID2LABEL``.
    """
    inputs = processor(images=images, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return F.softmax(logits, dim=1)


def _load_images(files):
    """Open each upload as an RGB image.

    Unreadable or non-image uploads are skipped (with a log line) rather
    than failing the whole batch.

    Returns:
        ``(images, names)``: parallel lists of RGB PIL images and basenames.
    """
    images, names = [], []
    for f in files:
        path = _file_path(f)
        try:
            name = os.path.basename(path)
            # ``with`` closes the file handle; convert() returns an
            # independent in-memory copy that stays valid afterwards.
            with Image.open(path) as im:
                rgb = im.convert("RGB")
        except Exception as e:  # best-effort over arbitrary user uploads
            print(f"[IC] Skipping unreadable upload {path!r}: {e}")
            continue
        # Append both only on success so images and names stay aligned.
        images.append(rgb)
        names.append(name)
    return images, names


def _write_csv(rows):
    """Write ``CSV_HEADER`` plus *rows* to a new CSV file under ``BASE_DIR``.

    Returns:
        The CSV path, or ``None`` if it could not be written (the failure is
        logged; the rest of the batch result is still returned to the UI).
    """
    # NOTE(review): one file per request is created in BASE_DIR and never
    # removed; consider periodic cleanup on long-running deployments.
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", newline="", encoding="utf-8",
            suffix=".csv", prefix="predictions_", dir=BASE_DIR, delete=False,
        ) as tmp:
            writer = csv.writer(tmp)
            writer.writerow(CSV_HEADER)
            writer.writerows(rows)
        return tmp.name
    except (OSError, ValueError, csv.Error) as e:
        print("[IC] Failed to write CSV:", e)
        return None


def classify_image(image):
    """Classify a single image.

    Args:
        image: PIL image or array from ``gr.Image``; ``None`` when the user
            submits without uploading anything.

    Returns:
        Dict mapping every label in ``ID2LABEL`` to its softmax probability
        (the value shape ``gr.Label`` displays).

    Raises:
        gr.Error: if no image was provided, so the UI shows a clear message
            instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    probs = _predict_probs([_to_rgb(image)])[0].tolist()
    return {ID2LABEL[i]: float(p) for i, p in enumerate(probs)}


def classify_images_batch(files):
    """Classify up to ``MAX_BATCH`` uploaded images in one forward pass.

    Args:
        files: list of ``gr.Files`` entries (may be ``None`` / empty). Only
            the first ``MAX_BATCH`` entries are used.

    Returns:
        ``(gallery, table_rows, csv_path)``:
          * gallery: list of ``(PIL image, caption)`` with the top-1 label,
          * table_rows: list of rows matching ``CSV_HEADER``,
          * csv_path: path of a CSV with the same rows, or ``None``.
        ``([], [], None)`` when no image could be loaded.
    """
    if not files:
        return [], [], None
    if len(files) > MAX_BATCH:
        print(f"[IC] Received {len(files)} files; only the first {MAX_BATCH} are classified.")

    pil_images, names = _load_images(files[:MAX_BATCH])
    if not pil_images:
        return [], [], None

    probs = _predict_probs(pil_images)

    gallery, table_rows = [], []
    for img, fname, p in zip(pil_images, names, probs.tolist()):
        top_idxs = sorted(range(len(p)), key=p.__getitem__, reverse=True)[:3]
        top1 = top_idxs[0]
        gallery.append((img, f"{fname}\n{ID2LABEL[top1]} ({p[top1]:.2%})"))
        table_rows.append([
            fname,
            ID2LABEL[top1],
            round(p[top1], 4),
            ", ".join(ID2LABEL[i] for i in top_idxs),
            ", ".join(str(round(p[i], 4)) for i in top_idxs),
        ])

    return gallery, table_rows, _write_csv(table_rows)


# ---------- UI ----------
single = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Underwater Image"),
    outputs=gr.Label(num_top_classes=len(ID2LABEL), label="Species Classification"),
    title="🌊 BenthicAI — Single Image",
    description=f"Classify one image into one of {len(ID2LABEL)} benthic species.",
)

batch = gr.Interface(
    fn=classify_images_batch,
    inputs=gr.Files(label=f"Upload up to {MAX_BATCH} images"),
    outputs=[
        gr.Gallery(label="Results (Top-1 in caption)", height=500, rows=3),
        gr.Dataframe(headers=list(CSV_HEADER), label="Predictions Table", wrap=True),
        gr.File(label="Download CSV"),
    ],
    title=f"🌊 BenthicAI — Batch (up to {MAX_BATCH})",
    description=f"Upload multiple images (max {MAX_BATCH}).",
)

demo = gr.TabbedInterface([single, batch], ["Single", "Batch"])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)