Spaces:
Running
on
Zero
| import gradio as gr | |
| import open3d_zerogpu_fix | |
| import spaces | |
| import re | |
| from utils import read_pcd, render_point_cloud, render_pcd_file, set_seed | |
| from inference.utils import get_legend | |
| from inference.inference import segment_obj, get_heatmap | |
| from huggingface_hub import login | |
| import os | |
def _install_pointops(package_dir="Pointcept/libs/pointops"):
    """Build and install the vendored ``pointops`` package via its ``setup.py``.

    The build runs with the *current* interpreter (``sys.executable``) and with
    ``package_dir`` as the child process's working directory. This process's
    CWD is never changed, so nothing needs to be restored afterwards.

    A failed build is reported on stderr but is not fatal. This keeps the
    best-effort behaviour of the previous ``os.system`` call, whose exit status
    was ignored.

    Returns:
        The exit code of the ``setup.py install`` process.
    """
    import subprocess
    import sys

    result = subprocess.run(
        [sys.executable, "setup.py", "install"],
        cwd=package_dir,
    )
    if result.returncode != 0:
        print(
            f"WARNING: installing {package_dir} failed "
            f"(exit code {result.returncode})",
            file=sys.stderr,
        )
    return result.returncode


_install_pointops()
# Authenticate to the Hugging Face Hub with the token stored in the `hfkey`
# environment variable (e.g. a Space secret).
login(token=os.getenv('hfkey'))
# Default part queries for every bundled example, keyed by example name.
# Each value is the comma-separated query string that on_select puts into the
# "Part Queries" text box when that example is picked from a gallery.
parts_dict = dict(
    fireplug="bonnet of a fireplug,side cap of a fireplug,barrel of a fireplug,base of a fireplug",
    mickey="ear,head,arms,hands,body,legs",
    motorvehicle="wheel of a motor vehicle,seat of a motor vehicle,handle of a motor vehicle",
    teddy="head,body,arms,legs",
    lamppost="lighting of a lamppost,pole of a lamppost",
    shirt="sleeve of a shirt,collar of a shirt,body of a shirt",
    capybara="hat worn by a capybara,head,body,feet",
    corgi="head,leg,body,ear",
    pushcar="wheel,body,handle",
    plant="pot,plant",
    chair="back of chair,leg,seat",
    objpart_redblack="head,arm,foot,body,leg,hand,knapsack,neck",
    objpart_dragon="body,head,leg,wing,tail,foot",
    objpart_catgirl="Body,Leg,Head,Hair,Foot,Hand,Ear,Shorts,Tail,Arm",
)
# Sub-folder of examples/ that holds each bundled example's .pcd/.jpg files
# (on_select builds "examples/{source}/{name}.pcd" from this mapping).
source_dict = dict.fromkeys(
    ("fireplug", "mickey", "motorvehicle", "teddy", "lamppost", "shirt"),
    "objaverse",
)
source_dict.update(
    dict.fromkeys(
        (
            "capybara",
            "corgi",
            "pushcar",
            "plant",
            "chair",
            "objpart_redblack",
            "objpart_catgirl",
            "objpart_dragon",
        ),
        "wild",
    )
)
def run_predict(*args):
    """Click handler for the "Run" button: re-yield everything ``predict`` yields.

    All positional arguments (file path, inference mode, part queries) are
    forwarded to ``predict`` untouched.
    """
    results = predict(*args)
    yield from results
# Characters accepted as separators between part queries.
_QUERY_DELIMITERS = re.compile(r'[,;.|]')


def split_part_queries(part_queries):
    """Split a raw query string into non-empty, whitespace-stripped parts.

    Parts may be separated by any of ``,`` ``;`` ``.`` ``|``. Empty pieces
    (from a trailing comma, doubled separators, or a blank string) are dropped.
    ``None`` is treated as an empty string.
    """
    pieces = _QUERY_DELIMITERS.split(part_queries or "")
    return [piece.strip() for piece in pieces if piece.strip()]


def predict(pcd_path, inference_mode, part_queries):
    """Run Find3D inference on a point cloud file and yield the rendered result.

    Args:
        pcd_path: Path of the selected ``.pcd`` file (value of the ``gr.File``
            component); falsy when the user has cleared the upload.
        inference_mode: ``"Segmentation"`` or ``"Localization"``. Any other
            value yields ``None`` without running the model.
        part_queries: Raw query text, with parts separated by ``, ; . |``.
            Segmentation needs at least two parts. Localization needs exactly one.

    Yields:
        Whatever ``render_point_cloud`` returns for the colored point cloud
        (shown in the output ``gr.Plot``), or ``None`` for an unknown mode.

    Raises:
        gr.Error: if the queries do not fit the selected mode or no file is
            selected.
    """
    parts = split_part_queries(part_queries)

    # Validate the inputs before touching the file or the model.
    if inference_mode == "Segmentation":
        if len(parts) < 2:
            raise gr.Error("For segmentation mode, please provide 2 or more parts", duration=5)
    elif inference_mode == "Localization":
        if not parts:
            raise gr.Error("For localization mode, please provide a part query", duration=5)
        if len(parts) > 1:
            raise gr.Error("For localization mode, please provide only one part", duration=5)
    else:
        yield None
        return

    if not pcd_path:
        raise gr.Error("Please upload a .pcd point cloud file first", duration=5)

    # Seed before reading, as before; read_pcd may itself be randomized
    # (e.g. subsampling) -- not visible here, so the order is kept.
    set_seed()
    xyz, rgb, normal = read_pcd(pcd_path)

    if inference_mode == "Segmentation":
        seg_rgb = segment_obj(xyz, rgb, normal, parts).cpu().numpy()
        yield render_point_cloud(xyz, seg_rgb, legend=get_legend(parts))
    else:
        heatmap_rgb = get_heatmap(xyz, rgb, normal, parts[0]).cpu().numpy()
        yield render_point_cloud(xyz, heatmap_rgb)
