Spaces:
Running
Running
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import json | |
| import subprocess | |
| from PIL import Image | |
| from functools import partial | |
| from datetime import datetime | |
| from sam_inference import get_sam_predictor, sam_seg | |
| from utils import blend_seg, blend_seg_pure | |
| import cv2 | |
| import uuid | |
| import torch | |
| import trimesh | |
| from huggingface_hub import snapshot_download | |
| from gradio_model3dcolor import Model3DColor | |
| # from gradio_model3dnormal import Model3DNormal | |
# Download the MeshFormer API bundle from the Hugging Face Hub, authenticating
# with the token in the HF_TOKEN environment variable (KeyError if it is unset).
code_dir = snapshot_download("sudo-ai/MeshFormer-API", token=os.environ['HF_TOKEN'])

# api.json provides the command templates that are later filled in with
# str.format() and executed for segmentation and mesh generation.
with open(f'{code_dir}/api.json', 'r') as api_file:
    api_dict = json.load(api_file)
SEG_CMD, MESH_CMD = api_dict["SEG_CMD"], api_dict["MESH_CMD"]
# HTML head snippet emitted in front of every alert built by message():
# the Bootstrap 5.3.2 stylesheet plus a rule forcing alert text to black.
_BOOTSTRAP_CSS_LINK = (
    '<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css"'
    ' rel="stylesheet"'
    ' integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN"'
    ' crossorigin="anonymous">'
)
_ALERT_TEXT_CSS = "\n".join(
    [
        "<style>",
        ".alert, .alert div, .alert b {",
        "color: black !important;",
        "}",
        "</style>",
    ]
)
STYLE = f"\n{_BOOTSTRAP_CSS_LINK}\n{_ALERT_TEXT_CSS}\n"
def _bootstrap_icon(fill, icon_class, path_data):
    """Return the inline 16x16 SVG markup for a single Bootstrap icon."""
    return (
        f'<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="{fill}" '
        f'class="bi {icon_class} flex-shrink-0 me-2" viewBox="0 0 16 16">\n'
        f'<path d="{path_data}"/>\n'
        "</svg>"
    )


