from typing import List

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import pipeline, CLIPProcessor, CLIPModel

MARKDOWN = """
# Segment Anything Model + MetaCLIP

This is a demo of open-vocabulary image segmentation using the
[Segment Anything Model](https://github.com/facebookresearch/segment-anything) and
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP) combination.
"""
EXAMPLES = [
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "dog", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog.jpeg", "building", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "jacket", 0.5],
    ["https://media.roboflow.com/notebooks/examples/dog-3.jpeg", "coffee", 0.6],
]
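# Masks covering less than 1% of the image area are discarded before CLIP scoring.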
MIN_AREA_THRESHOLD = 0.01
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

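# SAM proposes class-agnostic masks; MetaCLIP scores each masked crop against the text prompt.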
SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE)
CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")

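# Two visualizations: a semi-transparent red overlay on the input image and solid white masks on a black canvas.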
SEMITRANSPARENT_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.RED,
    color_lookup=sv.ColorLookup.INDEX)
SOLID_MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.WHITE,
    color_lookup=sv.ColorLookup.INDEX,
    opacity=1)

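# Run SAM in automatic mask-generation mode and wrap the resulting masks in an sv.Detections object.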
def run_sam(image_rgb_pil: Image.Image) -> sv.Detections:
    outputs = SAM_GENERATOR(image_rgb_pil, points_per_batch=32)
    mask = np.array(outputs['masks'])
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)

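# Zero-shot scoring with MetaCLIP: returns softmax probabilities of the image matching each text prompt.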
def run_clip(image_rgb_pil: Image.Image, text: List[str]) -> np.ndarray:
    inputs = CLIP_PROCESSOR(
        text=text,
        images=image_rgb_pil,
        return_tensors="pt",
        padding=True
    ).to(DEVICE)
    outputs = CLIP_MODEL(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs.detach().cpu().numpy()

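# Keep the masked pixels and replace the background with neutral gray so CLIP focuses on the segment.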
def reverse_mask_image(image: np.ndarray, mask: np.ndarray, gray_value=128):
    gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
    return np.where(mask[..., None], image, gray_color)

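# Draw detections on a BGR copy of the image and return the result as an RGB PIL image.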
def annotate(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    annotator: sv.MaskAnnotator
) -> Image.Image:
    img_bgr_numpy = np.array(image_rgb_pil)[:, :, ::-1]
    annotated_bgr_image = annotator.annotate(
        scene=img_bgr_numpy, detections=detections)
    return Image.fromarray(annotated_bgr_image[:, :, ::-1])

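# For every mask: crop it, gray out its background, and keep it only if CLIP scores the prompt
# above the confidence threshold against a generic "background" caption.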
def filter_detections(
    image_rgb_pil: Image.Image,
    detections: sv.Detections,
    prompt: str,
    confidence: float
) -> sv.Detections:
    img_rgb_numpy = np.array(image_rgb_pil)
    text = [f"a picture of {prompt}", "a picture of background"]
    filtering_mask = []
    for xyxy, mask in zip(detections.xyxy, detections.mask):
        crop = sv.crop_image(image=img_rgb_numpy, xyxy=xyxy)
        mask_crop = sv.crop_image(image=mask, xyxy=xyxy)
        masked_crop = reverse_mask_image(image=crop, mask=mask_crop)
        masked_crop_pil = Image.fromarray(masked_crop)
        probs = run_clip(image_rgb_pil=masked_crop_pil, text=text)
        filtering_mask.append(probs[0][0] > confidence)
    filtering_mask = np.array(filtering_mask)
    return detections[filtering_mask]

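# Full pipeline: SAM mask proposals -> area filter -> CLIP prompt filter -> two visualizations
# (semi-transparent overlay on the input and solid masks on a black canvas).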
def inference(
    image_rgb_pil: Image.Image,
    prompt: str,
    confidence: float
) -> List[Image.Image]:
    width, height = image_rgb_pil.size
    area = width * height

    detections = run_sam(image_rgb_pil)
    detections = detections[detections.area / area > MIN_AREA_THRESHOLD]
    detections = filter_detections(
        image_rgb_pil=image_rgb_pil,
        detections=detections,
        prompt=prompt,
        confidence=confidence)

    blank_image = Image.new("RGB", (width, height), "black")
    return [
        annotate(
            image_rgb_pil=image_rgb_pil,
            detections=detections,
            annotator=SEMITRANSPARENT_MASK_ANNOTATOR),
        annotate(
            image_rgb_pil=blank_image,
            detections=detections,
            annotator=SOLID_MASK_ANNOTATOR)
    ]

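# Gradio UI: image, prompt, and confidence inputs; results are shown as a two-image gallery.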
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                image_mode='RGB', type='pil', height=500)
            prompt_text = gr.Textbox(
                label="Prompt", value="dog")
            confidence_slider = gr.Slider(
                label="Confidence", minimum=0.5, maximum=1.0, step=0.05, value=0.6)
            submit_button = gr.Button("Submit")
        gallery = gr.Gallery(label="Result", object_fit="scale-down", preview=True)
    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            fn=inference,
            inputs=[input_image, prompt_text, confidence_slider],
            outputs=[gallery],
            cache_examples=True,
            run_on_click=True
        )

    submit_button.click(
        inference,
        inputs=[input_image, prompt_text, confidence_slider],
        outputs=gallery)

demo.launch(debug=False, show_error=True)