Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	
		Sergey Kolbin
		
	committed on
		
		
					Commit 
							
							·
						
						d384e64
	
1
								Parent(s):
							
							28dc509
								
mmma
Browse files- app.py +92 -41
- requirements.txt +2 -1
    	
        app.py
    CHANGED
    
    | @@ -1,24 +1,44 @@ | |
|  | |
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
             
            from transformers import pipeline
         | 
| 3 | 
             
            from PIL import Image, ImageDraw, ImageFont
         | 
| 4 | 
             
            from collections import defaultdict
         | 
| 5 | 
            -
             | 
| 6 | 
            -
             | 
| 7 | 
            -
            # | 
| 8 | 
            -
             | 
| 9 | 
            -
             | 
| 10 | 
            -
             | 
| 11 | 
            -
                 | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
                "tiger" | 
| 20 | 
            -
                " | 
| 21 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 22 |  | 
| 23 | 
             
            def iou(box_a, box_b):
         | 
| 24 | 
             
                xA = max(box_a["xmin"], box_b["xmin"])
         | 
| @@ -34,7 +54,6 @@ def iou(box_a, box_b): | |
| 34 | 
             
                return inter / denom
         | 
| 35 |  | 
| 36 | 
             
            def nms_single_class(dets, iou_thresh=0.5):
         | 
| 37 | 
            -
                # dets: list of dicts with keys {"box": {...}, "score": float, "label": str}
         | 
| 38 | 
             
                dets = sorted(dets, key=lambda d: d["score"], reverse=True)
         | 
| 39 | 
             
                kept = []
         | 
| 40 | 
             
                while dets:
         | 
| @@ -44,12 +63,12 @@ def nms_single_class(dets, iou_thresh=0.5): | |
| 44 | 
             
                return kept
         | 
| 45 |  | 
| 46 | 
             
            def class_aware_nms(dets, iou_thresh=0.5):
         | 
| 47 | 
            -
                #  | 
| 48 | 
             
                by_label = defaultdict(list)
         | 
| 49 | 
             
                for d in dets:
         | 
| 50 | 
             
                    by_label[d["label"].lower()].append(d)
         | 
| 51 | 
             
                merged = []
         | 
| 52 | 
            -
                for  | 
| 53 | 
             
                    merged.extend(nms_single_class(per_class, iou_thresh=iou_thresh))
         | 
| 54 | 
             
                return merged
         | 
| 55 |  | 
| @@ -61,41 +80,65 @@ def annotate(img, dets): | |
| 61 | 
             
                    font = None
         | 
| 62 | 
             
                for d in dets:
         | 
| 63 | 
             
                    b = d["box"]
         | 
| 64 | 
            -
                    color = COLOR_BY_LABEL.get(d["label"] | 
| 65 | 
             
                    draw.rectangle([(b["xmin"], b["ymin"]), (b["xmax"], b["ymax"])], outline=color, width=3)
         | 
| 66 | 
             
                    txt = f"{d['label']} {d['score']:.2f}"
         | 
| 67 | 
            -
                    # Estimate text width
         | 
| 68 | 
             
                    try:
         | 
| 69 | 
             
                        txt_w = draw.textlength(txt, font=font)
         | 
| 70 | 
             
                    except AttributeError:
         | 
| 71 | 
             
                        txt_w = 8 * len(txt)
         | 
| 72 | 
             
                    pad = 3
         | 
| 73 | 
            -
                     | 
| 74 | 
            -
             | 
| 75 | 
            -
             | 
| 76 | 
            -
                    )
         | 
| 77 | 
            -
                    draw.text((b["xmin"] + pad, b["ymin"] - 16), txt, fill="white", font=font)
         | 
| 78 | 
             
                return img
         | 
| 79 |  | 
| 80 | 
            -
             | 
| 81 | 
            -
             | 
| 82 | 
            -
                 | 
| 83 | 
            -
             | 
| 84 | 
            -
             | 
| 85 | 
            -
             | 
| 86 | 
            -
             | 
| 87 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 88 | 
             
                preds = class_aware_nms(preds, iou_thresh=iou_threshold)
         | 
| 89 |  | 
| 90 | 
            -
                 | 
| 91 | 
            -
                 | 
| 92 | 
            -
                lion_count = sum(1 for p in preds if p["label"].lower() == "lion")
         | 
| 93 | 
             
                total_count = tiger_count + lion_count
         | 
| 94 |  | 
| 95 | 
            -
                # 6) Draw boxes
         | 
| 96 | 
             
                img_annotated = annotate(img.copy(), preds)
         | 
| 97 | 
             
                return tiger_count, lion_count, total_count, img_annotated
         | 