# Inline SVG icons shown at the start of each alert, keyed by icon type:
# info (info-circle-fill), cursor (hand-index-thumb), wait (hourglass-split), done (check-circle)
ICONS = {
    "info": _bootstrap_icon(
        "#0d6efd",
        "bi-info-circle-fill",
        "M8 16A8 8 0 1 0 8 0a8 8 0 0 0 0 16zm.93-9.412-1 4.705c-.07.34.029.533.304.533.194 0 .487-.07.686-.246l-.088.416c-.287.346-.92.598-1.465.598-.703 0-1.002-.422-.808-1.319l.738-3.468c.064-.293.006-.399-.287-.47l-.451-.081.082-.381 2.29-.287zM8 5.5a1 1 0 1 1 0-2 1 1 0 0 1 0 2z",
    ),
    "cursor": _bootstrap_icon(
        "#0dcaf0",
        "bi-hand-index-thumb-fill",
        "M8.5 1.75v2.716l.047-.002c.312-.012.742-.016 1.051.046.28.056.543.18.738.288.273.152.456.385.56.642l.132-.012c.312-.024.794-.038 1.158.108.37.148.689.487.88.716.075.09.141.175.195.248h.582a2 2 0 0 1 1.99 2.199l-.272 2.715a3.5 3.5 0 0 1-.444 1.389l-1.395 2.441A1.5 1.5 0 0 1 12.42 16H6.118a1.5 1.5 0 0 1-1.342-.83l-1.215-2.43L1.07 8.589a1.517 1.517 0 0 1 2.373-1.852L5 8.293V1.75a1.75 1.75 0 0 1 3.5 0z",
    ),
    "wait": _bootstrap_icon(
        "#6c757d",
        "bi-hourglass-split",
        "M2.5 15a.5.5 0 1 1 0-1h1v-1a4.5 4.5 0 0 1 2.557-4.06c.29-.139.443-.377.443-.59v-.7c0-.213-.154-.451-.443-.59A4.5 4.5 0 0 1 3.5 3V2h-1a.5.5 0 0 1 0-1h11a.5.5 0 0 1 0 1h-1v1a4.5 4.5 0 0 1-2.557 4.06c-.29.139-.443.377-.443.59v.7c0 .213.154.451.443.59A4.5 4.5 0 0 1 12.5 13v1h1a.5.5 0 0 1 0 1h-11zm2-13v1c0 .537.12 1.045.337 1.5h6.326c.216-.455.337-.963.337-1.5V2h-7zm3 6.35c0 .701-.478 1.236-1.011 1.492A3.5 3.5 0 0 0 4.5 13s.866-1.299 3-1.48V8.35zm1 0v3.17c2.134.181 3 1.48 3 1.48a3.5 3.5 0 0 0-1.989-3.158C8.978 9.586 8.5 9.052 8.5 8.351z",
    ),
    "done": _bootstrap_icon(
        "#198754",
        "bi-check-circle-fill",
        "M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0zm-3.97-3.03a.75.75 0 0 0-1.08.022L7.477 9.417 5.384 7.323a.75.75 0 0 0-1.06 1.06L6.97 11.03a.75.75 0 0 0 1.079-.02l3.992-4.99a.75.75 0 0 0-.01-1.05z",
    ),
}
# Bootstrap contextual alert class ("alert-<value>") paired with each icon type.
icons2alert = dict(
    info="primary",  # blue
    cursor="info",  # light blue
    wait="secondary",  # gray
    done="success",  # green
)
def message(text, icon_type="info"):
    """Build the HTML for a Bootstrap alert banner containing ``text``.

    ``icon_type`` picks both the leading SVG (from ``ICONS``) and the alert
    class (from ``icons2alert``); an unknown key raises ``KeyError``.
    ``text`` is inserted as-is, so it may contain inline HTML such as ``<b>``.
    """
    alert_class = icons2alert[icon_type]
    icon_svg = ICONS[icon_type]
    opening_tag = (
        f'<div class="alert alert-{alert_class} d-flex align-items-center" role="alert">'
    )
    html_lines = [
        f"{STYLE} {opening_tag} {icon_svg}",
        "<div>",
        f"{text}",
        "</div>",
        "</div>",
    ]
    return "\n".join(html_lines)
def preprocess(tmp_dir, input_img, idx=None):
    """Run the automatic segmentation command on the input image.

    The image is saved as ``{tmp_dir}/input.png``, ``SEG_CMD`` is executed for
    that directory, and the resulting ``{tmp_dir}/seg.png`` is loaded.

    Args:
        tmp_dir: Working directory for the current input.
        input_img: A PIL image, or (when ``idx`` is given) an indexable
            collection of entries whose ``"name"`` value is an image path.
        idx: Optional index into ``input_img``.

    Returns:
        The segmented image resized to 320x320 with LANCZOS resampling.

    Raises:
        RuntimeError: If the segmentation command exits with a non-zero status.
    """
    if idx is not None:
        print("image idx:", int(idx))
        input_img = Image.open(input_img[int(idx)]["name"])
    input_img.save(f"{tmp_dir}/input.png")

    # Fail loudly here instead of later on a missing or stale seg.png.
    status = os.system(SEG_CMD.format(tmp_dir=tmp_dir))
    if status != 0:
        raise RuntimeError(
            f"Segmentation command failed with status {status} for {tmp_dir}"
        )

    # resize() returns a new, fully loaded image, so the file handle can be
    # released as soon as it has been produced.
    with Image.open(f"{tmp_dir}/seg.png") as processed_img:
        return processed_img.resize((320, 320), Image.Resampling.LANCZOS)
def ply_to_glb(ply_path):
    """Convert a PLY mesh to GLB by running ``ply2glb.py`` in a subprocess.

    Args:
        ply_path: Path of the ``.ply`` file to convert.

    Returns:
        ``ply_path`` with ``".ply"`` replaced by ``".glb"``.

    Raises:
        RuntimeError: If the converter exits with a non-zero return code; the
            captured stderr is included in the message.
    """
    result = subprocess.run(
        ["python", "ply2glb.py", "--", ply_path],
        capture_output=True,
        text=True,
    )
    print("Output of blender script:")
    print(result.stdout)
    if result.returncode != 0:
        # Previously the failure was dropped and a path to a file that was
        # never written got returned to the caller.
        print(result.stderr)
        raise RuntimeError(
            f"ply2glb.py failed with return code {result.returncode} "
            f"for {ply_path}: {result.stderr.strip()}"
        )
    # NOTE(review): presumably mirrors the output naming used inside
    # ply2glb.py -- keep both in sync if either changes.
    glb_path = ply_path.replace(".ply", ".glb")
    return glb_path
def mesh_gen(tmp_dir, simplify, num_inference_steps):
    """Generate the mesh for ``tmp_dir`` and export color and normal GLB files.

    Runs ``MESH_CMD``, loads ``{tmp_dir}/mesh.ply``, writes a copy whose
    vertex colors encode the rotated vertex normals to
    ``{tmp_dir}/mesh_normal.ply``, and converts both PLY files to GLB.

    Args:
        tmp_dir: Working directory for the current input.
        simplify: Accepted for interface compatibility; not used by this
            function.
        num_inference_steps: Value substituted into ``MESH_CMD``.

    Returns:
        A ``(color_glb_path, normal_glb_path)`` tuple.

    Raises:
        RuntimeError: If the mesh command exits with a non-zero status.
    """
    # Check the exit status: a failed re-run in the same directory would
    # otherwise silently pick up the mesh.ply left by an earlier run.
    status = os.system(
        MESH_CMD.format(tmp_dir=tmp_dir, num_inference_steps=num_inference_steps)
    )
    if status != 0:
        raise RuntimeError(
            f"Mesh generation command failed with status {status} for {tmp_dir}"
        )

    mesh = trimesh.load_mesh(f"{tmp_dir}/mesh.ply")
    vertex_normals = mesh.vertex_normals

    # Rotate the normals by -90 degrees about the z axis.
    theta = np.radians(-90)
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    rotation_matrix = np.array(
        [
            [cos_theta, -sin_theta, 0],
            [sin_theta, cos_theta, 0],
            [0, 0, 1],
        ]
    )
    rotated_normal = np.dot(vertex_normals, rotation_matrix.T)

    # Map the negated normal components to 8-bit colors
    # (assumes components lie in [-1, 1]; out-of-range values are clipped).
    colors = (-rotated_normal + 1) / 2.0
    colors = (colors * 255).clip(0, 255).astype(np.uint8)
    mesh.visual.vertex_colors = colors[..., [2, 1, 0]]  # RGB -> BGR
    mesh.export(f"{tmp_dir}/mesh_normal.ply", file_type="ply")

    color_path = ply_to_glb(f"{tmp_dir}/mesh.ply")
    normal_path = ply_to_glb(f"{tmp_dir}/mesh_normal.ply")
    return color_path, normal_path
