# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import ast

import gradio as gr
import numpy as np
import plotly.graph_objs as go
from PIL import Image
from plyfile import PlyData

from sam_3d import SAM3DDemo
def pc_to_plot(pc):
    """Render a point cloud (with per-point RGB) as an interactive Plotly 3D scatter."""
    return go.Figure(
        data=[
            go.Scatter3d(
                x=pc['x'], y=pc['y'], z=pc['z'],
                mode='markers',
                marker=dict(
                    size=2,
                    color=['rgb({},{},{})'.format(r, g, b)
                           for r, g, b in zip(pc['red'], pc['green'], pc['blue'])],
                ),
            )
        ],
        layout=dict(
            scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False))
        ),
    )
def inference(scene_name, granularity, coords, plot):
    """Run the 3D SAM pipeline for one scene, anchor coordinate, and mask granularity."""
    print(scene_name, coords)
    sam_3d = SAM3DDemo('vit_b', 'sam_vit_b_01ec64.pth', scene_name)
    # The coordinate comes from a free-form textbox, e.g. "[0, -2.5, 0.7]".
    coords = ast.literal_eval(coords)
    data_point_select, rgb_img_w_points, rgb_img_w_masks, data_final = sam_3d.run_with_coord(coords, int(granularity))
    return (pc_to_plot(data_point_select),
            Image.fromarray(rgb_img_w_points),
            Image.fromarray(rgb_img_w_masks),
            pc_to_plot(data_final))
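
# For reference, the same pipeline can also be driven without the UI, mirroring
# the inference() call above (checkpoint path and coordinates are the demo's
# own defaults; adjust them for your setup):
#
#   demo = SAM3DDemo('vit_b', 'sam_vit_b_01ec64.pth', 'scene0000_00')
#   points, img_pts, img_masks, mask_pc = demo.run_with_coord([0, -2.5, 0.7], 0)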
# Pre-load the supported ScanNet scenes so the example plots render immediately.
plydatas = []
for scene_name in ['scene0000_00', 'scene0001_00', 'scene0002_00']:
    plydata = PlyData.read(f"./scannet_data/{scene_name}/{scene_name}.ply")
    plydatas.append(plydata.elements[0].data)
# Coordinates are given as strings because the matching input component is a Textbox.
examples = [['scene0000_00', 0, '[0, -2.5, 0.7]', pc_to_plot(plydatas[0])],
            ['scene0001_00', 0, '[0, -2.5, 1]', pc_to_plot(plydatas[1])],
            ['scene0002_00', 0, '[0, -2.5, 1]', pc_to_plot(plydatas[2])]]
title = 'Segment Anything on 3D indoor point clouds'
description = """
Gradio demo for Segment Anything on 3D indoor scenes (ScanNet supported). \n
The logic is straightforward: 1) pick a point in 3D; 2) project the 3D point into the valid images; 3) run 2D SAM on those images; 4) reproject the 2D results back to 3D; 5) visualize.
Unfortunately, clicking the point cloud to generate coordinates automatically is not supported yet. You may want to note down the coordinates and enter them manually. \n
"""
article = """
<p style='text-align: center'>
    <a href='https://arxiv.org/abs/2210.04150' target='_blank'>
        Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP
    </a>
    |
    <a href='https://github.com/facebookresearch/ov-seg' target='_blank'>Github Repo</a>
</p>
"""
gr.Interface(
    inference,
    inputs=[
        gr.Dropdown(choices=['scene0000_00', 'scene0001_00', 'scene0002_00'], label="ScanNet scene name (limited scenes supported)"),
        gr.Dropdown(choices=[0, 1, 2], label="Mask granularity from 0 (coarsest) to 2 (most precise)"),
        gr.Textbox(lines=1, label='Coordinates'),
        gr.Plot(label="Input point cloud (for visualization and point finding only; click response not supported yet)"),
    ],
    outputs=[
        gr.Plot(label='Selected point(s): red points show the 10 closest points to your input anchor point'),
        gr.Image(label='Selected image with projected points'),
        gr.Image(label='Selected image processed by SAM'),
        gr.Plot(label='Output point cloud: blue points represent the mask'),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch()