Spaces:
Running
on
Zero
Running
on
Zero
| import copy | |
| import inspect | |
| from typing import Any, Callable, List, Sequence, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from pytorch3d.renderer import ( | |
| BlendParams, | |
| DirectionalLights, | |
| FoVPerspectiveCameras, | |
| MeshRasterizer, | |
| MeshRenderer, | |
| RasterizationSettings, | |
| SoftPhongShader, | |
| TexturesVertex, | |
| ) | |
| from pytorch3d.renderer.utils import TensorProperties | |
| from pytorch3d.structures import Meshes | |
| from shap_e.models.nn.checkpoint import checkpoint | |
| from .blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION | |
| from .torch_mesh import TorchMesh | |
| from .view_data import ProjectiveCamera | |
# Default blending parameters for render_images(): sigma scales the rasterizer
# blur radius and is passed to BlendParams; gamma is passed to BlendParams.
# Original note: using a lower value like 1e-4 seems to result in weird issues
# for our high-poly meshes.
# NOTE(review): 1e-4 is not lower than the sigma default below, so it is
# unclear which constant/value the note refers to -- confirm with the authors.
DEFAULT_RENDER_SIGMA = 1e-5
DEFAULT_RENDER_GAMMA = 1e-4
def render_images(
    image_size: int,
    meshes: Meshes,
    cameras: Any,
    lights: Any,
    sigma: float = DEFAULT_RENDER_SIGMA,
    gamma: float = DEFAULT_RENDER_GAMMA,
    max_faces_per_bin=100000,
    faces_per_pixel=50,
    bin_size=None,
    use_checkpoint: bool = False,
) -> torch.Tensor:
    """
    Render `meshes` with a soft rasterizer and a SoftPhongShader.

    :param image_size: forwarded to RasterizationSettings.
    :param meshes: the meshes to render; on the checkpoint path their textures
                   must be a TexturesVertex.
    :param cameras: cameras given to both the rasterizer and the shader.
    :param lights: lights given to the shader; on the checkpoint path they
                   must be a BidirectionalLights.
    :param sigma: scales the rasterizer blur radius and is passed to BlendParams.
    :param gamma: passed to BlendParams.
    :param max_faces_per_bin: forwarded to RasterizationSettings.
    :param faces_per_pixel: forwarded to RasterizationSettings.
    :param bin_size: forwarded to RasterizationSettings.
    :param use_checkpoint: if True, run the render through `checkpoint()` with
                           every tensor input passed explicitly.
    :return: the output of the MeshRenderer call.
    """
    if not use_checkpoint:
        # Direct path: build the renderer and run it once.
        soft_settings = RasterizationSettings(
            image_size=image_size,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma,
            faces_per_pixel=faces_per_pixel,
            max_faces_per_bin=max_faces_per_bin,
            bin_size=bin_size,
            perspective_correct=False,
        )
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=soft_settings)
        blend_params = BlendParams(sigma=sigma, gamma=gamma, background_color=(0, 0, 0))
        shader = SoftPhongShader(
            device=meshes.device,
            cameras=cameras,
            lights=lights,
            blend_params=blend_params,
        )
        return MeshRenderer(rasterizer=rasterizer, shader=shader)(meshes)

    # Checkpoint path: flatten every input into plain tensors so that
    # autograd can see exactly which tensors the render depends on.
    verts_list = meshes.verts_list()
    faces_list = meshes.faces_list()
    assert isinstance(meshes.textures, TexturesVertex)
    assert isinstance(lights, BidirectionalLights)
    textures = meshes.textures.verts_features_padded()
    light_vecs, light_fn = _deconstruct_tensor_props(lights)
    camera_vecs, camera_fn = _deconstruct_tensor_props(cameras)

    def ckpt_fn(
        *args: torch.Tensor,
        num_verts=len(verts_list),
        num_light_vecs=len(light_vecs),
        num_camera_vecs=len(camera_vecs),
        light_fn=light_fn,
        camera_fn=camera_fn,
        faces_list=faces_list
    ):
        # The flat argument layout is: verts..., light tensors...,
        # camera tensors..., textures.
        flat = list(args)
        lights_end = num_verts + num_light_vecs
        cameras_end = lights_end + num_camera_vecs

        rebuilt_meshes = Meshes(
            verts=flat[:num_verts],
            faces=faces_list,
            textures=TexturesVertex(flat[cameras_end]),
        )
        rebuilt_lights = light_fn(flat[num_verts:lights_end])
        rebuilt_cameras = camera_fn(flat[lights_end:cameras_end])

        # Recurse with checkpointing disabled to do the actual render.
        return render_images(
            image_size=image_size,
            meshes=rebuilt_meshes,
            cameras=rebuilt_cameras,
            lights=rebuilt_lights,
            sigma=sigma,
            gamma=gamma,
            max_faces_per_bin=max_faces_per_bin,
            faces_per_pixel=faces_per_pixel,
            bin_size=bin_size,
            use_checkpoint=False,
        )

    flat_inputs = (*verts_list, *light_vecs, *camera_vecs, textures)
    return checkpoint(ckpt_fn, flat_inputs, (), True)
