Spaces:
Runtime error
Runtime error
| # import gradio as gr | |
| # from gradio_image_prompter import ImagePrompter | |
| # import os | |
| # import torch | |
| # def prompter(prompts): | |
| # image = prompts["image"] # Get the image from prompts | |
| # points = prompts["points"] # Get the points from prompts | |
| # # Print the collected inputs for debugging or logging | |
| # print("Image received:", image) | |
| # print("Points received:", points) | |
| # import torch | |
| # from sam2.sam2_image_predictor import SAM2ImagePredictor | |
| # device = torch.device("cpu") | |
| # predictor = SAM2ImagePredictor.from_pretrained( | |
| # "facebook/sam2-hiera-base-plus", device=device | |
| # ) | |
| # with torch.inference_mode(): | |
| # predictor.set_image(image) | |
| # # masks, _, _ = predictor.predict([[point[0], point[1]] for point in points]) | |
| # input_point = [[point[0], point[1]] for point in points] | |
| # input_label = [1] | |
| # masks, _, _ = predictor.predict( | |
| # point_coords=input_point, point_labels=input_label | |
| # ) | |
| # print("Predicted Mask:", masks) | |
| # return image, points | |
| # # Define the Gradio interface | |
| # demo = gr.Interface( | |
| # fn=prompter, # Use the custom prompter function | |
| # inputs=ImagePrompter( | |
| # show_label=False | |
| # ), # ImagePrompter for image input and point selection | |
| # outputs=[ | |
| # gr.Image(show_label=False), # Display the image | |
| # gr.Dataframe(label="Points"), # Display the points in a DataFrame | |
| # ], | |
| # title="Image Point Collector", | |
| # description="Upload an image, click on it, and get the coordinates of the clicked points.", | |
| # ) | |
| # # Launch the Gradio app | |
| # demo.launch() | |
| # import gradio as gr | |
| # from gradio_image_prompter import ImagePrompter | |
| # import torch | |
| # from sam2.sam2_image_predictor import SAM2ImagePredictor | |
| # def prompter(prompts): | |
| # image = prompts["image"] # Get the image from prompts | |
| # points = prompts["points"] # Get the points from prompts | |
| # # Print the collected inputs for debugging or logging | |
| # print("Image received:", image) | |
| # print("Points received:", points) | |
| # device = torch.device("cpu") | |
| # # Load the SAM2ImagePredictor model | |
| # predictor = SAM2ImagePredictor.from_pretrained( | |
| # "facebook/sam2-hiera-base-plus", device=device | |
| # ) | |
| # # Perform inference | |
| # with torch.inference_mode(): | |
| # predictor.set_image(image) | |
| # input_point = [[point[0], point[1]] for point in points] | |
| # input_label = [1] * len(points) # Assuming all points are foreground | |
| # masks, _, _ = predictor.predict( | |
| # point_coords=input_point, point_labels=input_label | |
| # ) | |
| # # The masks are returned as a list of numpy arrays | |
| # print("Predicted Mask:", masks) | |
| # # Assuming there's only one mask returned, you can adjust if there are multiple | |
| # predicted_mask = masks[0] | |
| # print(len(image)) | |
| # print(len(predicted_mask)) | |
| # # Create annotations for AnnotatedImage | |
| # annotations = [(predicted_mask, "Predicted Mask")] | |
| # return image, annotations | |
| # # Define the Gradio interface | |
| # demo = gr.Interface( | |
| # fn=prompter, # Use the custom prompter function | |
| # inputs=ImagePrompter( | |
| # show_label=False | |
| # ), # ImagePrompter for image input and point selection | |
| # outputs=gr.AnnotatedImage(), # Display the image with the predicted mask | |
| # title="Image Point Collector with Mask Overlay", | |
| # description="Upload an image, click on it, and get the predicted mask overlayed on the image.", | |
| # ) | |
| # # Launch the Gradio app | |
| # demo.launch() | |
import functools
import threading

