Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import pipeline | |
| from PIL import Image, ImageDraw, ImageFont | |
| from collections import defaultdict | |
# 1) Zero-shot object detector: it scores image regions against free-text
# prompts, so no tiger/lion-specific training is required. The lightweight
# OWL-ViT base checkpoint is fast enough for CPU-only Spaces; swap in
# "google/owlv2-base-patch16-ensemble" for better accuracy at the cost of speed.
detector = pipeline(
    "zero-shot-object-detection",
    model="google/owlvit-base-patch32",  # fast & lightweight
)
# Text prompts handed to the zero-shot detector as candidate labels.
# count_big_cats() keeps only predictions whose lowercased label appears in
# this list, and tallies exactly "tiger" and "lion". Adding a synonym prompt
# (e.g. "Bengal tiger") therefore also needs matching changes there, or its
# detections will be dropped or left out of the counts.
LABELS = ["tiger", "lion"]

# Outline/tag color per lowercased label; annotate() falls back to red.
COLOR_BY_LABEL = dict(
    tiger="red",
    lion="blue",
)
def iou(box_a, box_b):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is a dict with ``xmin``, ``ymin``, ``xmax`` and ``ymax`` keys.
    A tiny epsilon is added to the union so two zero-area boxes yield 0.0
    instead of dividing by zero.
    """
    def _area(box):
        return (box["xmax"] - box["xmin"]) * (box["ymax"] - box["ymin"])

    # Overlap extent along each axis, clamped at zero when boxes are disjoint.
    overlap_w = max(0.0, min(box_a["xmax"], box_b["xmax"]) - max(box_a["xmin"], box_b["xmin"]))
    overlap_h = max(0.0, min(box_a["ymax"], box_b["ymax"]) - max(box_a["ymin"], box_b["ymin"]))
    intersection = overlap_w * overlap_h

    union = _area(box_a) + _area(box_b) - intersection + 1e-9
    return intersection / union
def nms_single_class(dets, iou_thresh=0.5):
    """Greedy non-maximum suppression over detections of a single class.

    Args:
        dets: Detections shaped like ``{"box": {...}, "score": float, "label": str}``.
            The input list is not modified.
        iou_thresh: A candidate is dropped when its IoU with any already-kept,
            higher-scoring box is at or above this value.

    Returns:
        The surviving detections, highest score first.
    """
    survivors = []
    # Visit candidates best-first; each one survives only if it does not
    # overlap too much with anything that survived before it.
    for candidate in sorted(dets, key=lambda det: det["score"], reverse=True):
        if all(iou(kept["box"], candidate["box"]) < iou_thresh for kept in survivors):
            survivors.append(candidate)
    return survivors
def class_aware_nms(dets, iou_thresh=0.5):
    """Run NMS separately within each case-insensitive label.

    Grouping first means a lion box never suppresses an overlapping tiger box
    (and vice versa). The result lists each label's survivors together, with
    labels in order of first appearance in ``dets``.
    """
    groups = {}
    for det in dets:
        groups.setdefault(det["label"].lower(), []).append(det)
    return [
        kept
        for same_label in groups.values()
        for kept in nms_single_class(same_label, iou_thresh=iou_thresh)
    ]
def annotate(img, dets):
    """Draw a colored box and a "label score" tag for each detection.

    ``ImageDraw.Draw`` writes into ``img`` directly, so the image is modified
    in place and also returned; pass a copy to keep the original intact.

    Args:
        img: PIL image to draw on.
        dets: Detections shaped like ``{"box": {"xmin", "ymin", "xmax", "ymax"},
            "score": float, "label": str}``.

    Returns:
        The same ``img`` object, annotated.
    """
    tag_h = 18  # height of the filled strip behind the label text
    pad = 3     # horizontal padding around the label text
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 14)
    except (OSError, ImportError):
        # Font file missing/unreadable (OSError) or Pillow built without
        # FreeType (ImportError): fall back to Pillow's default bitmap font.
        font = None
    for d in dets:
        b = d["box"]
        color = COLOR_BY_LABEL.get(d["label"].lower(), "red")
        draw.rectangle([(b["xmin"], b["ymin"]), (b["xmax"], b["ymax"])], outline=color, width=3)
        txt = f"{d['label']} {d['score']:.2f}"
        # Estimate text width
        try:
            txt_w = draw.textlength(txt, font=font)
        except AttributeError:
            # Older Pillow without ImageDraw.textlength: rough per-char estimate.
            txt_w = 8 * len(txt)
        # Put the tag above the box when there is room; for boxes touching the
        # top edge, draw it just inside the box so it is not clipped off-image.
        tag_top = b["ymin"] - tag_h if b["ymin"] >= tag_h else b["ymin"]
        draw.rectangle(
            [(b["xmin"], tag_top), (b["xmin"] + txt_w + 2 * pad, tag_top + tag_h)],
            fill=color,
        )
        draw.text((b["xmin"] + pad, tag_top + 2), txt, fill="white", font=font)
    return img
def count_big_cats(img, score_threshold, iou_threshold):
    """Detect tigers and lions in ``img`` and count them.

    Args:
        img: PIL image from the input component, or ``None`` if the user
            pressed the button without providing one.
        score_threshold: Minimum detector score for a prediction to be kept.
        iou_threshold: IoU at or above which a lower-scoring box of the same
            class is suppressed by NMS.

    Returns:
        ``(tiger_count, lion_count, total_count, annotated_image)``, where the
        annotated image is a copy of ``img`` with boxes drawn on it.

    Raises:
        gr.Error: If ``img`` is ``None``, so Gradio shows a readable message
            instead of a traceback from inside the detector.
    """
    if img is None:
        raise gr.Error("Please upload an image or pick a test image first.")

    # 2) Run zero-shot detection with both labels. Forward the slider value as
    # the pipeline's own threshold: its postprocessing otherwise applies a
    # default cutoff (0.1) first, so slider settings below that had no effect.
    preds = detector(img, candidate_labels=LABELS, threshold=score_threshold)
    # 3) Keep only our labels and apply the (inclusive) score filter.
    preds = [
        p for p in preds
        if p["label"].lower() in LABELS and p["score"] >= score_threshold
    ]
    # 4) Class-aware NMS so a lion box can't suppress an overlapping tiger box.
    preds = class_aware_nms(preds, iou_thresh=iou_threshold)
    # 5) Per-class counts.
    tiger_count = sum(1 for p in preds if p["label"].lower() == "tiger")
    lion_count = sum(1 for p in preds if p["label"].lower() == "lion")
    total_count = tiger_count + lion_count
    # 6) Draw on a copy so the user's input image is left untouched.
    img_annotated = annotate(img.copy(), preds)
    return tiger_count, lion_count, total_count, img_annotated
# Dropdown choice -> bundled example image path.
# NOTE(review): "Lion" and "Both" point at files named tiger2/tiger3 —
# confirm those files really show a lion / both cats.
TEST_IMAGES = dict(
    Tiger="examples/tiger1.png",
    Lion="examples/tiger2.png",
    Both="examples/tiger3.png",
)
def load_test_image(choice):
    """Load the bundled example image for a dropdown ``choice``.

    Returns ``None`` (which clears the input image component) when ``choice``
    is empty or not a key of ``TEST_IMAGES``, instead of raising ``KeyError``.
    """
    path = TEST_IMAGES.get(choice)
    if path is None:
        return None
    # Image.open is lazy and keeps the file handle open until pixels are read;
    # copy() forces a full load, so the context manager can close the file.
    with Image.open(path) as im:
        return im.copy()
with gr.Blocks(title="Big Cat Counter") as demo:
    gr.Markdown("# 🐯🦁 Big Cat Counter\nUpload an image and I’ll count how many **tigers** and **lions** I see.")
    with gr.Row():
        # Left column: image input, example picker and detection controls.
        with gr.Column():
            inp = gr.Image(type="pil", label="Input image")
            test_selector = gr.Dropdown(list(TEST_IMAGES), label="Pick a test image")
            score_th = gr.Slider(0.05, 0.95, value=0.20, step=0.05, label="Score threshold")
            iou_th = gr.Slider(0.1, 0.9, value=0.50, step=0.05, label="IOU (NMS) threshold")
            btn = gr.Button("Count Big Cats")
        # Right column: per-class counts and the annotated result.
        with gr.Column():
            out_tiger = gr.Number(label="Tiger count", precision=0)
            out_lion = gr.Number(label="Lion count", precision=0)
            out_total = gr.Number(label="Total big cats", precision=0)
            out_img = gr.Image(label="Annotated output")

    # Picking an example replaces whatever is in the input image slot.
    test_selector.change(
        fn=load_test_image,
        inputs=test_selector,
        outputs=inp,
    )
    # Output order must match count_big_cats' return tuple.
    btn.click(
        fn=count_big_cats,
        inputs=[inp, score_th, iou_th],
        outputs=[out_tiger, out_lion, out_total, out_img],
    )
if __name__ == "__main__":
    # Serve the app only when run as a script; a plain import defines `demo`
    # without starting the server.
    demo.launch()