def _deconstruct_tensor_props(
    props: TensorProperties,
) -> Tuple[List[torch.Tensor], Callable[[List[torch.Tensor]], TensorProperties]]:
    """
    Split `props` into its tensor attributes and a function that rebuilds it.

    Every attribute listed by `dir(props)` that is neither a dunder name nor a
    bound method is inspected: tensors are collected in order, everything else
    is remembered as a non-tensor attribute.

    :param props: the object to take apart.
    :return: a tuple (tensors, recreate_fn). Calling recreate_fn with a list
             of tensors (in the same order) constructs a new object of the
             same type via `type(props)(device=props.device)`, sets deep
             copies of the non-tensor attributes on it, then sets the given
             tensors under the original attribute names.
    """
    tensor_names = []
    tensors = []
    static_attrs = {}

    candidate_names = (name for name in dir(props) if not name.startswith("__"))
    for attr in candidate_names:
        value = getattr(props, attr)
        if inspect.ismethod(value):
            continue
        if torch.is_tensor(value):
            tensors.append(value)
            tensor_names.append(attr)
        else:
            static_attrs[attr] = value

    def recreate_fn(vecs_arg):
        rebuilt = type(props)(device=props.device)
        # Non-tensor state is deep-copied so the rebuilt object does not
        # share mutable values with the original.
        for attr, value in static_attrs.items():
            setattr(rebuilt, attr, copy.deepcopy(value))
        for attr, tensor in zip(tensor_names, vecs_arg):
            setattr(rebuilt, attr, tensor)
        return rebuilt

    return tensors, recreate_fn
def convert_meshes(raw_meshes: Sequence[TorchMesh], default_brightness=0.8) -> Meshes:
    """
    Build a batched `Meshes` object with per-vertex colors from `raw_meshes`.

    :param raw_meshes: the meshes to convert; each provides `verts`, `faces`
                       and `vertex_channels`.
    :param default_brightness: the value used for all three color channels of
                               every vertex when a mesh does not provide all
                               of the "R", "G" and "B" vertex channels.
    :return: a `Meshes` whose textures are a `TexturesVertex` holding one
             color tensor per input mesh.
    """
    meshes = Meshes(
        verts=[mesh.verts for mesh in raw_meshes], faces=[mesh.faces for mesh in raw_meshes]
    )
    rgbs = []
    for mesh in raw_meshes:
        channels = mesh.vertex_channels
        if channels and all(k in channels for k in "RGB"):
            rgbs.append(torch.stack([channels[k] for k in "RGB"], dim=-1))
        else:
            # The brightness scales the color values; it must not be part of
            # the tensor size (a float size makes torch.ones raise TypeError).
            rgbs.append(
                torch.ones(
                    len(mesh.verts),
                    3,
                    device=mesh.verts.device,
                    dtype=mesh.verts.dtype,
                )
                * default_brightness
            )
    meshes.textures = TexturesVertex(verts_features=rgbs)
    return meshes
