# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from embodied_gen.utils.monkey_patches import monkey_patch_pano2room
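# Called before the thirdparty.pano2room imports below so they pick up the
# patched behavior.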
monkey_patch_pano2room()

import os

import cv2
import numpy as np
import torch
import trimesh
from equilib import cube2equi, equi2pers
from kornia.morphology import dilation
from PIL import Image

from embodied_gen.models.sr_model import ImageRealESRGAN
from embodied_gen.utils.config import Pano2MeshSRConfig
from embodied_gen.utils.geometry import compute_pinhole_intrinsics
from embodied_gen.utils.log import logger
from thirdparty.pano2room.modules.geo_predictors import PanoJointPredictor
from thirdparty.pano2room.modules.geo_predictors.PanoFusionDistancePredictor import (
    PanoFusionDistancePredictor,
)
from thirdparty.pano2room.modules.inpainters import PanoPersFusionInpainter
from thirdparty.pano2room.modules.mesh_fusion.render import (
    features_to_world_space_mesh,
    render_mesh,
)
from thirdparty.pano2room.modules.mesh_fusion.sup_info import SupInfoPool
from thirdparty.pano2room.utils.camera_utils import gen_pano_rays
from thirdparty.pano2room.utils.functions import (
    depth_to_distance,
    get_cubemap_views_world_to_cam,
    resize_image_with_aspect_ratio,
    rot_z_world_to_cam,
    tensor_to_pil,
)


class Pano2MeshSRPipeline:
    """Converts a panoramic RGB image into a 3D mesh representation, followed by inpainting and mesh refinement.

    This class integrates several key components:
    - Depth estimation from the RGB panorama
    - Inpainting of regions revealed under camera offsets
    - RGB-D to mesh conversion
    - Multi-view mesh repair
    - 3D Gaussian Splatting (3DGS) dataset generation

    Args:
        config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters.

    Example:
        ```python
        pipeline = Pano2MeshSRPipeline(config)
        pipeline(pano_image="example.png", output_dir="./output")
        ```
    """

    def __init__(self, config: Pano2MeshSRConfig) -> None:
        self.cfg = config
        self.device = config.device
        # Init models.
        self.inpainter = PanoPersFusionInpainter(save_path=None)
        self.geo_predictor = PanoJointPredictor(save_path=None)
        self.pano_fusion_distance_predictor = PanoFusionDistancePredictor()
        self.super_model = ImageRealESRGAN(outscale=self.cfg.upscale_factor)
        # Init poses.
        cubemap_w2cs = get_cubemap_views_world_to_cam()
        self.cubemap_w2cs = [p.to(self.device) for p in cubemap_w2cs]
        self.camera_poses = self.load_camera_poses(self.cfg.trajectory_dir)
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, self.cfg.kernel_size
        )
        self.kernel = torch.from_numpy(kernel).float().to(self.device)

    def init_mesh_params(self) -> None:
        torch.set_default_device(self.device)
        self.inpaint_mask = torch.ones(
            (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool
        )
        self.vertices = torch.empty((3, 0), requires_grad=False)
        self.colors = torch.empty((3, 0), requires_grad=False)
        self.faces = torch.empty((3, 0), dtype=torch.long, requires_grad=False)

    @staticmethod
    def read_camera_pose_file(filepath: str) -> np.ndarray:
        with open(filepath, "r") as f:
            values = [float(num) for line in f for num in line.split()]

        return np.array(values).reshape(4, 4)

    def load_camera_poses(self, trajectory_dir: str) -> list[torch.Tensor]:
        pose_filenames = sorted(
            [
                fname
                for fname in os.listdir(trajectory_dir)
                if fname.startswith("camera_pose")
            ]
        )
        pano_pose_world = None
        relative_poses = []
        for idx, filename in enumerate(pose_filenames):
            pose_path = os.path.join(trajectory_dir, filename)
            pose_matrix = self.read_camera_pose_file(pose_path)
            if pano_pose_world is None:
                pano_pose_world = pose_matrix.copy()
                pano_pose_world[0, 3] += self.cfg.pano_center_offset[0]
                pano_pose_world[2, 3] += self.cfg.pano_center_offset[1]

            # Use a different reference for the first 6 cubemap views.
            reference_pose = pose_matrix if idx < 6 else pano_pose_world
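            # Express this view relative to the reference frame, then apply
            # axis-convention fixes (flip the x/y rows, rotate 180 deg about
            # z) -- presumably to match the mesh renderer's camera convention.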
            relative_matrix = pose_matrix @ np.linalg.inv(reference_pose)
            relative_matrix[0:2, :] *= -1  # flip_xy
            relative_matrix = (
                relative_matrix @ rot_z_world_to_cam(180).cpu().numpy()
            )
            relative_matrix[:3, 3] *= self.cfg.pose_scale
            relative_matrix = torch.tensor(
                relative_matrix, dtype=torch.float32
            )
            relative_poses.append(relative_matrix)

        return relative_poses

    def load_inpaint_poses(
        self, poses: list[torch.Tensor]
    ) -> dict[int, torch.Tensor]:
        inpaint_poses = dict()
        sampled_views = poses[:: self.cfg.inpaint_frame_stride]
        init_pose = torch.eye(4)
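        # Keep only the (negated) camera center of each sampled view; the
        # rotation stays identity since each view is a full panorama.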
        for idx, w2c_tensor in enumerate(sampled_views):
            w2c = w2c_tensor.cpu().numpy().astype(np.float32)
            c2w = np.linalg.inv(w2c)
            pose_tensor = init_pose.clone()
            pose_tensor[:3, 3] = torch.from_numpy(c2w[:3, 3])
            pose_tensor[:3, 3] *= -1
            inpaint_poses[idx] = pose_tensor.to(self.device)

        return inpaint_poses

    def project(self, world_to_cam: torch.Tensor):
        (
            project_image,
            project_depth,
            inpaint_mask,
            _,
            z_buf,
            mesh,
        ) = render_mesh(
            vertices=self.vertices,
            faces=self.faces,
            vertex_features=self.colors,
            H=self.cfg.cubemap_h,
            W=self.cfg.cubemap_w,
            fov_in_degrees=self.cfg.fov,
            RT=world_to_cam,
            blur_radius=self.cfg.blur_radius,
            faces_per_pixel=self.cfg.faces_per_pixel,
        )
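        # Zero out pixels not covered by the mesh; these are the regions the
        # inpainter fills later.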
        project_image = project_image * ~inpaint_mask

        return project_image[:3, ...], inpaint_mask, project_depth

    def render_pano(self, pose: torch.Tensor):
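        # Render the six cubemap faces around `pose`, then stitch them into a
        # single equirectangular RGB + distance + mask panorama via cube2equi.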
        cubemap_list = []
        for cubemap_pose in self.cubemap_w2cs:
            project_pose = cubemap_pose @ pose
            rgb, inpaint_mask, depth = self.project(project_pose)
            distance_map = depth_to_distance(depth[None, ...])
            mask = inpaint_mask[None, ...]
            cubemap_list.append(torch.cat([rgb, distance_map, mask], dim=0))

        # Set default tensor type for CPU operation in cube2equi.
        with torch.device("cpu"):
            pano_rgbd = cube2equi(
                cubemap_list, "list", self.cfg.pano_h, self.cfg.pano_w
            )

        pano_rgb = pano_rgbd[:3, :, :]
        pano_depth = pano_rgbd[3:4, :, :].squeeze(0)
        pano_mask = pano_rgbd[4:, :, :].squeeze(0)

        return pano_rgb, pano_depth, pano_mask

    def rgbd_to_mesh(
        self,
        rgb: torch.Tensor,
        depth: torch.Tensor,
        inpaint_mask: torch.Tensor,
        world_to_cam: torch.Tensor | None = None,
        using_distance_map: bool = True,
    ) -> None:
        if world_to_cam is None:
            world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device)
        if inpaint_mask.sum() == 0:
            return

        vertices, faces, colors = features_to_world_space_mesh(
            colors=rgb.squeeze(0),
            depth=depth,
            fov_in_degrees=self.cfg.fov,
            world_to_cam=world_to_cam,
            mask=inpaint_mask,
            faces=self.faces,
            vertices=self.vertices,
            using_distance_map=using_distance_map,
            edge_threshold=0.05,
        )
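        # Re-index the new faces past the existing vertices so they stay valid
        # after the buffers are concatenated below.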
        faces += self.vertices.shape[1]
        self.vertices = torch.cat([self.vertices, vertices], dim=1)
        self.colors = torch.cat([self.colors, colors], dim=1)
        self.faces = torch.cat([self.faces, faces], dim=1)

    def get_edge_image_by_depth(
        self, depth: torch.Tensor, dilate_iter: int = 1
    ) -> np.ndarray:
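        # Canny edges on the normalized depth map mark depth discontinuities;
        # dilating them widens the band so that meshing can skip triangles
        # that would stretch across depth jumps.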
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().detach().numpy()

        gray = (depth / depth.max() * 255).astype(np.uint8)
        edges = cv2.Canny(gray, 60, 150)
        if dilate_iter > 0:
            kernel = np.ones((3, 3), np.uint8)
            edges = cv2.dilate(edges, kernel, iterations=dilate_iter)

        return edges

    def mesh_repair_by_greedy_view_selection(
        self, pose_dict: dict[int, torch.Tensor], output_dir: str
    ) -> list:
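        # Greedy loop: score every remaining candidate pose by how complete
        # its rendering is, pick one, inpaint what the mesh cannot explain,
        # then fuse the inpainted RGB-D back into the mesh and repeat.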
        inpainted_panos_w_pose = []
        while len(pose_dict) > 0:
            logger.info(f"Repairing mesh: {len(pose_dict)} rounds left.")
            sampled_views = []
            for key, pose in pose_dict.items():
                pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
                completeness = torch.sum(1 - pano_mask) / pano_mask.numel()
                sampled_views.append((key, completeness.item(), pose))

            if len(sampled_views) == 0:
                break

            # Pick a view about two-thirds of the way up the completeness
            # ranking (sorted ascending).
            sampled_views = sorted(sampled_views, key=lambda x: x[1])
            key, _, pose = sampled_views[len(sampled_views) * 2 // 3]
            pose_dict.pop(key)
            pano_rgb, pano_distance, pano_mask = self.render_pano(pose)
            colors = pano_rgb.permute(1, 2, 0).clone()
            distances = pano_distance.unsqueeze(-1).clone()
            pano_inpaint_mask = pano_mask.clone()
            init_pose = pose.clone()
            normals = None
            if pano_inpaint_mask.min().item() < 0.5:
                colors, distances, normals = self.inpaint_panorama(
                    idx=key,
                    colors=colors,
                    distances=distances,
                    pano_mask=pano_inpaint_mask,
                )
                init_pose[0, 3], init_pose[1, 3], init_pose[2, 3] = (
                    -pose[0, 3],
                    pose[2, 3],
                    0,
                )
                rays = gen_pano_rays(
                    init_pose, self.cfg.pano_h, self.cfg.pano_w
                )
                # 0 means conflict, 1 means no conflict.
                conflict_mask = self.sup_pool.geo_check(
                    rays, distances.unsqueeze(-1)
                )
                pano_inpaint_mask *= conflict_mask

            self.rgbd_to_mesh(
                colors.permute(2, 0, 1),
                distances,
                pano_inpaint_mask,
                world_to_cam=pose,
            )
            self.sup_pool.register_sup_info(
                pose=init_pose,
                mask=pano_inpaint_mask.clone(),
                rgb=colors,
                distance=distances.unsqueeze(-1),
                normal=normals,
            )
            colors = colors.permute(2, 0, 1).unsqueeze(0)
            inpainted_panos_w_pose.append([colors, pose])

            if self.cfg.visualize:
                from embodied_gen.data.utils import DiffrastRender

                tensor_to_pil(pano_rgb.unsqueeze(0)).save(
                    f"{output_dir}/rendered_pano_{key}.jpg"
                )
                tensor_to_pil(colors).save(
                    f"{output_dir}/inpainted_pano_{key}.jpg"
                )
                norm_depth = DiffrastRender.normalize_map_by_mask(
                    distances, torch.ones_like(distances)
                )
                heatmap = (norm_depth.cpu().numpy() * 255).astype(np.uint8)
                heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
                Image.fromarray(heatmap).save(
                    f"{output_dir}/inpainted_depth_{key}.png"
                )

        return inpainted_panos_w_pose

    def inpaint_panorama(
        self,
        idx: int,
        colors: torch.Tensor,
        distances: torch.Tensor,
        pano_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        mask = (pano_mask[None, ..., None] > 0.5).float()
        mask = mask.permute(0, 3, 1, 2)
        mask = dilation(mask, kernel=self.kernel)
        mask = mask[0, 0, ..., None]  # HWC
        inpainted_img = self.inpainter.inpaint(idx, colors, mask)
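        # Composite: keep original colors where the mask is 0, take inpainted
        # pixels where the mask is 1.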
        inpainted_img = colors * (1 - mask) + inpainted_img * mask
        inpainted_distances, inpainted_normals = self.geo_predictor(
            idx,
            inpainted_img,
            distances[..., None],
            mask=mask,
            reg_loss_weight=0.0,
            normal_loss_weight=5e-2,
            normal_tv_loss_weight=5e-2,
        )

        return inpainted_img, inpainted_distances.squeeze(), inpainted_normals

    def preprocess_pano(
        self, image: Image.Image | str
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if isinstance(image, str):
            image = Image.open(image)

        image = image.convert("RGB")
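        # Equirectangular panoramas are wider than tall; transpose portrait
        # inputs so width >= height before resizing.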
        if image.size[0] < image.size[1]:
            image = image.transpose(Image.TRANSPOSE)

        image = resize_image_with_aspect_ratio(image, self.cfg.pano_w)
        image_rgb = torch.tensor(np.array(image)).permute(2, 0, 1) / 255
        image_rgb = image_rgb.to(self.device)
        image_depth = self.pano_fusion_distance_predictor.predict(
            image_rgb.permute(1, 2, 0)
        )
        image_depth = (
            image_depth / image_depth.max() * self.cfg.depth_scale_factor
        )

        return image_rgb, image_depth

    def pano_to_perspective(
        self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float
    ) -> torch.Tensor:
        rots = dict(
            roll=0,
            pitch=pitch,
            yaw=yaw,
        )
        perspective = equi2pers(
            equi=pano_image.squeeze(0),
            rots=rots,
            height=self.cfg.cubemap_h,
            width=self.cfg.cubemap_w,
            fov_x=fov,
            mode="bilinear",
        ).unsqueeze(0)

        return perspective

    def pano_to_cubemap(self, pano_rgb: torch.Tensor):
        # The six canonical cube-face directions as (pitch, yaw) in radians.
        directions = [
            (0, 0),
            (0, 1.5 * np.pi),
            (0, 1.0 * np.pi),
            (0, 0.5 * np.pi),
            (-0.5 * np.pi, 0),
            (0.5 * np.pi, 0),
        ]
        cubemaps_rgb = []
        for pitch, yaw in directions:
            rgb_view = self.pano_to_perspective(
                pano_rgb, pitch, yaw, fov=self.cfg.fov
            )
            cubemaps_rgb.append(rgb_view.cpu())

        return cubemaps_rgb

    def save_mesh(self, output_path: str) -> None:
        vertices_np = self.vertices.T.cpu().numpy()
        colors_np = self.colors.T.cpu().numpy()
        faces_np = self.faces.T.cpu().numpy()
        mesh = trimesh.Trimesh(
            vertices=vertices_np, faces=faces_np, vertex_colors=colors_np
        )
        mesh.export(output_path)

    def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray:
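        # Convert a mesh-renderer world-to-camera pose into a camera-to-world
        # matrix for the 3DGS data: flip the x/y rows, then invert with a y/z
        # sign flip (which looks like a change of camera-axis convention).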
        pose = mesh_pose.clone()
        pose[0, :] *= -1
        pose[1, :] *= -1
        Rw2c = pose[:3, :3].cpu().numpy()
        Tw2c = pose[:3, 3:].cpu().numpy()
        yz_reverse = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
        Rc2w = (yz_reverse @ Rw2c).T
        Tc2w = -(Rc2w @ yz_reverse @ Tw2c)
        c2w = np.concatenate((Rc2w, Tc2w), axis=1)
        c2w = np.concatenate((c2w, np.array([[0, 0, 0, 1]])), axis=0)

        return c2w

    def __call__(self, pano_image: Image.Image | str, output_dir: str):
        self.init_mesh_params()
        pano_rgb, pano_depth = self.preprocess_pano(pano_image)
        self.sup_pool = SupInfoPool()
        self.sup_pool.register_sup_info(
            pose=torch.eye(4).to(self.device),
            mask=torch.ones([self.cfg.pano_h, self.cfg.pano_w]),
            rgb=pano_rgb.permute(1, 2, 0),
            distance=pano_depth[..., None],
        )
        self.sup_pool.gen_occ_grid(res=256)

        logger.info("Init mesh from pano RGBD image...")
        depth_edge = self.get_edge_image_by_depth(pano_depth)
        inpaint_edge_mask = (
            ~torch.from_numpy(depth_edge).to(self.device).bool()
        )
        self.rgbd_to_mesh(pano_rgb, pano_depth, inpaint_edge_mask)

        repair_poses = self.load_inpaint_poses(self.camera_poses)
        inpainted_panos_w_poses = self.mesh_repair_by_greedy_view_selection(
            repair_poses, output_dir
        )
        torch.cuda.empty_cache()
        torch.set_default_device("cpu")

        if self.cfg.mesh_file is not None:
            mesh_path = os.path.join(output_dir, self.cfg.mesh_file)
            self.save_mesh(mesh_path)

        if self.cfg.gs_data_file is None:
            return

        logger.info("Dumping data for 3DGS training...")
        points_rgb = (self.colors.clip(0, 1) * 255).to(torch.uint8)
        data = {
            "points": self.vertices.permute(1, 0).cpu().numpy(),  # (N, 3)
            "points_rgb": points_rgb.permute(1, 0).cpu().numpy(),  # (N, 3)
            "train": [],
            "eval": [],
        }
        image_h = self.cfg.cubemap_h * self.cfg.upscale_factor
        image_w = self.cfg.cubemap_w * self.cfg.upscale_factor
        Ks = compute_pinhole_intrinsics(image_w, image_h, self.cfg.fov)
        for idx, (pano_img, pano_pose) in enumerate(inpainted_panos_w_poses):
            cubemaps = self.pano_to_cubemap(pano_img)
            for i in range(len(cubemaps)):
                cubemap = tensor_to_pil(cubemaps[i])
                cubemap = self.super_model(cubemap)
                mesh_pose = self.cubemap_w2cs[i] @ pano_pose
                c2w = self.mesh_pose_to_gs_pose(mesh_pose)
                data["train"].append(
                    {
                        "camtoworld": c2w.astype(np.float32),
                        "K": Ks.astype(np.float32),
                        "image": np.array(cubemap),
                        "image_h": image_h,
                        "image_w": image_w,
                        "image_id": len(cubemaps) * idx + i,
                    }
                )

        # Camera poses for evaluation.
        for idx in range(len(self.camera_poses)):
            c2w = self.mesh_pose_to_gs_pose(self.camera_poses[idx])
            data["eval"].append(
                {
                    "camtoworld": c2w.astype(np.float32),
                    "K": Ks.astype(np.float32),
                    "image_h": image_h,
                    "image_w": image_w,
                    "image_id": idx,
                }
            )

        data_path = os.path.join(output_dir, self.cfg.gs_data_file)
        torch.save(data, data_path)
| if __name__ == "__main__": | |
| output_dir = "outputs/bg_v2/test3" | |
| input_pano = "apps/assets/example_scene/result_pano.png" | |
| config = Pano2MeshSRConfig() | |
| pipeline = Pano2MeshSRPipeline(config) | |
| pipeline(input_pano, output_dir) | |
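
    # A minimal sketch of tweaking a run via config fields referenced in this
    # file (assumes Pano2MeshSRConfig fields can be assigned directly; the
    # values shown are illustrative, not recommended defaults):
    #
    #     config = Pano2MeshSRConfig()
    #     config.visualize = True  # save intermediate panoramas and depth maps
    #     config.pano_center_offset = (0.2, 0.0)  # hypothetical offset values
    #     Pano2MeshSRPipeline(config)("my_pano.png", "outputs/my_scene")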