Spaces:
Sleeping
Sleeping
| from typing import Any, Dict | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| from gradio_image_annotation import image_annotator | |
| from sam2 import load_model | |
| from sam2.utils.visualization import show_masks | |
| from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator | |
| from sam2.sam2_image_predictor import SAM2ImagePredictor | |
# Loaded SAM2 models keyed by checkpoint variant. Repeated clicks reuse the
# network instead of re-reading the checkpoint from disk on every request.
# There is at most one entry per dropdown choice.
_MODEL_CACHE: Dict[str, Any] = {}


def _load_sam2_model(model_choice: str) -> Any:
    """Return the SAM2 model for ``model_choice``, loading it on first use.

    The checkpoint is read from ``assets/checkpoints/sam2_hiera_<variant>.pt``
    and loaded on the CPU.
    """
    model = _MODEL_CACHE.get(model_choice)
    if model is None:
        model = load_model(
            variant=model_choice,
            ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
            device="cpu",
        )
        _MODEL_CACHE[model_choice] = model
    return model


def _boxes_to_array(boxes) -> np.ndarray:
    """Convert annotator box dicts into an ``(N, 4)`` int array of xmin, ymin, xmax, ymax."""
    return np.array(
        [
            [int(box["xmin"]), int(box["ymin"]), int(box["xmax"]), int(box["ymax"])]
            for box in boxes
        ]
    )


# NOTE(review): presumably the Hugging Face `spaces` GPU decorator; it is left
# disabled because the model is loaded on the CPU.
# @spaces.GPU()
def predict(model_choice, annotations: Dict[str, Any]):
    """Segment the annotated image with the selected SAM2 checkpoint.

    Args:
        model_choice: Checkpoint variant chosen in the dropdown
            ("tiny", "small", "base_plus" or "large").
        annotations: Value of the ``image_annotator`` component. It is a dict
            with an ``"image"`` entry and an optional ``"boxes"`` list of dicts
            with ``xmin``/``ymin``/``xmax``/``ymax`` keys.

    Returns:
        The visualisation produced by ``show_masks``:
        - If boxes were drawn, one mask per box prompt. The score is shown
          only when there is a single box.
        - Otherwise, the masks from ``SAM2AutomaticMaskGenerator``.

    Raises:
        gr.Error: If no image has been provided.
    """
    # Without this check, a missing image would crash deep inside SAM2 with an
    # opaque TypeError/KeyError instead of a user-facing message.
    if not annotations or annotations.get("image") is None:
        raise gr.Error("Please upload an image first.")

    image = annotations["image"]
    sam2_model = _load_sam2_model(model_choice)
    boxes = annotations.get("boxes") or []

    if not boxes:
        # No prompts drawn: fall back to segmenting everything in the image.
        mask_generator = SAM2AutomaticMaskGenerator(sam2_model)  # type: ignore
        masks = mask_generator.generate(image)
        return show_masks(
            image=image,
            masks=masks,  # type: ignore
            scores=None,
            only_best=False,
            autogenerated_mask=True,
        )

    predictor = SAM2ImagePredictor(sam2_model)  # type:ignore
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=_boxes_to_array(boxes),
        multimask_output=False,
    )
    # One box: draw its mask together with its score.
    # Several boxes: draw every mask and omit the scores.
    single_box = len(scores) == 1
    return show_masks(
        image=image,
        masks=masks,
        scores=scores if single_box else None,
        only_best=single_box,
    )
# cv2.imread returns pixels in BGR channel order, or None if the file cannot be
# read. Gradio image components (and, per its docs, SAM2's image predictor)
# expect RGB, so convert before using the image as the annotator's default.
_example_bgr = cv2.imread("assets/example.png")
_example_image = (
    None if _example_bgr is None else cv2.cvtColor(_example_bgr, cv2.COLOR_BGR2RGB)
)

# Build the demo UI: model picker -> box annotator -> button -> mask output.
with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
        ## To read more about the Segment Anything Project please refer to the [Lightly AI blogpost](https://www.lightly.ai/post/segment-anything-model-and-friends)
        """
    )
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )
    gr.Markdown(
        """
        # 2. Upload your Image and draw bounding box(es)
        """
    )
    annotator = image_annotator(
        # Start empty rather than with a broken value if the example is missing.
        value={"image": _example_image} if _example_image is not None else None,
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask(s)")
    # The output is declared after the button, so it renders below it,
    # exactly as the original inline declaration did.
    mask_output = gr.Image(label="Mask(s)")
    btn.click(fn=predict, inputs=[model, annotator], outputs=[mask_output])

if __name__ == "__main__":
    demo.launch()