def convert_cameras(
    cameras: Sequence[ProjectiveCamera], device: torch.device
) -> FoVPerspectiveCameras:
    """
    Convert a sequence of ProjectiveCamera objects into one batched
    FoVPerspectiveCameras.

    Every camera must have a square viewport (equal width/height and equal
    x/y field-of-view) and the same field-of-view as the first camera.

    :param cameras: the cameras to convert.
    :param device: the device passed to FoVPerspectiveCameras.
    :return: a FoVPerspectiveCameras whose fov is the first camera's x_fov,
             interpreted in radians (degrees=False).
    """
    rotations = []
    translations = []
    for cam in cameras:
        is_square = cam.width == cam.height and cam.x_fov == cam.y_fov
        assert is_square, "viewports must be square"
        assert cam.x_fov == cameras[0].x_fov, "all cameras must have same field-of-view"
        # Columns of the rotation are the camera's -x, -y and z axes.
        rotation = np.stack([-cam.x, -cam.y, cam.z], axis=0).T
        translation = (-rotation.T) @ cam.origin
        rotations.append(rotation)
        translations.append(translation)
    return FoVPerspectiveCameras(
        R=np.stack(rotations, axis=0),
        T=np.stack(translations, axis=0),
        fov=cameras[0].x_fov,
        degrees=False,
        device=device,
    )
def convert_cameras_torch(
    origins: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor, zs: torch.Tensor, fov: float
) -> FoVPerspectiveCameras:
    """
    Build a batched FoVPerspectiveCameras from tensors of camera origins and
    axes, iterating over the first dimension of each input in lockstep.

    :param origins: camera origins, one per camera.
    :param xs: camera x axes, one per camera.
    :param ys: camera y axes, one per camera.
    :param zs: camera z axes, one per camera.
    :param fov: the field-of-view shared by all cameras, interpreted in
                radians (degrees=False).
    :return: a FoVPerspectiveCameras on the same device as `origins`.
    """
    rotations = []
    translations = []
    for origin, x_axis, y_axis, z_axis in zip(origins, xs, ys, zs):
        # Columns of the rotation are the camera's -x, -y and z axes.
        rotation = torch.stack([-x_axis, -y_axis, z_axis], axis=0).T
        translation = (-rotation.T) @ origin
        rotations.append(rotation)
        translations.append(translation)
    return FoVPerspectiveCameras(
        R=torch.stack(rotations, dim=0),
        T=torch.stack(translations, dim=0),
        fov=fov,
        degrees=False,
        device=origins.device,
    )
def blender_uniform_lights(
    batch_size: int,
    device: torch.device,
    ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR,
    diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR,
    specular_color: Union[float, Tuple[float]] = 0.0,
) -> "BidirectionalLights":
    """
    Create a light that attempts to match the light used by the Blender
    renderer when run with `--light_mode basic`.

    :param batch_size: how many times each color and the light direction are
                       repeated in the returned lights.
    :param device: the device passed to BidirectionalLights.
    :param ambient_color: a scalar (used for all three channels) or a
                          per-channel color tuple.
    :param diffuse_color: a scalar or a per-channel color tuple.
    :param specular_color: a scalar or a per-channel color tuple.
    :return: a BidirectionalLights pointing along UNIFORM_LIGHT_DIRECTION.
    """

    def _as_rgb(color):
        # Accept any real scalar, not just float: an int such as 0 or 1 would
        # otherwise skip the expansion and produce a wrongly shaped color.
        if isinstance(color, (int, float)):
            return (float(color),) * 3
        return color

    ambient_color = _as_rgb(ambient_color)
    diffuse_color = _as_rgb(diffuse_color)
    specular_color = _as_rgb(specular_color)
    return BidirectionalLights(
        ambient_color=(ambient_color,) * batch_size,
        diffuse_color=(diffuse_color,) * batch_size,
        specular_color=(specular_color,) * batch_size,
        direction=(UNIFORM_LIGHT_DIRECTION,) * batch_size,
        device=device,
    )
class BidirectionalLights(DirectionalLights):
    """
    Adapted from here, but effectively shines the light in both positive and negative directions:
    https://github.com/facebookresearch/pytorch3d/blob/efea540bbcab56fccde6f4bc729d640a403dac56/pytorch3d/renderer/lighting.py#L159

    Each lighting term is the elementwise maximum of the parent class's result
    for the given normals and for the negated normals.
    """

    def diffuse(self, normals, points=None) -> torch.Tensor:
        front = super().diffuse(normals, points=points)
        back = super().diffuse(-normals, points=points)
        return torch.maximum(front, back)

    def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
        front = super().specular(normals, points, camera_position, shininess)
        back = super().specular(-normals, points, camera_position, shininess)
        return torch.maximum(front, back)