use pytorch3d to render, instead of nvdiffrast
Files changed:
- gradio_app/gradio_3dgen.py  +0 -5
- mesh_reconstruction/recon.py  +2 -2
- mesh_reconstruction/refine.py  +2 -2
- mesh_reconstruction/render.py  +118 -0
- scripts/project_mesh.py  +14 -18
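For context: `dr.RasterizeGLContext` makes nvdiffrast JIT-compile its CUDA/OpenGL plugin on first use, which needs CUDA headers (`cuda_runtime_api.h`) that the Spaces runtime does not ship, hence the failed builds this commit addresses (see the BUG comment deleted below). A minimal sketch of the replacement path, assuming only a CUDA-capable PyTorch3D install:

    import torch
    from pytorch3d.renderer import FoVOrthographicCameras, look_at_view_transform

    # Old path (removed below) compiled a GL plugin at startup:
    #   import nvdiffrast.torch as dr
    #   dr.RasterizeGLContext(output_db=False)  # fails: cuda_runtime_api.h not found
    #
    # New path: PyTorch3D ships prebuilt kernels, so camera and renderer setup
    # is plain Python with no runtime compilation.
    R, T = look_at_view_transform(dist=1.1, elev=0, azim=[0, 90, 180, 270])
    cameras = FoVOrthographicCameras(R=R, T=T)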
gradio_app/gradio_3dgen.py  CHANGED

@@ -10,13 +10,8 @@ from scripts.refine_lr_to_sr import run_sr_fast
 from scripts.utils import save_glb_and_video
 from scripts.multiview_inference import geo_reconstruct
 
-
-import nvdiffrast.torch as dr
-dr.RasterizeGLContext(output_db=False)
 @spaces.GPU
 def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
-    dr.RasterizeGLContext(output_db=False) # BUG: cuda_runtime_api.h: No such file or directory
-
     if preview_img is None:
         raise gr.Error("preview_img is none")
     if isinstance(preview_img, str):
mesh_reconstruction/recon.py  CHANGED

@@ -6,14 +6,14 @@ from typing import List
 from mesh_reconstruction.remesh import calc_vertex_normals
 from mesh_reconstruction.opt import MeshOptimizer
 from mesh_reconstruction.func import make_star_cameras_orthographic
-from mesh_reconstruction.render import NormalsRenderer
+from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
 from scripts.utils import to_py3d_mesh, init_target
 
 def reconstruct_stage1(pils: List[Image.Image], steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1):
     vertices, faces = vertices.to("cuda"), faces.to("cuda")
     assert len(pils) == 4
     mv,proj = make_star_cameras_orthographic(4, 1)
-    renderer = NormalsRenderer(mv,proj,list(pils[0].size))
+    renderer = Pytorch3DNormalsRenderer(mv,proj,list(pils[0].size))
 
     target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
     # 1. no rotate
mesh_reconstruction/refine.py  CHANGED

@@ -5,7 +5,7 @@ from typing import List
 from mesh_reconstruction.remesh import calc_vertex_normals
 from mesh_reconstruction.opt import MeshOptimizer
 from mesh_reconstruction.func import make_star_cameras_orthographic
-from mesh_reconstruction.render import NormalsRenderer
+from mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
 from scripts.project_mesh import multiview_color_projection, get_cameras_list
 from scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target
 
@@ -18,7 +18,7 @@ def run_mesh_refine(vertices, faces, pils: List[Image.Image], steps=100, start_e
 
     assert len(pils) == 4
     mv,proj = make_star_cameras_orthographic(4, 1)
-    renderer = NormalsRenderer(mv,proj,list(pils[0].size))
+    renderer = Pytorch3DNormalsRenderer(mv,proj,list(pils[0].size))
 
     target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
     # 1. no rotate
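Both call sites now build the renderer in one line from the star cameras. For reference, a hedged usage sketch that mirrors the `__main__` self-test added to mesh_reconstruction/render.py below (assumes a CUDA device and this repo's module layout):

    import torch
    from mesh_reconstruction.render import (
        Pytorch3DNormalsRenderer,
        make_star_cameras_orthographic_py3d,
    )

    # Four orthographic star views, as in the render.py self-test below.
    cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda")
    renderer = Pytorch3DNormalsRenderer(cameras, [512, 512], device="cuda")

    # A single triangle; the renderer encodes normals into RGB as (n + 1) / 2.
    vertices = torch.tensor([[0., 0., 0.], [0., 1., 0.], [1., 0., 0.]], device="cuda")
    normals = torch.nn.functional.normalize(torch.randn(3, 3, device="cuda"), dim=-1)
    faces = torch.tensor([[0, 1, 2]], device="cuda", dtype=torch.long)
    images = renderer.render(vertices, normals, faces)  # C,H,W,4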
mesh_reconstruction/render.py  CHANGED

@@ -49,3 +49,121 @@ class NormalsRenderer:
         col = torch.concat((col,alpha),dim=-1) #C,H,W,4
         col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
         return col #C,H,W,4
+
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer.mesh.shader import ShaderBase
+from pytorch3d.renderer import (
+    RasterizationSettings,
+    MeshRendererWithFragments,
+    TexturesVertex,
+    MeshRasterizer,
+    BlendParams,
+    FoVOrthographicCameras,
+    look_at_view_transform,
+    hard_rgb_blend,
+)
+
+class VertexColorShader(ShaderBase):
+    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
+        blend_params = kwargs.get("blend_params", self.blend_params)
+        texels = meshes.sample_textures(fragments)
+        return hard_rgb_blend(texels, fragments, blend_params)
+
+def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"):
+    if len(mesh) != len(cameras):
+        if len(cameras) % len(mesh) == 0:
+            mesh = mesh.extend(len(cameras))
+        else:
+            raise NotImplementedError()
+
+    # render requires everything in float16 or float32
+    input_dtype = dtype
+    blend_params = BlendParams(1e-4, 1e-4, bkgd)
+
+    # Define the settings for rasterization and shading
+    raster_settings = RasterizationSettings(
+        image_size=(H, W),
+        blur_radius=blur_radius,
+        faces_per_pixel=faces_per_pixel,
+        clip_barycentric_coords=True,
+        bin_size=None,
+        max_faces_per_bin=500000,
+    )
+
+    # Create a renderer by composing a rasterizer and a shader
+    # We simply render vertex colors through the custom VertexColorShader (no lighting, materials are used)
+    renderer = MeshRendererWithFragments(
+        rasterizer=MeshRasterizer(
+            cameras=cameras,
+            raster_settings=raster_settings
+        ),
+        shader=VertexColorShader(
+            device=device,
+            cameras=cameras,
+            blend_params=blend_params
+        )
+    )
+
+    # render RGB and depth, get mask
+    with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type):
+        images, _ = renderer(mesh)
+    return images # BHW4
+
+class Pytorch3DNormalsRenderer:
+    def __init__(self, cameras, image_size, device):
+        self.cameras = cameras.to(device)
+        self._image_size = image_size
+        self.device = device
+
+    def render(self,
+               vertices: torch.Tensor, #V,3 float
+               normals: torch.Tensor, #V,3 float in [-1, 1]
+               faces: torch.Tensor, #F,3 long
+               ) -> torch.Tensor: #C,H,W,4
+        mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device)
+        return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device)
+
+def get_camera(R, T, focal_length=1 / (2**0.5)):
+    focal_length = 1 / focal_length
+    camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
+    return camera
+
+def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1):
+    R, T = look_at_view_transform(dist, 0, azim_list)
+    focal_length = 1 / focal
+    return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device)
+
+def save_tensor_to_img(tensor, save_dir):
+    from PIL import Image
+    import numpy as np
+    for idx, img in enumerate(tensor):
+        img = img[..., :3].cpu().numpy()
+        img = (img * 255).astype(np.uint8)
+        img = Image.fromarray(img)
+        img.save(save_dir + f"{idx}.png")
+
+if __name__ == "__main__":
+    import sys
+    import os
+    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+    from mesh_reconstruction.func import make_star_cameras_orthographic
+    cameras = make_star_cameras_orthographic_py3d([0, 270, 180, 90], device="cuda", focal=1., dist=4.0)
+    mv,proj = make_star_cameras_orthographic(4, 1)
+    resolution = 1024
+    renderer1 = NormalsRenderer(mv,proj, [resolution,resolution], device="cuda")
+    renderer2 = Pytorch3DNormalsRenderer(cameras, [resolution,resolution], device="cuda")
+    vertices = torch.tensor([[0,0,0],[0,0,1],[0,1,0],[1,0,0]], device="cuda", dtype=torch.float32)
+    normals = torch.tensor([[-1,-1,-1],[1,-1,-1],[-1,-1,1],[-1,1,-1]], device="cuda", dtype=torch.float32)
+    faces = torch.tensor([[0,1,2],[0,1,3],[0,2,3],[1,2,3]], device="cuda", dtype=torch.long)
+
+    import time
+    t0 = time.time()
+    r1 = renderer1.render(vertices, normals, faces)
+    print("time r1:", time.time() - t0)
+
+    t0 = time.time()
+    r2 = renderer2.render(vertices, normals, faces)
+    print("time r2:", time.time() - t0)
+
+    for i in range(4):
+        print((r1[i]-r2[i]).abs().mean(), (r1[i]+r2[i]).abs().mean())
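One convention worth noting in the code above: `FoVOrthographicCameras` is parameterized by view-volume bounds rather than a focal length, so `get_camera` and `make_star_cameras_orthographic_py3d` both invert their `focal` argument to get the visible half-extent. A quick worked check:

    # The orthographic "focal" acts as a zoom factor: half-extent = 1 / focal.
    focal = 2 / 1.35
    half_extent = 1 / focal              # 0.675 world units
    # The camera then sees x and y in [-0.675, 0.675]:
    assert abs(half_extent - 0.675) < 1e-12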
scripts/project_mesh.py  CHANGED

@@ -13,17 +13,6 @@ from pytorch3d.renderer import (
 )
 from pytorch3d.renderer import MeshRasterizer
 
-def get_camera(world_to_cam, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'):
-    # pytorch3d expects transforms as row-vectors, so flip rotation: https://github.com/facebookresearch/pytorch3d/issues/1183
-    R = world_to_cam[:3, :3].t()[None, ...]
-    T = world_to_cam[:3, 3][None, ...]
-    if cam_type == 'fov':
-        camera = FoVPerspectiveCameras(device=world_to_cam.device, R=R, T=T, fov=fov_in_degrees, degrees=True)
-    else:
-        focal_length = 1 / focal_length
-        camera = FoVOrthographicCameras(device=world_to_cam.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
-    return camera
-
 def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1):
     """
     Renders pix2face of visible faces.

@@ -98,11 +87,11 @@ class Pix2FacesRenderer:
 pix2faces_renderer = None
 
 def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024):
-    global pix2faces_renderer
-    if pix2faces_renderer is None:
-        pix2faces_renderer = Pix2FacesRenderer()
-    pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution)
+    # global pix2faces_renderer
+    # if pix2faces_renderer is None:
+    #     pix2faces_renderer = Pix2FacesRenderer()
+    pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face']
+    # pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution)
 
     unique_faces = torch.unique(pix_to_face.flatten())
     unique_faces = unique_faces[unique_faces != -1]

@@ -313,12 +302,19 @@ def multiview_color_projection(meshes: Meshes, image_list: List[Image.Image], ca
     del meshes
     return ret_mesh
 
+def get_camera(R, T, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'):
+    if cam_type == 'fov':
+        camera = FoVPerspectiveCameras(device=R.device, R=R, T=T, fov=fov_in_degrees, degrees=True)
+    else:
+        focal_length = 1 / focal_length
+        camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
+    return camera
+
 def get_cameras_list(azim_list, device, focal=2/1.35, dist=1.1):
     ret = []
     for azim in azim_list:
         R, T = look_at_view_transform(dist, 0, azim)
-        cameras: OrthographicCameras = get_camera(w2c, focal_length=focal, cam_type='orthogonal').to(device)
+        cameras: OrthographicCameras = get_camera(R, T, focal_length=focal, cam_type='orthogonal').to(device)
         ret.append(cameras)
     return ret
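With `get_camera` now taking `R, T` straight from `look_at_view_transform`, which already returns rotations in PyTorch3D's row-vector convention, the old world-to-camera transpose step is unnecessary. A hedged sketch of the updated call path, assuming the repo's module layout:

    from pytorch3d.renderer import look_at_view_transform
    from scripts.project_mesh import get_camera, get_cameras_list

    # Single orthographic camera from an explicit viewpoint:
    R, T = look_at_view_transform(dist=1.1, elev=0, azim=45)
    cam = get_camera(R, T, focal_length=2 / 1.35, cam_type='orthogonal')

    # Or the four star views used by the reconstruction pipeline:
    cams = get_cameras_list([0, 90, 180, 270], device="cuda", focal=2 / 1.35)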