import gradio as gr
import numpy as np
import torch
from gradio_image_prompter import ImagePrompter
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Hugging Face checkpoint used for segmentation; loaded once and reused.
_SAM2_CHECKPOINT = "facebook/sam2-hiera-base-plus"
# The predictor is shared across requests and set_image() caches the current
# image's embedding on it for predict() to use, so overlapping requests must
# not interleave. The lock also prevents two first requests from loading the
# model twice.
_PREDICTOR_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _get_predictor():
    """Build the CPU SAM2 image predictor on first use and cache it."""
    return SAM2ImagePredictor.from_pretrained(
        _SAM2_CHECKPOINT, device=torch.device("cpu")
    )


def _overlay_mask(original_image, image, mask):
    """Blend *original_image* 50/50 with a red rendering of *mask*.

    Args:
        original_image: PIL image built from *image* (passed in so callers can
            convert once and reuse it for every mask).
        image: The same picture as a numpy array; used only for its shape and
            dtype when building the red layer.
        mask: 2-D array whose 1-valued pixels are painted red.

    Returns:
        A PIL image. Masked pixels are tinted red, and all other pixels are
        blended with black, i.e. shown at half brightness.
    """
    red_mask = np.zeros_like(image)
    red_mask[:, :, 0] = mask.astype(np.uint8) * 255  # Red channel only
    return Image.blend(original_image, Image.fromarray(red_mask), alpha=0.5)


def prompter(prompts):
    """Segment the clicked object with SAM2 and return one red overlay per mask.

    Args:
        prompts: Value produced by ``ImagePrompter``: a dict with an
            ``"image"`` entry (converted to a numpy array here) and a
            ``"points"`` list whose entries start with ``x, y``. May be
            ``None`` when nothing was uploaded.

    Returns:
        A list of PIL images, one per mask returned by SAM2 with
        ``multimask_output=True``, each showing the input blended with that
        mask in red.

    Raises:
        gr.Error: If no image was uploaded or no point was clicked, so the
            user sees a message instead of a stack trace.
    """
    if prompts is None or prompts.get("image") is None:
        raise gr.Error("Please upload an image first.")
    points = prompts.get("points") or []
    if not points:
        raise gr.Error("Please click at least one point on the image.")

    image = np.array(prompts["image"])  # Convert the image to a numpy array
    # Log a summary rather than dumping the whole pixel array.
    print("Image received: shape", image.shape)
    print("Points received:", points)

    # Assumes each ImagePrompter entry begins with (x, y); any trailing fields
    # (prompt-type codes, a second box corner) are ignored — TODO confirm.
    input_point = np.array(
        [[point[0], point[1]] for point in points], dtype=np.float32
    )
    input_label = np.ones(len(points), dtype=np.int32)  # All foreground

    with _PREDICTOR_LOCK:
        predictor = _get_predictor()  # Loaded outside inference_mode on purpose
        with torch.inference_mode():
            predictor.set_image(image)
            masks, _, _ = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
            )

    original_image = Image.fromarray(image)  # Invariant across masks
    overlay_images = []
    for i, mask in enumerate(masks):
        print(f"Predicted Mask {i+1}: {int(np.count_nonzero(mask))} pixels")
        overlay_images.append(_overlay_mask(original_image, image, mask))
    return overlay_images
# Define the Gradio interface: an ImagePrompter for upload + clicks, and up to
# three image outputs, one per mask returned by prompter().
demo = gr.Interface(
    fn=prompter,  # Runs SAM2 on the clicked points and returns the overlays
    inputs=ImagePrompter(
        show_label=False
    ),  # ImagePrompter for image input and point selection
    outputs=[
        gr.Image(show_label=False) for _ in range(3)
    ],  # Display up to 3 overlay images
    title="Image Point Collector with Multiple Separate Mask Overlays",
    description="Upload an image, click on it, and get each predicted mask overlaid separately in red on individual images.",
)

# Launch the Gradio app only when executed as a script, so importing this
# module (e.g. from tests or other tooling) does not start a web server.
if __name__ == "__main__":
    demo.launch()