Spaces:
Paused
Paused
| import math | |
| from collections import defaultdict | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import step1x3d_geometry | |
| from step1x3d_geometry.utils.typing import * | |
def dot(x, y):
    """Return the inner product of ``x`` and ``y`` over their last dimension.

    The reduced dimension is kept with size 1, so the result broadcasts back
    against ``x``/``y``.
    """
    product = x * y
    return product.sum(dim=-1, keepdim=True)
def reflect(x, n):
    """Reflect ``x`` about the direction ``n``: ``2 * (x . n) * n - x``.

    The dot product is taken over the last dimension via ``dot``. ``n`` is
    presumably expected to be unit length for a true mirror reflection (not
    enforced here).
    """
    doubled_projection = 2 * dot(x, n)
    return doubled_projection * n - x
# Accepted forms for a (low, high) value range: either a plain pair of floats,
# or a tensor annotated as shape (2, D) whose row 0 holds per-dimension lows
# and row 1 holds per-dimension highs.
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
def scale_tensor(
    dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
):
    """Linearly remap ``dat`` from the ``inp_scale`` range to the ``tgt_scale`` range.

    Each scale is indexed as ``[low, high]``; ``None`` means ``(0, 1)``. When
    ``tgt_scale`` is a tensor, its last dimension must match ``dat``'s last
    dimension (checked with an ``assert``).
    """
    inp_scale = (0, 1) if inp_scale is None else inp_scale
    tgt_scale = (0, 1) if tgt_scale is None else tgt_scale
    if isinstance(tgt_scale, Tensor):
        assert dat.shape[-1] == tgt_scale.shape[-1]

    src_low, src_high = inp_scale[0], inp_scale[1]
    dst_low, dst_high = tgt_scale[0], tgt_scale[1]

    # Normalize to [0, 1] relative to the source range, then stretch/shift
    # into the destination range.
    unit = (dat - src_low) / (src_high - src_low)
    return unit * (dst_high - dst_low) + dst_low
