Spaces:
Sleeping
Sleeping
import functools

import gradio as gr
import numpy as np
import cv2
from PIL import Image
from gradio_image_annotation import image_annotator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from src.plot_utils import show_masks
# Dropdown key -> [config file name, checkpoint path], unpacked by predict()
# and handed to build_sam2. The config name uses the short size code while the
# checkpoint file spells the size out, so both are derived from the pairs below.
choice_mapping = {
    size: [f"sam2_hiera_{code}.yaml", f"assets/checkpoints/sam2_hiera_{size}.pt"]
    for size, code in (
        ("tiny", "t"),
        ("small", "s"),
        ("base_plus", "b+"),
        ("large", "l"),
    )
}
@functools.lru_cache(maxsize=1)
def _load_sam2_model(model_choice: str):
    """Build the SAM2 model for ``model_choice`` on CPU, reusing the last one built.

    Loading a checkpoint is the slowest step of a request, so consecutive
    requests with the same dropdown value share one model instead of reloading
    it from disk. ``maxsize=1`` keeps at most one model resident in memory.
    """
    config_file, ckpt_path = choice_mapping[model_choice]
    return build_sam2(config_file, ckpt_path, device="cpu")


def predict(model_choice: str, annotations, image):
    """Segment the object inside the user's bounding box with SAM2.

    Args:
        model_choice: Key of ``choice_mapping`` selecting the checkpoint.
        annotations: Value of the ``image_annotator`` component. Only
            ``annotations["boxes"][0]`` (keys ``xmin``/``ymin``/``xmax``/``ymax``)
            is used.
        image: Input image from the ``gr.Image(type="numpy")`` component.

    Returns:
        A two-item list: the output of ``show_masks`` (shown in a ``gr.Plot``)
        and a visible ``gr.DownloadButton`` pointing at the mask written to
        ``mask.png``.

    Raises:
        gr.Error: If no image was provided or no bounding box was drawn, so the
            user sees a readable message instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    boxes = (annotations or {}).get("boxes") or []
    if not boxes:
        raise gr.Error("Please draw a bounding box around the object to segment.")

    # The predictor keeps per-image state from set_image(), so build a fresh
    # wrapper per call; only the heavy model itself is shared via the cache.
    predictor = SAM2ImagePredictor(_load_sam2_model(str(model_choice)))
    predictor.set_image(image)

    box = boxes[0]
    coordinates = np.array(
        [int(box[key]) for key in ("xmin", "ymin", "xmax", "ymax")]
    )
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=coordinates[None, :],  # predictor expects a batch of boxes
        multimask_output=False,
    )

    # Move the mask axis last (H, W, 1) and scale to 0/255 for an 8-bit PNG.
    mask = masks.transpose(1, 2, 0)
    mask_image = (mask * 255).astype(np.uint8)
    cv2.imwrite("mask.png", mask_image)
    return [
        show_masks(image, masks, scores, box_coords=coordinates),
        gr.DownloadButton("Download Mask", value="mask.png", visible=True),
    ]
def _sync_annotator(image):
    """Mirror a newly uploaded (or cleared) input image into the box annotator.

    Without this the annotator keeps showing the default image it was built
    with, so a box drawn there would be applied to a different uploaded image.
    Any previous box is dropped because its coordinates belonged to the old
    image.
    """
    if image is None:
        return None
    return {"image": image, "boxes": []}


with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
# 1. Choose Model Checkpoint
"""
    )
    with gr.Row():
        model = gr.Dropdown(
            # Derived from choice_mapping so the dropdown can never offer a key
            # that predict() cannot look up.
            choices=list(choice_mapping),
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )
    gr.Markdown(
        """
# 2. Upload an Image
"""
    )
    with gr.Row():
        img = gr.Image(value="./assets/img.png", type="numpy", label="Input Image")
    gr.Markdown(
        """
# 3. Draw Bounding Box
"""
    )
    annotator = image_annotator(
        value={"image": img.value["path"]},
        disable_edit_boxes=True,
        single_box=True,
        label="Draw a bounding box",
    )
    # Keep the annotator showing the same image that predict() segments.
    img.change(fn=_sync_annotator, inputs=img, outputs=annotator)

    btn = gr.Button("Get Segmentation Mask")
    # Hidden until predict() returns a visible replacement pointing at the
    # freshly written mask, so no file needs to exist at startup.
    download_btn = gr.DownloadButton("Download Mask", value=None, visible=False)
    btn.click(
        fn=predict, inputs=[model, annotator, img], outputs=[gr.Plot(), download_btn]
    )

if __name__ == "__main__":
    demo.launch()