Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Optional, Sequence | |
| import torch | |
| from shap_e.rendering.blender.constants import ( | |
| BASIC_AMBIENT_COLOR, | |
| BASIC_DIFFUSE_COLOR, | |
| UNIFORM_LIGHT_DIRECTION, | |
| ) | |
| from shap_e.rendering.view_data import ProjectiveCamera | |
| from .cast import cast_camera | |
| from .types import RayCollisions, TriMesh | |
| def render_diffuse_mesh( | |
| camera: ProjectiveCamera, | |
| mesh: TriMesh, | |
| light_direction: Sequence[float] = tuple(UNIFORM_LIGHT_DIRECTION), | |
| diffuse: float = BASIC_DIFFUSE_COLOR, | |
| ambient: float = BASIC_AMBIENT_COLOR, | |
| ray_batch_size: Optional[int] = None, | |
| checkpoint: Optional[bool] = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Return an [H x W x 4] RGBA tensor of the rendered image. | |
| The pixels are floating points, with alpha in the range [0, 1] and the | |
| other colors matching the scale used by the mesh's vertex colors. | |
| """ | |
| light_direction = torch.tensor( | |
| light_direction, device=mesh.vertices.device, dtype=mesh.vertices.dtype | |
| ) | |
| all_collisions = RayCollisions.collect( | |
| cast_camera( | |
| camera=camera, | |
| mesh=mesh, | |
| ray_batch_size=ray_batch_size, | |
| checkpoint=checkpoint, | |
| ) | |
| ) | |
| num_rays = len(all_collisions.normals) | |
| if mesh.vertex_colors is None: | |
| vertex_colors = torch.tensor([[0.8, 0.8, 0.8]]).to(mesh.vertices).repeat(num_rays, 1) | |
| else: | |
| vertex_colors = mesh.vertex_colors | |
| light_coeffs = ambient + ( | |
| diffuse * torch.sum(all_collisions.normals * light_direction, dim=-1).abs() | |
| ) | |
| vertex_colors = mesh.vertex_colors[mesh.faces[all_collisions.tri_indices]] | |
| bary_products = torch.sum(vertex_colors * all_collisions.barycentric[..., None], axis=-2) | |
| out_colors = bary_products * light_coeffs[..., None] | |
| res = torch.where(all_collisions.collides[:, None], out_colors, torch.zeros_like(out_colors)) | |
| return torch.cat([res, all_collisions.collides[:, None].float()], dim=-1).view( | |
| camera.height, camera.width, 4 | |
| ) | |