def create_tmp_dir():
    """Create a working directory under ``demo_exp/`` and return its path.

    The directory name is the current local time formatted as
    ``YYYY-MM-DD_HH-MM-SS`` followed by ``_`` and the first four characters
    of a random UUID. An already existing directory is reused silently.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    suffix = str(uuid.uuid4())[:4]
    tmp_dir = f"demo_exp/{timestamp}_{suffix}"
    os.makedirs(tmp_dir, exist_ok=True)
    print("create tmp_exp_dir", tmp_dir)
    return tmp_dir
def vis_seg(checkbox):
    """Show or hide the manual-segmentation widgets.

    Returns a five-item list: updates for the two segmentation image panes
    (the same update object is used for both), the point-label radio, an
    empty list, and the mask-option checkbox.
    """
    if not checkbox:
        print("Clear manual seg")
        hidden_image = gr.Image(visible=False)
        return [
            hidden_image,
            hidden_image,
            gr.Radio(visible=False),
            [],
            gr.Checkbox(visible=False),
        ]

    print("Show manual seg windows")
    # value=None is passed only when showing -- presumably to clear any
    # previously displayed image; confirm against Gradio's update semantics.
    cleared_image = gr.Image(value=None, visible=True)
    return [
        cleared_image,
        cleared_image,
        gr.Radio(visible=True),
        [],
        gr.Checkbox(visible=True),
    ]
def calc_feat(checkbox, predictor, input_image, idx=None):
    """Prepare the manual-segmentation canvas and compute its SAM features.

    The input is shrunk in place to fit within 512x512, pasted centered on a
    white square canvas, and handed to ``predictor.set_image``.

    Args:
        checkbox: State of the "manual segmentation" checkbox.
        predictor: SAM predictor exposing ``set_image``.
        input_image: A PIL image, or (when ``idx`` is given) an indexable
            collection of entries whose ``"name"`` value is an image path.
        idx: Optional index into ``input_image``.

    Returns:
        A visible ``gr.Image`` update holding the square canvas.

    Raises:
        ValueError: If ``checkbox`` is false; this stops the event chain
            when manual segmentation is switched off.
        gr.Error: If manual segmentation is enabled with no image loaded.
    """
    if not checkbox:
        print("Quit manual seg")
        raise ValueError("Quit manual seg")

    if idx is not None:
        print("image idx:", int(idx))
        input_image = Image.open(input_image[int(idx)]["name"])
    if input_image is None:
        # Ticking the checkbox before choosing an image used to end in an
        # AttributeError on None; tell the user what to do instead.
        raise gr.Error("Please input an image before enabling manual segmentation.")

    input_image.thumbnail([512, 512], Image.Resampling.LANCZOS)
    w, h = input_image.size
    print("image size:", w, h)

    # Pad the shorter side so the image sits centered on a square canvas.
    side_len = max(w, h)
    seg_in = Image.new(input_image.mode, (side_len, side_len), (255, 255, 255))
    seg_in.paste(input_image, (max(0, (h - w) // 2), max(0, (w - h) // 2)))

    print("Calculating image SAM feature...")
    predictor.set_image(np.array(seg_in.convert("RGB")))
    torch.cuda.empty_cache()
    return gr.Image(value=seg_in, visible=True)
def manual_seg(
    predictor,
    seg_in,
    selected_points,
    fg_bg_radio,
    tmp_dir,
    seg_mask_opt,
    evt: gr.SelectData,
):
    """Segment the clicked image with SAM and save the recentered cut-out.

    The clicked point is appended to ``selected_points`` (mutated in place),
    SAM is run with all points collected so far, and the masked object is
    cropped to its bounding box and centered on a square RGBA canvas, which
    is written to ``{tmp_dir}/seg.png``.

    Args:
        predictor: SAM predictor passed through to ``sam_seg``.
        seg_in: PIL image the user clicked on.
        selected_points: List of ``{"coord", "add_del"}`` dicts from earlier
            clicks.
        fg_bg_radio: Radio value; ``"+ (add mask)"`` marks a foreground point.
        tmp_dir: Working directory that receives ``seg.png``.
        seg_mask_opt: If true, preview with ``blend_seg_pure``; otherwise
            with ``blend_seg``.
        evt: Gradio select event; ``evt.index`` is the clicked coordinate.

    Returns:
        ``(preview, cutout)``: the blended preview resized to 380x380 and the
        RGBA cut-out resized to 320x320.

    Raises:
        gr.Error: If the resulting mask is empty.
    """
    print("Start segmentation")
    selected_points.append(
        {"coord": evt.index, "add_del": fg_bg_radio == "+ (add mask)"}
    )
    input_points = np.array([point["coord"] for point in selected_points])
    input_labels = np.array([point["add_del"] for point in selected_points])

    seg_rgb = seg_in.convert("RGB")
    out_image = sam_seg(predictor, np.array(seg_rgb), input_points, input_labels)

    blend = blend_seg_pure if seg_mask_opt else blend_seg
    segmentation = blend(seg_rgb, out_image, input_points, input_labels)

    # recenter and rescale
    image_arr = np.array(out_image)
    # Binary mask of every pixel whose last channel is non-zero.
    _, mask = cv2.threshold(
        np.array(out_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
    )
    x, y, w, h = cv2.boundingRect(mask)
    if w == 0 or h == 0:
        # An empty mask would give a zero-size canvas and an obscure
        # failure further down; report it explicitly instead.
        raise gr.Error(
            "Segmentation produced an empty mask; please add a foreground (+) point."
        )

    # The object's longer side fills 75% of the square canvas.
    max_size = max(w, h)
    ratio = 0.75
    side_len = int(max_size / ratio)
    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len // 2
    top = center - h // 2
    left = center - w // 2
    padded_image[top : top + h, left : left + w] = image_arr[y : y + h, x : x + w]

    rgba = Image.fromarray(padded_image)
    rgba.save(f"{tmp_dir}/seg.png")
    torch.cuda.empty_cache()
    return (
        segmentation.resize((380, 380), Image.Resampling.LANCZOS),
        rgba.resize((320, 320), Image.Resampling.LANCZOS),
    )
# Soft theme with a blue primary hue and neutral fills for secondary buttons.
_secondary_button_fills = {
    "button_secondary_background_fill": "*neutral_100",
    "button_secondary_background_fill_hover": "*neutral_200",
}
custom_theme = gr.themes.Soft(primary_hue="blue").set(**_secondary_button_fills)
with gr.Blocks(title="MeshFormer Demo", css="style.css", theme=custom_theme) as demo:
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a copy of this file whose indentation had been
    # stripped -- confirm the layout against the running app.

    # ---- Header -----------------------------------------------------------
    with gr.Row():
        gr.Markdown(
            "# MeshFormer: High-Quality Mesh Generation with 3D-Guided Reconstruction Model"
        )
    with gr.Row():
        gr.Markdown(
            "[Project Page](https://meshformer3d.github.io/) | [arXiv](https://arxiv.org/pdf/2408.10198)"
        )
    with gr.Row():
        gr.Markdown(
            """
