Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import gradio as gr | |
| import supervision as sv | |
| from scipy.ndimage import binary_fill_holes | |
| from ultralytics import YOLOE | |
| from ultralytics.utils.torch_utils import smart_inference_mode | |
| from ultralytics.models.yolo.yoloe.predict_vp import YOLOEVPSegPredictor | |
| from gradio_image_prompter import ImagePrompter | |
| from huggingface_hub import hf_hub_download | |
| import spaces | |
| def init_model(model_id, is_pf=False): | |
| filename = f"{model_id}-seg.pt" if not is_pf else f"{model_id}-seg-pf.pt" | |
| path = hf_hub_download(repo_id="jameslahm/yoloe", filename=filename) | |
| model = YOLOE(path) | |
| model.eval() | |
| model.to("cuda") | |
| return model | |
| def yoloe_inference(image, prompts, target_image, model_id, image_size, conf_thresh, iou_thresh, prompt_type): | |
| model = init_model(model_id) | |
| kwargs = {} | |
| if prompt_type == "Text": | |
| texts = prompts["texts"] | |
| model.set_classes(texts, model.get_text_pe(texts)) | |
| elif prompt_type == "Visual": | |
| kwargs = dict( | |
| prompts=prompts, | |
| predictor=YOLOEVPSegPredictor | |
| ) | |
| if target_image: | |
| model.predict(source=image, imgsz=image_size, conf=conf_thresh, iou=iou_thresh, return_vpe=True, **kwargs) | |
| model.set_classes(["object0"], model.predictor.vpe) | |
| model.predictor = None # unset VPPredictor | |
| image = target_image | |
| kwargs = {} | |
| elif prompt_type == "Prompt-free": | |
| vocab = model.get_vocab(prompts["texts"]) | |
| model = init_model(model_id, is_pf=True) | |
| model.set_vocab(vocab, names=prompts["texts"]) | |
| model.model.model[-1].is_fused = True | |
| model.model.model[-1].conf = 0.001 | |
| model.model.model[-1].max_det = 1000 | |
| results = model.predict(source=image, imgsz=image_size, conf=conf_thresh, iou=iou_thresh, **kwargs) | |
| detections = sv.Detections.from_ultralytics(results[0]) | |
| resolution_wh = image.size | |
| thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh) | |
| text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh) | |
| labels = [ | |
| f"{class_name} {confidence:.2f}" | |
| for class_name, confidence | |
| in zip(detections['class_name'], detections.confidence) | |
| ] | |
| annotated_image = image.copy() | |
| annotated_image = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX, opacity=0.4).annotate( | |
| scene=annotated_image, detections=detections) | |
| annotated_image = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=thickness).annotate( | |
| scene=annotated_image, detections=detections) | |
| annotated_image = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX, text_scale=text_scale, smart_position=True).annotate( | |
| scene=annotated_image, detections=detections, labels=labels) | |
| return annotated_image | |
| def app(): | |
| with gr.Blocks(): | |
| with gr.Row(): | |
| with gr.Column(): | |
| with gr.Row(): | |
| raw_image = gr.Image(type="pil", label="Image", visible=True, interactive=True, height=600) | |
| box_image = ImagePrompter(type="pil", label="DrawBox", visible=False, interactive=True) | |
| mask_image = gr.ImageEditor(type="pil", label="DrawMask", visible=False, interactive=True, layers=False, canvas_size=(640, 640)) | |
| target_image = gr.Image(type="pil", label="Target Image", visible=False, interactive=True, height=600) | |
| yoloe_infer = gr.Button(value="Detect & Segment Objects") | |
| prompt_type = gr.Textbox(value="Text", visible=False) | |
| with gr.Tab("Text") as text_tab: | |
| texts = gr.Textbox(label="Input Texts", value='person,bus', placeholder='person,bus', visible=True, interactive=True) | |
| with gr.Tab("Visual") as visual_tab: | |
| with gr.Row(): | |
| visual_prompt_type = gr.Dropdown(choices=["bboxes", "masks"], value="bboxes", label="Visual Type", interactive=True) | |
| visual_usage_type = gr.Radio(choices=["Intra-Image", "Cross-Image"], value="Intra-Image", label="Intra/Cross Image", interactive=True) | |
| with gr.Tab("Prompt-Free") as prompt_free_tab: | |
| gr.HTML( | |
| """ | |
| <p style='text-align: center'> | |
| <b>Prompt-Free Mode is On</b> | |
| </p> | |
| """, show_label=False) | |
| model_id = gr.Dropdown( | |
| label="Model", | |
| choices=[ | |
| "yoloe-v8s", | |
| "yoloe-v8m", | |
| "yoloe-v8l", | |
| "yoloe-11s", | |
| "yoloe-11m", | |
| "yoloe-11l", | |
| ], | |
| value="yoloe-v8l", | |
| ) | |
| image_size = gr.Slider( | |
| label="Image Size", | |
| minimum=320, | |
| maximum=1280, | |
| step=32, | |
| value=640, | |
| ) | |
| conf_thresh = gr.Slider( | |
| label="Confidence Threshold", | |
| minimum=0.0, | |
| maximum=1.0, | |
| step=0.05, | |
| value=0.25, | |
| ) | |
| iou_thresh = gr.Slider( | |
| label="IoU Threshold", | |
| minimum=0.0, | |
| maximum=1.0, | |
| step=0.05, | |
| value=0.70, | |
| ) | |
| with gr.Column(): | |
| output_image = gr.Image(type="numpy", label="Annotated Image", visible=True, height=600) | |
| def update_text_image_visibility(): | |
| return gr.update(value="Text"), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) | |
| def update_visual_image_visiblity(visual_prompt_type, visual_usage_type): | |
| if visual_prompt_type == "bboxes": | |
| return gr.update(value="Visual"), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=(visual_usage_type == "Cross-Image")) | |
| elif visual_prompt_type == "masks": | |
| return gr.update(value="Visual"), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=(visual_usage_type == "Cross-Image")) | |
| def update_pf_image_visibility(): | |
| return gr.update(value="Prompt-free"), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) | |
| text_tab.select( | |
| fn=update_text_image_visibility, | |
| inputs=None, | |
| outputs=[prompt_type, raw_image, box_image, mask_image, target_image] | |
| ) | |
| visual_tab.select( | |
| fn=update_visual_image_visiblity, | |
| inputs=[visual_prompt_type, visual_usage_type], | |
| outputs=[prompt_type, raw_image, box_image, mask_image, target_image] | |
| ) | |
| prompt_free_tab.select( | |
| fn=update_pf_image_visibility, | |
| inputs=None, | |
| outputs=[prompt_type, raw_image, box_image, mask_image, target_image] | |
| ) | |
| def update_visual_prompt_type(visual_prompt_type): | |
| if visual_prompt_type == "bboxes": | |
| return gr.update(visible=True), gr.update(visible=False) | |
| if visual_prompt_type == "masks": | |
| return gr.update(visible=False), gr.update(visible=True) | |
| return gr.update(visible=False), gr.update(visible=False) | |
| def update_visual_usage_type(visual_usage_type): | |
| if visual_usage_type == "Intra-Image": | |
| return gr.update(visible=False) | |
| if visual_usage_type == "Cross-Image": | |
| return gr.update(visible=True) | |
| return gr.update(visible=False) | |
| visual_prompt_type.change( | |
| fn=update_visual_prompt_type, | |
| inputs=[visual_prompt_type], | |
| outputs=[box_image, mask_image] | |
| ) | |
| visual_usage_type.change( | |
| fn=update_visual_usage_type, | |
| inputs=[visual_usage_type], | |
| outputs=[target_image] | |
| ) | |
| def run_inference(raw_image, box_image, mask_image, target_image, texts, model_id, image_size, conf_thresh, iou_thresh, prompt_type, visual_prompt_type, visual_usage_type): | |
| # add text/built-in prompts | |
| if prompt_type == "Text" or prompt_type == "Prompt-free": | |
| target_image = None | |
| image = raw_image | |
| if prompt_type == "Prompt-free": | |
| with open('tools/ram_tag_list.txt', 'r') as f: | |
| texts = [x.strip() for x in f.readlines()] | |
| else: | |
| texts = [text.strip() for text in texts.split(',')] | |
| prompts = { | |
| "texts": texts | |
| } | |
| # add visual prompt | |
| elif prompt_type == "Visual": | |
| if visual_usage_type != "Cross-Image": | |
| target_image = None | |
| if visual_prompt_type == "bboxes": | |
| image, points = box_image["image"], box_image["points"] | |
| points = np.array(points) | |
| if len(points) == 0: | |
| gr.Warning("No boxes are provided. No image output.", visible=True) | |
| return gr.update(value=None) | |
| bboxes = np.array([p[[0, 1, 3, 4]] for p in points if p[2] == 2]) | |
| prompts = { | |
| "bboxes": bboxes, | |
| "cls": np.array([0] * len(bboxes)) | |
| } | |
| elif visual_prompt_type == "masks": | |
| image, masks = mask_image["background"], mask_image["layers"][0] | |
| # image = image.convert("RGB") | |
| masks = np.array(masks.convert("L")) | |
| masks = binary_fill_holes(masks).astype(np.uint8) | |
| masks[masks > 0] = 1 | |
| if masks.sum() == 0: | |
| gr.Warning("No masks are provided. No image output.", visible=True) | |
| return gr.update(value=None) | |
| prompts = { | |
| "masks": masks[None], | |
| "cls": np.array([0]) | |
| } | |
| return yoloe_inference(image, prompts, target_image, model_id, image_size, conf_thresh, iou_thresh, prompt_type) | |
| run_inference.zerogpu = True | |
| yoloe_infer.click( | |
| fn=run_inference, | |
| inputs=[raw_image, box_image, mask_image, target_image, texts, model_id, image_size, conf_thresh, iou_thresh, prompt_type, visual_prompt_type, visual_usage_type], | |
| outputs=[output_image], | |
| ) | |
| ###################### Examples ########################## | |
| text_examples = gr.Examples( | |
| examples=[[ | |
| "ultralytics/assets/bus.jpg", | |
| "person,bus", | |
| "yoloe-v8l", | |
| 640, | |
| 0.25, | |
| 0.7]], | |
| inputs=[raw_image, texts, model_id, image_size, conf_thresh, iou_thresh], | |
| visible=True, cache_examples=False, label="Text Prompt Examples") | |
| box_examples = gr.Examples( | |
| examples=[[ | |
| {"image": "ultralytics/assets/bus_box.jpg", "points": [[235, 408, 2, 342, 863, 3]]}, | |
| "ultralytics/assets/zidane.jpg", | |
| "yoloe-v8l", | |
| 640, | |
| 0.2, | |
| 0.7, | |
| ]], | |
| inputs=[box_image, target_image, model_id, image_size, conf_thresh, iou_thresh], | |
| visible=False, cache_examples=False, label="Box Visual Prompt Examples") | |
| mask_examples = gr.Examples( | |
| examples=[[ | |
| {"background": "ultralytics/assets/bus.jpg", "layers": ["ultralytics/assets/bus_mask.png"], "composite": "ultralytics/assets/bus_composite.jpg"}, | |
| "ultralytics/assets/zidane.jpg", | |
| "yoloe-v8l", | |
| 640, | |
| 0.15, | |
| 0.7, | |
| ]], | |
| inputs=[mask_image, target_image, model_id, image_size, conf_thresh, iou_thresh], | |
| visible=False, cache_examples=False, label="Mask Visual Prompt Examples") | |
| pf_examples = gr.Examples( | |
| examples=[[ | |
| "ultralytics/assets/bus.jpg", | |
| "yoloe-v8l", | |
| 640, | |
| 0.25, | |
| 0.7, | |
| ]], | |
| inputs=[raw_image, model_id, image_size, conf_thresh, iou_thresh], | |
| visible=False, cache_examples=False, label="Prompt-free Examples") | |
| # Components update | |
| def load_box_example(visual_usage_type): | |
| return (gr.update(visible=True, value={"image": "ultralytics/assets/bus_box.jpg", "points": [[235, 408, 2, 342, 863, 3]]}), | |
| gr.update(visible=(visual_usage_type=="Cross-Image"))) | |
| def load_mask_example(visual_usage_type): | |
| return gr.update(visible=True), gr.update(visible=(visual_usage_type=="Cross-Image")) | |
| box_examples.load_input_event.then( | |
| fn=load_box_example, | |
| inputs=visual_usage_type, | |
| outputs=[box_image, target_image] | |
| ) | |
| mask_examples.load_input_event.then( | |
| fn=load_mask_example, | |
| inputs=visual_usage_type, | |
| outputs=[mask_image, target_image] | |
| ) | |
| # Examples update | |
| def update_text_examples(): | |
| return gr.Dataset(visible=True), gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=False) | |
| def update_pf_examples(): | |
| return gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=True) | |
| def update_visual_examples(visual_prompt_type): | |
| if visual_prompt_type == "bboxes": | |
| return gr.Dataset(visible=False), gr.Dataset(visible=True), gr.Dataset(visible=False), gr.Dataset(visible=False), | |
| elif visual_prompt_type == "masks": | |
| return gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=True), gr.Dataset(visible=False), | |
| text_tab.select( | |
| fn=update_text_examples, | |
| inputs=None, | |
| outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset] | |
| ) | |
| visual_tab.select( | |
| fn=update_visual_examples, | |
| inputs=[visual_prompt_type], | |
| outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset] | |
| ) | |
| prompt_free_tab.select( | |
| fn=update_pf_examples, | |
| inputs=None, | |
| outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset] | |
| ) | |
| visual_prompt_type.change( | |
| fn=update_visual_examples, | |
| inputs=[visual_prompt_type], | |
| outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset] | |
| ) | |
| visual_usage_type.change( | |
| fn=update_visual_examples, | |
| inputs=[visual_prompt_type], | |
| outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset] | |
| ) | |
| gradio_app = gr.Blocks() | |
| with gradio_app: | |
| gr.HTML( | |
| """ | |
| <h1 style='text-align: center'> | |
| <img src="/file=figures/logo.png" width="2.5%" style="display:inline;padding-bottom:4px"> | |
| YOLOE: Real-Time Seeing Anything | |
| </h1> | |
| """) | |
| gr.HTML( | |
| """ | |
| <h3 style='text-align: center'> | |
| <a href='https://arxiv.org/abs/2503.07465' target='_blank'>arXiv</a> | <a href='https://github.com/THU-MIG/yoloe' target='_blank'>github</a> | |
| </h3> | |
| """) | |
| gr.Markdown( | |
| """ | |
| We introduce **YOLOE(ye)**, a highly **efficient**, **unified**, and **open** object detection and segmentation model, like human eye, under different prompt mechanisms, like *texts*, *visual inputs*, and *prompt-free paradigm*. | |
| """ | |
| ) | |
| gr.Markdown( | |
| """ | |
| If desired objects are not identified, pleaset set a **smaller** confidence threshold, e.g., for visual prompts with handcrafted shape or cross-image prompts. | |
| """ | |
| ) | |
| gr.Markdown( | |
| """ | |
| Drawing **multiple** boxes or handcrafted shapes as visual prompt in an image is also supported, which leads to more accurate prompt. | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| app() | |
| if __name__ == '__main__': | |
| gradio_app.launch(allowed_paths=["figures"]) | |