def on_select(evt: gr.SelectData):
    """Load the example picked in one of the example galleries.

    The example name is the selected image's file name without its extension.
    It is used to look up the example's folder in ``source_dict`` and its
    default queries in ``parts_dict``.

    Returns:
        ``[pcd_path, part_queries]`` for the ``file_upload`` and
        ``part_queries`` components.

    Raises:
        gr.Error: if the selected image does not correspond to a known example.
    """
    image_info = evt.value['image']
    # NOTE(review): fall back to the served path in case orig_name is absent
    # in some Gradio versions -- confirm against the installed version.
    image_name = image_info.get('orig_name') or image_info.get('path') or ""
    # splitext instead of slicing off 4 chars, so .jpeg/.png names also work.
    obj_name = os.path.splitext(os.path.basename(image_name))[0]
    if obj_name not in source_dict or obj_name not in parts_dict:
        raise gr.Error(f"Unknown example: {obj_name!r}", duration=5)
    src = source_dict[obj_name]
    return [f"examples/{src}/{obj_name}.pcd", parts_dict[obj_name]]
# UI: a row with the controls, the input preview and the output plot, then a
# row with two example galleries. Event listeners are wired inside the Blocks
# context.
with gr.Blocks(theme=gr.themes.Default(text_size="lg", radius_size="none")) as demo:
    gr.HTML(
        '''<h1 style="text-align: center;">Find Any Part in 3D</h1>
        <p style='font-size: 16px;'>This is a demo for Find3D: Find Any Part in 3D! Two modes are supported: <b>segmentation</b> and <b>localization</b>.
        <br>
        For <b>segmentation mode</b>, please provide multiple part queries in the "Part Queries" text box as a comma-separated string, such as "part1,part2,part3".
        After hitting "Run", the model will segment the object into the provided parts.
        <br>
        For <b>localization mode</b>, please provide only <b>one query string</b> in the "Part Queries" text box. After hitting "Run", the model will generate a heatmap for the provided query text.
        Please click on the example images under "Objaverse" and "In the Wild" to try them out. You can also upload your own .pcd files.</p>
        <p style='font-size: 16px;'>Hint:
        When uploading your own point cloud, please first close the existing point cloud by clicking on the "x" button.
        <br>
        We show some sample queries for the provided examples. When working with your own point cloud, feel free to rephrase the query (e.g. "part" vs "part of an object") to achieve better performance!</p>
        '''
    )

    with gr.Row(variant="panel"):
        with gr.Column(scale=4):
            file_upload = gr.File(
                label="Upload Point Cloud File",
                type="filepath",
                file_types=[".pcd"],
                value="examples/objaverse/lamppost.pcd",
            )
            inference_mode = gr.Radio(
                choices=["Segmentation", "Localization"],
                label="Inference Mode",
                value="Segmentation",
            )
            part_queries = gr.Textbox(
                label="Part Queries",
                value="lighting of a lamppost,pole of a lamppost",
            )
            run_button = gr.Button(
                value="Run",
                variant="primary",
            )
        with gr.Column(scale=4):
            input_image = gr.Image(label="Input Image", visible=False, type='pil', image_mode='RGBA', height=290)
            input_point_cloud = gr.Plot(label="Input Point Cloud")
        with gr.Column(scale=4):
            output_point_cloud = gr.Plot(label="Output Result")

    with gr.Row(variant="panel"):
        with gr.Column(scale=6):
            title = gr.HTML('''<h1 style="text-align: center;">Objaverse</h1>
            <p style='font-size: 16px;'>Online 3D assets from Objaverse!</p>
            ''')
            gallery_objaverse = gr.Gallery(
                [("examples/objaverse/lamppost.jpg", "lamppost"),
                 ("examples/objaverse/fireplug.jpg", "fireplug"),
                 ("examples/objaverse/mickey.jpg", "Mickey"),
                 ("examples/objaverse/motorvehicle.jpg", "motor vehicle"),
                 ("examples/objaverse/teddy.jpg", "teddy bear"),
                 ("examples/objaverse/shirt.jpg", "shirt")],
                columns=3,
                allow_preview=False,
            )
            # Picking an example fills in both its .pcd path and its sample queries.
            gallery_objaverse.select(fn=on_select,
                                     inputs=None,
                                     outputs=[file_upload, part_queries])
        with gr.Column(scale=6):
            title = gr.HTML("""<h1 style="text-align: center;">In the Wild & PartObjaverse-Tiny</h1>
            <p style='font-size: 16px;'>Challenging in-the-wild examples reconstructed from iPhone photos, plus objects from PartObjaverse-Tiny!</p>
            """)
            gallery_wild = gr.Gallery(
                [("examples/wild/pushcar.jpg", "iPhone-pushcar"),
                 ("examples/wild/plant.jpg", "iPhone-plant"),
                 ("examples/wild/objpart_catgirl.jpg", "Cat girl character"),
                 ("examples/wild/objpart_dragon.jpg", "Dragon character"),
                 ("examples/wild/objpart_redblack.jpg", "Cartoon character")],
                columns=3,
                allow_preview=False,
            )
            gallery_wild.select(fn=on_select,
                                inputs=None,
                                outputs=[file_upload, part_queries])

    # Re-render the input preview whenever the selected file changes.
    file_upload.change(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud],
    )
    run_button.click(
        fn=run_predict,
        inputs=[file_upload, inference_mode, part_queries],
        outputs=[output_point_cloud],
    )
    # Render the default example once when the page first loads.
    demo.load(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud],
    )

if __name__ == "__main__":
    demo.launch()