Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from gradio_image_prompter import ImagePrompter | |
| import torch | |
| import numpy as np | |
| from sam2.sam2_image_predictor import SAM2ImagePredictor | |
| from PIL import Image | |
| from uuid import uuid4 | |
| import os | |
| from huggingface_hub import upload_folder | |
| import shutil | |
# Hugging Face checkpoint id for the SAM2 image predictor.
MODEL = "facebook/sam2-hiera-large"
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Loaded once at import time and shared by every callback below.
PREDICTOR = SAM2ImagePredictor.from_pretrained(MODEL, device=DEVICE)

# Module-level state shared by the Gradio callbacks (process-wide, so it is
# shared by every connected user).
GLOBALS = dict()  # dataset names stored by save_dataset_name, keyed by slot name
IMAGE = MASKS = INDEX = None  # last image, its predicted masks, chosen mask index
def prompter(prompts):
    """Run SAM2 on the clicked points and return three red mask overlays.

    Args:
        prompts: Value of the ``ImagePrompter`` component: a dict with an
            ``"image"`` entry (PIL image, array, or file path) and a
            ``"points"`` list whose entries start with ``x, y``.

    Returns:
        ``(overlay_1, overlay_2, overlay_3, masks)`` where each overlay is a PIL
        image of the input blended 50/50 with one predicted mask painted red,
        and ``masks`` is the raw array returned by ``PREDICTOR.predict``.

    Raises:
        gr.Error: If no image was uploaded or no point was clicked.

    Side effects:
        Stores the RGB image and the masks in the module-level ``IMAGE`` and
        ``MASKS`` so the save callback can read them later.
    """
    global IMAGE, MASKS

    if not prompts or prompts.get("image") is None:
        raise gr.Error("Please upload an image first.")
    points = prompts.get("points") or []
    if not points:
        raise gr.Error("Please click on the image to add at least one point.")

    image = _to_rgb_array(prompts["image"])
    # Only the first two entries of each point (x, y) are used, and every
    # click is treated as a foreground point (label 1).
    input_point = np.array([[p[0], p[1]] for p in points], dtype=np.float32)
    input_label = np.ones(len(points), dtype=np.int32)

    with torch.inference_mode():
        PREDICTOR.set_image(image)
        masks, _, _ = PREDICTOR.predict(
            point_coords=input_point, point_labels=input_label, multimask_output=True
        )

    # The base image is the same for every overlay, so convert it once.
    original_image = Image.fromarray(image)
    overlay_images = [_red_overlay(original_image, image, mask) for mask in masks]

    IMAGE, MASKS = image, masks
    return overlay_images[0], overlay_images[1], overlay_images[2], masks


def _to_rgb_array(image):
    """Return *image* (PIL image, array, or file path) as an HxWx3 uint8 array."""
    if isinstance(image, str):
        image = Image.open(image)
    elif not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    # SAM2 and the overlay code both need exactly three channels: this drops
    # the alpha channel of RGBA uploads and expands grayscale ones.
    return np.array(image.convert("RGB"))


def _red_overlay(original_image, image, mask):
    """Blend *original_image* 50/50 with *mask* painted into the red channel."""
    red_mask = np.zeros_like(image)
    red_mask[:, :, 0] = (mask > 0).astype(np.uint8) * 255
    return Image.blend(original_image, Image.fromarray(red_mask), alpha=0.5)
def select_mask(
    selected_mask_index,
    mask1,
    mask2,
    mask3,
):
    """Record which of the three candidates the user picked and return it.

    The index is stored in the module-level ``INDEX``, which the save callback
    reads to pick ``MASKS[INDEX]``.

    Args:
        selected_mask_index: 0-based index from the ``type="index"`` radio, or
            ``None`` when the selection is cleared.
        mask1, mask2, mask3: The three candidates in radio order. NOTE(review):
            the UI currently passes the overlay images here, not raw masks.

    Returns:
        The chosen candidate, or ``None`` when nothing is selected.
    """
    global INDEX
    INDEX = selected_mask_index
    if selected_mask_index is None:
        # A cleared radio sends None; indexing a list with None raises TypeError.
        return None
    return (mask1, mask2, mask3)[selected_mask_index]
