import gradio as gr
import spaces
import sahi.utils.cv
import sahi.utils.file
from sahi import AutoDetectionModel
import sahi.predict
import sahi.slicing
from PIL import Image
import numpy
from ultralytics import YOLO
import sys
import types

# Compatibility shim: some dependencies still import RepositoryNotFoundError from
# huggingface_hub.utils._errors, a module that newer huggingface_hub releases no longer provide.
if 'huggingface_hub.utils._errors' not in sys.modules:
    mock_errors = types.ModuleType('_errors')
    mock_errors.RepositoryNotFoundError = Exception
    sys.modules['huggingface_hub.utils._errors'] = mock_errors

IMAGE_SIZE = 640
# Images
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142730935-2ace3999-a47b-49bb-83e0-2bdd509f1c90.jpg",
    "apple_tree.jpg",
)
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142730936-1b397756-52e5-43be-a949-42ec0134d5d8.jpg",
    "highway.jpg",
)
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142742871-bf485f84-0355-43a3-be86-96b44e63c3a2.jpg",
    "highway2.jpg",
)
sahi.utils.file.download_from_url(
    "https://user-images.githubusercontent.com/34196005/142742872-1fefcc4d-d7e6-4c43-bbb7-6b5982f7e4ba.jpg",
    "highway3.jpg",
)
# Global model variable
model = None


def load_yolo_model(model_name, confidence_threshold=0.5):
    """
    Loads a YOLOv11 detection model.

    Args:
        model_name (str): The name of the YOLOv11 model to load (e.g., "yolo11n.pt").
        confidence_threshold (float): The confidence threshold for object detection.

    Returns:
        AutoDetectionModel: The loaded SAHI AutoDetectionModel.
    """
    global model
    model_path = model_name
    model = AutoDetectionModel.from_pretrained(
        model_type="ultralytics",
        model_path=model_path,
        device=None,  # auto device selection
        confidence_threshold=confidence_threshold,
        image_size=IMAGE_SIZE,
    )
    return model
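
# Assumption: this Space runs on ZeroGPU hardware, where a GPU is only attached while a
# function decorated with @spaces.GPU is executing (this is why `spaces` is imported above).
# Outside of a ZeroGPU Space the decorator is a no-op, so it is safe to keep either way.
@spaces.GPU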
def sahi_yolo_inference(
    image,
    yolo_model_name,
    confidence_threshold,
    max_detections,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    postprocess_type="NMS",
    postprocess_match_metric="IOU",
    postprocess_match_threshold=0.5,
    postprocess_class_agnostic=False,
):
| """ | |
| Performs object detection using SAHI with a specified YOLOv11 model. | |
| Args: | |
| image (PIL.Image.Image): The input image for detection. | |
| yolo_model_name (str): The name of the YOLOv11 model to use for inference. | |
| confidence_threshold (float): The confidence threshold for object detection. | |
| max_detections (int): The maximum number of detections to return. | |
| slice_height (int): The height of each slice for sliced inference. | |
| slice_width (int): The width of each slice for sliced inference. | |
| overlap_height_ratio (float): The overlap ratio for slice height. | |
| overlap_width_ratio (float): The overlap ratio for slice width. | |
| postprocess_type (str): The type of postprocessing to apply ("NMS" or "GREEDYNMM"). | |
| postprocess_match_metric (str): The metric for postprocessing matching ("IOU" or "IOS"). | |
| postprocess_match_threshold (float): The threshold for postprocessing matching. | |
| postprocess_class_agnostic (bool): Whether postprocessing should be class agnostic. | |
| Returns: | |
| tuple: A tuple containing two PIL.Image.Image objects: | |
| - The image with standard YOLO inference results. | |
| - The image with SAHI sliced YOLO inference results. | |
| """ | |
    load_yolo_model(yolo_model_name, confidence_threshold)

    image_width, image_height = image.size
    slice_height = int(slice_height)
    slice_width = int(slice_width)

    # Estimate how many slices the sliced inference would produce before running it.
    sliced_bboxes = sahi.slicing.get_slice_bboxes(
        image_height,
        image_width,
        slice_height,
        slice_width,
        False,  # do not auto-select the slice size; use the values above
        overlap_height_ratio,
        overlap_width_ratio,
    )
    if len(sliced_bboxes) > 60:
        raise ValueError(
            f"{len(sliced_bboxes)} slices are too many for Hugging Face Spaces; try a larger slice size."
        )
    # Standard inference
    prediction_result_1 = sahi.predict.get_prediction(
        image=image,
        detection_model=model,
    )
    # Filter by max_detections for standard inference
    if max_detections is not None and len(prediction_result_1.object_prediction_list) > max_detections:
        prediction_result_1.object_prediction_list = sorted(
            prediction_result_1.object_prediction_list,
            key=lambda x: x.score.value,
            reverse=True,
        )[:max_detections]
    visual_result_1 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_1.object_prediction_list,
    )
    output_1 = Image.fromarray(visual_result_1["image"])
    # Sliced inference
    prediction_result_2 = sahi.predict.get_sliced_prediction(
        image=image,
        detection_model=model,
        slice_height=int(slice_height),
        slice_width=int(slice_width),
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        postprocess_type=postprocess_type,
        postprocess_match_metric=postprocess_match_metric,
        postprocess_match_threshold=postprocess_match_threshold,
        postprocess_class_agnostic=postprocess_class_agnostic,
    )
    # Filter by max_detections for sliced inference
    if max_detections is not None and len(prediction_result_2.object_prediction_list) > max_detections:
        prediction_result_2.object_prediction_list = sorted(
            prediction_result_2.object_prediction_list,
            key=lambda x: x.score.value,
            reverse=True,
        )[:max_detections]
    visual_result_2 = sahi.utils.cv.visualize_object_predictions(
        image=numpy.array(image),
        object_prediction_list=prediction_result_2.object_prediction_list,
    )
    output_2 = Image.fromarray(visual_result_2["image"])

    return output_1, output_2
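
# Minimal sketch of calling the function directly, outside the Gradio UI. It assumes the
# example images above have been downloaded and that ultralytics can fetch the
# "yolo11n.pt" checkpoint. Kept commented out so the Space only runs inference on request:
#
#     standard_img, sliced_img = sahi_yolo_inference(
#         Image.open("highway.jpg"),
#         "yolo11n.pt",
#         confidence_threshold=0.5,
#         max_detections=300,
#         slice_height=256,
#         slice_width=256,
#     )
#     sliced_img.save("highway_sahi.png")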
with gr.Blocks() as app:
    gr.Markdown("# Small Object Detection with SAHI + YOLOv11")
    gr.Markdown(
        "SAHI + YOLOv11 demo for small object detection. "
        "Upload your own image or click one of the example images below."
    )
    with gr.Row():
        with gr.Column():
            original_image_input = gr.Image(type="pil", label="Original Image")
            yolo_model_dropdown = gr.Dropdown(
                choices=["yolo11n.pt", "yolo11s.pt", "yolo11m.pt", "yolo11l.pt", "yolo11x.pt"],
                value="yolo11s.pt",
                label="YOLOv11 Model",
            )
            confidence_threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.5,
                label="Confidence Threshold",
            )
            max_detections_slider = gr.Slider(
                minimum=1,
                maximum=500,
                step=1,
                value=300,
                label="Max Detections",
            )
            slice_height_input = gr.Number(value=512, label="Slice Height")
            slice_width_input = gr.Number(value=512, label="Slice Width")
            overlap_height_ratio_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.2,
                label="Overlap Height Ratio",
            )
            overlap_width_ratio_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.2,
                label="Overlap Width Ratio",
            )
            postprocess_type_dropdown = gr.Dropdown(
                ["NMS", "GREEDYNMM"],
                type="value",
                value="NMS",
                label="Postprocess Type",
            )
            postprocess_match_metric_dropdown = gr.Dropdown(
                ["IOU", "IOS"], type="value", value="IOU", label="Postprocess Match Metric"
            )
            postprocess_match_threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.5,
                label="Postprocess Match Threshold",
            )
            postprocess_class_agnostic_checkbox = gr.Checkbox(value=True, label="Postprocess Class Agnostic")
            submit_button = gr.Button("Run Inference")
        with gr.Column():
            output_standard = gr.Image(type="pil", label="YOLOv11 Standard")
            output_sahi_sliced = gr.Image(type="pil", label="YOLOv11 + SAHI Sliced")

    gr.Examples(
        examples=[
            ["apple_tree.jpg", "yolo11s.pt", 0.5, 300, 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
            ["highway.jpg", "yolo11s.pt", 0.5, 300, 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
            ["highway2.jpg", "yolo11s.pt", 0.5, 300, 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
            ["highway3.jpg", "yolo11s.pt", 0.5, 300, 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
        ],
        inputs=[
            original_image_input,
            yolo_model_dropdown,
            confidence_threshold_slider,
            max_detections_slider,
            slice_height_input,
            slice_width_input,
            overlap_height_ratio_slider,
            overlap_width_ratio_slider,
            postprocess_type_dropdown,
            postprocess_match_metric_dropdown,
            postprocess_match_threshold_slider,
            postprocess_class_agnostic_checkbox,
        ],
        outputs=[output_standard, output_sahi_sliced],
        fn=sahi_yolo_inference,
        cache_examples=True,
    )
    submit_button.click(
        fn=sahi_yolo_inference,
        inputs=[
            original_image_input,
            yolo_model_dropdown,
            confidence_threshold_slider,
            max_detections_slider,
            slice_height_input,
            slice_width_input,
            overlap_height_ratio_slider,
            overlap_width_ratio_slider,
            postprocess_type_dropdown,
            postprocess_match_metric_dropdown,
            postprocess_match_threshold_slider,
            postprocess_class_agnostic_checkbox,
        ],
        outputs=[output_standard, output_sahi_sliced],
    )

app.launch(mcp_server=True)
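
# Note (assumption): mcp_server=True needs a Gradio build with MCP support (e.g. installed
# as `gradio[mcp]`); with it, the app is also exposed as an MCP tool endpoint in addition
# to the web UI. Drop the flag if the installed Gradio version does not accept it.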