# # app.py - Object Detection only (multi-image YOLO, up to 10)
# import os
# import csv
# import tempfile
# from pathlib import Path
# from typing import List, Tuple
#
# import gradio as gr
# from PIL import Image
#
# # Try import ultralytics (ensure it's in requirements.txt)
# try:
#     from ultralytics import YOLO
# except Exception:
#     YOLO = None
#
# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# MAX_BATCH = 10
#
# # Option A: local file baked into Space (easiest if allowed)
# YOLO_WEIGHTS = os.path.join(BASE_DIR, "yolo11_best.pt")
#
# # Option B (optional): pull from a private HF model repo using a Space secret.
# # Set these env vars in your Space if you want auto-download:
# #   HF_TOKEN=<read token>  YOLO_REPO_ID="yourname/yolo-detector"
# HF_TOKEN = os.environ.get("HF_TOKEN")
# YOLO_REPO_ID = os.environ.get("YOLO_REPO_ID")
# def _download_from_hub_if_needed() -> str | None:
#     """If YOLO_REPO_ID is set, download weights with huggingface_hub; else return None."""
#     if not YOLO_REPO_ID:
#         return None
#     try:
#         from huggingface_hub import snapshot_download
#         local_dir = snapshot_download(
#             repo_id=YOLO_REPO_ID, repo_type="model", token=HF_TOKEN
#         )
#         # try common filenames
#         for name in ("yolo11_best.pt", "best.pt", "yolo.pt", "weights.pt"):
#             cand = Path(local_dir) / name
#             if cand.exists():
#                 return str(cand)
#     except Exception as e:
#         print("[YOLO] Hub download failed:", e)
#     return None
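#
# # A minimal alternative sketch (not part of the original app): if the exact weights
# # filename in the Hub repo is known, huggingface_hub.hf_hub_download can fetch that
# # single file instead of a whole snapshot. The filename "yolo11_best.pt" is assumed here.
# def _download_single_weight_file() -> str | None:
#     if not YOLO_REPO_ID:
#         return None
#     try:
#         from huggingface_hub import hf_hub_download
#         return hf_hub_download(repo_id=YOLO_REPO_ID, filename="yolo11_best.pt", token=HF_TOKEN)
#     except Exception as e:
#         print("[YOLO] Single-file download failed:", e)
#         return None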
#
# _yolo_model = None
#
# def _load_yolo():
#     """Load YOLO weights either from local file or HF Hub."""
#     global _yolo_model
#     if _yolo_model is not None:
#         return _yolo_model
#     if YOLO is None:
#         raise RuntimeError("ultralytics package not installed. Add 'ultralytics' to requirements.txt")
#     model_path = None
#     if os.path.exists(YOLO_WEIGHTS):
#         model_path = YOLO_WEIGHTS
#     else:
#         hub_path = _download_from_hub_if_needed()
#         if hub_path:
#             model_path = hub_path
#     if not model_path:
#         raise FileNotFoundError(
#             "YOLO weights not found. Either include 'yolo11_best.pt' in the repo root, "
#             "or set YOLO_REPO_ID (+ HF_TOKEN if private) to pull from the Hub."
#         )
#     _yolo_model = YOLO(model_path)
#     return _yolo_model
# def detect_objects_batch(files, conf=0.25, iou=0.25):
#     """
#     Run YOLO detection on multiple images (up to 10).
#     Returns: gallery of annotated images, table rows, CSV filepath.
#     """
#     if YOLO is None:
#         return [], [], None
#     if not files:
#         return [], [], None
#     try:
#         ymodel = _load_yolo()
#     except Exception as e:
#         print("YOLO load error:", e)
#         return [], [], None
#
#     gallery, table_rows = [], []
#     for f in files[:MAX_BATCH]:
#         path = getattr(f, "name", None) or getattr(f, "path", None) or f
#         try:
#             results = ymodel.predict(source=path, conf=conf, iou=iou, imgsz=640, verbose=False)
#         except Exception as e:
#             print(f"Detection failed for {path}:", e)
#             continue
#         res = results[0]
#
#         # annotated image
#         ann_path = None
#         try:
#             ann_img = res.plot()  # plot() returns a BGR numpy array
#             ann_pil = Image.fromarray(ann_img[..., ::-1])  # flip channels to RGB for PIL
#             out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
#             ann_filename = Path(path).stem + "_annotated.jpg"
#             ann_path = os.path.join(out_dir, ann_filename)
#             ann_pil.save(ann_path)
#         except Exception:
#             try:
#                 out_dir = tempfile.mkdtemp(prefix="yolo_out_", dir=BASE_DIR)
#                 res.save(save_dir=out_dir)
#                 saved_files = getattr(res, "files", [])
#                 ann_path = saved_files[0] if saved_files else None
#             except Exception:
#                 ann_path = None
#         # extract detections
#         boxes = getattr(res, "boxes", None)
#         if boxes is None or len(boxes) == 0:
#             table_rows.append([os.path.basename(path), 0, "", "", ""])
#             img_for_gallery = Image.open(ann_path).convert("RGB") if ann_path and os.path.exists(ann_path) \
#                 else Image.open(path).convert("RGB")
#             gallery.append((img_for_gallery, f"{os.path.basename(path)}\nNo detections"))
#             continue
#
#         det_labels, det_scores, det_boxes = [], [], []
#         for box in boxes:
#             cls = int(box.cls.cpu().item()) if hasattr(box, "cls") else None
#             # conf
#             try:
#                 confscore = float(box.conf.cpu().item()) if hasattr(box, "conf") else None
#             except Exception:
#                 try:
#                     confscore = float(box.conf.item())
#                 except Exception:
#                     confscore = None
#             # xyxy
#             coords = []
#             if hasattr(box, "xyxy"):
#                 try:
#                     arr = box.xyxy.cpu().numpy()
#                     if getattr(arr, "ndim", None) == 2 and arr.shape[0] == 1:
#                         coords = arr[0].tolist()
#                     elif getattr(arr, "ndim", None) == 1:
#                         coords = arr.tolist()
#                     else:
#                         coords = arr.reshape(-1).tolist()
#                 except Exception:
#                     try:
#                         coords = box.xyxy.tolist()
#                     except Exception:
#                         coords = []
#             det_labels.append(ymodel.names.get(cls, str(cls)) if cls is not None else "")
#             det_scores.append(round(confscore, 4) if confscore is not None else "")
#             try:
#                 det_boxes.append([round(float(x), 2) for x in coords])
#             except Exception:
#                 det_boxes.append([str(coords)])
#
#         label_conf_pairs = [f"{l}:{s}" for l, s in zip(det_labels, det_scores)]
#         boxes_repr = ["[" + ", ".join(map(str, b)) + "]" for b in det_boxes]
#         table_rows.append([
#             os.path.basename(path),
#             len(det_labels),
#             ", ".join(label_conf_pairs),
#             ", ".join(boxes_repr),
#             "; ".join([str(b) for b in det_boxes]),
#         ])
#         img_for_gallery = Image.open(ann_path).convert("RGB") if ann_path and os.path.exists(ann_path) \
#             else Image.open(path).convert("RGB")
#         gallery.append((img_for_gallery, f"{os.path.basename(path)}\n{len(det_labels)} detections"))
#
#     # write CSV
#     csv_path = None
#     try:
#         tmp = tempfile.NamedTemporaryFile(
#             delete=False, suffix=".csv", prefix="yolo_preds_", dir=BASE_DIR,
#             mode="w", newline="", encoding="utf-8"
#         )
#         writer = csv.writer(tmp)
#         writer.writerow(["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"])
#         for r in table_rows:
#             writer.writerow(r)
#         tmp.flush(); tmp.close()
#         csv_path = tmp.name
#     except Exception as e:
#         print("Failed to write CSV:", e)
#         csv_path = None
#
#     return gallery, table_rows, csv_path
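#
# # Example call (illustrative file paths, assuming weights have been loaded successfully):
# #   gallery, rows, csv_path = detect_objects_batch(["reef_01.jpg", "reef_02.jpg"], conf=0.3, iou=0.5)
# # gallery feeds gr.Gallery, rows the gr.Dataframe, and csv_path the gr.File download below.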
#
# # ---------- UI ----------
# if YOLO is None:
#     demo = gr.Interface(
#         fn=lambda *a, **k: "Ultralytics not installed; add 'ultralytics' to requirements.txt",
#         inputs=[],
#         outputs="text",
#         title="BenthicAI - Object Detection",
#         description="Ultralytics is not installed."
#     )
# else:
#     demo = gr.Interface(
#         fn=detect_objects_batch,
#         inputs=[
#             gr.Files(label="Upload images (max 10)"),
#             gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="Conf threshold"),
#             gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.01, label="IoU threshold"),
#         ],
#         outputs=[
#             gr.Gallery(label="Detections (annotated)", height=500, rows=3),
#             gr.Dataframe(headers=["filename", "num_detections", "labels_with_conf", "boxes", "raw_boxes"],
#                          label="Detection Table"),
#             gr.File(label="Download CSV"),
#         ],
#         title="BenthicAI - Object Detection",
#         description=(
#             "Run YOLO object detection on multiple images. "
#             "Place 'yolo11_best.pt' in the repo root, OR set YOLO_REPO_ID (+ HF_TOKEN if private) "
#             "to fetch from the Hub."
#         ),
#     )
#
# if __name__ == "__main__":
#     demo.launch(server_name="0.0.0.0", server_port=7860)
# app.py - Image Classification only (single + batch up to 10)
import os
import csv
import tempfile
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
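
# A plausible requirements.txt for this Space, inferred from the imports above (version
# pins omitted): gradio, torch, transformers, pillow. Re-enabling the commented-out
# detection app would additionally need ultralytics and huggingface_hub.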
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

MODEL_ID = "dshi01/convnext-tiny-224-7clss"  # your HF model repo id
PROCESSOR_ID = "facebook/convnext-tiny-224"  # feature extractor

print(f"[IC] Loading model: {MODEL_ID}")
processor = AutoImageProcessor.from_pretrained(PROCESSOR_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
model.eval()

# Build id2label list (stable order)
ID2LABEL = [
    model.config.id2label.get(str(i), model.config.id2label.get(i, f"Label_{i}"))
    for i in range(model.config.num_labels)
]
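
# Note: depending on how the checkpoint config was serialized, id2label may expose its keys
# as ints or strings, which is why both lookups are tried above; "Label_{i}" is only a
# fallback for class ids missing from the config.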

def classify_image(image):
    """Single-image classification."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=1)[0].tolist()
    return {ID2LABEL[i]: float(p) for i, p in enumerate(probs)}
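
# Illustrative usage (hypothetical filename), matching what the gr.Image input provides:
#   preds = classify_image(Image.open("some_benthic_photo.jpg"))
#   # -> dict mapping each of the 7 class labels to a softmax probability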

MAX_BATCH = 10

def classify_images_batch(files):
    """
    Batch classification (up to 10).
    Returns: gallery [(img, caption)], table rows, CSV filepath
    """
    if not files:
        return [], [], None
    files = files[:MAX_BATCH]

    # Load PILs
    pil_images, names = [], []
    for f in files:
        path = getattr(f, "name", None) or getattr(f, "path", None) or f
        try:
            img = Image.open(path).convert("RGB")
            pil_images.append(img)
            names.append(os.path.basename(path))
        except Exception:
            continue
    if not pil_images:
        return [], [], None

    inputs = processor(images=pil_images, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=1)

    gallery = []
    table_rows = []  # [filename, top1_label, top1_conf, top3_labels, top3_confs]
    for idx, (img, fname) in enumerate(zip(pil_images, names)):
        p = probs[idx].tolist()
        top_idxs = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:3]
        top1 = top_idxs[0]
        caption = f"{ID2LABEL[top1]} ({p[top1]:.2%})"
        gallery.append((img, f"{fname}\n{caption}"))
        top3_labels = [ID2LABEL[i] for i in top_idxs]
        top3_scores = [round(p[i], 4) for i in top_idxs]
        table_rows.append([
            fname,
            ID2LABEL[top1],
            round(p[top1], 4),
            ", ".join(top3_labels),
            ", ".join(map(str, top3_scores)),
        ])

    # Create CSV for download
    csv_path = None
    try:
        tmp = tempfile.NamedTemporaryFile(
            delete=False, suffix=".csv", prefix="predictions_", dir=BASE_DIR,
            mode="w", newline="", encoding="utf-8"
        )
        writer = csv.writer(tmp)
        writer.writerow(["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"])
        for row in table_rows:
            writer.writerow(row)
        tmp.flush(); tmp.close()
        csv_path = tmp.name
    except Exception:
        csv_path = None

    return gallery, table_rows, csv_path
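
# Illustrative usage (hypothetical filenames; plain path strings work via the getattr fallback):
#   gallery, rows, csv_path = classify_images_batch(["a.jpg", "b.jpg"])
#   # rows -> [[filename, top1_label, top1_conf, "label1, label2, label3", "c1, c2, c3"], ...]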

# ---------- UI ----------
single = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Underwater Image"),
    outputs=gr.Label(num_top_classes=len(ID2LABEL), label="Species Classification"),
    title="BenthicAI - Single Image",
    description="Classify one image into one of 7 benthic species."
)

batch = gr.Interface(
    fn=classify_images_batch,
    inputs=gr.Files(label="Upload up to 10 images"),
    outputs=[
        gr.Gallery(label="Results (Top-1 in caption)", height=500, rows=3),
        gr.Dataframe(
            headers=["filename", "top1_label", "top1_conf", "top3_labels", "top3_confs"],
            label="Predictions Table",
            wrap=True
        ),
        gr.File(label="Download CSV")
    ],
    title="BenthicAI - Batch (up to 10)",
    description="Upload multiple images (max 10)."
)

demo = gr.TabbedInterface([single, batch], ["Single", "Batch"])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
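
# For local testing, a bare demo.launch() also works; the explicit host/port above match the
# 0.0.0.0:7860 endpoint that a Hugging Face Space serves on.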