def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
    """Apply ``func`` to consecutive slices of a batch and merge the outputs.

    Every ``torch.Tensor`` in ``args``/``kwargs`` is sliced along dim 0 into
    pieces of at most ``chunk_size`` rows. Non-tensor arguments are passed
    unchanged to every call. The batch size is taken from the first tensor
    argument found.

    Args:
        func: Called once per chunk. It must return a tensor, a tuple/list
            (including subclasses such as namedtuples), a dict, or ``None``.
            Container entries must be tensors or ``None``.
        chunk_size: Maximum rows per chunk. ``<= 0`` disables chunking and
            calls ``func`` once on the full inputs.
        *args, **kwargs: Forwarded to ``func``; tensors are sliced per chunk.

    Returns:
        Per-chunk outputs concatenated along dim 0, in the container type
        ``func`` returned (dict subclasses come back as a plain ``dict``).
        Keys/positions that are ``None`` in every chunk stay ``None``. When
        grad mode is disabled, tensors are detached. Returns ``None`` if
        every call returned ``None``.

    Raises:
        AssertionError: No tensor argument, so the batch size is unknown.
        TypeError: ``func`` returned an unsupported type, or mixed tensors
            and non-tensors under the same key/position.
    """
    if chunk_size <= 0:
        return func(*args, **kwargs)

    B = None
    for arg in list(args) + list(kwargs.values()):
        if isinstance(arg, torch.Tensor):
            B = arg.shape[0]
            break
    assert (
        B is not None
    ), "No tensor found in args or kwargs, cannot determine batch size."

    out = defaultdict(list)
    out_type = None
    chunk_length = 0
    # max(1, B) so func is still invoked once (on empty slices) when B == 0.
    for start in range(0, max(1, B), chunk_size):
        stop = start + chunk_size
        out_chunk = func(
            *[
                arg[start:stop] if isinstance(arg, torch.Tensor) else arg
                for arg in args
            ],
            **{
                k: arg[start:stop] if isinstance(arg, torch.Tensor) else arg
                for k, arg in kwargs.items()
            },
        )
        if out_chunk is None:
            continue
        out_type = type(out_chunk)
        # Normalize every supported output to a dict keyed by name/position.
        if isinstance(out_chunk, torch.Tensor):
            out_chunk = {0: out_chunk}
        elif isinstance(out_chunk, (tuple, list)):
            chunk_length = len(out_chunk)
            out_chunk = dict(enumerate(out_chunk))
        elif not isinstance(out_chunk, dict):
            raise TypeError(
                f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
            )
        grad_enabled = torch.is_grad_enabled()
        for k, v in out_chunk.items():
            # Only tensors can be detached. None entries (allowed below) and
            # invalid values (rejected below) are stored untouched.
            if isinstance(v, torch.Tensor) and not grad_enabled:
                v = v.detach()
            out[k].append(v)

    if out_type is None:
        return None

    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
    for k, v in out.items():
        if all(vv is None for vv in v):
            # allow None in return value
            out_merged[k] = None
        elif all(isinstance(vv, torch.Tensor) for vv in v):
            out_merged[k] = torch.cat(v, dim=0)
        else:
            raise TypeError(
                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
            )

    # issubclass (not identity) so subclasses such as namedtuples are rebuilt
    # instead of silently falling through to an implicit None.
    if issubclass(out_type, torch.Tensor):
        return out_merged[0]
    if issubclass(out_type, (tuple, list)):
        values = [out_merged[idx] for idx in range(chunk_length)]
        if hasattr(out_type, "_make"):
            # namedtuple constructors take fields positionally; _make takes an iterable.
            return out_type._make(values)
        return out_type(values)
    return out_merged
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """Create a standard-normal random tensor on the desired ``device`` with the desired ``dtype``.

    Args:
        shape: Output shape (tuple or list); ``shape[0]`` is the batch size.
        generator: A single generator, or a list with one generator per batch
            element. Each element then gets its own ``(1, *shape[1:])`` sample.
            A list of length 1 is treated as a single generator.
        device: Target device (``torch.device`` or a device string). Defaults
            to CPU for the final ``.to(...)``.
        dtype: Optional dtype passed to ``torch.randn``.
        layout: Optional layout; defaults to ``torch.strided``.

    Returns:
        The sampled tensor, moved to ``device``.

    Raises:
        ValueError: The generator is on CUDA but the target device type
            differs, or a generator list is shorter than the batch size.

    If a CPU generator is passed for a non-CPU device, sampling happens on the
    CPU and the result is then moved to ``device``.
    """
    # Sampling device; stays as the caller passed it unless a CPU generator
    # forces CPU sampling below.
    rand_device = device
    batch_size = shape[0]
    layout = layout or torch.strided
    # Normalize so `.type` also works when callers pass a string like "cuda".
    device = torch.device(device) if device is not None else torch.device("cpu")

    if generator is not None:
        gen_device_type = (
            generator.device.type
            if not isinstance(generator, list)
            else generator[0].device.type
        )
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare the device *type*: a torch.device never equals the str "mps".
            if device.type != "mps":
                print(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(
                f"Cannot generate a {device} tensor from a generator of type {gen_device_type}."
            )

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        if len(generator) < batch_size:
            raise ValueError(
                f"Got {len(generator)} generators for a batch of size {batch_size};"
                " pass one generator per batch element."
            )
        # tuple() so a list-typed shape can be concatenated with (1,).
        sample_shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(
                sample_shape,
                generator=generator[i],
                device=rand_device,
                dtype=dtype,
                layout=layout,
            )
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(
            shape, generator=generator, device=rand_device, dtype=dtype, layout=layout
        ).to(device)
    return latents
def generate_dense_grid_points(
    bbox_min: np.ndarray, bbox_max: np.ndarray, octree_depth: int, indexing: str = "ij"
):
    """Sample a regular 3D grid of points spanning an axis-aligned bounding box.

    Each axis gets ``2**octree_depth + 1`` evenly spaced float32 samples from
    ``bbox_min[d]`` to ``bbox_max[d]`` inclusive.

    Returns:
        A tuple ``(xyz, grid_size, length)``:
        - ``xyz``: ``(N, 3)`` array of grid points, flattened from the
          ``np.meshgrid`` output built with ``indexing``.
        - ``grid_size``: list of the three per-axis sample counts.
        - ``length``: ``bbox_max - bbox_min``.
    """
    length = bbox_max - bbox_min
    samples_per_axis = int(np.exp2(octree_depth)) + 1
    axes = [
        np.linspace(bbox_min[axis], bbox_max[axis], samples_per_axis, dtype=np.float32)
        for axis in range(3)
    ]
    mesh = np.meshgrid(*axes, indexing=indexing)
    points = np.stack(mesh, axis=-1).reshape(-1, 3)
    return points, [samples_per_axis] * 3, length