| 98 |  | 
|  | |
| 99 | 
             
            TEST_IMAGES = {
         | 
| 100 | 
             
                "Tigers": "examples/tiger1.png",
         | 
| 101 | 
             
                "More Tigers": "examples/tiger2.png",
         | 
| @@ -106,13 +149,16 @@ TEST_IMAGES = { | |
| 106 | 
             
            def load_test_image(choice):
         | 
| 107 | 
             
                return Image.open(TEST_IMAGES[choice])
         | 
| 108 |  | 
|  | |
|  | |
| 109 |  | 
| 110 | 
             
            with gr.Blocks(title="Big Cat Counter") as demo:
         | 
| 111 | 
            -
                gr.Markdown("# 🐯🦁 Big Cat Counter\nUpload an image and I’ll count how many **tigers** and **lions** I see.")
         | 
| 112 | 
             
                with gr.Row():
         | 
| 113 | 
             
                    with gr.Column():
         | 
| 114 | 
             
                        inp = gr.Image(type="pil", label="Input image")
         | 
| 115 | 
             
                        test_selector = gr.Dropdown(list(TEST_IMAGES.keys()), label="Pick a test image")
         | 
|  | |
| 116 | 
             
                        score_th = gr.Slider(0.05, 0.95, value=0.20, step=0.05, label="Score threshold")
         | 
| 117 | 
             
                        iou_th = gr.Slider(0.1, 0.9, value=0.50, step=0.05, label="IOU (NMS) threshold")
         | 
| 118 | 
             
                        btn = gr.Button("Count Big Cats")
         | 
| @@ -121,8 +167,13 @@ with gr.Blocks(title="Big Cat Counter") as demo: | |
| 121 | 
             
                        out_lion = gr.Number(label="Lion count", precision=0)
         | 
| 122 | 
             
                        out_total = gr.Number(label="Total big cats", precision=0)
         | 
| 123 | 
             
                        out_img = gr.Image(label="Annotated output")
         | 
|  | |
| 124 | 
             
                test_selector.change(fn=load_test_image, inputs=test_selector, outputs=inp)
         | 
| 125 | 
            -
                btn.click( | 
|  | |
|  | |
|  | |
|  | |
| 126 |  | 
| 127 | 
             
            if __name__ == "__main__":
         | 
| 128 | 
             
                demo.launch()
         | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
             
            import gradio as gr
         | 
| 4 | 
             
            from transformers import pipeline
         | 
| 5 | 
             
            from PIL import Image, ImageDraw, ImageFont
         | 
| 6 | 
             
            from collections import defaultdict
         | 
| 7 | 
            +
            from functools import lru_cache
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # ---------- Config ----------
         | 
| 10 | 
            +
            DEFAULT_MODEL = os.getenv("MODEL_ID", "google/owlvit-large-patch14")
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            MODEL_CHOICES = [
         | 
| 13 | 
            +
                "google/owlvit-large-patch14",              # default
         | 
| 14 | 
            +
                "google/owlv2-large-patch14-ensemble",
         | 
| 15 | 
            +
                "google/owlv2-base-patch16-ensemble",
         | 
| 16 | 
            +
                "google/owlvit-base-patch32",
         | 
| 17 | 
            +
            ]
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            # Candidate labels include toys/plush/lego/figurines/cartoons/etc.
         | 
| 20 | 
            +
            TIGER_SYNS = [
         | 
| 21 | 
            +
                "tiger", "tiger cub", "tiger toy", "toy tiger", "plush tiger",
         | 
| 22 | 
            +
                "stuffed tiger", "stuffed animal tiger", "lego tiger", "tiger figurine",
         | 
| 23 | 
            +
                "tiger statue", "cartoon tiger", "tiger drawing"
         | 
| 24 | 
            +
            ]
         | 
| 25 | 
            +
            LION_SYNS = [
         | 
| 26 | 
            +
                "lion", "lioness", "lion cub", "lion toy", "toy lion", "plush lion",
         | 
| 27 | 
            +
                "stuffed lion", "stuffed animal lion", "lego lion", "lion figurine",
         | 
| 28 | 
            +
                "lion statue", "cartoon lion", "lion drawing"
         | 
| 29 | 
            +
            ]
         | 
| 30 | 
            +
            CANDIDATE_LABELS = TIGER_SYNS + LION_SYNS
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            COLOR_BY_LABEL = {"tiger": "red", "lion": "blue"}
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            # ---------- Utils ----------
         | 
| 35 | 
            +
            def canonicalize(label: str):
                """Collapse a detector label onto its canonical big-cat class.

                The match is a case-insensitive substring test, so synonyms such
                as "plush tiger" or "lioness" resolve to "tiger" / "lion".
                "tiger" is tested first and therefore wins if both words occur.

                Returns:
                    "tiger", "lion", or None when the label mentions neither.
                """
                lowered = label.lower()
                # Order matters: it fixes the precedence between the two classes.
                for species in ("tiger", "lion"):
                    if species in lowered:
                        return species
                return None
         | 
| 42 |  | 
| 43 | 
             
            def iou(box_a, box_b):
         | 
| 44 | 
             
                xA = max(box_a["xmin"], box_b["xmin"])
         | 
|  | |
| 54 | 
             
                return inter / denom
         | 
| 55 |  | 
| 56 | 
             
            def nms_single_class(dets, iou_thresh=0.5):
         | 
|  | |
| 57 | 
             
                dets = sorted(dets, key=lambda d: d["score"], reverse=True)
         | 
| 58 | 
             
                kept = []
         | 
| 59 | 
             
                while dets:
         | 
|  | |
| 63 | 
             
                return kept
         | 
| 64 |  | 
| 65 | 
             
            def class_aware_nms(dets, iou_thresh=0.5):
         | 
| 66 | 
            +
                # NMS per class so synonyms don't suppress each other across classes
         | 
| 67 | 
             
                by_label = defaultdict(list)
         | 
| 68 | 
             
                for d in dets:
         | 
| 69 | 
             
                    by_label[d["label"].lower()].append(d)
         | 
| 70 | 
             
                merged = []
         | 
| 71 | 
            +
                for per_class in by_label.values():
         | 
| 72 | 
             
                    merged.extend(nms_single_class(per_class, iou_thresh=iou_thresh))
         | 
| 73 | 
             
                return merged
         | 
| 74 |  | 
|  | |
| 80 | 
             
                    font = None
         | 
| 81 | 
             
                for d in dets:
         | 
| 82 | 
             
                    b = d["box"]
         | 
| 83 | 
            +
                    color = COLOR_BY_LABEL.get(d["label"], "red")
         | 
| 84 | 
             
                    draw.rectangle([(b["xmin"], b["ymin"]), (b["xmax"], b["ymax"])], outline=color, width=3)
         | 
| 85 | 
             
                    txt = f"{d['label']} {d['score']:.2f}"
         | 
|  | |
| 86 | 
             
                    try:
         | 
| 87 | 
             
                        txt_w = draw.textlength(txt, font=font)
         | 
| 88 | 
             
                    except AttributeError:
         | 
| 89 | 
             
                        txt_w = 8 * len(txt)
         | 
| 90 | 
             
                    pad = 3
         | 
| 91 | 
            +
                    top = max(0, b["ymin"] - 18)
         | 
| 92 | 
            +
                    draw.rectangle([(b["xmin"], top), (b["xmin"] + txt_w + 2 * pad, top + 18)], fill=color)
         | 
| 93 | 
            +
                    draw.text((b["xmin"] + pad, top + 2), txt, fill="white", font=font)
         | 
|  | |
|  | |
| 94 | 
             
                return img
         | 
| 95 |  | 
| 96 | 
            +
            @lru_cache(maxsize=4)
            def get_detector(model_id: str):
                """Build the zero-shot object-detection pipeline for a model id.

                Results are memoized (up to four model ids) by lru_cache, so a
                repeated request for the same model reuses the existing pipeline.
                When CUDA is available the pipeline is placed on device 0 in
                float16; otherwise it runs on CPU (device -1) in float32.
                """
                if torch.cuda.is_available():
                    device, dtype = 0, torch.float16
                else:
                    device, dtype = -1, torch.float32
                return pipeline(
                    "zero-shot-object-detection",
                    model=model_id,
                    device=device,
                    torch_dtype=dtype,
                )
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            # ---------- Inference ----------
         | 
| 106 | 
            +
            def count_big_cats(img, score_threshold, iou_threshold, model_id):
                """Count tigers and lions in an image and draw the surviving boxes.

                Args:
                    img: PIL image supplied by the Gradio input, or None when
                        nothing has been uploaded.
                    score_threshold: detections scoring below this are discarded.
                    iou_threshold: IoU threshold passed to the per-class NMS.
                    model_id: model identifier handed to get_detector.

                Returns:
                    (tiger_count, lion_count, total_count, annotated_image);
                    (0, 0, 0, None) when img is None.
                """
                if img is None:
                    return 0, 0, 0, None

                # Keep memory in check for huge uploads (~4MP). thumbnail() resizes
                # in place, so work on a copy to leave the caller's image untouched.
                if img.width * img.height > 4_000_000:
                    img = img.copy()
                    img.thumbnail((2048, 2048))

                detector = get_detector(model_id)
                raw = detector(img, candidate_labels=CANDIDATE_LABELS)

                # Canonicalize labels BEFORE NMS so synonyms don't double-count.
                # The list must start empty: seeding it with NMS output of the raw
                # detections would let boxes skip the score filter below and carry
                # un-canonicalized synonym labels into the counts and the drawing.
                preds = []
                for p in raw:
                    if p["score"] < score_threshold:
                        continue
                    canon = canonicalize(p["label"])
                    if canon is None:
                        continue
                    q = dict(p)  # shallow copy; the detector's own dict keeps its label
                    q["label"] = canon  # overwrite with canonical ('tiger'/'lion')
                    preds.append(q)

                # NMS per canonical class
                preds = class_aware_nms(preds, iou_thresh=iou_threshold)

                tiger_count = sum(1 for p in preds if p["label"] == "tiger")
                lion_count = sum(1 for p in preds if p["label"] == "lion")
                total_count = tiger_count + lion_count

                # Annotate a copy so the (possibly shared) input image stays clean.
                img_annotated = annotate(img.copy(), preds)
                return tiger_count, lion_count, total_count, img_annotated
         | 
| 140 |  | 
| 141 | 
            +
            # ---------- Demo ----------
         | 
| 142 | 
             
            TEST_IMAGES = {
         | 
| 143 | 
             
                "Tigers": "examples/tiger1.png",
         | 
| 144 | 
             
                "More Tigers": "examples/tiger2.png",
         | 
|  | |
| 149 | 
             
            def load_test_image(choice):
                """Open the bundled example image registered under `choice`.

                `choice` must be a key of TEST_IMAGES; the mapped path is opened
                with PIL and the resulting image is returned.
                """
                example_path = TEST_IMAGES[choice]
                return Image.open(example_path)
         | 
| 151 |  | 
| 152 | 
            +
            # Default dropdown value (env override supported)
         | 
| 153 | 
            +
            default_choice = DEFAULT_MODEL if DEFAULT_MODEL in MODEL_CHOICES else MODEL_CHOICES[0]
         | 
| 154 |  | 
| 155 | 
             
            with gr.Blocks(title="Big Cat Counter") as demo:
         | 
| 156 | 
            +
                gr.Markdown("# 🐯🦁 Big Cat Counter\nUpload an image and I’ll count how many **tigers** and **lions** I see (including toys, plush, LEGO, etc.).")
         | 
| 157 | 
             
                with gr.Row():
         | 
| 158 | 
             
                    with gr.Column():
         | 
| 159 | 
             
                        inp = gr.Image(type="pil", label="Input image")
         | 
| 160 | 
             
                        test_selector = gr.Dropdown(list(TEST_IMAGES.keys()), label="Pick a test image")
         | 
| 161 | 
            +
                        model_dd = gr.Dropdown(MODEL_CHOICES, value=default_choice, label="Model")
         | 
| 162 | 
             
                        score_th = gr.Slider(0.05, 0.95, value=0.20, step=0.05, label="Score threshold")
         | 
| 163 | 
             
                        iou_th = gr.Slider(0.1, 0.9, value=0.50, step=0.05, label="IOU (NMS) threshold")
         | 
| 164 | 
             
                        btn = gr.Button("Count Big Cats")
         | 
|  | |
| 167 | 
             
                        out_lion = gr.Number(label="Lion count", precision=0)
         | 
| 168 | 
             
                        out_total = gr.Number(label="Total big cats", precision=0)
         | 
| 169 | 
             
                        out_img = gr.Image(label="Annotated output")
         | 
| 170 | 
            +
             | 
| 171 | 
             
                test_selector.change(fn=load_test_image, inputs=test_selector, outputs=inp)
         | 
| 172 | 
            +
                btn.click(
         | 
| 173 | 
            +
                    fn=count_big_cats,
         | 
| 174 | 
            +
                    inputs=[inp, score_th, iou_th, model_dd],
         | 
| 175 | 
            +
                    outputs=[out_tiger, out_lion, out_total, out_img],
         | 
| 176 | 
            +
                )
         | 
| 177 |  | 
| 178 | 
             
            if __name__ == "__main__":
         | 
| 179 | 
             
                demo.launch()
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -1,4 +1,4 @@ | |
| 1 | 
            -
            transformers>=4. | 
| 2 | 
             
            huggingface_hub>=0.23.0
         | 
| 3 | 
             
            torch
         | 
| 4 | 
             
            scipy
         | 
| @@ -6,3 +6,4 @@ gradio>=4.0.0 | |
| 6 | 
             
            pillow
         | 
| 7 | 
             
            safetensors
         | 
| 8 | 
             
            accelerate
         | 
|  | 
|  | |
| 1 | 
            +
            transformers>=4.43
         | 
| 2 | 
             
            huggingface_hub>=0.23.0
         | 
| 3 | 
             
            torch
         | 
| 4 | 
             
            scipy
         | 
|  | |
| 6 | 
             
            pillow
         | 
| 7 | 
             
            safetensors
         | 
| 8 | 
             
            accelerate
         | 
| 9 | 
            +
            numpy
         | 
