import os
from typing import Optional

import cv2
import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from tqdm import tqdm
from gradio_image_prompter import ImagePrompter

from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
    MASK_GENERATION_MODE, BOX_PROMPT_MODE, VIDEO_SEGMENTATION_MODE
from utils.video import create_directory, generate_unique_name
from sam2.build_sam import build_sam2_video_predictor
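
# `utils.models` and `utils.video` are local helper modules bundled with this
# Space and are not shown here; the comments below describe what they appear
# to provide, judging from how they are used in this file.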

MARKDOWN = """
# Segment Anything Model 2 🔥

<div>
    <a href="https://github.com/facebookresearch/segment-anything-2">
        <img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block;">
    </a>
    <a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">
        <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
    </a>
    <a href="https://blog.roboflow.com/what-is-segment-anything-2/">
        <img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
    </a>
    <a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">
        <img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
    </a>
</div>

Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable
visual segmentation in both images and videos. **Video segmentation will be available
soon.**
"""

EXAMPLES = [
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", None],
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", None],
    ["tiny", MASK_GENERATION_MODE, "https://media.roboflow.com/notebooks/examples/dog-4.jpeg", None],
]
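
# Each example row follows the `inputs` order of the gr.Examples block below:
# (checkpoint name, mode, value for the plain image input, value for the
# image prompter input).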

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
SCALE_FACTOR = 0.5
TARGET_DIRECTORY = "tmp"

# creating video results directory
create_directory(directory_path=TARGET_DIRECTORY)
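
# IMAGE_PREDICTORS and MASK_GENERATORS are indexed by checkpoint name in
# process_image() below, so load_models() presumably returns two dicts keyed by
# CHECKPOINT_NAMES, mapping to SAM 2 image predictors and automatic mask
# generators respectively (an assumption; utils/models.py is not shown here).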

def on_mode_dropdown_change(text):
    # toggle component visibility so that only the inputs, buttons, and outputs
    # relevant to the selected mode are shown
    return [
        gr.Image(visible=text == MASK_GENERATION_MODE),
        ImagePrompter(visible=text == BOX_PROMPT_MODE),
        gr.Video(visible=text == VIDEO_SEGMENTATION_MODE),
        ImagePrompter(visible=text == VIDEO_SEGMENTATION_MODE),
        gr.Button(visible=text != VIDEO_SEGMENTATION_MODE),
        gr.Button(visible=text == VIDEO_SEGMENTATION_MODE),
        gr.Image(visible=text != VIDEO_SEGMENTATION_MODE),
        gr.Video(visible=text == VIDEO_SEGMENTATION_MODE)
    ]

def on_video_input_change(video_input):
    # grab the first frame of the uploaded video and pre-populate the frame
    # prompter so the user can draw prompts on it
    if not video_input:
        return None
    frames_generator = sv.get_video_frames_generator(video_input)
    frame = next(frames_generator)
    frame = sv.scale_image(frame, SCALE_FACTOR)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = Image.fromarray(frame)
    return {'image': frame, 'points': []}
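
# Note on the ImagePrompter value format assumed throughout this file: the
# component returns a dict with an "image" and a list of "points", where each
# entry is a 6-tuple (x1, y1, flag, x2, y2, flag) and the flags encode the
# prompt type. process_image() keeps both box corners, while process_video()
# keeps only the first coordinate pair of each entry.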

def process_image(
    checkpoint_dropdown,
    mode_dropdown,
    image_input,
    image_prompter_input
) -> Optional[Image.Image]:
    if mode_dropdown == BOX_PROMPT_MODE:
        image_input = image_prompter_input["image"]
        prompt = image_prompter_input["points"]
        if len(prompt) == 0:
            return image_input

        model = IMAGE_PREDICTORS[checkpoint_dropdown]
        image = np.array(image_input.convert("RGB"))
        box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in prompt])

        model.set_image(image)
        masks, _, _ = model.predict(box=box, multimask_output=False)

        # dirty fix; remove this later: with multiple boxes, predict() returns
        # masks shaped (N, 1, H, W), so squeeze the extra channel dimension
        if len(masks.shape) == 4:
            masks = np.squeeze(masks)

        detections = sv.Detections(
            xyxy=sv.mask_to_xyxy(masks=masks),
            mask=masks.astype(bool)
        )
        return MASK_ANNOTATOR.annotate(image_input, detections)

    if mode_dropdown == MASK_GENERATION_MODE:
        model = MASK_GENERATORS[checkpoint_dropdown]
        image = np.array(image_input.convert("RGB"))
        result = model.generate(image)
        detections = sv.Detections.from_sam(result)
        return MASK_ANNOTATOR.annotate(image_input, detections)
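
# For reference, a minimal sketch of driving one of the image predictors
# directly, outside Gradio (illustrative only; the file path and box
# coordinates are placeholders, and "tiny" is assumed to be one of
# CHECKPOINT_NAMES, as in EXAMPLES above):
#
#   predictor = IMAGE_PREDICTORS["tiny"]
#   image = np.array(Image.open("example.jpeg").convert("RGB"))
#   predictor.set_image(image)
#   masks, scores, _ = predictor.predict(
#       box=np.array([[10, 20, 300, 400]]),
#       multimask_output=False,
#   )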

def process_video(
    checkpoint_dropdown,
    mode_dropdown,
    video_input,
    video_prompter_input,
    progress=gr.Progress(track_tqdm=True)
) -> str:
    if mode_dropdown != VIDEO_SEGMENTATION_MODE:
        return str(video_input)

    # split the uploaded video into downscaled JPEG frames, as expected by the
    # SAM 2 video predictor
    name = generate_unique_name()
    frame_directory_path = os.path.join(TARGET_DIRECTORY, name)
    frames_sink = sv.ImageSink(
        target_dir_path=frame_directory_path,
        image_name_pattern="{:05d}.jpeg"
    )
    video_info = sv.VideoInfo.from_video_path(video_input)
    frames_generator = sv.get_video_frames_generator(video_input)
    with frames_sink:
        for frame in tqdm(
            frames_generator,
            total=video_info.total_frames,
            desc="splitting video into frames"
        ):
            frame = sv.scale_image(frame, SCALE_FACTOR)
            frames_sink.save_image(frame)

    model = build_sam2_video_predictor(
        "sam2_hiera_t.yaml",
        "checkpoints/sam2_hiera_tiny.pt",
        device=DEVICE
    )
    inference_state = model.init_state(
        video_path=frame_directory_path,
        offload_video_to_cpu=DEVICE == torch.device('cpu'),
        offload_state_to_cpu=DEVICE == torch.device('cpu'),
    )

    # use the prompt coordinates from the first frame as positive point prompts
    # for a single tracked object
    prompt = video_prompter_input["points"]
    points = np.array([[x1, y1] for x1, y1, _, _, _, _ in prompt])
    labels = np.ones(len(points))
    _, object_ids, mask_logits = model.add_new_points(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=1,
        points=points,
        labels=labels,
    )

    del inference_state
    del model

    # mask propagation and video rendering are not implemented yet; the input
    # video is returned unchanged for now (see the note in MARKDOWN)
    video_path = os.path.join(TARGET_DIRECTORY, f"{name}.mp4")
    return str(video_input)
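
# The propagation step above is intentionally missing (see the "available soon"
# note in MARKDOWN): prompts are registered with the predictor, but the state is
# then discarded and the input video is returned unchanged. A rough sketch of
# how the loop could be completed with SAM 2's `propagate_in_video` API, writing
# annotated frames to `video_path` (illustrative only; it would need to run
# before the `del` statements, and the VideoInfo would need to reflect the
# SCALE_FACTOR-scaled frame resolution):
#
#   with sv.VideoSink(target_path=video_path, video_info=video_info) as sink:
#       for frame_idx, object_ids, mask_logits in model.propagate_in_video(inference_state):
#           frame = cv2.imread(
#               os.path.join(frame_directory_path, f"{frame_idx:05d}.jpeg"))
#           masks = (mask_logits > 0.0).cpu().numpy()[:, 0, :, :]
#           detections = sv.Detections(
#               xyxy=sv.mask_to_xyxy(masks=masks), mask=masks)
#           sink.write_frame(MASK_ANNOTATOR.annotate(frame, detections))
#   return video_path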

with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        checkpoint_dropdown_component = gr.Dropdown(
            choices=CHECKPOINT_NAMES,
            value=CHECKPOINT_NAMES[0],
            label="Checkpoint",
            info="Select a SAM 2 checkpoint to use.",
            interactive=True
        )
        mode_dropdown_component = gr.Dropdown(
            choices=MODE_NAMES,
            value=MODE_NAMES[0],
            label="Mode",
            info="Select a mode to use: `box prompt` to generate masks for "
                 "selected objects, `mask generation` to generate masks for "
                 "the whole image, and `video segmentation` to track objects "
                 "in a video.",
            interactive=True
        )
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image', visible=False)
            image_prompter_input_component = ImagePrompter(
                type='pil', label='Prompt image')
            video_input_component = gr.Video(
                label='Step 1: Upload video', visible=False)
            video_prompter_input_component = ImagePrompter(
                type='pil', label='Step 2: Prompt frame', visible=False)
            submit_image_button_component = gr.Button(
                value='Submit', variant='primary')
            submit_video_button_component = gr.Button(
                value='Submit', variant='primary', visible=False)
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image output')
            video_output_component = gr.Video(
                label='Step 2: Video output', visible=False)
    with gr.Row():
        gr.Examples(
            fn=process_image,
            examples=EXAMPLES,
            inputs=[
                checkpoint_dropdown_component,
                mode_dropdown_component,
                image_input_component,
                image_prompter_input_component,
            ],
            outputs=[image_output_component],
            run_on_click=True
        )
    mode_dropdown_component.change(
        on_mode_dropdown_change,
        inputs=[mode_dropdown_component],
        outputs=[
            image_input_component,
            image_prompter_input_component,
            video_input_component,
            video_prompter_input_component,
            submit_image_button_component,
            submit_video_button_component,
            image_output_component,
            video_output_component
        ]
    )
    video_input_component.change(
        fn=on_video_input_change,
        inputs=[video_input_component],
        outputs=[video_prompter_input_component]
    )
    submit_image_button_component.click(
        fn=process_image,
        inputs=[
            checkpoint_dropdown_component,
            mode_dropdown_component,
            image_input_component,
            image_prompter_input_component,
        ],
        outputs=[image_output_component]
    )
    submit_video_button_component.click(
        fn=process_video,
        inputs=[
            checkpoint_dropdown_component,
            mode_dropdown_component,
            video_input_component,
            video_prompter_input_component,
        ],
        outputs=[video_output_component]
    )

demo.launch(debug=False, show_error=True, max_threads=1)