tiger_counter / app.py
Sergey Kolbin
mn
98de2f1
raw
history blame
6.38 kB
import os
import torch
import gradio as gr
from transformers import pipeline
from PIL import Image, ImageDraw, ImageFont
from collections import defaultdict
from functools import lru_cache
# ---------- Config ----------
# Checkpoint used when nothing else is selected; the MODEL_ID environment
# variable overrides the built-in value.
DEFAULT_MODEL = os.environ.get("MODEL_ID", "google/owlvit-large-patch14")

# Checkpoints offered for selection.
MODEL_CHOICES = [
    "google/owlvit-large-patch14",  # default
    "google/owlv2-large-patch14-ensemble",
    "google/owlv2-base-patch16-ensemble",
    "google/owlvit-base-patch32",
]

# Candidate labels include toys/plush/lego/figurines/cartoons/etc.
TIGER_SYNS = [
    "tiger",
    "tiger cub",
    "tiger toy",
    "toy tiger",
    "plush tiger",
    "stuffed tiger",
    "stuffed animal tiger",
    "lego tiger",
    "tiger figurine",
    "tiger statue",
    "cartoon tiger",
    "tiger drawing",
]
LION_SYNS = [
    "lion",
    "lioness",
    "lion cub",
    "lion toy",
    "toy lion",
    "plush lion",
    "stuffed lion",
    "stuffed animal lion",
    "lego lion",
    "lion figurine",
    "lion statue",
    "cartoon lion",
    "lion drawing",
]

# Full prompt list handed to the detector: tiger phrases first, then lions.
CANDIDATE_LABELS = [*TIGER_SYNS, *LION_SYNS]

# Box/label colour per canonical class.
COLOR_BY_LABEL = {"tiger": "red", "lion": "blue"}
# ---------- Utils ----------
def canonicalize(label: str):
    """Collapse a free-form label onto its canonical big-cat class.

    The match is a case-insensitive substring test, and "tiger" is checked
    before "lion".

    Args:
        label: Label text to classify.

    Returns:
        ``"tiger"`` or ``"lion"`` when that word occurs in ``label``,
        otherwise ``None``.
    """
    lowered = label.lower()
    for canonical in ("tiger", "lion"):
        if canonical in lowered:
            return canonical
    return None