<div>
<b><em>Check out <a href="https://www.sudo.ai/3dgen">Hillbot (sudoAI)</a> for more details and advanced features.</em></b>
</div>
"""
        )

    # ---- Status banner and per-session working directories -----------------
    with gr.Row():
        guide_text_i2m = gr.HTML(message("Please input an image!"), visible=True)
        tmp_dir_img = gr.State("./demo_exp/placeholder")
        tmp_dir_txt = gr.State("./demo_exp/placeholder")
        tmp_dir_3t3 = gr.State("./demo_exp/placeholder")

    # Every file in demo_examples/ (sorted by name) is offered as an example;
    # no extension filter is applied.
    example_folder = os.path.join(os.path.dirname(__file__), "demo_examples")
    example_fns = sorted(os.listdir(example_folder))
    img_examples = [os.path.join(example_folder, fn) for fn in example_fns]

    # ---- Main panel ---------------------------------------------------------
    with gr.Row(variant="panel"):
        with gr.Row():
            # Left column: input image, options, manual segmentation, run.
            with gr.Column(scale=8):
                input_image = gr.Image(
                    type="pil",
                    image_mode="RGBA",
                    height=320,
                    label="Input Image",
                    interactive=True,
                )
                gr.Examples(
                    examples=img_examples,
                    inputs=[input_image],
                    outputs=[input_image],
                    cache_examples=False,
                    label="Image Examples (Click one of the images below to start)",
                    examples_per_page=27,
                )
                with gr.Accordion("Options", open=False):
                    img_simplify = gr.Checkbox(
                        False, label="simplify the generated mesh", visible=False
                    )
                    n_steps_img = gr.Slider(
                        value=28,
                        minimum=15,
                        maximum=100,
                        step=1,
                        label="number of inference steps",
                    )

                # manual segmentation
                checkbox_manual_seg = gr.Checkbox(False, label="manual segmentation")
                with gr.Row():
                    with gr.Column(scale=1):
                        seg_in = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Click to segment",
                            visible=False,
                            show_download_button=False,
                            height=380,
                        )
                    with gr.Column(scale=1):
                        seg_out = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Segmentation",
                            interactive=False,
                            visible=False,
                            show_download_button=False,
                            height=380,
                            elem_id="disp_image",
                        )
                fg_bg_radio = gr.Radio(
                    ["+ (add mask)", "- (remove area)"],
                    value="+ (add mask)",
                    info="Select foreground (+) or background (-) point",
                    label="Point label",
                    visible=False,
                    interactive=True,
                )
                seg_mask_opt = gr.Checkbox(
                    True,
                    label="show foreground mask in manual segmentation",
                    visible=False,
                )

                # run
                img_run_btn = gr.Button(
                    "Generate", variant="primary", interactive=False
                )

            # Right column: segmented input and the generated meshes.
            with gr.Column(scale=6):
                processed_image = gr.Image(
                    type="pil",
                    label="Processed Image",
                    interactive=False,
                    height=320,
                    image_mode="RGBA",
                    elem_id="disp_image",
                )
                mesh_output_normal = Model3DColor(
                    label="Generated Mesh (normal)",
                    elem_id="mesh-normal-out",
                    height=400,
                )
                mesh_output = Model3DColor(
                    label="Generated Mesh (color)",
                    elem_id="mesh-out",
                    height=400,
                )

    # ---- Shared state -------------------------------------------------------
    predictor = gr.State(value=get_sam_predictor())
    selected_points = gr.State(value=[])
    selected_points_t2i = gr.State(value=[])

    # ---- Small update helpers used by the event chains ----------------------
    def disable_checkbox():
        return gr.Checkbox(value=False)

    def disable_button():
        return gr.Button(interactive=False)

    def enable_button():
        return gr.Button(interactive=True)

    def update_guide(GUIDE_TEXT, icon_type="info"):
        return gr.HTML(value=message(GUIDE_TEXT, icon_type))

    def update_md(GUIDE_TEXT):
        return gr.Markdown(value=GUIDE_TEXT)

    def is_img_clear(input_image):
        # Raising stops the rest of the .success() chain when the image
        # has been cleared.
        if not input_image:
            raise ValueError("Input image cleared.")

    # ---- Manual segmentation toggle -----------------------------------------
    # Show/hide the widgets, then compute SAM features, then start a new
    # working directory for the manually segmented result.
    seg_widgets_toggled = checkbox_manual_seg.change(
        vis_seg,
        inputs=[checkbox_manual_seg],
        outputs=[seg_in, seg_out, fg_bg_radio, selected_points, seg_mask_opt],
        queue=False,
    )
    sam_features_ready = seg_widgets_toggled.success(
        calc_feat,
        inputs=[checkbox_manual_seg, predictor, input_image],
        outputs=[seg_in],
        queue=True,
    )
    sam_features_ready.success(fn=create_tmp_dir, outputs=[tmp_dir_img], queue=False)

    # Each click on the segmentation canvas adds a point and re-segments.
    seg_in.select(
        manual_seg,
        [predictor, seg_in, selected_points, fg_bg_radio, tmp_dir_img, seg_mask_opt],
        [seg_out, processed_image],
        queue=True,
    )

    # ---- New input image: reset controls, then preprocess -------------------
    run_disabled = input_image.change(
        disable_button, outputs=img_run_btn, queue=False
    )
    manual_seg_reset = run_disabled.success(
        disable_checkbox, outputs=checkbox_manual_seg, queue=False
    )
    image_present = manual_seg_reset.success(
        fn=is_img_clear, inputs=input_image, queue=False
    )
    workdir_created = image_present.success(
        fn=create_tmp_dir, outputs=tmp_dir_img, queue=False
    )
    preprocess_announced = workdir_created.success(
        fn=partial(update_guide, "Preprocessing the image!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    )
    preprocessed = preprocess_announced.success(
        fn=preprocess,
        inputs=[tmp_dir_img, input_image],
        outputs=[processed_image],
        queue=True,
    )
    ready_announced = preprocessed.success(
        fn=partial(
            update_guide,
            "Click <b>Generate</b> to generate mesh! If the input image was not segmented accurately, please adjust it using <b>manual segmentation</b>.",
            "cursor",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    )
    ready_announced.success(enable_button, outputs=img_run_btn, queue=False)

    # ---- Generate button ------------------------------------------------------
    generation_announced = img_run_btn.click(
        fn=partial(update_guide, "Generating the mesh!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    )
    mesh_generated = generation_announced.success(
        fn=mesh_gen,
        inputs=[tmp_dir_img, img_simplify, n_steps_img],
        outputs=[mesh_output, mesh_output_normal],
        queue=True,
    )
    mesh_generated.success(
        fn=partial(
            update_guide,
            "Successfully generated the mesh. (It might take a few seconds to load the mesh)",
            "done",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    )
# Enable request queuing and start the server, listening on all interfaces.
demo.queue().launch(
    server_name="0.0.0.0",
    share=False,
    inline=False,
    debug=False,
    show_api=False,
)