# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
#
# Part of the code comes from https://github.com/nerfstudio-project/gsplat;
# both projects are under the Apache License, Version 2.0.

import math
import random
from io import BytesIO
from typing import Dict, Literal, Optional, Tuple

import numpy as np
import torch
import trimesh
from gsplat.optimizers import SelectiveAdam
from scipy.spatial.transform import Rotation
from sklearn.neighbors import NearestNeighbors
from torch import Tensor

from embodied_gen.models.gs_model import GaussianOperator

__all__ = [
    "set_random_seed",
    "export_splats",
    "create_splats_with_optimizers",
    "resize_pinhole_intrinsics",
    "restore_scene_scale_and_position",
]


def knn(x: Tensor, K: int = 4) -> Tensor:
    """Return distances to the K nearest neighbors of each point in `x`.

    The query set equals the data set, so column 0 is each point's
    distance to itself (zero).
    """
    x_np = x.cpu().numpy()
    model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
    distances, _ = model.kneighbors(x_np)
    return torch.from_numpy(distances).to(x)


def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Convert RGB values in [0, 1] to degree-0 SH (DC) coefficients."""
    C0 = 0.28209479177387814  # Y_0^0 = 1 / (2 * sqrt(pi))
    return (rgb - 0.5) / C0
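
# A minimal round-trip check (hypothetical values): the degree-0 SH basis is
# the constant C0, so rendering the DC term gives color = C0 * sh + 0.5, and
# the conversion above inverts it exactly:
#   rgb = torch.tensor([[1.0, 0.5, 0.0]])
#   sh = rgb_to_sh(rgb)  # ~[[1.7725, 0.0000, -1.7725]]
#   rgb_back = sh * 0.28209479177387814 + 0.5  # recovers rgb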


def set_random_seed(seed: int):
    """Seed the Python, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def splat2ply_bytes(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
) -> bytes:
    """Serialize Gaussian splat attributes into binary PLY bytes."""
    num_splats = means.shape[0]
    buffer = BytesIO()

    # Write the PLY header.
    buffer.write(b"ply\n")
    buffer.write(b"format binary_little_endian 1.0\n")
    buffer.write(f"element vertex {num_splats}\n".encode())
    buffer.write(b"property float x\n")
    buffer.write(b"property float y\n")
    buffer.write(b"property float z\n")
    for i, data in enumerate([sh0, shN]):
        prefix = "f_dc" if i == 0 else "f_rest"
        for j in range(data.shape[1]):
            buffer.write(f"property float {prefix}_{j}\n".encode())
    buffer.write(b"property float opacity\n")
    for i in range(scales.shape[1]):
        buffer.write(f"property float scale_{i}\n".encode())
    for i in range(quats.shape[1]):
        buffer.write(f"property float rot_{i}\n".encode())
    buffer.write(b"end_header\n")

    # Concatenate all attributes in the same order as the header properties.
    splat_data = torch.cat(
        [means, sh0, shN, opacities.unsqueeze(1), scales, quats], dim=1
    )
    splat_data = splat_data.to(torch.float32)

    # Write the binary payload as little-endian float32.
    float_dtype = np.dtype(np.float32).newbyteorder("<")
    buffer.write(
        splat_data.detach().cpu().numpy().astype(float_dtype).tobytes()
    )

    return buffer.getvalue()


def export_splats(
    means: torch.Tensor,
    scales: torch.Tensor,
    quats: torch.Tensor,
    opacities: torch.Tensor,
    sh0: torch.Tensor,
    shN: torch.Tensor,
    format: Literal["ply"] = "ply",
    save_to: Optional[str] = None,
) -> bytes:
    """Export a Gaussian Splats model to bytes in PLY file format."""
    total_splats = means.shape[0]
    assert means.shape == (total_splats, 3), "Means must be of shape (N, 3)"
    assert scales.shape == (total_splats, 3), "Scales must be of shape (N, 3)"
    assert quats.shape == (
        total_splats,
        4,
    ), "Quaternions must be of shape (N, 4)"
    assert opacities.shape == (
        total_splats,
    ), "Opacities must be of shape (N,)"
    assert sh0.shape == (total_splats, 1, 3), "sh0 must be of shape (N, 1, 3)"
    assert (
        shN.ndim == 3 and shN.shape[0] == total_splats and shN.shape[2] == 3
    ), f"shN must be of shape (N, K, 3), got {shN.shape}"

    # Flatten spherical harmonics for serialization.
    sh0 = sh0.squeeze(1)  # Shape (N, 3)
    shN = shN.permute(0, 2, 1).reshape(means.shape[0], -1)  # Shape (N, K * 3)

    # Build a per-splat mask of entries containing NaN or Inf values. Note
    # that opacities is 1-D, so the element-wise check is already per splat
    # (the original `.any(dim=0)` reduced it to a scalar, which would have
    # invalidated every splat whenever a single opacity was NaN or Inf).
    invalid_mask = (
        torch.isnan(means).any(dim=1)
        | torch.isinf(means).any(dim=1)
        | torch.isnan(scales).any(dim=1)
        | torch.isinf(scales).any(dim=1)
        | torch.isnan(quats).any(dim=1)
        | torch.isinf(quats).any(dim=1)
        | torch.isnan(opacities)
        | torch.isinf(opacities)
        | torch.isnan(sh0).any(dim=1)
        | torch.isinf(sh0).any(dim=1)
        | torch.isnan(shN).any(dim=1)
        | torch.isinf(shN).any(dim=1)
    )

    # Keep only the valid splats.
    valid_mask = ~invalid_mask
    means = means[valid_mask]
    scales = scales[valid_mask]
    quats = quats[valid_mask]
    opacities = opacities[valid_mask]
    sh0 = sh0[valid_mask]
    shN = shN[valid_mask]

    if format == "ply":
        data = splat2ply_bytes(means, scales, quats, opacities, sh0, shN)
    else:
        raise ValueError(f"Unsupported format: {format}")

    if save_to:
        with open(save_to, "wb") as binary_file:
            binary_file.write(data)

    return data
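
# Usage sketch (hypothetical values; sh_degree=3 gives 15 "rest" coefficients):
#   N = 1000
#   ply_bytes = export_splats(
#       means=torch.randn(N, 3),
#       scales=torch.full((N, 3), -5.0),  # log-space scales
#       quats=torch.tensor([[1.0, 0.0, 0.0, 0.0]]).repeat(N, 1),
#       opacities=torch.zeros(N),  # logit-space, i.e. opacity 0.5
#       sh0=torch.zeros(N, 1, 3),
#       shN=torch.zeros(N, 15, 3),
#       save_to="/tmp/splats.ply",  # optional; the bytes are also returned
#   )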


def create_splats_with_optimizers(
    points: Optional[np.ndarray] = None,
    points_rgb: Optional[np.ndarray] = None,
    init_num_pts: int = 100_000,
    init_extent: float = 3.0,
    init_opacity: float = 0.1,
    init_scale: float = 1.0,
    means_lr: float = 1.6e-4,
    scales_lr: float = 5e-3,
    opacities_lr: float = 5e-2,
    quats_lr: float = 1e-3,
    sh0_lr: float = 2.5e-3,
    shN_lr: float = 2.5e-3 / 20,
    scene_scale: float = 1.0,
    sh_degree: int = 3,
    sparse_grad: bool = False,
    visible_adam: bool = False,
    batch_size: int = 1,
    feature_dim: Optional[int] = None,
    device: str = "cuda",
    world_rank: int = 0,
    world_size: int = 1,
) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
    if points is not None and points_rgb is not None:
        points = torch.from_numpy(points).float()
        rgbs = torch.from_numpy(points_rgb / 255.0).float()
    else:
        points = init_extent * scene_scale * (
            torch.rand((init_num_pts, 3)) * 2 - 1
        )
        rgbs = torch.rand((init_num_pts, 3))

    # Initialize the GS size to be the average dist of the 3 nearest neighbors
    # (column 0 of knn is the point's zero distance to itself).
    dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
    dist_avg = torch.sqrt(dist2_avg)
    scales = (
        torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3)
    )  # [N, 3]

    # Distribute the GSs to different ranks (also works for a single rank).
    points = points[world_rank::world_size]
    rgbs = rgbs[world_rank::world_size]
    scales = scales[world_rank::world_size]

    N = points.shape[0]
    quats = torch.rand((N, 4))  # [N, 4]
    opacities = torch.logit(torch.full((N,), init_opacity))  # [N,]

    params = [
        # name, value, lr
        ("means", torch.nn.Parameter(points), means_lr * scene_scale),
        ("scales", torch.nn.Parameter(scales), scales_lr),
        ("quats", torch.nn.Parameter(quats), quats_lr),
        ("opacities", torch.nn.Parameter(opacities), opacities_lr),
    ]

    if feature_dim is None:
        # Color is represented by SH coefficients.
        colors = torch.zeros((N, (sh_degree + 1) ** 2, 3))  # [N, K, 3]
        colors[:, 0, :] = rgb_to_sh(rgbs)
        params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), sh0_lr))
        params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), shN_lr))
    else:
        # Features will be used for appearance and view-dependent shading.
        features = torch.rand(N, feature_dim)  # [N, feature_dim]
        params.append(("features", torch.nn.Parameter(features), sh0_lr))
        colors = torch.logit(rgbs)  # [N, 3]
        params.append(("colors", torch.nn.Parameter(colors), sh0_lr))

    splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)

    # Scale learning rate based on batch size, reference:
    # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
    # Note that this would not make the training exactly equivalent, see
    # https://arxiv.org/pdf/2402.18824v1
    BS = batch_size * world_size
    if sparse_grad:
        optimizer_class = torch.optim.SparseAdam
    elif visible_adam:
        optimizer_class = SelectiveAdam
    else:
        optimizer_class = torch.optim.Adam
    optimizers = {
        name: optimizer_class(
            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
            eps=1e-15 / math.sqrt(BS),
            # TODO: check betas logic; at BS == 10, betas[0] is already zero,
            # and it becomes negative for larger BS.
            betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
        )
        for name, _, lr in params
    }
    return splats, optimizers
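
# Usage sketch (hypothetical inputs): initialize from a colored point cloud,
# with colors given as uint8-range values in [0, 255] as expected by the
# `points_rgb / 255.0` normalization above:
#   xyz = np.random.rand(10_000, 3)
#   rgb = np.random.randint(0, 256, size=(10_000, 3)).astype(np.float64)
#   splats, optimizers = create_splats_with_optimizers(
#       points=xyz, points_rgb=rgb, sh_degree=3, device="cuda"
#   )
#   # splats["means"].shape == (10_000, 3); one optimizer per parameter group.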


def compute_intrinsics_from_fovy(
    image_w: int, image_h: int, fovy_deg: float
) -> np.ndarray:
    """Build a 3x3 pinhole intrinsic matrix from image size and vertical FoV."""
    fovy_rad = np.deg2rad(fovy_deg)
    fy = image_h / (2 * np.tan(fovy_rad / 2))
    fx = fy * (image_w / image_h)
    cx = image_w / 2
    cy = image_h / 2
    K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
    return K
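
# Worked example: for a 640x480 image with fovy = 60 degrees,
#   fy = 480 / (2 * tan(30 deg)) ~= 415.7
#   fx = fy * (640 / 480)        ~= 554.3
#   cx, cy = 320, 240
# Note fx != fy here because fx is derived from fy via the aspect ratio.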


def resize_pinhole_intrinsics(
    raw_K: np.ndarray | torch.Tensor,
    raw_hw: tuple[int, int],
    new_hw: tuple[int, int],
) -> np.ndarray | torch.Tensor:
    """Rescale a pinhole intrinsic matrix to a new image resolution."""
    raw_h, raw_w = raw_hw
    new_h, new_w = new_hw
    scale_x = new_w / raw_w
    scale_y = new_h / raw_h
    new_K = raw_K.copy() if isinstance(raw_K, np.ndarray) else raw_K.clone()
    new_K[0, 0] *= scale_x  # fx
    new_K[0, 2] *= scale_x  # cx
    new_K[1, 1] *= scale_y  # fy
    new_K[1, 2] *= scale_y  # cy
    return new_K
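
# Example: downscaling a 960x720 image to 480x360 halves fx, cx, fy, and cy:
#   K = compute_intrinsics_from_fovy(image_w=960, image_h=720, fovy_deg=60.0)
#   K_small = resize_pinhole_intrinsics(K, raw_hw=(720, 960), new_hw=(360, 480))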


def restore_scene_scale_and_position(
    real_height: float, mesh_path: str, gs_path: str
) -> None:
    """Scale a mesh and its corresponding GS model to a real-world height.

    Uses the 1st and 99th percentiles of the mesh's vertical (Y) coordinate
    to robustly estimate its height, then applies the same scaling and
    vertical alignment to both the mesh and the Gaussian Splatting model,
    saving each in place.

    Args:
        real_height (float): Target real-world height along the up axis.
        mesh_path (str): Path to the input mesh file.
        gs_path (str): Path to the Gaussian Splatting model file.
    """
    mesh = trimesh.load(mesh_path)
    z_min = np.percentile(mesh.vertices[:, 1], 1)
    z_max = np.percentile(mesh.vertices[:, 1], 99)
    height = z_max - z_min
    scale = real_height / height

    # Rotate 180 degrees about the Y axis (scipy quaternions are [x, y, z, w]),
    # then align the base to the origin and rescale to the target height.
    rot = Rotation.from_quat([0, 1, 0, 0])
    mesh.vertices = rot.apply(mesh.vertices)
    mesh.vertices[:, 1] -= z_min
    mesh.vertices *= scale
    mesh.export(mesh_path)

    # Apply the matching translation and rotation to the GS model, then rescale.
    gs_model: GaussianOperator = GaussianOperator.load_from_ply(gs_path)
    gs_model = gs_model.get_gaussians(
        instance_pose=torch.tensor([0.0, -z_min, 0, 0, 1, 0, 0])
    )
    gs_model.rescale(scale)
    gs_model.save_to_ply(gs_path)
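
# Usage sketch (hypothetical paths): rescale a generated asset so it stands
# 0.75 m tall, rewriting both files in place:
#   restore_scene_scale_and_position(
#       real_height=0.75,
#       mesh_path="outputs/asset/mesh.obj",
#       gs_path="outputs/asset/gs_model.ply",
#   )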