def save_selected_mask(
    image, mask, output_dir="output", repo_id="amaye15/object-segmentation"
):
    """Save an image/mask pair as ``.npy`` files and upload them to the Hub.

    The pair is staged in ``<cwd>/<output_dir>/<uuid4>/{image,mask}.npy``.
    That single folder is uploaded to the dataset repo under ``<uuid4>/``,
    and the local copy is removed afterwards, even if the upload fails.

    Args:
        image: Image array to save. When ``None`` (the UI currently wires an
            empty ``gr.State`` here), the module-level ``IMAGE`` is used.
        mask: Mask array to save. When ``None``, ``MASKS[INDEX]`` (the mask
            chosen with the radio) is used.
        output_dir: Local staging directory, relative to the working directory
            unless absolute.
        repo_id: Hugging Face dataset repository that receives the upload.

    Returns:
        A status message naming the uploaded location.

    Raises:
        gr.Error: If no image/mask pair is available yet (nothing predicted or
            no mask selected).
    """
    if image is None:
        image = IMAGE
    if mask is None and MASKS is not None and INDEX is not None:
        mask = MASKS[INDEX]
    if image is None or mask is None:
        raise gr.Error("Run a prediction and select a mask before saving.")

    folder_id = str(uuid4())
    folder_path = os.path.join(os.getcwd(), output_dir, folder_id)
    os.makedirs(folder_path, exist_ok=True)
    try:
        np.save(os.path.join(folder_path, "image.npy"), np.asarray(image))
        np.save(os.path.join(folder_path, "mask.npy"), np.asarray(mask))
        # Upload only this sample's folder. Uploading all of output_dir would
        # re-push any folders left behind by earlier failed runs. The repo
        # layout (<uuid>/image.npy, <uuid>/mask.npy) is unchanged.
        upload_folder(
            folder_path=folder_path,
            path_in_repo=folder_id,
            repo_id=repo_id,
            repo_type="dataset",
        )
    finally:
        # Always clean up the staging folder, even when the upload raises.
        shutil.rmtree(folder_path, ignore_errors=True)
    return f"Image and mask uploaded to {repo_id}/{folder_id}."
def save_dataset_name(key, dataset_name):
    """Remember a dataset name and build its Markdown label and viewer iframe.

    Args:
        key: Slot in ``GLOBALS`` to store the name under (the UI uses
            ``"source_dataset"`` and ``"destination_dataset"``).
        dataset_name: Hub dataset id typed by the user, e.g. ``"owner/name"``.
            Surrounding whitespace is stripped.

    Returns:
        ``(markdown, iframe_html)``. Both are empty strings when the name is
        blank, so no iframe pointing at a non-existent dataset is rendered.
    """
    from urllib.parse import quote

    dataset_name = (dataset_name or "").strip()
    GLOBALS[key] = dataset_name
    if not dataset_name:
        return "", ""
    # The name comes straight from a textbox. Percent-encode it (keeping the
    # "owner/name" slash) so it cannot break out of the src="..." attribute.
    safe_name = quote(dataset_name, safe="/")
    iframe_code = f"""
<iframe
  src="https://huggingface.co/datasets/{safe_name}/embed/viewer/default/train"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>
"""
    return f"Huggingface Dataset: {dataset_name}", iframe_code
# Define the Gradio Blocks app.
# NOTE(review): prompter, select_mask and save_selected_mask exchange data
# through module-level globals (IMAGE, MASKS, INDEX), not through the
# gr.State() objects wired below, which are never read back. That state is
# therefore shared by every connected user.
with gr.Blocks() as demo:
    with gr.Tab("Setup"):
        with gr.Row():
            with gr.Column():
                source = gr.Textbox(label="Source Dataset")
                source_display = gr.Markdown()
                iframe_display = gr.HTML()
                source.change(
                    save_dataset_name,
                    inputs=(gr.State("source_dataset"), source),
                    outputs=(source_display, iframe_display),
                )
            with gr.Column():
                destination = gr.Textbox(label="Destination Dataset")
                destination_display = gr.Markdown()
                # save_dataset_name returns (markdown, iframe), but this column
                # only has a Markdown output. Forward just the first element
                # instead of handing the whole tuple to the single Markdown.
                destination.change(
                    lambda key, name: save_dataset_name(key, name)[0],
                    inputs=(gr.State("destination_dataset"), destination),
                    outputs=destination_display,
                )
    with gr.Tab("Object Mask - Point Prompt"):
        gr.Markdown("# Image Point Collector with Multiple Separate Mask Overlays")
        gr.Markdown(
            "Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images."
        )
        with gr.Row():
            with gr.Column():
                # Input: ImagePrompter
                image_input = ImagePrompter(show_label=False)
                submit_button = gr.Button("Submit")
        with gr.Row():
            with gr.Column():
                # Outputs: the three overlay images returned by prompter
                image_output_1 = gr.Image(show_label=False)
            with gr.Column():
                image_output_2 = gr.Image(show_label=False)
            with gr.Column():
                image_output_3 = gr.Image(show_label=False)
        # Radio for selecting the correct mask (type="index" yields 0, 1, 2)
        with gr.Row():
            mask_selector = gr.Radio(
                label="Select the correct mask",
                choices=["Mask 1", "Mask 2", "Mask 3"],
                type="index",
            )
        save_button = gr.Button("Save Selected Mask and Image")
        # NOTE(review): hidden, so the save status is never shown to the user.
        save_message = gr.Textbox(visible=False)

        # Submit: run SAM2 on the clicked points.
        submit_button.click(
            fn=prompter,
            inputs=image_input,
            outputs=[image_output_1, image_output_2, image_output_3, gr.State()],
        )
        # Selection: record the chosen index (stored in the global INDEX).
        mask_selector.change(
            fn=select_mask,
            inputs=[mask_selector, image_output_1, image_output_2, image_output_3],
            outputs=gr.State(),
        )
        # Save: the empty states pass None, so save_selected_mask reads the
        # globals set by the two callbacks above.
        save_button.click(
            fn=save_selected_mask,
            inputs=[gr.State(), gr.State()],
            outputs=save_message,
            show_progress="full",
        )
# Launch the Gradio app
demo.launch()