def iou(box_a, box_b):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        box_a: Mapping with ``"xmin"``, ``"ymin"``, ``"xmax"``, ``"ymax"``.
        box_b: Mapping with the same four keys.

    Returns:
        Overlap area divided by union area. A tiny epsilon is added to the
        denominator so two zero-area boxes do not divide by zero.
    """

    def _area(box):
        return (box["xmax"] - box["xmin"]) * (box["ymax"] - box["ymin"])

    # Corners of the overlap rectangle (may be inverted when boxes are disjoint).
    left = max(box_a["xmin"], box_b["xmin"])
    top = max(box_a["ymin"], box_b["ymin"])
    right = min(box_a["xmax"], box_b["xmax"])
    bottom = min(box_a["ymax"], box_b["ymax"])

    # Clamp to zero so disjoint boxes contribute no overlap.
    overlap = max(0.0, right - left) * max(0.0, bottom - top)

    union = _area(box_a) + _area(box_b) - overlap + 1e-9
    return overlap / union
def nms_single_class(dets, iou_thresh=0.5):
    """Greedy non-maximum suppression over one group of detections.

    Detections are visited from highest to lowest score. Each survivor
    removes every remaining detection whose box overlaps it with an IoU of
    ``iou_thresh`` or more.

    Args:
        dets: Iterable of dicts carrying a ``"score"`` and a ``"box"``.
        iou_thresh: Overlap at or above which a lower-scored box is dropped.

    Returns:
        A new list of the kept detections, ordered by descending score.
        The input is not modified.
    """
    remaining = sorted(dets, key=lambda det: det["score"], reverse=True)
    survivors = []
    while remaining:
        top, *rest = remaining
        survivors.append(top)
        remaining = [
            det for det in rest if iou(top["box"], det["box"]) < iou_thresh
        ]
    return survivors
def class_aware_nms(dets, iou_thresh=0.5):
    """Run non-maximum suppression separately for each label.

    Detections are grouped by lower-cased ``"label"``, so boxes of one class
    never suppress boxes of another.

    Args:
        dets: Iterable of dicts carrying ``"label"``, ``"score"`` and ``"box"``.
        iou_thresh: Overlap threshold forwarded to ``nms_single_class``.

    Returns:
        The survivors of every group in one list, groups in first-seen order.
    """
    grouped = {}
    for det in dets:
        grouped.setdefault(det["label"].lower(), []).append(det)

    return [
        det
        for group in grouped.values()
        for det in nms_single_class(group, iou_thresh=iou_thresh)
    ]
def annotate(img, dets):
    """Draw a coloured box and a "label score" tag for every detection.

    The drawing happens directly on ``img``; pass a copy if the original must
    stay untouched.

    Args:
        img: PIL image to draw on (modified in place).
        dets: Iterable of dicts with ``"label"``, ``"score"`` and a ``"box"``
            mapping holding ``"xmin"``, ``"ymin"``, ``"xmax"``, ``"ymax"``.

    Returns:
        The same ``img`` object, now annotated.
    """
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("DejaVuSans.ttf", 14)
    except (OSError, ImportError):
        # Font file missing/unreadable, or FreeType support not available.
        # Best effort: fall through with font=None instead of failing.
        font = None

    pad = 3
    tag_height = 18
    for det in dets:
        box = det["box"]
        # Labels without a configured colour are drawn in red.
        color = COLOR_BY_LABEL.get(det["label"], "red")
        draw.rectangle(
            [(box["xmin"], box["ymin"]), (box["xmax"], box["ymax"])],
            outline=color,
            width=3,
        )

        txt = f"{det['label']} {det['score']:.2f}"
        try:
            txt_w = draw.textlength(txt, font=font)
        except AttributeError:
            # No textlength() on this ImageDraw object; estimate the width
            # at roughly 8 px per character.
            txt_w = 8 * len(txt)

        # Place the tag just above the box, clamped to the top image edge.
        top = max(0, box["ymin"] - tag_height)
        draw.rectangle(
            [(box["xmin"], top), (box["xmin"] + txt_w + 2 * pad, top + tag_height)],
            fill=color,
        )
        draw.text((box["xmin"] + pad, top + 2), txt, fill="white", font=font)
    return img
@lru_cache(maxsize=4)
def get_detector(model_id: str):
    """Build (or fetch from cache) a zero-shot object-detection pipeline.

    Up to four pipelines are cached, keyed by ``model_id``. When CUDA is
    available the pipeline is placed on device 0 in float16; otherwise it
    runs on CPU in float32.

    Args:
        model_id: Checkpoint identifier passed to ``pipeline`` as ``model``.

    Returns:
        The pipeline object created for ``model_id``.
    """
    if torch.cuda.is_available():
        device, dtype = 0, torch.float16
    else:
        device, dtype = -1, torch.float32

    return pipeline(
        "zero-shot-object-detection",
        model=model_id,
        device=device,
        torch_dtype=dtype,
    )
# ---------- Inference ----------
def count_big_cats(img, score_threshold, iou_threshold, model_id):
    """Detect tigers and lions in an image and count them.

    Args:
        img: PIL image to analyse, or ``None`` when nothing was uploaded.
        score_threshold: Minimum detection score to keep.
        iou_threshold: IoU threshold used for per-class NMS.
        model_id: Checkpoint identifier passed to ``get_detector``.

    Returns:
        A tuple ``(tiger_count, lion_count, total_count, annotated_image)``.
        With no input image the result is ``(0, 0, 0, None)``.
    """
    if img is None:
        return 0, 0, 0, None

    # Keep memory in check for huge uploads
    if img.width * img.height > 4_000_000:  # ~4MP
        img = img.copy()  # thumbnail() resizes in place; don't touch the caller's image
        img.thumbnail((2048, 2048))

    detector = get_detector(model_id)
    # The pipeline applies its own score cut-off before returning (documented
    # default: 0.1). Forward the requested threshold so values below that
    # default actually take effect.
    raw = detector(
        img,
        candidate_labels=CANDIDATE_LABELS,
        threshold=score_threshold,
    )

    # 1) Filter by score and canonicalize labels to {"tiger","lion"}
    preds = []
    for p in raw:
        if p["score"] < score_threshold:
            continue
        canon = canonicalize(p["label"])
        if canon is None:
            continue
        # Shallow copy with the label replaced by its canonical class.
        preds.append({**p, "label": canon})

    # 2) NMS per canonical class (avoids double-counting synonyms like "toy tiger")
    preds = class_aware_nms(preds, iou_thresh=iou_threshold)

    tiger_count = sum(1 for p in preds if p["label"] == "tiger")
    lion_count = sum(1 for p in preds if p["label"] == "lion")
    total_count = tiger_count + lion_count

    # Draw on a copy so the input image is left unmodified.
    img_annotated = annotate(img.copy(), preds)
    return tiger_count, lion_count, total_count, img_annotated
# ---------- Demo ----------
# Display name -> path of the bundled example images.
TEST_IMAGES = {
    "Tigers": "examples/tiger1.png",
    "More Tigers": "examples/tiger2.png",
    "Funny Tigers": "examples/tiger3.png",
    "Lions": "examples/tigers_and_lions_2.png",
}


def load_test_image(choice):
    """Load the example image registered under ``choice``.

    Args:
        choice: A key of ``TEST_IMAGES``, or ``None`` when nothing is selected.

    Returns:
        A fully loaded PIL image, or ``None`` when ``choice`` is ``None`` or
        not a known example name.
    """
    path = TEST_IMAGES.get(choice)
    if path is None:
        # Nothing (or something unknown) selected: no image to load.
        return None
    with Image.open(path) as im:
        # Image.open is lazy; copy() reads the pixel data so the returned
        # image stays usable after the file handle is closed.
        return im.copy()
# Default dropdown value (env override supported): use DEFAULT_MODEL when it
# is one of the offered choices, otherwise the first choice.
default_choice = MODEL_CHOICES[0]
if DEFAULT_MODEL in MODEL_CHOICES:
    default_choice = DEFAULT_MODEL

with gr.Blocks(title="Big Cat Counter") as demo:
    gr.Markdown(
        "# 🐯🦁 Big Cat Counter\n"
        "Upload an image and I’ll count how many **tigers** and **lions** "
        "I see (including toys, plush, LEGO, etc.)."
    )

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            inp = gr.Image(type="pil", label="Input image")
            test_selector = gr.Dropdown(
                choices=list(TEST_IMAGES),
                label="Pick a test image",
            )
            model_dd = gr.Dropdown(
                choices=MODEL_CHOICES,
                value=default_choice,
                label="Model",
            )
            score_th = gr.Slider(
                minimum=0.05,
                maximum=0.95,
                value=0.20,
                step=0.05,
                label="Score threshold",
            )
            iou_th = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.50,
                step=0.05,
                label="IOU (NMS) threshold",
            )
            btn = gr.Button("Count Big Cats")

        # Right column: counts and the annotated image.
        with gr.Column():
            out_tiger = gr.Number(label="Tiger count", precision=0)
            out_lion = gr.Number(label="Lion count", precision=0)
            out_total = gr.Number(label="Total big cats", precision=0)
            out_img = gr.Image(label="Annotated output")

    # Picking an example loads it into the input image component.
    test_selector.change(fn=load_test_image, inputs=test_selector, outputs=inp)

    # The button runs detection and fills all four outputs.
    result_components = [out_tiger, out_lion, out_total, out_img]
    btn.click(
        fn=count_big_cats,
        inputs=[inp, score_th, iou_th, model_dd],
        outputs=result_components,
    )

if __name__ == "__main__":
    demo.launch()