from abc import abstractmethod
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import torch

from shap_e.models.nn.camera import (
    DifferentiableCamera,
    DifferentiableProjectiveCamera,
    get_image_coords,
    projective_camera_frame,
)
from shap_e.models.nn.meta import MetaModule
from shap_e.util.collections import AttrDict


class Renderer(MetaModule):
    """
    A rendering abstraction that can render rays and views by calling the
    appropriate models. The models are instantiated outside but registered in
    this module.
    """

    @abstractmethod
    def render_views(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        Returns a backproppable rendering of a view.

        :param batch: contains
            - height: Optional[int]
            - width: Optional[int]
            - inner_batch_size or ray_batch_size: Optional[int], defaults to 4096 rays

            and additionally, to specify poses with a default up direction:
            - poses: [batch_size x *shape x 2 x 3] where poses[:, ..., 0, :] are the camera
              positions, and poses[:, ..., 1, :] are the z-axes (toward the object) of
              the camera frames.
            - camera: DifferentiableCamera. Assumes the same camera intrinsics
              across the batch for simplicity. Could eventually support
              batched cameras.

            or, to specify a batch of arbitrary poses:
            - cameras: DifferentiableCameraBatch of shape [batch_size x *shape].
        :param params: meta parameters
        :param options: Optional[Dict]
        """


class RayRenderer(Renderer):
    """
    A rendering abstraction that can render rays and views by calling the
    appropriate models. The models are instantiated outside but registered in
    this module.
    """

    @abstractmethod
    def render_rays(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        :param batch: has
            - rays: [batch_size x ... x 2 x 3] specifying the origin and direction of each ray.
            - radii (optional): [batch_size x ... x 1] the "thickness" of each ray.
        :param options: Optional[Dict]
        """

    def render_views(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        output = render_views_from_rays(
            self.render_rays,
            batch,
            params=params,
            options=options,
            device=self.device,
        )
        return output

    def forward(
        self,
        batch: AttrDict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        """
        :param batch: must contain either
            - rays: [batch_size x ... x 2 x 3] specifying the origin and direction of each ray.

            or
            - poses: [batch_size x 2 x 3] where poses[:, 0] are the camera
              positions, and poses[:, 1] are the z-axes (toward the object) of
              the camera frames.
            - camera: an instance of Camera that implements camera_rays

            or
            - cameras: DifferentiableCameraBatch of shape [batch_size x *shape].

            For either of the two camera-based options, these may additionally
            be specified:
            - height: Optional[int]
            - width: Optional[int]
            - ray_batch_size or inner_batch_size: Optional[int], defaults to 4096 rays
        :param params: a dictionary of optional meta parameters.
        :param options: a dictionary of other hyperparameters that may be
            related to rendering or debugging
        :return: a dictionary containing
            - channels: [batch_size, *shape, n_channels]
            - distances: [batch_size, *shape, 1]
            - transmittance: [batch_size, *shape, 1]
            - aux_losses: Dict[str, torch.Tensor]
        """
| if "rays" in batch: | |
| for key in ["poses", "camera", "height", "width"]: | |
| assert key not in batch | |
| return self.render_rays(batch, params=params, options=options) | |
| elif "poses" in batch or "cameras" in batch: | |
| assert "rays" not in batch | |
| if "poses" in batch: | |
| assert "camera" in batch | |
| else: | |
| assert "camera" not in batch | |
| return self.render_views(batch, params=params, options=options) | |
| raise NotImplementedError | |
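
    # Dispatch summary (illustrative, for a hypothetical concrete subclass
    # instance `renderer`):
    #
    #     renderer(AttrDict(rays=rays))                   # -> render_rays
    #     renderer(AttrDict(poses=poses, camera=camera))  # -> render_views
    #     renderer(AttrDict(cameras=camera_batch))        # -> render_views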


def get_camera_from_batch(batch: AttrDict) -> Tuple[DifferentiableCamera, int, Tuple[int]]:
    if "poses" in batch:
        assert "cameras" not in batch
        batch_size, *inner_shape, n_vecs, spatial_dim = batch.poses.shape
        assert n_vecs == 2 and spatial_dim == 3
        inner_batch_size = int(np.prod(inner_shape))
        poses = batch.poses.view(batch_size * inner_batch_size, 2, 3)
        position, direction = poses[:, 0], poses[:, 1]
        camera = projective_camera_frame(position, direction, batch.camera)
    elif "cameras" in batch:
        assert "camera" not in batch
        batch_size, *inner_shape = batch.cameras.shape
        camera = batch.cameras.flat_camera
    else:
        raise ValueError(f'neither "poses" nor "cameras" found in keys: {batch.keys()}')
    if "height" in batch and "width" in batch:
        camera = camera.resize_image(batch.width, batch.height)
    return camera, batch_size, inner_shape


def append_tensor(val_list: Optional[List[torch.Tensor]], output: Optional[torch.Tensor]):
    if val_list is None:
        return [output]
    return val_list + [output]
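
# Usage note: append_tensor is the reduction passed to AttrDict.combine in
# render_views_from_rays below, folding each ray chunk's outputs into per-key
# lists, e.g. append_tensor(None, t) == [t] and append_tensor([t0], t1) == [t0, t1].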


def render_views_from_rays(
    render_rays: Callable[[AttrDict, AttrDict, AttrDict], AttrDict],
    batch: AttrDict,
    params: Optional[Dict] = None,
    options: Optional[Dict] = None,
    device: torch.device = torch.device("cuda"),
) -> AttrDict:
    camera, batch_size, inner_shape = get_camera_from_batch(batch)
    inner_batch_size = int(np.prod(inner_shape))

    coords = get_image_coords(camera.width, camera.height).to(device)
    coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape])
    rays = camera.camera_rays(coords)

    # mip-NeRF radii calculation from:
    # https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/datasets.py#L193-L200
    directions = rays.view(batch_size, inner_batch_size, camera.height, camera.width, 2, 3)[
        ..., 1, :
    ]
    neighbor_dists = torch.linalg.norm(directions[:, :, :, 1:] - directions[:, :, :, :-1], dim=-1)
    neighbor_dists = torch.cat([neighbor_dists, neighbor_dists[:, :, :, -2:-1]], dim=3)
    radii = (neighbor_dists * 2 / np.sqrt(12)).view(batch_size, -1, 1)
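    # The radius for each pixel is the distance between its ray direction and
    # its horizontal neighbor's, scaled by 2/sqrt(12); per mip-NeRF, this
    # chooses a cone radius between the circles inscribed in and circumscribed
    # about the pixel footprint.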

    rays = rays.view(batch_size, inner_batch_size * camera.height * camera.width, 2, 3)

    if isinstance(camera, DifferentiableProjectiveCamera):
        # Compute the camera z direction corresponding to every ray's pixel.
        # Used for depth computations below.
        z_directions = (
            (camera.z / torch.linalg.norm(camera.z, dim=-1, keepdim=True))
            .reshape([batch_size, inner_batch_size, 1, 3])
            .repeat(1, 1, camera.width * camera.height, 1)
            .reshape(batch_size, inner_batch_size * camera.height * camera.width, 3)
        )

    ray_batch_size = batch.get("ray_batch_size", batch.get("inner_batch_size", 4096))
    assert rays.shape[1] % ray_batch_size == 0
    n_batches = rays.shape[1] // ray_batch_size

    output_list = AttrDict(aux_losses=dict())

    for idx in range(n_batches):
        rays_batch = AttrDict(
            rays=rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size],
            radii=radii[:, idx * ray_batch_size : (idx + 1) * ray_batch_size],
        )
        output = render_rays(rays_batch, params=params, options=options)

        if isinstance(camera, DifferentiableProjectiveCamera):
            # Project each ray's rendered distance onto the camera z-axis to get depth.
            z_batch = z_directions[:, idx * ray_batch_size : (idx + 1) * ray_batch_size]
            ray_directions = rays_batch.rays[:, :, 1]
            z_dots = (ray_directions * z_batch).sum(-1, keepdim=True)
            output.depth = output.distances * z_dots

        output_list = output_list.combine(output, append_tensor)

    def _resize(val_list: List[torch.Tensor]):
        val = torch.cat(val_list, dim=1)
        assert val.shape[1] == inner_batch_size * camera.height * camera.width
        return val.view(batch_size, *inner_shape, camera.height, camera.width, -1)

    def _avg(_key: str, loss_list: List[torch.Tensor]):
        return sum(loss_list) / n_batches

    output = AttrDict(
        {name: _resize(val_list) for name, val_list in output_list.items() if name != "aux_losses"}
    )
    output.aux_losses = output_list.aux_losses.map(_avg)

    return output
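

# A minimal smoke test, offered as an illustration rather than part of the
# module. It drives render_views_from_rays through the poses path with a
# constant-color render_rays stub. The DifferentiableProjectiveCamera
# construction below assumes its dataclass fields are origin/x/y/z frame
# vectors plus width/height/x_fov/y_fov; adjust if the camera API differs.
if __name__ == "__main__":

    def _constant_render_rays(rays_batch, params=None, options=None):
        # Return flat white everywhere, at unit distance, fully transparent-free.
        batch_shape = rays_batch.rays.shape[:2]
        return AttrDict(
            channels=torch.ones(*batch_shape, 3),
            distances=torch.ones(*batch_shape, 1),
            transmittance=torch.zeros(*batch_shape, 1),
            aux_losses=AttrDict(),
        )

    ref_camera = DifferentiableProjectiveCamera(
        origin=torch.tensor([[0.0, 0.0, -4.0]]),
        x=torch.tensor([[1.0, 0.0, 0.0]]),
        y=torch.tensor([[0.0, 1.0, 0.0]]),
        z=torch.tensor([[0.0, 0.0, 1.0]]),
        width=64,
        height=64,
        x_fov=0.7,
        y_fov=0.7,
    )
    # One scene, one view: the camera sits on -z looking toward the origin.
    batch = AttrDict(
        poses=torch.tensor([[[[0.0, 0.0, -4.0], [0.0, 0.0, 1.0]]]]),  # [1 x 1 x 2 x 3]
        camera=ref_camera,
        ray_batch_size=4096,
    )
    views = render_views_from_rays(_constant_render_rays, batch, device=torch.device("cpu"))
    print({k: tuple(v.shape) for k, v in views.items() if k != "aux_losses"})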