import os
from tqdm import tqdm
import kiui
from kiui.op import recenter
import kornia
import collections
import math
import time
import itertools
import pickle
from typing import Any
import lmdb
import cv2
import trimesh
cv2.setNumThreads(0)  # disable OpenCV's internal multithreading (avoids oversubscription in DataLoader workers)
# import imageio
import imageio.v3 as imageio
import numpy as np
from PIL import Image
import Imath
import OpenEXR
from pdb import set_trace as st
from pathlib import Path
import torchvision
from torchvision.transforms import v2
from einops import rearrange, repeat
from functools import partial
import io
from scipy.stats import special_ortho_group
import gzip
import random
import torch
import torch as th
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data.distributed import DistributedSampler
import lz4.frame
from nsr.volumetric_rendering.ray_sampler import RaySampler
import point_cloud_utils as pcu
import torch.multiprocessing
# torch.multiprocessing.set_sharing_strategy('file_system')
from utils.general_utils import PILtoTorch, matrix_to_quaternion
from guided_diffusion import logger
import json
import webdataset as wds
from webdataset.shardlists import expand_source
# st()
from .shapenet import LMDBDataset, LMDBDataset_MV_Compressed, decompress_and_open_image_gzip, decompress_array
from kiui.op import safe_normalize
from utils.gs_utils.graphics_utils import getWorld2View2, getProjectionMatrix, getView2World
from nsr.camera_utils import generate_input_camera
def random_rotation_matrix():
    # Generate a random rotation matrix in 3D
    random_rotation_3d = special_ortho_group.rvs(3)
    # Embed the 3x3 rotation matrix into a 4x4 homogeneous matrix
    rotation_matrix_4x4 = np.eye(4)
    rotation_matrix_4x4[:3, :3] = random_rotation_3d
    return rotation_matrix_4x4


def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))


def focal2fov(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))
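
# Quick sanity check for the two conversions above (a non-authoritative example;
# the 512-px width and 50-degree field of view are illustrative assumptions):
#   f = fov2focal(math.radians(50), 512)       # focal length in pixels, ~549
#   math.degrees(focal2fov(f, 512))            # ~50.0, i.e. the mapping round-trips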
def resize_depth_mask(depth_to_resize, resolution):
    depth_resized = cv2.resize(depth_to_resize, (resolution, resolution),
                               interpolation=cv2.INTER_LANCZOS4)
    # interpolation=cv2.INTER_AREA)
    return depth_resized, depth_resized > 0  # type: ignore


def resize_depth_mask_Tensor(depth_to_resize, resolution):
    if depth_to_resize.shape[-1] != resolution:
        depth_resized = torch.nn.functional.interpolate(
            input=depth_to_resize.unsqueeze(1),
            size=(resolution, resolution),
            # mode='bilinear',
            mode='nearest',
            # align_corners=False,
        ).squeeze(1)
    else:
        depth_resized = depth_to_resize
    return depth_resized.float(), depth_resized > 0  # type: ignore
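
# Both helpers return (resized_depth, foreground_mask): the mask is simply
# depth > 0, so background pixels are expected to carry zero depth.
# Minimal usage sketch (shapes are illustrative assumptions):
#   depth_b = torch.rand(8, 512, 512)                       # B H W
#   depth_128, fg_128 = resize_depth_mask_Tensor(depth_b, 128)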
| class PostProcess: | |
| def __init__( | |
| self, | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize, | |
| plucker_embedding, | |
| decode_encode_img_only, | |
| mv_input, | |
| split_chunk_input, | |
| duplicate_sample, | |
| append_depth, | |
| gs_cam_format, | |
| orthog_duplicate, | |
| frame_0_as_canonical, | |
| pcd_path=None, | |
| load_pcd=False, | |
| split_chunk_size=8, | |
| append_xyz=False, | |
| ) -> None: | |
| self.load_pcd = load_pcd | |
| if pcd_path is None: # hard-coded | |
| pcd_path = '/cpfs01/user/lanyushi.p/data/FPS_PCD/pcd-V=6_256_again/fps-pcd/' | |
| self.pcd_path = Path(pcd_path) | |
| self.append_xyz = append_xyz | |
| if append_xyz: | |
| assert append_depth is False | |
| self.frame_0_as_canonical = frame_0_as_canonical | |
| self.gs_cam_format = gs_cam_format | |
| self.append_depth = append_depth | |
| self.plucker_embedding = plucker_embedding | |
| self.decode_encode_img_only = decode_encode_img_only | |
| self.duplicate_sample = duplicate_sample | |
| self.orthog_duplicate = orthog_duplicate | |
| self.zfar = 100.0 | |
| self.znear = 0.01 | |
| transformations = [] | |
| if not split_chunk_input: | |
| transformations.append(transforms.ToTensor()) | |
| if imgnet_normalize: | |
| transformations.append( | |
| transforms.Normalize((0.485, 0.456, 0.406), | |
| (0.229, 0.224, 0.225)) # type: ignore | |
| ) | |
| else: | |
| transformations.append( | |
| transforms.Normalize((0.5, 0.5, 0.5), | |
| (0.5, 0.5, 0.5))) # type: ignore | |
| self.normalize = transforms.Compose(transformations) | |
| self.reso_encoder = reso_encoder | |
| self.reso = reso | |
| self.instance_data_length = 40 | |
| # self.pair_per_instance = 1 # compat | |
| self.mv_input = mv_input | |
| self.split_chunk_input = split_chunk_input # 8 | |
| self.chunk_size = split_chunk_size if split_chunk_input else 40 | |
| # assert self.chunk_size in [8, 10] | |
| self.V = self.chunk_size // 2 # 4 views as input | |
| # else: | |
| # assert self.chunk_size == 20 | |
| # self.V = 12 # 6 + 6 here | |
| # st() | |
| assert split_chunk_input | |
| self.pair_per_instance = 1 | |
| # else: | |
| # self.pair_per_instance = 4 if mv_input else 2 # check whether improves IO | |
| self.ray_sampler = RaySampler() # load xyz | |
    def gen_rays(self, c):
        # Generate rays
        intrinsics, c2w = c[16:], c[:16].reshape(4, 4)
        self.h = self.reso_encoder
        self.w = self.reso_encoder
        yy, xx = torch.meshgrid(
            torch.arange(self.h, dtype=torch.float32) + 0.5,
            torch.arange(self.w, dtype=torch.float32) + 0.5,
            indexing='ij')
        # normalize to 0-1 pixel range
        yy = yy / self.h
        xx = xx / self.w
        # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3)
        cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[
            0], intrinsics[4]
        # cx *= self.w
        # cy *= self.h
        # f_x = f_y = fx * h / res_raw
        c2w = torch.from_numpy(c2w).float()
        xx = (xx - cx) / fx
        yy = (yy - cy) / fy
        zz = torch.ones_like(xx)
        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
        dirs /= torch.norm(dirs, dim=-1, keepdim=True)
        dirs = dirs.reshape(-1, 3, 1)
        del xx, yy, zz
        # st()
        dirs = (c2w[None, :3, :3] @ dirs)[..., 0]
        origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
        origins = origins.view(self.h, self.w, 3)
        dirs = dirs.view(self.h, self.w, 3)
        return origins, dirs
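
    # The 25-dim camera vector `c` used throughout this class is laid out as
    # [flattened 4x4 cam2world (16) | flattened 3x3 intrinsics (9)] (see also
    # _unproj_depth_given_c below), with intrinsics in normalized image units,
    # hence the xx/yy division by w/h above. A usage sketch (c_single is one
    # 25-dim numpy camera row; names are illustrative):
    #   rays_o, rays_d = self.gen_rays(c_single)                  # (H, W, 3) each
    #   plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], -1)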
| def _post_process_batch_sample(self, | |
| sample): # sample is an instance batch here | |
| caption, ins = sample[-2:] | |
| instance_samples = [] | |
| for instance_idx in range(sample[0].shape[0]): | |
| instance_samples.append( | |
| self._post_process_sample(item[instance_idx] | |
| for item in sample[:-2])) | |
| return (*instance_samples, caption, ins) | |
| def _post_process_sample(self, data_sample): | |
| # raw_img, depth, c, bbox, caption, ins = data_sample | |
| # st() | |
| raw_img, depth, c, bbox = data_sample | |
| bbox = (bbox * (self.reso / 256)).astype( | |
| np.uint8) # normalize bbox to the reso range | |
| if raw_img.shape[-2] != self.reso_encoder: | |
| img_to_encoder = cv2.resize(raw_img, | |
| (self.reso_encoder, self.reso_encoder), | |
| interpolation=cv2.INTER_LANCZOS4) | |
| else: | |
| img_to_encoder = raw_img | |
| img_to_encoder = self.normalize(img_to_encoder) | |
| if self.plucker_embedding: | |
| rays_o, rays_d = self.gen_rays(c) | |
| rays_plucker = torch.cat( | |
| [torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
| dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w | |
| img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0) | |
| img = cv2.resize(raw_img, (self.reso, self.reso), | |
| interpolation=cv2.INTER_LANCZOS4) | |
| img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1 | |
| if self.decode_encode_img_only: | |
| depth_reso, fg_mask_reso = depth, depth | |
| else: | |
| depth_reso, fg_mask_reso = resize_depth_mask(depth, self.reso) | |
| # return { | |
| # # **sample, | |
| # 'img_to_encoder': img_to_encoder, | |
| # 'img': img, | |
| # 'depth_mask': fg_mask_reso, | |
| # # 'img_sr': img_sr, | |
| # 'depth': depth_reso, | |
| # 'c': c, | |
| # 'bbox': bbox, | |
| # 'caption': caption, | |
| # 'ins': ins | |
| # # ! no need to load img_sr for now | |
| # } | |
| # if len(data_sample) == 4: | |
| return (img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox) | |
| # else: | |
| # return (img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox, data_sample[-2], data_sample[-1]) | |
    def canonicalize_pts(self, c, pcd, for_encoder=True, canonical_idx=0):
        # pcd: sampled in world space
        assert c.shape[0] == self.chunk_size
        assert for_encoder
        # st()
        B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # (B, 4, 4) cam2world
        cam_radius = np.linalg.norm(
            c[[0, self.V]][:, :16].reshape(2, 4, 4)[:, :3, 3],
            axis=-1,
            keepdims=False)  # since g-buffer adopts dynamic radius here.
        frame1_fixed_pos = np.repeat(np.eye(4)[None], 2, axis=0)
        frame1_fixed_pos[:, 2, -1] = -cam_radius
        transform = frame1_fixed_pos @ np.linalg.inv(
            camera_poses[[0, self.V]])  # B 4 4
        transform = np.expand_dims(transform, axis=1)  # B 1 4 4
        # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
        # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
        repeated_homo_pcd = np.repeat(np.concatenate(
            [pcd, np.ones_like(pcd[..., 0:1])], -1)[None],
                                      2,
                                      axis=0)[..., None]  # B N 4 1
        new_pcd = (transform @ repeated_homo_pcd)[..., :3, 0]  # 2 N 3
        return new_pcd
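
    # In effect, canonicalize_pts applies the same rigid transform that
    # normalize_camera (below) applies to the poses: each canonical camera
    # (index 0 and index self.V) is moved to a fixed pose at z = -cam_radius,
    # i.e. transform = T_fixed @ inv(c2w_canonical), and the sampled point
    # cloud is carried along so it stays aligned with the normalized cameras.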
| def canonicalize_pts_v6(self, c, pcd, for_encoder=True, canonical_idx=0): | |
| exit() # deprecated function | |
| # pcd: sampled in world space | |
| assert c.shape[0] == self.chunk_size | |
| assert for_encoder | |
| encoder_canonical_idx = [0, 6, 12, 18] | |
| B = c.shape[0] | |
| camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 | |
| cam_radius = np.linalg.norm( | |
| c[encoder_canonical_idx][:, :16].reshape(4, 4, 4)[:, :3, 3], | |
| axis=-1, | |
| keepdims=False) # since g-buffer adopts dynamic radius here. | |
| frame1_fixed_pos = np.repeat(np.eye(4)[None], 4, axis=0) | |
| frame1_fixed_pos[:, 2, -1] = -cam_radius | |
| transform = frame1_fixed_pos @ np.linalg.inv( | |
| camera_poses[encoder_canonical_idx]) # B 4 4 | |
| transform = np.expand_dims(transform, axis=1) # B 1 4 4 | |
| # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 | |
| # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) | |
| repeated_homo_pcd = np.repeat(np.concatenate( | |
| [pcd, np.ones_like(pcd[..., 0:1])], -1)[None], | |
| 4, | |
| axis=0)[..., None] # B N 4 1 | |
| new_pcd = (transform @ repeated_homo_pcd)[..., :3, 0] # 2 N 3 | |
| return new_pcd | |
    def normalize_camera(self, c, for_encoder=True, canonical_idx=0):
        assert c.shape[0] == self.chunk_size  # 8 or 10
        B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # (B, 4, 4) cam2world
        if for_encoder:
            encoder_canonical_idx = [0, self.V]
            # st()
            cam_radius = np.linalg.norm(
                c[encoder_canonical_idx][:, :16].reshape(2, 4, 4)[:, :3, 3],
                axis=-1,
                keepdims=False)  # since g-buffer adopts dynamic radius here.
            frame1_fixed_pos = np.repeat(np.eye(4)[None], 2, axis=0)
            frame1_fixed_pos[:, 2, -1] = -cam_radius
            transform = frame1_fixed_pos @ np.linalg.inv(
                camera_poses[encoder_canonical_idx])
            # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
            # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
            new_camera_poses = np.repeat(
                transform, self.V, axis=0
            ) @ camera_poses  # [V, 4, 4]. np.repeat() here behaves like th.repeat_interleave()
        else:
            cam_radius = np.linalg.norm(
                c[canonical_idx][:16].reshape(4, 4)[:3, 3],
                axis=-1,
                keepdims=False)  # since g-buffer adopts dynamic radius here.
            frame1_fixed_pos = np.eye(4)
            frame1_fixed_pos[2, -1] = -cam_radius
            transform = frame1_fixed_pos @ np.linalg.inv(
                camera_poses[canonical_idx])  # 4,4
            # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127
            # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]])
            new_camera_poses = np.repeat(transform[None],
                                         self.chunk_size,
                                         axis=0) @ camera_poses  # [V, 4, 4]
        c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]],
                           axis=-1)
        return c
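
    # Usage sketch (c has shape [chunk_size, 25]; call sites are in
    # _post_process_sample_batch below):
    #   c_enc = self.normalize_camera(c, for_encoder=True)                  # frames 0 and self.V as canonical
    #   c_nv  = self.normalize_camera(c, for_encoder=False, canonical_idx=0)
    # Both variants rewrite only the pose part c[:, :16] and keep the
    # intrinsics c[:, 16:] untouched.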
| def normalize_camera_v6(self, c, for_encoder=True, canonical_idx=0): | |
| B = c.shape[0] | |
| camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 | |
| if for_encoder: | |
| assert c.shape[0] == 24 | |
| encoder_canonical_idx = [0, 6, 12, 18] | |
| cam_radius = np.linalg.norm( | |
| c[encoder_canonical_idx][:, :16].reshape(4, 4, 4)[:, :3, 3], | |
| axis=-1, | |
| keepdims=False) # since g-buffer adopts dynamic radius here. | |
| frame1_fixed_pos = np.repeat(np.eye(4)[None], 4, axis=0) | |
| frame1_fixed_pos[:, 2, -1] = -cam_radius | |
| transform = frame1_fixed_pos @ np.linalg.inv( | |
| camera_poses[encoder_canonical_idx]) | |
| # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 | |
| # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) | |
| new_camera_poses = np.repeat(transform, 6, | |
| axis=0) @ camera_poses # [V, 4, 4] | |
| else: | |
| assert c.shape[0] == 12 | |
| cam_radius = np.linalg.norm( | |
| c[canonical_idx][:16].reshape(4, 4)[:3, 3], | |
| axis=-1, | |
| keepdims=False) # since g-buffer adopts dynamic radius here. | |
| frame1_fixed_pos = np.eye(4) | |
| frame1_fixed_pos[2, -1] = -cam_radius | |
| transform = frame1_fixed_pos @ np.linalg.inv( | |
| camera_poses[canonical_idx]) # 4,4 | |
| # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 | |
| # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) | |
| new_camera_poses = np.repeat(transform[None], 12, | |
| axis=0) @ camera_poses # [V, 4, 4] | |
| c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], | |
| axis=-1) | |
| return c | |
    def get_plucker_ray(self, c):
        rays_plucker = []
        for idx in range(c.shape[0]):
            rays_o, rays_d = self.gen_rays(c[idx])
            rays_plucker.append(
                torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d],
                          dim=-1).permute(2, 0, 1))  # [h, w, 6] -> 6,h,w
        rays_plucker = torch.stack(rays_plucker, 0)
        return rays_plucker

    def _unproj_depth_given_c(self, c, depth):
        # get xyz hxw for each pixel, like MCC
        # img_size = self.reso
        img_size = depth.shape[-1]
        B = c.shape[0]
        cam2world_matrix = c[:, :16].reshape(B, 4, 4)
        intrinsics = c[:, 16:25].reshape(B, 3, 3)
        ray_origins, ray_directions = self.ray_sampler(  # shape:
            cam2world_matrix, intrinsics, img_size)[:2]
        depth = depth.reshape(B, -1).unsqueeze(-1)
        xyz = ray_origins + depth * ray_directions  # BV HW 3, already in the world space
        xyz = xyz.reshape(B, img_size, img_size, 3).permute(0, 3, 1,
                                                            2)  # B 3 H W
        xyz = xyz.clip(
            -0.45, 0.45)  # g-buffer saves depth with anti-alias = True .....
        xyz = torch.where(xyz.abs() == 0.45, 0, xyz)  # no boundary here? Yes.
        return xyz
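
    # Depth unprojection sketch (B views, depth in world units, names as above):
    #   xyz = self._unproj_depth_given_c(c, depth)    # B 3 H W, world-space points
    # Points are clipped to the [-0.45, 0.45]^3 cube and exact boundary hits are
    # zeroed out, since the anti-aliased g-buffer depth bleeds across silhouettes.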
| def _post_process_sample_batch(self, data_sample): | |
| # raw_img, depth, c, bbox, caption, ins = data_sample | |
| alpha = None | |
| if len(data_sample) == 4: | |
| raw_img, depth, c, bbox = data_sample | |
| else: | |
| raw_img, depth, c, alpha, bbox = data_sample # put c to position 2 | |
| if isinstance(depth, tuple): | |
| self.append_normal = True | |
| depth, normal = depth | |
| else: | |
| self.append_normal = False | |
| normal = None | |
| # if raw_img.shape[-1] == 4: | |
| # depth_reso, _ = resize_depth_mask_Tensor( | |
| # torch.from_numpy(depth), self.reso) | |
| # raw_img, fg_mask_reso = raw_img[..., :3], raw_img[..., -1] | |
| # # st() # ! check has 1 dim in alpha? | |
| # else: | |
| if not isinstance(depth, torch.Tensor): | |
| depth = torch.from_numpy(depth).float() | |
| else: | |
| depth = depth.float() | |
| depth_reso, fg_mask_reso = resize_depth_mask_Tensor(depth, self.reso) | |
| if alpha is None: | |
| alpha = fg_mask_reso | |
| else: | |
| # ! resize first | |
| # st() | |
| alpha = torch.from_numpy(alpha / 255.0).float() | |
            if alpha.shape[-1] != self.reso:  # bilinearly interpolate to the target reso
| alpha = torch.nn.functional.interpolate( | |
| input=alpha.unsqueeze(1), | |
| size=(self.reso, self.reso), | |
| mode='bilinear', | |
| align_corners=False, | |
| ).squeeze(1) | |
| if self.reso < 256: | |
| bbox = (bbox * (self.reso / 256)).astype( | |
| np.uint8) # normalize bbox to the reso range | |
| else: # 3dgs | |
| bbox = bbox.astype(np.uint8) | |
| # st() # ! shall compat with 320 input | |
| # assert raw_img.shape[-2] == self.reso_encoder | |
| # img_to_encoder = cv2.resize( | |
| # raw_img, (self.reso_encoder, self.reso_encoder), | |
| # interpolation=cv2.INTER_LANCZOS4) | |
| # else: | |
| # img_to_encoder = raw_img | |
| raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, | |
| 2) / 255.0 # [0,1] | |
| if normal is not None: | |
| normal = torch.from_numpy(normal).permute(0,3,1,2) | |
| # if raw_img.shape[-1] != self.reso: | |
| if raw_img.shape[1] != self.reso_encoder: | |
| img_to_encoder = torch.nn.functional.interpolate( | |
| input=raw_img, | |
| size=(self.reso_encoder, self.reso_encoder), | |
| mode='bilinear', | |
| align_corners=False,) | |
| img_to_encoder = self.normalize(img_to_encoder) | |
| if normal is not None: | |
| normal_for_encoder = torch.nn.functional.interpolate( | |
| input=normal, | |
| size=(self.reso_encoder, self.reso_encoder), | |
| # mode='bilinear', | |
| mode='nearest', | |
| # align_corners=False, | |
| ) | |
| else: | |
| img_to_encoder = self.normalize(raw_img) | |
| normal_for_encoder = normal | |
| if raw_img.shape[-1] != self.reso: | |
| img = torch.nn.functional.interpolate( | |
| input=raw_img, | |
| size=(self.reso, self.reso), | |
| mode='bilinear', | |
| align_corners=False, | |
| ) # [-1,1] range | |
| img = img * 2 - 1 # as gt | |
| if normal is not None: | |
| normal = torch.nn.functional.interpolate( | |
| input=normal, | |
| size=(self.reso, self.reso), | |
| # mode='bilinear', | |
| mode='nearest', | |
| # align_corners=False, | |
| ) | |
| else: | |
| img = raw_img * 2 - 1 | |
| # fg_mask_reso = depth[..., -1:] # ! use | |
| pad_v6_fn = lambda x: torch.concat([x, x[:4]], 0) if isinstance( | |
| x, torch.Tensor) else np.concatenate([x, x[:4]], 0) | |
| # ! processing encoder input image. | |
| # ! normalize camera feats | |
| if self.frame_0_as_canonical: # 4 views as input per batch | |
| # if self.chunk_size in [8, 10]: | |
| if True: | |
| # encoder_canonical_idx = [0, 4] | |
| # encoder_canonical_idx = [0, self.chunk_size//2] | |
| encoder_canonical_idx = [0, self.V] | |
| c_for_encoder = self.normalize_camera(c, for_encoder=True) | |
| c_for_render = self.normalize_camera( | |
| c, | |
| for_encoder=False, | |
| canonical_idx=encoder_canonical_idx[0] | |
| ) # allocated to nv_c, frame0 (in 8 views) as the canonical | |
| c_for_render_nv = self.normalize_camera( | |
| c, | |
| for_encoder=False, | |
| canonical_idx=encoder_canonical_idx[1] | |
| ) # allocated to nv_c, frame0 (in 8 views) as the canonical | |
| c_for_render = np.concatenate([c_for_render, c_for_render_nv], | |
| axis=-1) # for compat | |
| # st() | |
| else: | |
| assert self.chunk_size == 20 | |
| c_for_encoder = self.normalize_camera_v6(c, | |
| for_encoder=True) # | |
| paired_c_0 = np.concatenate([c[0:6], c[12:18]]) | |
| paired_c_1 = np.concatenate([c[6:12], c[18:24]]) | |
| def process_paired_camera(paired_c): | |
| c_for_render = self.normalize_camera_v6( | |
| paired_c, for_encoder=False, canonical_idx=0 | |
| ) # allocated to nv_c, frame0 (in 8 views) as the canonical | |
| c_for_render_nv = self.normalize_camera_v6( | |
| paired_c, for_encoder=False, canonical_idx=6 | |
| ) # allocated to nv_c, frame0 (in 8 views) as the canonical | |
| c_for_render = np.concatenate( | |
| [c_for_render, c_for_render_nv], axis=-1) # for compat | |
| return c_for_render | |
| paired_c_for_render_0 = process_paired_camera(paired_c_0) | |
| paired_c_for_render_1 = process_paired_camera(paired_c_1) | |
| c_for_render = np.empty(shape=(24, 50)) | |
| c_for_render[list(range(6)) + | |
| list(range(12, 18))] = paired_c_for_render_0 | |
| c_for_render[list(range(6, 12)) + | |
| list(range(18, 24))] = paired_c_for_render_1 | |
| else: # use g-buffer canonical c | |
| c_for_encoder, c_for_render = c, c | |
| if self.append_normal and normal is not None: | |
| img_to_encoder = torch.cat([img_to_encoder, normal_for_encoder], | |
| # img_to_encoder = torch.cat([img_to_encoder, normal], | |
| 1) # concat in C dim | |
| if self.plucker_embedding: | |
| # rays_plucker = self.get_plucker_ray(c) | |
| rays_plucker = self.get_plucker_ray(c_for_encoder) | |
| img_to_encoder = torch.cat([img_to_encoder, rays_plucker], | |
| 1) # concat in C dim | |
| # torchvision.utils.save_image(raw_img, 'tmp/inp.png', normalize=True, value_range=(0,1), nrow=1, padding=0) | |
| # torchvision.utils.save_image(rays_plucker[:,:3], 'tmp/plucker.png', normalize=True, value_range=(-1,1), nrow=1, padding=0) | |
| # torchvision.utils.save_image(depth_reso.unsqueeze(1), 'tmp/depth.png', normalize=True, nrow=1, padding=0) | |
| c = torch.from_numpy(c_for_render).to(torch.float32) | |
| if self.append_depth: | |
| normalized_depth = torch.from_numpy(depth_reso).clone().unsqueeze( | |
| 1) # min=0 | |
| # normalized_depth -= torch.min(normalized_depth) # always 0 here | |
| # normalized_depth /= torch.max(normalized_depth) | |
| # normalized_depth = normalized_depth.unsqueeze(1) * 2 - 1 # normalize to [-1,1] | |
| # st() | |
| img_to_encoder = torch.cat([img_to_encoder, normalized_depth], | |
| 1) # concat in C dim | |
| elif self.append_xyz: | |
| depth_for_unproj = depth.clone() | |
| depth_for_unproj[depth_for_unproj == | |
| 0] = 1e10 # so that rays_o will not appear in the final pcd. | |
| xyz = self._unproj_depth_given_c(c.float(), depth) | |
| # pcu.save_mesh_v(f'unproj_xyz_before_Nearest.ply', xyz[0:9].float().detach().permute(0,2,3,1).reshape(-1,3).cpu().numpy(),) | |
| if xyz.shape[-1] != self.reso_encoder: | |
| xyz = torch.nn.functional.interpolate( | |
| input=xyz, # [-1,1] | |
| # size=(self.reso, self.reso), | |
| size=(self.reso_encoder, self.reso_encoder), | |
| mode='nearest', | |
| ) | |
| # pcu.save_mesh_v(f'unproj_xyz_afterNearest.ply', xyz[0:9].float().detach().permute(0,2,3,1).reshape(-1,3).cpu().numpy(),) | |
| # st() | |
| img_to_encoder = torch.cat([img_to_encoder, xyz], 1) | |
| return (img_to_encoder, img, alpha, depth_reso, c, | |
| torch.from_numpy(bbox)) | |
| def rand_sample_idx(self): | |
| return random.randint(0, self.instance_data_length - 1) | |
| def rand_pair(self): | |
| return (self.rand_sample_idx() for _ in range(2)) | |
| def paired_post_process(self, sample): | |
| # repeat n times? | |
| all_inp_list = [] | |
| all_nv_list = [] | |
| caption, ins = sample[-2:] | |
| # expanded_return = [] | |
| for _ in range(self.pair_per_instance): | |
| cano_idx, nv_idx = self.rand_pair() | |
| cano_sample = self._post_process_sample(item[cano_idx] | |
| for item in sample[:-2]) | |
| nv_sample = self._post_process_sample(item[nv_idx] | |
| for item in sample[:-2]) | |
| all_inp_list.extend(cano_sample) | |
| all_nv_list.extend(nv_sample) | |
| return (*all_inp_list, *all_nv_list, caption, ins) | |
| # return [cano_sample, nv_sample, caption, ins] | |
| # return (*cano_sample, *nv_sample, caption, ins) | |
| def get_source_cw2wT(self, source_cameras_view_to_world): | |
| return matrix_to_quaternion( | |
| source_cameras_view_to_world[:3, :3].transpose(0, 1)) | |
    def c_to_3dgs_format(self, pose):
        # TODO, switch to torch version (batched later)
        c2w = pose[:16].reshape(4, 4)  # cam2world, 4x4
        # ! load cam
        w2c = np.linalg.inv(c2w)
        R = np.transpose(
            w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]
        fx = pose[16]
        FovX = focal2fov(fx, 1)
        FovY = focal2fov(fx, 1)
        tanfovx = math.tan(FovX * 0.5)
        tanfovy = math.tan(FovY * 0.5)
        assert tanfovx == tanfovy
        trans = np.array([0.0, 0.0, 0.0])
        scale = 1.0
        view_world_transform = torch.tensor(getView2World(R, T, trans,
                                                          scale)).transpose(
                                                              0, 1)
        world_view_transform = torch.tensor(getWorld2View2(R, T, trans,
                                                           scale)).transpose(
                                                               0, 1)
        projection_matrix = getProjectionMatrix(znear=self.znear,
                                                zfar=self.zfar,
                                                fovX=FovX,
                                                fovY=FovY).transpose(0, 1)
        full_proj_transform = (world_view_transform.unsqueeze(0).bmm(
            projection_matrix.unsqueeze(0))).squeeze(0)
        camera_center = world_view_transform.inverse()[3, :3]
        # ! check pytorch3d camera system alignment.
        # item.update(viewpoint_cam=[viewpoint_cam])
        c = {}
        c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform)
        c.update(
            # projection_matrix=projection_matrix, # K
            R=torch.from_numpy(R),
            T=torch.from_numpy(T),
            cam_view=world_view_transform,  # world_view_transform
            cam_view_proj=full_proj_transform,  # full_proj_transform
            cam_pos=camera_center,
            tanfov=tanfovx,  # TODO, fix in the renderer
            orig_pose=torch.from_numpy(pose),
            orig_c2w=torch.from_numpy(c2w),
            orig_w2c=torch.from_numpy(w2c),
            orig_intrin=torch.from_numpy(pose[16:]).reshape(3, 3),
            # tanfovy=tanfovy,
        )
        return c  # dict for gs rendering
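
    # Conversion sketch (pose is one 25-dim camera row, e.g. c[0]):
    #   gs_cam = self.c_to_3dgs_format(pose)
    #   gs_cam['cam_view'], gs_cam['cam_view_proj'], gs_cam['cam_pos'], gs_cam['tanfov']
    # FovX and FovY are both derived from the single normalized focal length
    # pose[16] with pixels=1, so tanfovx == tanfovy always holds here (square
    # renders assumed) and only one tanfov is exported for the renderer.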
| def paired_post_process_chunk(self, sample): | |
| # st() | |
| # sample_npz, ins, caption = sample_pyd # three items | |
| # sample = *(sample[0][k] for k in ['raw_img', 'depth', 'c', 'bbox']), sample[-1], sample[-2] | |
| # repeat n times? | |
| all_inp_list = [] | |
| all_nv_list = [] | |
| auxiliary_sample = list(sample[-2:]) | |
| # caption, ins = sample[-2:] | |
| ins = sample[-1] | |
| assert sample[0].shape[0] == self.chunk_size # random chunks | |
| # expanded_return = [] | |
| if self.load_pcd: | |
| # fps_pcd = pcu.load_mesh_v( | |
| # # str(self.pcd_path / ins / 'fps-24576.ply')) # N, 3 | |
| # str(self.pcd_path / ins / 'fps-4096.ply')) # N, 3 | |
| # # 'fps-4096.ply')) # N, 3 | |
| fps_pcd = trimesh.load(str(self.pcd_path / ins / 'fps-4096.ply')).vertices | |
| auxiliary_sample += [fps_pcd] | |
| assert self.duplicate_sample | |
| # st() | |
| if self.duplicate_sample: | |
| # ! shuffle before process, since frame_0_as_canonical fixed c. | |
| if self.chunk_size in [20, 18, 16, 12]: | |
| shuffle_sample = sample[:-2] # no order shuffle required | |
| else: | |
| shuffle_sample = [] | |
| # indices = torch.randperm(self.chunk_size) | |
| indices = np.random.permutation(self.chunk_size) | |
| for _, item in enumerate(sample[:-2]): | |
| shuffle_sample.append(item[indices]) # random shuffle | |
| processed_sample = self._post_process_sample_batch(shuffle_sample) | |
            # ! process pcd if frame_0 alignment
| if self.load_pcd: | |
| if self.frame_0_as_canonical: | |
| # ! normalize camera feats | |
| # normalized camera feats as in paper (transform the first pose to a fixed position) | |
| # if self.chunk_size == 20: | |
| # auxiliary_sample[-1] = self.canonicalize_pts_v6( | |
| # c=shuffle_sample[2], | |
| # pcd=auxiliary_sample[-1], | |
| # for_encoder=True) # B N 3 | |
| # else: | |
| auxiliary_sample[-1] = self.canonicalize_pts( | |
| c=shuffle_sample[2], | |
| pcd=auxiliary_sample[-1], | |
| for_encoder=True) # B N 3 | |
| else: | |
| auxiliary_sample[-1] = np.repeat( | |
| auxiliary_sample[-1][None], 2, | |
                        axis=0)  # share the same camera system, just repeat
| assert not self.orthog_duplicate | |
| # if self.chunk_size == 8: | |
| all_inp_list.extend(item[:self.V] for item in processed_sample) | |
| all_nv_list.extend(item[self.V:] for item in processed_sample) | |
| # elif self.chunk_size == 20: # V=6 | |
| # # indices_v6 = [np.random.permutation(self.chunk_size)[:12] for _ in range(2)] # random sample 6 views from chunks | |
| # all_inp_list.extend(item[:12] for item in processed_sample) | |
| # # indices_v6 = np.concatenate([np.arange(12, 20), np.arange(0,4)]) | |
| # all_nv_list.extend( | |
| # item[12:] for item in | |
| # processed_sample) # already repeated inside batch fn | |
| # else: | |
| # raise NotImplementedError(self.chunk_size) | |
| # else: | |
| # all_inp_list.extend(item[:8] for item in processed_sample) | |
| # all_nv_list.extend(item[8:] for item in processed_sample) | |
| # st() | |
| return (*all_inp_list, *all_nv_list, *auxiliary_sample) | |
| else: | |
| processed_sample = self._post_process_sample_batch( # avoid shuffle shorten processing time | |
| item[:4] for item in sample[:-2]) | |
| all_inp_list.extend(item for item in processed_sample) | |
| all_nv_list.extend(item | |
| for item in processed_sample) # ! placeholder | |
| # return (*all_inp_list, *all_nv_list, caption, ins) | |
| return (*all_inp_list, *all_nv_list, *auxiliary_sample) | |
| # randomly shuffle 8 views, avoid overfitting | |
| def single_sample_create_dict_noBatch(self, sample, prefix=''): | |
| # if len(sample) == 1: | |
| # sample = sample[0] | |
| # assert len(sample) == 6 | |
| img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
| if self.gs_cam_format: | |
| # TODO, can optimize later after model converges | |
| B, V, _ = c.shape # B 4 25 | |
| c = rearrange(c, 'B V C -> (B V) C').cpu().numpy() | |
| # c = c.cpu().numpy() | |
| all_gs_c = [self.c_to_3dgs_format(pose) for pose in c] | |
| # st() | |
| # all_gs_c = self.c_to_3dgs_format(c.cpu().numpy()) | |
| c = { | |
| k: | |
| rearrange(torch.stack([gs_c[k] for gs_c in all_gs_c]), | |
| '(B V) ... -> B V ...', | |
| B=B, | |
| V=V) | |
| # torch.stack([gs_c[k] for gs_c in all_gs_c]) | |
| if isinstance(all_gs_c[0][k], torch.Tensor) else all_gs_c[0][k] | |
| for k in all_gs_c[0].keys() | |
| } | |
| # c = collate_gs_c | |
| return { | |
| # **sample, | |
| f'{prefix}img_to_encoder': img_to_encoder, | |
| f'{prefix}img': img, | |
| f'{prefix}depth_mask': fg_mask_reso, | |
| f'{prefix}depth': depth_reso, | |
| f'{prefix}c': c, | |
| f'{prefix}bbox': bbox, | |
| } | |
| def single_sample_create_dict(self, sample, prefix=''): | |
| # if len(sample) == 1: | |
| # sample = sample[0] | |
| # assert len(sample) == 6 | |
| img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
| if self.gs_cam_format: | |
| # TODO, can optimize later after model converges | |
| B, V, _ = c.shape # B 4 25 | |
| c = rearrange(c, 'B V C -> (B V) C').cpu().numpy() | |
| all_gs_c = [self.c_to_3dgs_format(pose) for pose in c] | |
| c = { | |
| k: | |
| rearrange(torch.stack([gs_c[k] for gs_c in all_gs_c]), | |
| '(B V) ... -> B V ...', | |
| B=B, | |
| V=V) | |
| if isinstance(all_gs_c[0][k], torch.Tensor) else all_gs_c[0][k] | |
| for k in all_gs_c[0].keys() | |
| } | |
| # c = collate_gs_c | |
| return { | |
| # **sample, | |
| f'{prefix}img_to_encoder': img_to_encoder, | |
| f'{prefix}img': img, | |
| f'{prefix}depth_mask': fg_mask_reso, | |
| f'{prefix}depth': depth_reso, | |
| f'{prefix}c': c, | |
| f'{prefix}bbox': bbox, | |
| } | |
    def single_instance_sample_create_dict(self, sample, prefix=''):
| assert len(sample) == 42 | |
| inp_sample_list = [[] for _ in range(6)] | |
| for item in sample[:40]: | |
| for item_idx in range(6): | |
| inp_sample_list[item_idx].append(item[0][item_idx]) | |
| inp_sample = self.single_sample_create_dict( | |
| (torch.stack(item_list) for item_list in inp_sample_list), | |
| prefix='') | |
| return { | |
| **inp_sample, # | |
| 'caption': sample[-2], | |
| 'ins': sample[-1] | |
| } | |
    def decode_gzip(self, sample_pyd, shape=(256, 256)):
        # sample_npz, ins, caption = sample_pyd # three items
        # c, bbox, depth, ins, caption, raw_img = sample_pyd[:5], sample_pyd[5:]
        # wds.to_tuple('raw_img.jpeg', 'depth.jpeg',
        #              'd_near.npy',
        #              'd_far.npy',
        #              "c.npy", 'bbox.npy', 'ins.txt', 'caption.txt'),
        raw_img, depth, alpha_mask, d_near, d_far, c, bbox, ins, caption = sample_pyd
        # st()
        raw_img = rearrange(raw_img, 'h (b w) c -> b h w c', b=self.chunk_size)
        depth = rearrange(depth, 'h (b w) c -> b h w c', b=self.chunk_size)
        alpha_mask = rearrange(
            alpha_mask, 'h (b w) c -> b h w c', b=self.chunk_size) / 255.0
        d_far = d_far.reshape(self.chunk_size, 1, 1, 1)
        d_near = d_near.reshape(self.chunk_size, 1, 1, 1)
        # d = 1 / ( (d_normalized / 255) * (far-near) + near)
        depth = 1 / ((depth / 255) * (d_far - d_near) + d_near)
        depth = depth[..., 0]  # decoded from jpeg
        # depth = decompress_array(depth['depth'], (self.chunk_size, *shape),
        #                          np.float32,
        #                          decompress=True,
        #                          decompress_fn=lz4.frame.decompress)
        # return raw_img, depth, d_near, d_far, c, bbox, caption, ins
        raw_img = np.concatenate([raw_img, alpha_mask[..., 0:1]], -1)
        return raw_img, depth, c, bbox, caption, ins
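
    # decode_gzip above recovers metric depth from the 8-bit jpeg encoding via
    # d = 1 / ((d_uint8 / 255) * (d_far - d_near) + d_near) and appends the
    # alpha mask as a 4th channel of raw_img before returning the chunk.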
| def decode_zip( | |
| self, | |
| sample_pyd, | |
| ): | |
| shape = (self.reso_encoder, self.reso_encoder) | |
| if isinstance(sample_pyd, tuple): | |
| sample_pyd = sample_pyd[0] | |
| assert isinstance(sample_pyd, dict) | |
| raw_img = decompress_and_open_image_gzip( | |
| sample_pyd['raw_img'], | |
| is_img=True, | |
| decompress=True, | |
| decompress_fn=lz4.frame.decompress) | |
| caption = sample_pyd['caption'].decode('utf-8') | |
| ins = sample_pyd['ins'].decode('utf-8') | |
| c = decompress_array(sample_pyd['c'], ( | |
| self.chunk_size, | |
| 25, | |
| ), | |
| np.float32, | |
| decompress=True, | |
| decompress_fn=lz4.frame.decompress) | |
| bbox = decompress_array( | |
| sample_pyd['bbox'], | |
| ( | |
| self.chunk_size, | |
| 4, | |
| ), | |
| np.float32, | |
| # decompress=False) | |
| decompress=True, | |
| decompress_fn=lz4.frame.decompress) | |
| if self.decode_encode_img_only: | |
| depth = np.zeros(shape=(self.chunk_size, | |
| *shape)) # save loading time | |
| else: | |
| depth = decompress_array(sample_pyd['depth'], | |
| (self.chunk_size, *shape), | |
| np.float32, | |
| decompress=True, | |
| decompress_fn=lz4.frame.decompress) | |
| # return {'raw_img': raw_img, 'depth': depth, 'bbox': bbox, 'caption': caption, 'ins': ins, 'c': c} | |
| # return raw_img, depth, c, bbox, caption, ins | |
| # return raw_img, bbox, caption, ins | |
| # return bbox, caption, ins | |
| return raw_img, depth, c, bbox, caption, ins | |
| # ! run single-instance pipeline first | |
| # return raw_img[0], depth[0], c[0], bbox[0], caption, ins | |
| def create_dict_nobatch(self, sample): | |
| # sample = [item[0] for item in sample] # wds wrap items in [] | |
| sample_length = 6 | |
| # if self.load_pcd: | |
| # sample_length += 1 | |
| cano_sample_list = [[] for _ in range(sample_length)] | |
| nv_sample_list = [[] for _ in range(sample_length)] | |
| # st() | |
| # bs = (len(sample)-2) // 6 | |
| for idx in range(0, self.pair_per_instance): | |
| cano_sample = sample[sample_length * idx:sample_length * (idx + 1)] | |
| nv_sample = sample[sample_length * self.pair_per_instance + | |
| sample_length * idx:sample_length * | |
| self.pair_per_instance + sample_length * | |
| (idx + 1)] | |
| for item_idx in range(sample_length): | |
| if self.frame_0_as_canonical: | |
| # ! cycle input/output view for more pairs | |
| if item_idx == 4: | |
| cano_sample_list[item_idx].append( | |
| cano_sample[item_idx][..., :25]) | |
| nv_sample_list[item_idx].append( | |
| nv_sample[item_idx][..., :25]) | |
| cano_sample_list[item_idx].append( | |
| nv_sample[item_idx][..., 25:]) | |
| nv_sample_list[item_idx].append( | |
| cano_sample[item_idx][..., 25:]) | |
| else: | |
| cano_sample_list[item_idx].append( | |
| cano_sample[item_idx]) | |
| nv_sample_list[item_idx].append(nv_sample[item_idx]) | |
| cano_sample_list[item_idx].append(nv_sample[item_idx]) | |
| nv_sample_list[item_idx].append(cano_sample[item_idx]) | |
| else: | |
| cano_sample_list[item_idx].append(cano_sample[item_idx]) | |
| nv_sample_list[item_idx].append(nv_sample[item_idx]) | |
| cano_sample_list[item_idx].append(nv_sample[item_idx]) | |
| nv_sample_list[item_idx].append(cano_sample[item_idx]) | |
| cano_sample = self.single_sample_create_dict_noBatch( | |
| (torch.stack(item_list, 0) for item_list in cano_sample_list), | |
| prefix='' | |
| ) # torch.Size([5, 10, 256, 256]). Since no batch dim here for now. | |
| nv_sample = self.single_sample_create_dict_noBatch( | |
| (torch.stack(item_list, 0) for item_list in nv_sample_list), | |
| prefix='nv_') | |
| ret_dict = { | |
| **cano_sample, | |
| **nv_sample, | |
| } | |
| if not self.load_pcd: | |
| ret_dict.update({'caption': sample[-2], 'ins': sample[-1]}) | |
| else: | |
| # if self.frame_0_as_canonical: | |
| # # fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ! wrong order. | |
| # # if self.chunk_size == 8: | |
| # fps_pcd = rearrange( | |
| # sample[-1], 'B V ... -> (V B) ...') # mimic torch.repeat | |
| # # else: | |
| # # fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ugly code to match the input format... | |
| # else: | |
| # fps_pcd = sample[-1].repeat( | |
| # 2, 1, | |
| # 1) # mimic torch.cat(), from torch.Size([3, 4096, 3]) | |
| # ! TODO, check fps_pcd order | |
| ret_dict.update({ | |
| 'caption': sample[-3], | |
| 'ins': sample[-2], | |
| 'fps_pcd': sample[-1] | |
| }) | |
| return ret_dict | |
| def create_dict(self, sample): | |
| # sample = [item[0] for item in sample] # wds wrap items in [] | |
| # st() | |
| sample_length = 6 | |
| # if self.load_pcd: | |
| # sample_length += 1 | |
| cano_sample_list = [[] for _ in range(sample_length)] | |
| nv_sample_list = [[] for _ in range(sample_length)] | |
| # st() | |
| # bs = (len(sample)-2) // 6 | |
| for idx in range(0, self.pair_per_instance): | |
| cano_sample = sample[sample_length * idx:sample_length * (idx + 1)] | |
| nv_sample = sample[sample_length * self.pair_per_instance + | |
| sample_length * idx:sample_length * | |
| self.pair_per_instance + sample_length * | |
| (idx + 1)] | |
| for item_idx in range(sample_length): | |
| if self.frame_0_as_canonical: | |
| # ! cycle input/output view for more pairs | |
| if item_idx == 4: | |
| cano_sample_list[item_idx].append( | |
| cano_sample[item_idx][..., :25]) | |
| nv_sample_list[item_idx].append( | |
| nv_sample[item_idx][..., :25]) | |
| cano_sample_list[item_idx].append( | |
| nv_sample[item_idx][..., 25:]) | |
| nv_sample_list[item_idx].append( | |
| cano_sample[item_idx][..., 25:]) | |
| else: | |
| cano_sample_list[item_idx].append( | |
| cano_sample[item_idx]) | |
| nv_sample_list[item_idx].append(nv_sample[item_idx]) | |
| cano_sample_list[item_idx].append(nv_sample[item_idx]) | |
| nv_sample_list[item_idx].append(cano_sample[item_idx]) | |
| else: | |
| cano_sample_list[item_idx].append(cano_sample[item_idx]) | |
| nv_sample_list[item_idx].append(nv_sample[item_idx]) | |
| cano_sample_list[item_idx].append(nv_sample[item_idx]) | |
| nv_sample_list[item_idx].append(cano_sample[item_idx]) | |
| # if self.split_chunk_input: | |
| # cano_sample = self.single_sample_create_dict( | |
| # (torch.cat(item_list, 0) for item_list in cano_sample_list), | |
| # prefix='') | |
| # nv_sample = self.single_sample_create_dict( | |
| # (torch.cat(item_list, 0) for item_list in nv_sample_list), | |
| # prefix='nv_') | |
| # else: | |
| # st() | |
| cano_sample = self.single_sample_create_dict( | |
| (torch.cat(item_list, 0) for item_list in cano_sample_list), | |
| prefix='') # torch.Size([4, 4, 10, 256, 256]) | |
| nv_sample = self.single_sample_create_dict( | |
| (torch.cat(item_list, 0) for item_list in nv_sample_list), | |
| prefix='nv_') | |
| ret_dict = { | |
| **cano_sample, | |
| **nv_sample, | |
| } | |
| if not self.load_pcd: | |
| ret_dict.update({'caption': sample[-2], 'ins': sample[-1]}) | |
| else: | |
| if self.frame_0_as_canonical: | |
| # fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ! wrong order. | |
| # if self.chunk_size == 8: | |
| fps_pcd = rearrange( | |
| sample[-1], 'B V ... -> (V B) ...') # mimic torch.repeat | |
| # else: | |
| # fps_pcd = rearrange( sample[-1], 'B V ... -> (B V) ...') # ugly code to match the input format... | |
| else: | |
| fps_pcd = sample[-1].repeat( | |
| 2, 1, | |
| 1) # mimic torch.cat(), from torch.Size([3, 4096, 3]) | |
| ret_dict.update({ | |
| 'caption': sample[-3], | |
| 'ins': sample[-2], | |
| 'fps_pcd': fps_pcd | |
| }) | |
| return ret_dict | |
| def prepare_mv_input(self, sample): | |
| # sample = [item[0] for item in sample] # wds wrap items in [] | |
| bs = len(sample['caption']) # number of instances | |
| chunk_size = sample['img'].shape[0] // bs | |
| assert self.split_chunk_input | |
| for k, v in sample.items(): | |
| if isinstance(v, torch.Tensor) and k != 'fps_pcd': | |
| sample[k] = rearrange(v, "b f c ... -> (b f) c ...", | |
| f=self.V).contiguous() | |
| # # ! shift nv | |
| # else: | |
| # for k, v in sample.items(): | |
| # if k not in ['ins', 'caption']: | |
| # rolled_idx = torch.LongTensor( | |
| # list( | |
| # itertools.chain.from_iterable( | |
| # list(range(i, sample['img'].shape[0], bs)) | |
| # for i in range(bs)))) | |
| # v = torch.index_select(v, dim=0, index=rolled_idx) | |
| # sample[k] = v | |
| # # img = sample['img'] | |
| # # gt = sample['nv_img'] | |
| # # torchvision.utils.save_image(img[0], 'inp.jpg', normalize=True) | |
| # # torchvision.utils.save_image(gt[0], 'nv.jpg', normalize=True) | |
| # for k, v in sample.items(): | |
| # if 'nv' in k: | |
| # rolled_idx = torch.LongTensor( | |
| # list( | |
| # itertools.chain.from_iterable( | |
| # list( | |
| # np.roll( | |
| # np.arange(i * chunk_size, (i + 1) * | |
| # chunk_size), 4) | |
| # for i in range(bs))))) | |
| # v = torch.index_select(v, dim=0, index=rolled_idx) | |
| # sample[k] = v | |
| # torchvision.utils.save_image(sample['nv_img'], 'nv.png', normalize=True) | |
| # torchvision.utils.save_image(sample['img'], 'inp.png', normalize=True) | |
| return sample | |
| def load_dataset( | |
| file_path="", | |
| reso=64, | |
| reso_encoder=224, | |
| batch_size=1, | |
| # shuffle=True, | |
| num_workers=6, | |
| load_depth=False, | |
| preprocess=None, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| trainer_name='input_rec', | |
| use_lmdb=False, | |
| use_wds=False, | |
| use_chunk=False, | |
| use_lmdb_compressed=False, | |
| infi_sampler=True): | |
| # st() | |
| # dataset_cls = { | |
| # 'input_rec': MultiViewDataset, | |
| # 'nv': NovelViewDataset, | |
| # }[trainer_name] | |
| # st() | |
| if use_wds: | |
| return load_wds_data(file_path, reso, reso_encoder, batch_size, | |
| num_workers) | |
| if use_lmdb: | |
| logger.log('using LMDB dataset') | |
| # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. | |
| if use_lmdb_compressed: | |
| if 'nv' in trainer_name: | |
| dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| else: | |
| dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| else: | |
| if 'nv' in trainer_name: | |
| dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| else: | |
| dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| # dataset = dataset_cls(file_path) | |
| elif use_chunk: | |
| dataset_cls = ChunkObjaverseDataset | |
| else: | |
| if 'nv' in trainer_name: | |
| dataset_cls = NovelViewObjverseDataset | |
| else: | |
| dataset_cls = MultiViewObjverseDataset # 1.5-2iter/s | |
| dataset = dataset_cls(file_path, | |
| reso, | |
| reso_encoder, | |
| test=False, | |
| preprocess=preprocess, | |
| load_depth=load_depth, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size) | |
| logger.log('dataset_cls: {}, dataset size: {}'.format( | |
| trainer_name, len(dataset))) | |
| if use_chunk: | |
        def chunk_collate_fn(sample):
            default_collate_sample = torch.utils.data.default_collate(
                sample[0])
            return default_collate_sample
| collate_fn = chunk_collate_fn | |
| else: | |
| collate_fn = None | |
| loader = DataLoader(dataset, | |
| batch_size=batch_size, | |
| num_workers=num_workers, | |
| drop_last=False, | |
| pin_memory=True, | |
| persistent_workers=num_workers > 0, | |
| shuffle=use_chunk, | |
| collate_fn=collate_fn) | |
| return loader | |
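
# Minimal usage sketch for load_dataset (argument values are placeholders; the
# chunked dataset classes referenced above are defined later in this module):
#   loader = load_dataset(file_path='/path/to/chunks', reso=128, reso_encoder=256,
#                         batch_size=4, num_workers=4, use_chunk=True)
#   batch = next(iter(loader))   # typically a dict with 'img', 'img_to_encoder', 'c', ...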
| def chunk_collate_fn(sample): | |
| sample = torch.utils.data.default_collate(sample) | |
| # ! change from stack to cat | |
| # sample = self.post_process.prepare_mv_input(sample) | |
| bs = len(sample['caption']) # number of instances | |
| # chunk_size = sample['img'].shape[0] // bs | |
| def merge_internal_batch(sample, merge_b_only=False): | |
| for k, v in sample.items(): | |
| if isinstance(v, torch.Tensor): | |
| if v.ndim > 1: | |
| if k == 'fps_pcd' or merge_b_only: | |
| sample[k] = rearrange( | |
| v, | |
| "b1 b2 ... -> (b1 b2) ...").float().contiguous() | |
| else: | |
| sample[k] = rearrange( | |
| v, "b1 b2 f c ... -> (b1 b2 f) c ...").float( | |
| ).contiguous() | |
| elif k == 'tanfov': | |
| sample[k] = v[0].float().item() # tanfov. | |
| if isinstance(sample['c'], dict): # 3dgs | |
| merge_internal_batch(sample['c'], merge_b_only=True) | |
| merge_internal_batch(sample['nv_c'], merge_b_only=True) | |
| merge_internal_batch(sample) | |
| return sample | |
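
# chunk_collate_fn flattens the nested (outer batch, chunk, view) dimensions
# produced by the chunked dataset into a single batch axis, handles the
# gs-format camera dicts ('c' / 'nv_c') separately, and converts the collated
# tanfov back to a scalar float. It is passed as collate_fn to the DataLoader
# in load_data / load_eval_data when use_chunk is enabled.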
| def chunk_ddpm_collate_fn(sample): | |
| sample = torch.utils.data.default_collate(sample) | |
| # ! change from stack to cat | |
| # sample = self.post_process.prepare_mv_input(sample) | |
| # bs = len(sample['caption']) # number of instances | |
| # chunk_size = sample['img'].shape[0] // bs | |
| def merge_internal_batch(sample, merge_b_only=False): | |
| for k, v in sample.items(): | |
| if isinstance(v, torch.Tensor): | |
| if v.ndim > 1: | |
| # if k in ['c', 'latent']: | |
| sample[k] = rearrange( | |
| v, | |
| "b1 b2 ... -> (b1 b2) ...").float().contiguous() | |
| # else: # img | |
| # sample[k] = rearrange( | |
| # v, "b1 b2 f ... -> (b1 b2 f) ...").float( | |
| # ).contiguous() | |
| else: # caption & ins | |
| v = [v[i][0] for i in range(len(v))] | |
| merge_internal_batch(sample) | |
| # if 'caption' in sample: | |
| # sample['caption'] = sample['caption'][0] + sample['caption'][1] | |
| return sample | |
| def load_data_cls( | |
| file_path="", | |
| reso=64, | |
| reso_encoder=224, | |
| batch_size=1, | |
| # shuffle=True, | |
| num_workers=6, | |
| load_depth=False, | |
| preprocess=None, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| trainer_name='input_rec', | |
| use_lmdb=False, | |
| use_wds=False, | |
| use_chunk=False, | |
| use_lmdb_compressed=False, | |
| # plucker_embedding=False, | |
| # frame_0_as_canonical=False, | |
| infi_sampler=True, | |
| load_latent=False, | |
| return_dataset=False, | |
| load_caption_dataset=False, | |
| load_mv_dataset=False, | |
| **kwargs): | |
| # st() | |
| # dataset_cls = { | |
| # 'input_rec': MultiViewDataset, | |
| # 'nv': NovelViewDataset, | |
| # }[trainer_name] | |
| # st() | |
| # if use_lmdb: | |
| # logger.log('using LMDB dataset') | |
| # # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. | |
| # if 'nv' in trainer_name: | |
| # dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| # else: | |
| # dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| # # dataset = dataset_cls(file_path) | |
| collate_fn = None | |
| if use_lmdb: | |
| logger.log('using LMDB dataset') | |
| # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. | |
| if use_lmdb_compressed: | |
| if 'nv' in trainer_name: | |
| dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| else: | |
| dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| else: | |
| if 'nv' in trainer_name: | |
| dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| else: | |
| dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| elif use_chunk: | |
| if load_latent: | |
| # if 'gs_cam_format' in kwargs: | |
| if kwargs['gs_cam_format']: | |
| if load_caption_dataset: | |
| dataset_cls = ChunkObjaverseDatasetDDPMgsT23D | |
| collate_fn = chunk_ddpm_collate_fn | |
| else: | |
| if load_mv_dataset: | |
| # dataset_cls = ChunkObjaverseDatasetDDPMgsMV23D # ! if multi-view | |
| dataset_cls = ChunkObjaverseDatasetDDPMgsMV23DSynthetic # ! if multi-view | |
| # collate_fn = chunk_ddpm_collate_fn | |
| collate_fn = None | |
| else: | |
| dataset_cls = ChunkObjaverseDatasetDDPMgsI23D | |
| collate_fn = None | |
| else: | |
| dataset_cls = ChunkObjaverseDatasetDDPM | |
| collate_fn = chunk_ddpm_collate_fn | |
| else: | |
| dataset_cls = ChunkObjaverseDataset | |
| collate_fn = chunk_collate_fn | |
| else: | |
| if 'nv' in trainer_name: | |
| dataset_cls = NovelViewObjverseDataset # 1.5-2iter/s | |
| else: | |
| dataset_cls = MultiViewObjverseDataset | |
| dataset = dataset_cls(file_path, | |
| reso, | |
| reso_encoder, | |
| test=False, | |
| preprocess=preprocess, | |
| load_depth=load_depth, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size, | |
| **kwargs | |
| # plucker_embedding=plucker_embedding | |
| ) | |
| logger.log('dataset_cls: {}, dataset size: {}'.format( | |
| trainer_name, len(dataset))) | |
| # st() | |
| return dataset | |
| def load_data( | |
| file_path="", | |
| reso=64, | |
| reso_encoder=224, | |
| batch_size=1, | |
| # shuffle=True, | |
| num_workers=6, | |
| load_depth=False, | |
| preprocess=None, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| trainer_name='input_rec', | |
| use_lmdb=False, | |
| use_wds=False, | |
| use_chunk=False, | |
| use_lmdb_compressed=False, | |
| # plucker_embedding=False, | |
| # frame_0_as_canonical=False, | |
| infi_sampler=True, | |
| load_latent=False, | |
| return_dataset=False, | |
| load_caption_dataset=False, | |
| load_mv_dataset=False, | |
| **kwargs): | |
| # st() | |
| # dataset_cls = { | |
| # 'input_rec': MultiViewDataset, | |
| # 'nv': NovelViewDataset, | |
| # }[trainer_name] | |
| # st() | |
| # if use_lmdb: | |
| # logger.log('using LMDB dataset') | |
| # # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. | |
| # if 'nv' in trainer_name: | |
| # dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| # else: | |
| # dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| # # dataset = dataset_cls(file_path) | |
| collate_fn = None | |
| if use_lmdb: | |
| logger.log('using LMDB dataset') | |
| # dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. | |
| if use_lmdb_compressed: | |
| if 'nv' in trainer_name: | |
| dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| else: | |
| dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| else: | |
| if 'nv' in trainer_name: | |
| dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| else: | |
| dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| elif use_chunk: | |
| # st() | |
| if load_latent: | |
| if kwargs['gs_cam_format']: | |
| if load_caption_dataset: | |
| dataset_cls = ChunkObjaverseDatasetDDPMgsT23D | |
| # collate_fn = chunk_ddpm_collate_fn | |
| collate_fn = None | |
| else: | |
| if load_mv_dataset: | |
| # dataset_cls = ChunkObjaverseDatasetDDPMgsMV23D | |
| dataset_cls = ChunkObjaverseDatasetDDPMgsMV23DSynthetic # ! if multi-view | |
| # collate_fn = chunk_ddpm_collate_fn | |
| collate_fn = None | |
| else: | |
| # dataset_cls = ChunkObjaverseDatasetDDPMgsI23D # load i23d | |
| # collate_fn = None | |
| # load mv dataset for i23d | |
| dataset_cls = ChunkObjaverseDatasetDDPMgsI23D_loadMV | |
| collate_fn = chunk_ddpm_collate_fn | |
| else: | |
| dataset_cls = ChunkObjaverseDatasetDDPM | |
| collate_fn = chunk_ddpm_collate_fn | |
| else: | |
| dataset_cls = ChunkObjaverseDataset | |
| collate_fn = chunk_collate_fn | |
| else: | |
| if 'nv' in trainer_name: | |
| dataset_cls = NovelViewObjverseDataset # 1.5-2iter/s | |
| else: | |
| dataset_cls = MultiViewObjverseDataset | |
| dataset = dataset_cls(file_path, | |
| reso, | |
| reso_encoder, | |
| test=False, | |
| preprocess=preprocess, | |
| load_depth=load_depth, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size, | |
| **kwargs | |
| # plucker_embedding=plucker_embedding | |
| ) | |
| logger.log('dataset_cls: {}, dataset size: {}'.format( | |
| trainer_name, len(dataset))) | |
| # st() | |
| if return_dataset: | |
| return dataset | |
| assert infi_sampler | |
| if infi_sampler: | |
| train_sampler = DistributedSampler(dataset=dataset, | |
| shuffle=True, | |
| drop_last=True) | |
| loader = DataLoader( | |
| dataset, | |
| batch_size=batch_size, | |
| num_workers=num_workers, | |
| drop_last=True, | |
| pin_memory=True, | |
| persistent_workers=num_workers > 0, | |
| sampler=train_sampler, | |
| collate_fn=collate_fn, | |
| # prefetch_factor=3 if num_workers>0 else None, | |
| ) | |
| while True: | |
| yield from loader | |
| # else: | |
| # # loader = DataLoader(dataset, | |
| # # batch_size=batch_size, | |
| # # num_workers=num_workers, | |
| # # drop_last=False, | |
| # # pin_memory=True, | |
| # # persistent_workers=num_workers > 0, | |
| # # shuffle=False) | |
| # st() | |
| # return dataset | |
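
# Note: with infi_sampler=True, load_data is an infinite generator; the usual
# pattern is (arguments are placeholders):
#   data = load_data(file_path='/path/to/chunks', use_chunk=True, batch_size=4)
#   batch = next(data)
# rather than iterating over a finite epoch.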
| def load_eval_data( | |
| file_path="", | |
| reso=64, | |
| reso_encoder=224, | |
| batch_size=1, | |
| num_workers=1, | |
| load_depth=False, | |
| preprocess=None, | |
| imgnet_normalize=True, | |
| interval=1, | |
| use_lmdb=False, | |
| plucker_embedding=False, | |
| load_real=False, | |
| load_mv_real=False, | |
| load_gso=False, | |
| four_view_for_latent=False, | |
| shuffle_across_cls=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| single_view_for_i23d=False, | |
| use_chunk=False, | |
| **kwargs, | |
| ): | |
| collate_fn = None | |
| if use_lmdb: | |
| logger.log('using LMDB dataset') | |
| dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
| dataset = dataset_cls(file_path, | |
| reso, | |
| reso_encoder, | |
| test=True, | |
| preprocess=preprocess, | |
| load_depth=load_depth, | |
| imgnet_normalize=imgnet_normalize, | |
| interval=interval) | |
| elif use_chunk: | |
| dataset = ChunkObjaverseDataset( | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| test=False, | |
| preprocess=preprocess, | |
| load_depth=load_depth, | |
| imgnet_normalize=imgnet_normalize, | |
| # dataset_size=dataset_size, | |
| gs_cam_format=gs_cam_format, | |
| plucker_embedding=plucker_embedding, | |
| wds_split_all=2, | |
| # frame_0_as_canonical=frame_0_as_canonical, | |
| **kwargs) | |
| collate_fn = chunk_collate_fn | |
| elif load_real: | |
| if load_mv_real: | |
| dataset_cls = RealMVDataset | |
| elif load_gso: | |
| # st() | |
| dataset_cls = RealDataset_GSO | |
| else: # single-view i23d | |
| dataset_cls = RealDataset | |
| dataset = dataset_cls(file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=preprocess, | |
| load_depth=load_depth, | |
| test=True, | |
| imgnet_normalize=imgnet_normalize, | |
| interval=interval, | |
| plucker_embedding=plucker_embedding) | |
| else: | |
| dataset = MultiViewObjverseDataset( | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=preprocess, | |
| load_depth=load_depth, | |
| test=True, | |
| imgnet_normalize=imgnet_normalize, | |
| interval=interval, | |
| plucker_embedding=plucker_embedding, | |
| four_view_for_latent=four_view_for_latent, | |
| load_extra_36_view=load_extra_36_view, | |
| shuffle_across_cls=shuffle_across_cls, | |
| gs_cam_format=gs_cam_format, | |
| single_view_for_i23d=single_view_for_i23d, | |
| **kwargs) | |
| print('eval dataset size: {}'.format(len(dataset))) | |
| # train_sampler = DistributedSampler(dataset=dataset) | |
| loader = DataLoader( | |
| dataset, | |
| batch_size=batch_size, | |
| num_workers=num_workers, | |
| drop_last=False, | |
| shuffle=False, | |
| collate_fn=collate_fn, | |
| ) | |
| # sampler=train_sampler) | |
| # return loader | |
| return iter(loader) | |
| def load_data_for_lmdb( | |
| file_path="", | |
| reso=64, | |
| reso_encoder=224, | |
| batch_size=1, | |
| # shuffle=True, | |
| num_workers=6, | |
| load_depth=False, | |
| preprocess=None, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| trainer_name='input_rec', | |
| shuffle_across_cls=False, | |
| four_view_for_latent=False, | |
| wds_split=1): | |
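| # Builds a plain (non-distributed, non-shuffled) DataLoader over the raw | |
| # multi-view renders so they can be re-packed into an LMDB; returns the | |
| # loader together with the dataset name and its length. | |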
| # st() | |
| # dataset_cls = { | |
| # 'input_rec': MultiViewDataset, | |
| # 'nv': NovelViewDataset, | |
| # }[trainer_name] | |
| # if 'nv' in trainer_name: | |
| # dataset_cls = NovelViewDataset | |
| # else: | |
| # dataset_cls = MultiViewDataset | |
| # st() | |
| # dataset_cls = MultiViewObjverseDatasetforLMDB | |
| dataset_cls = MultiViewObjverseDatasetforLMDB_nocaption | |
| dataset = dataset_cls(file_path, | |
| reso, | |
| reso_encoder, | |
| test=False, | |
| preprocess=preprocess, | |
| load_depth=load_depth, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size, | |
| shuffle_across_cls=shuffle_across_cls, | |
| wds_split=wds_split, | |
| four_view_for_latent=four_view_for_latent) | |
| logger.log('dataset_cls: {}, dataset size: {}'.format( | |
| dataset_cls.__name__, len(dataset))) | |
| # train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True) | |
| loader = DataLoader( | |
| dataset, | |
| shuffle=False, | |
| batch_size=batch_size, | |
| num_workers=num_workers, | |
| drop_last=False, | |
| # prefetch_factor=2, | |
| # prefetch_factor=3, | |
| pin_memory=True, | |
| persistent_workers=num_workers > 0, | |
| ) | |
| # sampler=train_sampler) | |
| # while True: | |
| # yield from loader | |
| return loader, dataset.dataset_name, len(dataset) | |
| def load_lmdb_for_lmdb( | |
| file_path="", | |
| reso=64, | |
| reso_encoder=224, | |
| batch_size=1, | |
| # shuffle=True, | |
| num_workers=6, | |
| load_depth=False, | |
| preprocess=None, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| trainer_name='input_rec'): | |
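| # Same as load_data_for_lmdb, but reads from an existing compressed LMDB | |
| # (Objv_LMDBDataset_MV_Compressed_for_lmdb) and returns (loader, dataset length). | |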
| # st() | |
| # dataset_cls = { | |
| # 'input_rec': MultiViewDataset, | |
| # 'nv': NovelViewDataset, | |
| # }[trainer_name] | |
| # if 'nv' in trainer_name: | |
| # dataset_cls = NovelViewDataset | |
| # else: | |
| # dataset_cls = MultiViewDataset | |
| # st() | |
| dataset_cls = Objv_LMDBDataset_MV_Compressed_for_lmdb | |
| dataset = dataset_cls(file_path, | |
| reso, | |
| reso_encoder, | |
| test=False, | |
| preprocess=preprocess, | |
| load_depth=load_depth, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size) | |
| logger.log('dataset_cls: {}, dataset size: {}'.format( | |
| dataset_cls.__name__, len(dataset))) | |
| # train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True) | |
| loader = DataLoader( | |
| dataset, | |
| shuffle=False, | |
| batch_size=batch_size, | |
| num_workers=num_workers, | |
| drop_last=False, | |
| # prefetch_factor=2, | |
| # prefetch_factor=3, | |
| pin_memory=True, | |
| persistent_workers=True, | |
| ) | |
| # sampler=train_sampler) | |
| # while True: | |
| # yield from loader | |
| return loader, len(dataset) | |
| def load_memory_data( | |
| file_path="", | |
| reso=64, | |
| reso_encoder=224, | |
| batch_size=1, | |
| num_workers=1, | |
| # load_depth=True, | |
| preprocess=None, | |
| imgnet_normalize=True, | |
| use_chunk=True, | |
| **kwargs): | |
| # load a single-instance into the memory to speed up training IO | |
| # dataset = MultiViewObjverseDataset(file_path, | |
| collate_fn = None | |
| if use_chunk: | |
| dataset_cls = ChunkObjaverseDataset | |
| collate_fn = chunk_collate_fn | |
| else: | |
| dataset_cls = NovelViewObjverseDataset | |
| dataset = dataset_cls(file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=preprocess, | |
| load_depth=True, | |
| test=False, | |
| overfitting=True, | |
| imgnet_normalize=imgnet_normalize, | |
| overfitting_bs=batch_size, | |
| **kwargs) | |
| logger.log('!!!!!!! memory dataset size: {} !!!!!!'.format(len(dataset))) | |
| # train_sampler = DistributedSampler(dataset=dataset) | |
| loader = DataLoader( | |
| dataset, | |
| batch_size=len(dataset), | |
| num_workers=num_workers, | |
| drop_last=False, | |
| shuffle=False, | |
| collate_fn=collate_fn, | |
| ) | |
| all_data: dict = next( | |
| iter(loader) | |
| ) # torchvision.utils.save_image(all_data['img'], 'gt.jpg', normalize=True, value_range=(-1,1)) | |
| # st() | |
| if kwargs.get('gs_cam_format', False): # gs rendering pipeline | |
| # ! load V=4 images for training in a batch. | |
| while True: | |
| # st() | |
| # indices = torch.randperm(len(dataset))[:4] | |
| indices = torch.randperm( | |
| len(dataset) * 2)[:batch_size] # all instances | |
| # indices2 = torch.randperm(len(dataset))[:] # all instances | |
| batch_c = collections.defaultdict(dict) | |
| V = all_data['c']['source_cv2wT_quat'].shape[1] | |
| for k in ['c', 'nv_c']: | |
| for k_c, v_c in all_data[k].items(): | |
| if k_c == 'tanfov': | |
| continue | |
| try: | |
| batch_c[k][ | |
| k_c] = torch.index_select( # ! chunk data reading pipeline | |
| v_c, | |
| dim=0, | |
| index=indices | |
| ).reshape(batch_size, V, *v_c.shape[2:]).float( | |
| ) if isinstance( | |
| v_c, | |
| torch.Tensor) else v_c # float | |
| except Exception as e: | |
| print(e) | |
| st()  # drop into the debugger after reporting the error | |
| # ! read chunk not required, already float | |
| batch_c['c']['tanfov'] = all_data['c']['tanfov'] | |
| batch_c['nv_c']['tanfov'] = all_data['nv_c']['tanfov'] | |
| indices_range = torch.arange(indices[0]*V, (indices[0]+1)*V) | |
| batch_data = {} | |
| for k, v in all_data.items(): | |
| if k not in ['c', 'nv_c']: | |
| try: | |
| if k == 'fps_pcd': | |
| batch_data[k] = torch.index_select( | |
| v, dim=0, index=indices).float() if isinstance( | |
| v, torch.Tensor) else v # float | |
| else: | |
| batch_data[k] = torch.index_select( | |
| v, dim=0, index=indices_range).float() if isinstance( | |
| v, torch.Tensor) else v # float | |
| except Exception as e: | |
| print(e) | |
| st()  # drop into the debugger after reporting the error | |
| memory_batch_data = { | |
| **batch_data, | |
| **batch_c, | |
| } | |
| yield memory_batch_data | |
| else: | |
| while True: | |
| start_idx = np.random.randint(0, len(dataset) - batch_size + 1) | |
| yield { | |
| k: v[start_idx:start_idx + batch_size] | |
| for k, v in all_data.items() | |
| } | |
| def read_dnormal(normald_path, cond_pos, h=None, w=None): | |
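| # Reads the packed normal+depth render (the *_nd.exr files collected below) with | |
| # cv2, splits it into a 3-channel normal map and a 1-channel depth map, zeroes | |
| # out depth closer than the camera near plane, optionally resizes both to (h, w) | |
| # with nearest-neighbour interpolation, and returns (depth, normal) as float tensors. | |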
| cond_cam_dis = np.linalg.norm(cond_pos, 2) | |
| near = 0.867 #sqrt(3) * 0.5 | |
| near_distance = cond_cam_dis - near | |
| normald = cv2.imread(normald_path, cv2.IMREAD_UNCHANGED).astype(np.float32) | |
| normal, depth = normald[..., :3], normald[..., 3:] | |
| depth[depth < near_distance] = 0 | |
| if h is not None: | |
| assert w is not None | |
| if depth.shape[1] != h: | |
| depth = cv2.resize(depth, (h, w), interpolation=cv2.INTER_NEAREST | |
| ) # 512,512, 1 -> self.reso, self.reso | |
| # depth = cv2.resize(depth, (h, w), interpolation=cv2.INTER_LANCZOS4 | |
| # ) # ! may fail if nearest. dirty data. | |
| # st() | |
| else: | |
| depth = depth[..., 0] | |
| if normal.shape[1] != h: | |
| normal = cv2.resize(normal, (h, w), | |
| interpolation=cv2.INTER_NEAREST | |
| ) # 512,512, 1 -> self.reso, self.reso | |
| else: | |
| depth = depth[..., 0] | |
| return torch.from_numpy(depth).float(), torch.from_numpy(normal).float() | |
| def get_intri(target_im=None, h=None, w=None, normalize=False): | |
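| # Builds a 3x3 pinhole intrinsics matrix for the renders: the raw renders use a | |
| # 1422.222 focal length at 1024px, which is rescaled to the requested (h, w) and | |
| # centred at the image midpoint; normalize=True divides by the image size so the | |
| # principal point sits at (0.5, 0.5). | |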
| if target_im is None: | |
| assert (h is not None and w is not None) | |
| else: | |
| h, w = target_im.shape[:2] | |
| fx = fy = 1422.222 | |
| res_raw = 1024 | |
| f_x = f_y = fx * h / res_raw | |
| K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) | |
| if normalize: # center is [0.5, 0.5], eg3d renderer tradition | |
| K[:6] /= h  # NOTE: K is already reshaped to (3, 3) here, so this slice scales every row (including the homogeneous one); only fx, fy, cx, cy are consumed by gen_rays / c_to_3dgs_format here. | |
| # print("intr: ", K) | |
| return K | |
| def convert_pose(C2W): | |
| # https://github.com/modelscope/richdreamer/blob/c3d9a77fa15fc42dbae12c2d41d64aaec14efd37/dataset/gobjaverse/depth_warp_example.py#L402 | |
| flip_yz = np.eye(4) | |
| flip_yz[1, 1] = -1 | |
| flip_yz[2, 2] = -1 | |
| C2W = np.matmul(C2W, flip_yz) | |
| return torch.from_numpy(C2W) | |
| def read_camera_matrix_single(json_file): | |
| with open(json_file, 'r', encoding='utf8') as reader: | |
| json_content = json.load(reader) | |
| ''' | |
| # NOTE that different from unity2blender experiments. | |
| camera_matrix = np.eye(4) | |
| camera_matrix[:3, 0] = np.array(json_content['x']) | |
| camera_matrix[:3, 1] = -np.array(json_content['y']) | |
| camera_matrix[:3, 2] = -np.array(json_content['z']) | |
| camera_matrix[:3, 3] = np.array(json_content['origin']) | |
| ''' | |
| camera_matrix = np.eye(4) # blender-based | |
| camera_matrix[:3, 0] = np.array(json_content['x']) | |
| camera_matrix[:3, 1] = np.array(json_content['y']) | |
| camera_matrix[:3, 2] = np.array(json_content['z']) | |
| camera_matrix[:3, 3] = np.array(json_content['origin']) | |
| # print(camera_matrix) | |
| # ''' | |
| # return convert_pose(camera_matrix) | |
| return camera_matrix | |
| def unity2blender(normal): | |
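| # Remaps normal-map channels between the renderer's coordinate conventions: | |
| # new_x = -old_z, new_y = -old_x, new_z = old_y (see unity2blender_fix below for | |
| # the corrected variant actually used when reading chunks). | |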
| normal_clone = normal.copy() | |
| normal_clone[..., 0] = -normal[..., -1] | |
| normal_clone[..., 1] = -normal[..., 0] | |
| normal_clone[..., 2] = normal[..., 1] | |
| return normal_clone | |
| def unity2blender_fix(normal): # up blue, left green, front (towards inside) red | |
| normal_clone = normal.copy() | |
| # normal_clone[..., 0] = -normal[..., 2] | |
| # normal_clone[..., 1] = -normal[..., 0] | |
| normal_clone[..., 0] = -normal[..., 0] # swap r and g | |
| normal_clone[..., 1] = -normal[..., 2] | |
| normal_clone[..., 2] = normal[..., 1] | |
| return normal_clone | |
| def unity2blender_th(normal): | |
| assert normal.shape[1] == 3 # B 3 H W... | |
| normal_clone = normal.clone() | |
| normal_clone[:, 0, ...] = -normal[:, -1, ...] | |
| normal_clone[:, 1, ...] = -normal[:, 0, ...] | |
| normal_clone[:, 2, ...] = normal[:, 1, ...] | |
| return normal_clone | |
| def blender2midas(img): | |
| '''Blender: rub | |
| midas: lub | |
| ''' | |
| img[..., 0] = -img[..., 0] | |
| img[..., 1] = -img[..., 1] | |
| img[..., -1] = -img[..., -1] | |
| return img | |
| def current_milli_time(): | |
| return round(time.time() * 1000) | |
| # modified from ShapeNet class | |
| class MultiViewObjverseDataset(Dataset): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=False, | |
| single_view_for_i23d=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| frame_0_as_canonical=False, | |
| **kwargs): | |
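| # Indexes per-view files of the G-Objaverse renders: for every selected instance | |
| # it records paths to <view>.png, <view>.json (camera) and <view>_nd.exr | |
| # (normal+depth), optionally sub-sampling views for latent extraction or i23d. | |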
| self.load_extra_36_view = load_extra_36_view | |
| # st() | |
| self.gs_cam_format = gs_cam_format | |
| self.frame_0_as_canonical = frame_0_as_canonical | |
| self.four_view_for_latent = four_view_for_latent # export 0 12 30 36, 4 views for reconstruction | |
| self.single_view_for_i23d = single_view_for_i23d | |
| self.file_path = file_path | |
| self.overfitting = overfitting | |
| self.scene_scale = scene_scale | |
| self.reso = reso | |
| self.reso_encoder = reso_encoder | |
| self.classes = False | |
| self.load_depth = load_depth | |
| self.preprocess = preprocess | |
| self.plucker_embedding = plucker_embedding | |
| self.intrinsics = get_intri(h=self.reso, w=self.reso, | |
| normalize=True).reshape(9) | |
| assert not self.classes, "Not support class condition now." | |
| dataset_name = Path(self.file_path).stem.split('_')[0] | |
| self.dataset_name = dataset_name | |
| self.zfar = 100.0 | |
| self.znear = 0.01 | |
| # if test: | |
| # self.ins_list = sorted(os.listdir(self.file_path))[0:1] # the first 1 instance for evaluation reference. | |
| # else: | |
| # ! TODO, read from list? | |
| def load_single_cls_instances(file_path): | |
| ins_list = [] # the first 1 instance for evaluation reference. | |
| # ''' | |
| # for dict_dir in os.listdir(file_path)[:]: | |
| # for dict_dir in os.listdir(file_path)[:]: | |
| for dict_dir in os.listdir(file_path): | |
| # for dict_dir in os.listdir(file_path)[:2]: | |
| for ins_dir in os.listdir(os.path.join(file_path, dict_dir)): | |
| # self.ins_list.append(os.path.join(self.file_path, dict_dir, ins_dir,)) | |
| # /nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/170K/infer-latents/189w/v=6-rotate/latent_dir | |
| # st() # check latent whether saved | |
| # root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/MV/170K/infer-latents/189w/v=6-rotate/latent_dir' | |
| # if os.path.exists(os.path.join(root,file_path.split('/')[-1], dict_dir, ins_dir, 'latent.npy') ): | |
| # continue | |
| # pcd_root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/pcd-V=8_24576_polish' | |
| # pcd_root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/pcd-V=10_4096_polish' | |
| # if os.path.exists( | |
| # os.path.join(pcd_root, 'fps-pcd', | |
| # file_path.split('/')[-1], dict_dir, | |
| # ins_dir, 'fps-4096.ply')): | |
| # continue | |
| # ! split=8 has some missing instances | |
| # root = '/cpfs01/user/lanyushi.p/data/chunk-jpeg-normal/bs_16_fixsave3/170K/384/' | |
| # if os.path.exists(os.path.join(root,file_path.split('/')[-1], dict_dir, ins_dir,) ): | |
| # continue | |
| # else: | |
| # ins_list.append( | |
| # os.path.join(file_path, dict_dir, ins_dir, | |
| # 'campos_512_v4')) | |
| # filter out some data | |
| if not os.path.exists(os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v2')): | |
| continue | |
| if not os.path.exists(os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v2', '00025', '00025.png')): | |
| continue | |
| if len(os.listdir(os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v2'))) != 40: | |
| continue | |
| ins_list.append( | |
| os.path.join(file_path, dict_dir, ins_dir, | |
| 'campos_512_v2')) | |
| # ''' | |
| # check pcd performnace | |
| # ins_list.append( | |
| # os.path.join(file_path, '0', '10634', | |
| # 'campos_512_v4')) | |
| return ins_list | |
| # st() | |
| self.ins_list = [] | |
| # for subset in ['Animals', 'Transportations_tar', 'Furnitures']: | |
| # for subset in ['Furnitures']: | |
| # selected subset for training | |
| # if False: | |
| if True: | |
| for subset in [ # ! around 17W instances in total. | |
| # 'Animals', | |
| # 'BuildingsOutdoor', | |
| # 'daily-used', | |
| # 'Furnitures', | |
| # 'Food', | |
| # 'Plants', | |
| # 'Electronics', | |
| # 'Transportations_tar', | |
| # 'Human-Shape', | |
| 'gobjaverse_alignment_unzip', | |
| ]: # selected subset for training | |
| # if os.path.exists(f'{self.file_path}/{subset}.txt'): | |
| # dataset_list = f'{self.file_path}/{subset}_filtered.txt' | |
| dataset_list = f'{self.file_path}/{subset}_filtered_more.txt' | |
| assert os.path.exists(dataset_list) | |
| if os.path.exists(dataset_list): | |
| with open(dataset_list, 'r') as f: | |
| self.ins_list += [os.path.join(self.file_path, item.strip()) for item in f.readlines()] | |
| else: | |
| self.ins_list += load_single_cls_instances( | |
| os.path.join(self.file_path, subset)) | |
| # st() | |
| # current_time = int(current_milli_time() | |
| # ) # randomly shuffle given current time | |
| # random.seed(current_time) | |
| # random.shuffle(self.ins_list) | |
| else: # preprocess single class | |
| self.ins_list = load_single_cls_instances(self.file_path) | |
| self.ins_list = sorted(self.ins_list) | |
| if overfitting: | |
| self.ins_list = self.ins_list[:1] | |
| self.rgb_list = [] | |
| self.frame0_pose_list = [] | |
| self.pose_list = [] | |
| self.depth_list = [] | |
| self.data_ins_list = [] | |
| self.instance_data_length = -1 | |
| # self.pcd_path = Path('/cpfs01/shared/V2V/V2V_hdd/yslan/logs/nips23/Reconstruction/pcd-V=6/fps-pcd') | |
| self.pcd_path = Path( | |
| '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/pcd-V=6/fps-pcd') | |
| with open( | |
| '/nas/shared/public/yslan/data/text_captions_cap3d.json') as f: | |
| # '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f: | |
| self.caption_data = json.load(f) | |
| self.shuffle_across_cls = shuffle_across_cls | |
| # for ins in self.ins_list[47000:]: | |
| if four_view_for_latent: # also saving dense pcd | |
| # self.wds_split_all = 1 # ! when dumping latent | |
| # self.wds_split_all = 2 # ! when dumping latent | |
| # self.wds_split_all = 4 | |
| # self.wds_split_all = 6 | |
| # self.wds_split_all = 4 | |
| # self.wds_split_all = 5 | |
| # self.wds_split_all = 6 | |
| # self.wds_split_all = 7 | |
| # self.wds_split_all = 1 | |
| self.wds_split_all = 8 | |
| # self.wds_split_all = 2 | |
| # ins_list_to_process = self.ins_list | |
| all_ins_size = len(self.ins_list) | |
| ratio_size = all_ins_size // self.wds_split_all + 1 | |
| # ratio_size = int(all_ins_size / self.wds_split_all) + 1 | |
| ins_list_to_process = self.ins_list[ratio_size * | |
| (wds_split):ratio_size * | |
| (wds_split + 1)]  # NOTE: this branch indexes the split as 0-based, unlike the shard-creation branch below which treats wds_split as 1-based | |
| else: # ! create shards dataset | |
| # self.wds_split_all = 4 | |
| self.wds_split_all = 8 | |
| # self.wds_split_all = 1 | |
| all_ins_size = len(self.ins_list) | |
| random.seed(0) | |
| random.shuffle(self.ins_list) # avoid same category appears in the same shard | |
| ratio_size = all_ins_size // self.wds_split_all + 1 | |
| ins_list_to_process = self.ins_list[ratio_size * # 1 - 8 | |
| (wds_split - 1):ratio_size * | |
| wds_split] | |
| # uniform_sample = False | |
| uniform_sample = True | |
| # st() | |
| for ins in tqdm(ins_list_to_process): | |
| # ins = os.path.join( | |
| # # self.file_path, ins , 'campos_512_v4' | |
| # self.file_path, ins , | |
| # # 'compos_512_v4' | |
| # ) | |
| # cur_rgb_path = os.path.join(self.file_path, ins, 'compos_512_v4') | |
| # cur_pose_path = os.path.join(self.file_path, ins, 'pose') | |
| # st() | |
| # ][:27]) | |
| if self.four_view_for_latent: | |
| # cur_all_fname = [t.split('.')[0] for t in os.listdir(ins) | |
| # ] # use full set for training | |
| # cur_all_fname = [f'{idx:05d}' for idx in [0, 12, 30, 36] | |
| # cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24] | |
| # cur_all_fname = [f'{idx:05d}' for idx in [7,16,24,25] | |
| # cur_all_fname = [f'{idx:05d}' for idx in [25,26,0,9,18,27,33,39]] | |
| cur_all_fname = [ | |
| f'{idx:05d}' | |
| for idx in [25, 26, 6, 12, 18, 24, 27, 31, 35, 39] # ! for extracting PCD | |
| ] | |
| # cur_all_fname = [f'{idx:05d}' for idx in [25,26,0,9,18,27,30,33,36,39]] # more down side for better bottom coverage. | |
| # cur_all_fname = [f'{idx:05d}' for idx in [25,0, 7,15]] | |
| # cur_all_fname = [f'{idx:05d}' for idx in [4,12,20,25,26] | |
| # cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24,25,26] | |
| # cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24,25,26, 39, 33, 27] | |
| # cur_all_fname = [f'{idx:05d}' for idx in [6,12,18,24,25,26, 39, 33, 27] | |
| # cur_all_fname = [ | |
| # f'{idx:05d}' for idx in [25, 26, 27, 30, 33, 36] | |
| # ] # for pcd unprojection | |
| # cur_all_fname = [ | |
| # f'{idx:05d}' for idx in [25, 26, 27, 30] # ! for infer latents | |
| # ] # | |
| # cur_all_fname = [ | |
| # f'{idx:05d}' for idx in [25, 27, 29, 31, 33, 35, 37 | |
| # ] # ! for infer latents | |
| # ] # | |
| # cur_all_fname = [ | |
| # f'{idx:05d}' for idx in [25, 27, 31, 35 | |
| # ] # ! for infer latents | |
| # ] # | |
| # cur_all_fname += [f'{idx:05d}' for idx in range(40) if idx not in [0,12,30,36]] # ! four views for inference | |
| elif self.single_view_for_i23d: | |
| # cur_all_fname = [f'{idx:05d}' | |
| # for idx in [16]] # 20 is also fine | |
| cur_all_fname = [f'{idx:05d}' | |
| for idx in [2]] # ! furniture side view | |
| else: | |
| cur_all_fname = [t.split('.')[0] for t in os.listdir(ins) | |
| ] # use full set for training | |
| if shuffle_across_cls: | |
| if uniform_sample: | |
| cur_all_fname = sorted(cur_all_fname) | |
| # 0-24, 25 views | |
| # 25,26, 2 views | |
| # 27-39, 13 views | |
| uniform_all_fname = [] | |
| # !!!! if bs=9 or 8 | |
| for idx in range(6): | |
| if idx % 2 == 0: | |
| chunk_all_fname = [25] | |
| else: | |
| chunk_all_fname = [26] | |
| # chunk_all_fname = [25] # no bottom view required as input | |
| # start_1 = np.random.randint(0,5) # for first 24 views | |
| # chunk_all_fname += [start_1+uniform_idx for uniform_idx in range(0,25,5)] | |
| start_1 = np.random.randint(0,4) # for first 24 views, v=8 | |
| chunk_all_fname += [start_1+uniform_idx for uniform_idx in range(0,25,7)] # [0-21] | |
| start_2 = np.random.randint(0,5) + 27 # for first 24 views | |
| chunk_all_fname += [start_2, start_2 + 4, start_2 + 8] | |
| assert len(chunk_all_fname) == 8, len(chunk_all_fname) | |
| uniform_all_fname += [cur_all_fname[fname] for fname in chunk_all_fname] | |
| # ! if bs=6 | |
| # for idx in range(8): | |
| # if idx % 2 == 0: | |
| # chunk_all_fname = [ | |
| # 25 | |
| # ] # no bottom view required as input | |
| # else: | |
| # chunk_all_fname = [ | |
| # 26 | |
| # ] # no bottom view required as input | |
| # start_1 = np.random.randint( | |
| # 0, 7) # for first 24 views | |
| # # chunk_all_fname += [start_1+uniform_idx for uniform_idx in range(0,25,5)] | |
| # chunk_all_fname += [ | |
| # start_1 + uniform_idx | |
| # for uniform_idx in range(0, 25, 9) | |
| # ] # 0 9 18 | |
| # start_2 = np.random.randint( | |
| # 0, 7) + 27 # for first 24 views | |
| # # chunk_all_fname += [start_2, start_2 + 4, start_2 + 8] | |
| # chunk_all_fname += [start_2, | |
| # start_2 + 6] # 2 frames | |
| # assert len(chunk_all_fname) == 6 | |
| # uniform_all_fname += [ | |
| # cur_all_fname[fname] | |
| # for fname in chunk_all_fname | |
| # ] | |
| cur_all_fname = uniform_all_fname | |
| else: | |
| current_time = int(current_milli_time( | |
| )) # randomly shuffle given current time | |
| random.seed(current_time) | |
| random.shuffle(cur_all_fname) | |
| else: | |
| cur_all_fname = sorted(cur_all_fname) | |
| # ! skip the check | |
| # if self.instance_data_length == -1: | |
| # self.instance_data_length = len(cur_all_fname) | |
| # else: | |
| # try: # data missing? | |
| # assert len(cur_all_fname) == self.instance_data_length | |
| # except: | |
| # # with open('error_log.txt', 'a') as f: | |
| # # f.write(str(e) + '\n') | |
| # with open('missing_ins_new2.txt', 'a') as f: | |
| # f.write(str(Path(ins.parent)) + | |
| # '\n') # remove the "campos_512_v4" | |
| # continue | |
| # if test: # use middle image as the novel view model input | |
| # mid_index = len(cur_all_fname) // 3 * 2 | |
| # cur_all_fname.insert(0, cur_all_fname[mid_index]) | |
| self.frame0_pose_list += ([ | |
| os.path.join(ins, fname, fname + '.json') | |
| for fname in [cur_all_fname[0]] | |
| ] * len(cur_all_fname)) | |
| self.pose_list += ([ | |
| os.path.join(ins, fname, fname + '.json') | |
| for fname in cur_all_fname | |
| ]) | |
| self.rgb_list += ([ | |
| os.path.join(ins, fname, fname + '.png') | |
| for fname in cur_all_fname | |
| ]) | |
| self.depth_list += ([ | |
| os.path.join(ins, fname, fname + '_nd.exr') | |
| for fname in cur_all_fname | |
| ]) | |
| self.data_ins_list += ([ins] * len(cur_all_fname)) | |
| # check | |
| # ! setup normalizataion | |
| transformations = [ | |
| transforms.ToTensor(), # [0,1] range | |
| ] | |
| if imgnet_normalize: | |
| transformations.append( | |
| transforms.Normalize((0.485, 0.456, 0.406), | |
| (0.229, 0.224, 0.225)) # type: ignore | |
| ) | |
| else: | |
| transformations.append( | |
| transforms.Normalize((0.5, 0.5, 0.5), | |
| (0.5, 0.5, 0.5))) # type: ignore | |
| # st() | |
| self.normalize = transforms.Compose(transformations) | |
| def get_source_cw2wT(self, source_cameras_view_to_world): | |
| return matrix_to_quaternion( | |
| source_cameras_view_to_world[:3, :3].transpose(0, 1)) | |
| def c_to_3dgs_format(self, pose): | |
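| # Converts one flattened 25-dim camera (16 c2w entries + 9 normalized intrinsics, | |
| # of which pose[16] = fx is used) into the dict the 3DGS renderer expects: | |
| # world/view and full projection transforms, camera centre, tan(fov), the | |
| # source-view rotation quaternion, and the original pose matrices. | |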
| # TODO, switch to torch version (batched later) | |
| c2w = pose[:16].reshape(4, 4) # 3x4 | |
| # ! load cam | |
| w2c = np.linalg.inv(c2w) | |
| R = np.transpose( | |
| w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code | |
| T = w2c[:3, 3] | |
| fx = pose[16] | |
| FovX = focal2fov(fx, 1) | |
| FovY = focal2fov(fx, 1) | |
| tanfovx = math.tan(FovX * 0.5) | |
| tanfovy = math.tan(FovY * 0.5) | |
| assert tanfovx == tanfovy | |
| trans = np.array([0.0, 0.0, 0.0]) | |
| scale = 1.0 | |
| world_view_transform = torch.tensor(getWorld2View2(R, T, trans, | |
| scale)).transpose( | |
| 0, 1) | |
| projection_matrix = getProjectionMatrix(znear=self.znear, | |
| zfar=self.zfar, | |
| fovX=FovX, | |
| fovY=FovY).transpose(0, 1) | |
| full_proj_transform = (world_view_transform.unsqueeze(0).bmm( | |
| projection_matrix.unsqueeze(0))).squeeze(0) | |
| camera_center = world_view_transform.inverse()[3, :3] | |
| view_world_transform = torch.tensor(getView2World(R, T, trans, | |
| scale)).transpose( | |
| 0, 1) | |
| # item.update(viewpoint_cam=[viewpoint_cam]) | |
| c = {} | |
| c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform) | |
| c.update( | |
| # projection_matrix=projection_matrix, # K | |
| cam_view=world_view_transform, # world_view_transform | |
| cam_view_proj=full_proj_transform, # full_proj_transform | |
| cam_pos=camera_center, | |
| tanfov=tanfovx, # TODO, fix in the renderer | |
| # orig_c2w=c2w, | |
| # orig_w2c=w2c, | |
| orig_pose=torch.from_numpy(pose), | |
| orig_c2w=torch.from_numpy(c2w), | |
| orig_w2c=torch.from_numpy(w2c), | |
| # tanfovy=tanfovy, | |
| ) | |
| return c # dict for gs rendering | |
| def __len__(self): | |
| return len(self.rgb_list) | |
| def load_bbox(self, mask): | |
| # st() | |
| nonzero_value = torch.nonzero(mask) | |
| bottom, right = nonzero_value.max(dim=0)[0]  # max row/col of the foreground | |
| top, left = nonzero_value.min(dim=0)[0]  # min row/col of the foreground | |
| bbox = torch.tensor([top, left, bottom, right], dtype=torch.float32)  # corner coordinates, not width/height | |
| return bbox | |
| def __getitem__(self, idx): | |
| # try: | |
| data = self._read_data(idx) | |
| return data | |
| # except Exception as e: | |
| # # with open('error_log_pcd.txt', 'a') as f: | |
| # with open('error_log_pcd.txt', 'a') as f: | |
| # f.write(str(e) + '\n') | |
| # with open('error_idx_pcd.txt', 'a') as f: | |
| # f.write(str(self.data_ins_list[idx]) + '\n') | |
| # print(e, flush=True) | |
| # return {} | |
| def gen_rays(self, c2w): | |
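| # Casts one ray per pixel at encoder resolution: pixel centres are normalized to | |
| # [0, 1], unprojected with the normalized intrinsics, rotated into world space by | |
| # c2w, and returned as per-pixel (origins, directions) of shape [h, w, 3]. | |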
| # Generate rays | |
| self.h = self.reso_encoder | |
| self.w = self.reso_encoder | |
| yy, xx = torch.meshgrid( | |
| torch.arange(self.h, dtype=torch.float32) + 0.5, | |
| torch.arange(self.w, dtype=torch.float32) + 0.5, | |
| indexing='ij') | |
| # normalize to 0-1 pixel range | |
| yy = yy / self.h | |
| xx = xx / self.w | |
| # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) | |
| cx, cy, fx, fy = self.intrinsics[2], self.intrinsics[ | |
| 5], self.intrinsics[0], self.intrinsics[4] | |
| # cx *= self.w | |
| # cy *= self.h | |
| # f_x = f_y = fx * h / res_raw | |
| c2w = torch.from_numpy(c2w).float() | |
| xx = (xx - cx) / fx | |
| yy = (yy - cy) / fy | |
| zz = torch.ones_like(xx) | |
| dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention | |
| dirs /= torch.norm(dirs, dim=-1, keepdim=True) | |
| dirs = dirs.reshape(-1, 3, 1) | |
| del xx, yy, zz | |
| # st() | |
| dirs = (c2w[None, :3, :3] @ dirs)[..., 0] | |
| origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous() | |
| origins = origins.view(self.h, self.w, 3) | |
| dirs = dirs.view(self.h, self.w, 3) | |
| return origins, dirs | |
| def normalize_camera(self, c, c_frame0): | |
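| # Re-expresses every camera relative to the first frame: the canonical (frame-0) | |
| # pose is mapped to a fixed position at its original radius along -z, and the | |
| # same rigid transform is applied to all other poses, keeping intrinsics untouched. | |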
| # assert c.shape[0] == self.chunk_size # 8 o r10 | |
| B = c.shape[0] | |
| camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 | |
| canonical_camera_poses = c_frame0[:, :16].reshape(B, 4, 4) | |
| # if for_encoder: | |
| # encoder_canonical_idx = [0, self.V] | |
| # st() | |
| cam_radius = np.linalg.norm( | |
| c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3], | |
| axis=-1, | |
| keepdims=False) # since g-buffer adopts dynamic radius here. | |
| frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0) | |
| frame1_fixed_pos[:, 2, -1] = -cam_radius | |
| transform = frame1_fixed_pos @ np.linalg.inv(canonical_camera_poses) | |
| # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 | |
| # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) | |
| new_camera_poses = np.repeat( | |
| transform, 1, axis=0 | |
| ) @ camera_poses # [V, 4, 4]. np.repeat() is th.repeat_interleave() | |
| # else: | |
| # cam_radius = np.linalg.norm( | |
| # c[canonical_idx][:16].reshape(4, 4)[:3, 3], | |
| # axis=-1, | |
| # keepdims=False | |
| # ) # since g-buffer adopts dynamic radius here. | |
| # frame1_fixed_pos = np.eye(4) | |
| # frame1_fixed_pos[2, -1] = -cam_radius | |
| # transform = frame1_fixed_pos @ np.linalg.inv( | |
| # camera_poses[canonical_idx]) # 4,4 | |
| # # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 | |
| # # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) | |
| # new_camera_poses = np.repeat( | |
| # transform[None], self.chunk_size, | |
| # axis=0) @ camera_poses # [V, 4, 4] | |
| # st() | |
| c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], | |
| axis=-1) | |
| # st() | |
| return c | |
| def _read_data( | |
| self, | |
| idx, | |
| ): | |
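| # Loads a single view: RGB composited onto white, the encoder input (optionally | |
| # with plucker rays and depth concatenated along channels), the 25-dim camera | |
| # (optionally frame-0 normalized or converted to the 3DGS dict), plus depth, | |
| # normal, alpha mask, bbox, caption and instance name. | |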
| rgb_fname = self.rgb_list[idx] | |
| pose_fname = self.pose_list[idx] | |
| raw_img = imageio.imread(rgb_fname) | |
| # ! RGBD | |
| alpha_mask = raw_img[..., -1:] / 255 | |
| raw_img = alpha_mask * raw_img[..., :3] + ( | |
| 1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 | |
| raw_img = raw_img.astype( | |
| np.uint8) # otherwise, float64 won't call ToTensor() | |
| # return raw_img | |
| # st() | |
| if self.preprocess is None: | |
| img_to_encoder = cv2.resize(raw_img, | |
| (self.reso_encoder, self.reso_encoder), | |
| interpolation=cv2.INTER_LANCZOS4) | |
| # interpolation=cv2.INTER_AREA) | |
| img_to_encoder = img_to_encoder[ | |
| ..., :3] #[3, reso_encoder, reso_encoder] | |
| img_to_encoder = self.normalize(img_to_encoder) | |
| else: | |
| img_to_encoder = self.preprocess(Image.open(rgb_fname)) # clip | |
| # return img_to_encoder | |
| img = cv2.resize(raw_img, (self.reso, self.reso), | |
| interpolation=cv2.INTER_LANCZOS4) | |
| # interpolation=cv2.INTER_AREA) | |
| # img_sr = cv2.resize(raw_img, (512, 512), interpolation=cv2.INTER_AREA) | |
| # img_sr = cv2.resize(raw_img, (256, 256), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution | |
| # img_sr = cv2.resize(raw_img, (128, 128), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution | |
| # img_sr = cv2.resize( | |
| # raw_img, (128, 128), interpolation=cv2.INTER_LANCZOS4 | |
| # ) # just as refinement, since eg3d uses 64->128 final resolution | |
| # img = torch.from_numpy(img)[..., :3].permute( | |
| # 2, 0, 1) / 255.0 #[3, reso, reso] | |
| img = torch.from_numpy(img)[..., :3].permute( | |
| 2, 0, 1 | |
| ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
| # img_sr = torch.from_numpy(img_sr)[..., :3].permute( | |
| # 2, 0, 1 | |
| # ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
| c2w = read_camera_matrix_single(pose_fname) #[1, 4, 4] -> [1, 16] | |
| # c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed. | |
| # return c2w | |
| # if self.load_depth: | |
| # depth, depth_mask, depth_mask_sr = read_dnormal(self.depth_list[idx], | |
| # try: | |
| depth, normal = read_dnormal(self.depth_list[idx], c2w[:3, 3:], | |
| self.reso, self.reso) | |
| # ! frame0 alignment | |
| # if self.frame_0_as_canonical: | |
| # return depth | |
| # except: | |
| # # print(self.depth_list[idx]) | |
| # raise NotImplementedError(self.depth_list[idx]) | |
| # if depth | |
| try: | |
| bbox = self.load_bbox(depth > 0) | |
| except: | |
| print(rgb_fname, flush=True) | |
| with open('error_log.txt', 'a') as f: | |
| f.write(str(rgb_fname + '\n')) | |
| bbox = self.load_bbox(torch.ones_like(depth)) | |
| # plucker | |
| # ! normalize camera | |
| c = np.concatenate([c2w.reshape(16), self.intrinsics], | |
| axis=0).reshape(25).astype( | |
| np.float32) # 25, no '1' dim needed. | |
| if self.frame_0_as_canonical: # 4 views as input per batch | |
| frame0_pose_name = self.frame0_pose_list[idx] | |
| c2w_frame0 = read_camera_matrix_single( | |
| frame0_pose_name) #[1, 4, 4] -> [1, 16] | |
| c = self.normalize_camera(c[None], c2w_frame0[None])[0] | |
| c2w = c[:16].reshape(4, 4) # ! | |
| # st() | |
| # pass | |
| rays_o, rays_d = self.gen_rays(c2w) | |
| rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
| dim=-1) # [h, w, 6] | |
| img_to_encoder = torch.cat( | |
| [img_to_encoder, rays_plucker.permute(2, 0, 1)], | |
| 0).float() # concat in C dim | |
| # ! add depth as input | |
| depth, normal = read_dnormal(self.depth_list[idx], c2w[:3, 3:], | |
| self.reso_encoder, self.reso_encoder) | |
| normalized_depth = depth.unsqueeze(0) # min=0 | |
| img_to_encoder = torch.cat([img_to_encoder, normalized_depth], | |
| 0) # concat in C dim | |
| if self.gs_cam_format: | |
| c = self.c_to_3dgs_format(c) | |
| else: | |
| c = torch.from_numpy(c) | |
| ret_dict = { | |
| # 'rgb_fname': rgb_fname, | |
| 'img_to_encoder': img_to_encoder, | |
| 'img': img, | |
| 'c': c, | |
| # 'img_sr': img_sr, | |
| # 'ins_name': self.data_ins_list[idx] | |
| } | |
| # ins = str( | |
| # (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent) | |
| pcd_ins = Path(self.data_ins_list[idx]).relative_to( | |
| Path(self.file_path).parent).parent | |
| # load pcd | |
| # fps_pcd = pcu.load_mesh_v( | |
| # str(self.pcd_path / pcd_ins / 'fps-10000.ply')) | |
| ins = str( # for compat | |
| (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent) | |
| # if self.shuffle_across_cls: | |
| caption = self.caption_data['/'.join(ins.split('/')[1:])] | |
| # else: | |
| # caption = self.caption_data[ins] | |
| ret_dict.update({ | |
| 'depth': depth, | |
| 'normal': normal, | |
| 'alpha_mask': alpha_mask, | |
| 'depth_mask': depth > 0, | |
| # 'depth_mask_sr': depth_mask_sr, | |
| 'bbox': bbox, | |
| 'caption': caption, | |
| 'rays_plucker': rays_plucker, # cam embedding used in lgm | |
| 'ins': ins, # placeholder | |
| # 'fps_pcd': fps_pcd, | |
| }) | |
| return ret_dict | |
| # class MultiViewObjverseDatasetChunk(MultiViewObjverseDataset): | |
| # def __init__(self, | |
| # file_path, | |
| # reso, | |
| # reso_encoder, | |
| # preprocess=None, | |
| # classes=False, | |
| # load_depth=False, | |
| # test=False, | |
| # scene_scale=1, | |
| # overfitting=False, | |
| # imgnet_normalize=True, | |
| # dataset_size=-1, | |
| # overfitting_bs=-1, | |
| # interval=1, | |
| # plucker_embedding=False, | |
| # shuffle_across_cls=False, | |
| # wds_split=1, | |
| # four_view_for_latent=False, | |
| # single_view_for_i23d=False, | |
| # load_extra_36_view=False, | |
| # gs_cam_format=False, | |
| # **kwargs): | |
| # super().__init__(file_path, reso, reso_encoder, preprocess, classes, | |
| # load_depth, test, scene_scale, overfitting, | |
| # imgnet_normalize, dataset_size, overfitting_bs, | |
| # interval, plucker_embedding, shuffle_across_cls, | |
| # wds_split, four_view_for_latent, single_view_for_i23d, | |
| # load_extra_36_view, gs_cam_format, **kwargs) | |
| # # load 40 views at a time, for inferring latents. | |
| # TODO merge all the useful APIs together | |
| class ChunkObjaverseDataset(Dataset): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=False, | |
| single_view_for_i23d=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| frame_0_as_canonical=True, | |
| split_chunk_size=10, | |
| mv_input=True, | |
| append_depth=False, | |
| append_xyz=False, | |
| wds_split_all=1, | |
| pcd_path=None, | |
| load_pcd=False, | |
| read_normal=False, | |
| load_raw=False, | |
| load_instance_only=False, | |
| mv_latent_dir='', | |
| perturb_pcd_scale=0.0, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs): | |
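| # Chunked variant: instead of per-view files it reads pre-packed per-instance | |
| # chunks (see read_chunk below) listed either in dataset.json or in raw image | |
| # lists, and defers image/camera post-processing to PostProcess. | |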
| super().__init__() | |
| # st() | |
| self.mv_latent_dir = mv_latent_dir | |
| self.load_raw = load_raw | |
| self.load_instance_only = load_instance_only | |
| self.read_normal = read_normal | |
| self.file_path = file_path | |
| self.chunk_size = split_chunk_size | |
| self.gs_cam_format = gs_cam_format | |
| self.frame_0_as_canonical = frame_0_as_canonical | |
| self.four_view_for_latent = four_view_for_latent # export 0 12 30 36, 4 views for reconstruction | |
| self.overfitting = overfitting | |
| self.scene_scale = scene_scale | |
| self.reso = reso | |
| self.reso_encoder = reso_encoder | |
| self.classes = False | |
| self.load_depth = load_depth | |
| self.preprocess = preprocess | |
| self.plucker_embedding = plucker_embedding | |
| self.intrinsics = get_intri(h=self.reso, w=self.reso, | |
| normalize=True).reshape(9) | |
| self.perturb_pcd_scale = perturb_pcd_scale | |
| assert not self.classes, "Not support class condition now." | |
| dataset_name = Path(self.file_path).stem.split('_')[0] | |
| self.dataset_name = dataset_name | |
| self.ray_sampler = RaySampler() | |
| self.zfar = 100.0 | |
| self.znear = 0.01 | |
| # ! load all chunk paths | |
| self.chunk_list = [] | |
| # if dataset_size != -1: # predefined instance | |
| # self.chunk_list = self.fetch_chunk_list(os.path.join(self.file_path, 'debug')) | |
| # else: | |
| # # for shard_idx in range(1, 5): # shard_dir 1-4 by default | |
| # for shard_idx in os.listdir(self.file_path): | |
| # self.chunk_list += self.fetch_chunk_list(os.path.join(self.file_path, shard_idx)) | |
| def load_single_cls_instances(file_path): | |
| ins_list = [] # the first 1 instance for evaluation reference. | |
| for dict_dir in os.listdir(file_path)[:]: # ! for debugging | |
| for ins_dir in os.listdir(os.path.join(file_path, dict_dir)): | |
| ins_list.append( | |
| os.path.join(file_path, dict_dir, ins_dir, | |
| 'campos_512_v4')) | |
| return ins_list | |
| # st() | |
| if self.load_raw: | |
| with open( | |
| # '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f: | |
| # '/nas/shared/public/yslan//data/text_captions_cap3d.json') as f: | |
| './dataset/text_captions_3dtopia.json') as f: | |
| self.caption_data = json.load(f) | |
| # with open | |
| # # '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f: | |
| # '/nas/shared/public/yslan//data/text_captions_cap3d.json') as f: | |
| # # '/cpfs01/shared/public/yhluo/Projects/threed/3D-Enhancer/develop/text_captions_3dtopia.json') as f: | |
| # self.old_caption_data = json.load(f) | |
| for subset in [ # ! around 17.6 W instances in total. | |
| 'Animals', | |
| # 'daily-used', | |
| # 'BuildingsOutdoor', | |
| # 'Furnitures', | |
| # 'Food', | |
| # 'Plants', | |
| # 'Electronics', | |
| # 'Transportations_tar', | |
| # 'Human-Shape', | |
| ]: # selected subset for training | |
| # self.chunk_list += load_single_cls_instances( | |
| # os.path.join(self.file_path, subset)) | |
| with open(f'shell_scripts/raw_img_list/{subset}.txt', 'r') as f: | |
| self.chunk_list += [os.path.join(subset, item.strip()) for item in f.readlines()] | |
| # st() # save to local | |
| # with open('/cpfs01/user/lanyushi.p/Repo/diffusion-3d/shell_scripts/shards_list/chunk_list.txt', 'w') as f: | |
| # f.writelines(self.chunk_list) | |
| # load raw g-objv dataset | |
| # self.img_ext = 'png' # ln3diff | |
| # for k, v in dataset_json.items(): # directly load from folders instead | |
| # self.chunk_list.extend(v) | |
| else: | |
| # ! direclty load from json | |
| with open(f'{self.file_path}/dataset.json', 'r') as f: | |
| dataset_json = json.load(f) | |
| # dataset_json = {'Animals': ['Animals/0/10017/1']} | |
| if self.chunk_size == 12: | |
| self.img_ext = 'png' # ln3diff | |
| for k, v in dataset_json.items(): | |
| self.chunk_list.extend(v) | |
| else: | |
| # extract latent | |
| assert self.chunk_size in [16,18, 20] | |
| self.img_ext = 'jpg' # more views | |
| for k, v in dataset_json.items(): | |
| # if k != 'BuildingsOutdoor': # cannot be handled by gs | |
| self.chunk_list.extend(v) | |
| # filter | |
| # st() | |
| # root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/gs/infer-latents/768/8x8/animals-gs-latent/latent_dir' | |
| # root = '/nas/shared/V2V/yslan/logs/nips23/Reconstruction/final/objav/vae/gs/infer-latents/768/8x8/animals-gs-latent-dim=10-fullset/latent_dir' | |
| # filtered_chunk_list = [] | |
| # for v in self.chunk_list: | |
| # if os.path.exists(os.path.join(root, v[:-2], 'gaussians.npy') ): | |
| # continue | |
| # filtered_chunk_list.append(v) | |
| # self.chunk_list = filtered_chunk_list | |
| dataset_size = len(self.chunk_list) | |
| self.chunk_list = sorted(self.chunk_list) | |
| # self.chunk_list, self.eval_list = self.chunk_list[:int(dataset_size*0.95)], self.chunk_list[int(dataset_size*0.95):] | |
| # self.chunk_list = self.eval_list | |
| # self.wds_split_all = wds_split_all # for | |
| # self.wds_split_all = 1 | |
| # self.wds_split_all = 7 | |
| # self.wds_split_all = 4 | |
| self.wds_split_all = 1 | |
| # ! filter | |
| # st() | |
| if wds_split_all != 1: | |
| # ! retrieve the right wds split | |
| all_ins_size = len(self.chunk_list) | |
| ratio_size = all_ins_size // self.wds_split_all + 1 | |
| # ratio_size = int(all_ins_size / self.wds_split_all) + 1 | |
| print('ratio_size: ', ratio_size, 'all_ins_size: ', all_ins_size) | |
| self.chunk_list = self.chunk_list[ratio_size * | |
| (wds_split):ratio_size * | |
| (wds_split + 1)] | |
| # st() | |
| # load images from raw | |
| self.rgb_list = [] | |
| if self.load_instance_only: | |
| for ins in tqdm(self.chunk_list): | |
| ins_name = str(Path(ins).parent) | |
| # cur_all_fname = [f'{t:05d}' for t in range(40)] # load all instances for now | |
| self.rgb_list += ([ | |
| os.path.join(self.file_path, ins, fname + '.png') | |
| for fname in [f'{t}' for t in range(2)] | |
| # for fname in [f'{t:05d}' for t in range(2)] | |
| ]) # synthetic mv data | |
| # index mapping of mvi data to objv single-view data | |
| self.mvi_objv_mapping = { | |
| '0': '00000', | |
| '1': '00012', | |
| } | |
| # load gt mv data | |
| self.gt_chunk_list = [] | |
| self.gt_mv_file_path = '/cpfs01/user/lanyushi.p/data/chunk-jpeg-normal/bs_16_fixsave3/170K/512/' | |
| assert self.chunk_size in [16,18, 20] | |
| with open(f'{self.gt_mv_file_path}/dataset.json', 'r') as f: | |
| dataset_json = json.load(f) | |
| # dataset_json = {'Animals': dataset_json['Animals'] } # | |
| self.img_ext = 'jpg' # more views | |
| for k, v in dataset_json.items(): | |
| # if k != 'BuildingsOutdoor': # cannot be handled by gs | |
| self.gt_chunk_list.extend(v) | |
| elif self.load_raw: | |
| for ins in tqdm(self.chunk_list): | |
| # | |
| # st() | |
| # ins = ins[len('/nas/shared/V2V/yslan/aigc3d/unzip4/'):] | |
| # ins_name = str(Path(ins).relative_to(self.file_path).parent) | |
| ins_name = str(Path(ins).parent) | |
| # latent_path = os.path.join(self.mv_latent_dir, ins_name, 'latent.npz') | |
| # if not os.path.exists(latent_path): | |
| # continue | |
| cur_all_fname = [f'{t:05d}' for t in range(40)] # load all instances for now | |
| self.rgb_list += ([ | |
| os.path.join(self.file_path, ins, fname, fname + '.png') | |
| for fname in cur_all_fname | |
| ]) | |
| self.post_process = PostProcess( | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize=imgnet_normalize, | |
| plucker_embedding=plucker_embedding, | |
| decode_encode_img_only=False, | |
| mv_input=mv_input, | |
| split_chunk_input=split_chunk_size, | |
| duplicate_sample=True, | |
| append_depth=append_depth, | |
| append_xyz=append_xyz, | |
| gs_cam_format=gs_cam_format, | |
| orthog_duplicate=False, | |
| frame_0_as_canonical=frame_0_as_canonical, | |
| pcd_path=pcd_path, | |
| load_pcd=load_pcd, | |
| split_chunk_size=split_chunk_size, | |
| ) | |
| self.kernel = torch.tensor([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) | |
| # self.no_bottom = True # avoid loading bottom vew | |
| def fetch_chunk_list(self, file_path): | |
| if os.path.isdir(file_path): | |
| chunks = [ | |
| os.path.join(file_path, fname) | |
| for fname in os.listdir(file_path) if fname.isdigit() | |
| ] | |
| return chunks | |
| else: | |
| return [] | |
| def _pre_process_chunk(self): | |
| # e.g., remove bottom view | |
| pass | |
| def read_chunk(self, chunk_path): | |
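| # A chunk directory packs one instance's views side by side: raw_img.<ext> (all | |
| # views concatenated along width), c.npy (cameras), caption.txt, ins.txt, | |
| # bbox.npy, and either a joint depth_alpha.jpg + d_near_far.npy (chunk_size > 16) | |
| # or separate alpha.<ext> + depth.npz, with an optional normal.png. | |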
| # equivalent to decode_zip() in wds | |
| # reshape chunk | |
| raw_img = imageio.imread( | |
| os.path.join(chunk_path, f'raw_img.{self.img_ext}')) | |
| h, bw, c = raw_img.shape | |
| raw_img = raw_img.reshape(h, self.chunk_size, -1, c).transpose( | |
| (1, 0, 2, 3)) | |
| c = np.load(os.path.join(chunk_path, 'c.npy')) | |
| with open(os.path.join(chunk_path, 'caption.txt'), | |
| 'r', | |
| encoding="utf-8") as f: | |
| caption = f.read() | |
| with open(os.path.join(chunk_path, 'ins.txt'), 'r', | |
| encoding="utf-8") as f: | |
| ins = f.read() | |
| bbox = np.load(os.path.join(chunk_path, 'bbox.npy')) | |
| if self.chunk_size > 16: | |
| depth_alpha = imageio.imread( | |
| os.path.join(chunk_path, 'depth_alpha.jpg')) # 2h 10w | |
| depth_alpha = depth_alpha.reshape(h * 2, self.chunk_size, | |
| -1).transpose((1, 0, 2)) | |
| depth, alpha = np.split(depth_alpha, 2, axis=1) | |
| d_near_far = np.load(os.path.join(chunk_path, 'd_near_far.npy')) | |
| d_near = d_near_far[0].reshape(self.chunk_size, 1, 1) | |
| d_far = d_near_far[1].reshape(self.chunk_size, 1, 1) | |
| # d = 1 / ( (d_normalized / 255) * (far-near) + near) | |
| depth = 1 / ((depth / 255) * (d_far - d_near) + d_near) | |
| depth[depth > 2.9] = 0.0 # background as 0, follow old tradition | |
| # ! filter anti-alias artifacts | |
| erode_mask = kornia.morphology.erosion( | |
| torch.from_numpy(alpha == 255).float().unsqueeze(1), | |
| self.kernel) # B 1 H W | |
| depth = (torch.from_numpy(depth).unsqueeze(1) * erode_mask).squeeze( | |
| 1) # shrink anti-alias bug | |
| else: | |
| # load separate alpha and depth map | |
| alpha = imageio.imread( | |
| os.path.join(chunk_path, f'alpha.{self.img_ext}')) | |
| alpha = alpha.reshape(h, self.chunk_size, h).transpose( | |
| (1, 0, 2)) | |
| depth = np.load(os.path.join(chunk_path, 'depth.npz'))['depth'] | |
| # depth = depth * (alpha==255) # mask out background | |
| # depth = np.stack([depth, alpha], -1) # rgba | |
| # if self.no_bottom: | |
| # raw_img | |
| # pass | |
| if self.read_normal: | |
| normal = imageio.imread(os.path.join( | |
| chunk_path, 'normal.png')).astype(np.float32) / 255.0 | |
| normal = (normal * 2 - 1).reshape(h, self.chunk_size, -1, | |
| 3).transpose((1, 0, 2, 3)) | |
| # fix g-buffer normal rendering coordinate | |
| # normal = unity2blender(normal) # ! still wrong | |
| normal = unity2blender_fix(normal) # ! | |
| depth = (depth, normal)  # pack the normal map together with the depth map when normals are requested | |
| return raw_img, depth, c, alpha, bbox, caption, ins | |
| def __len__(self): | |
| return len(self.chunk_list) | |
| def __getitem__(self, index) -> Any: | |
| sample = self.read_chunk( | |
| os.path.join(self.file_path, self.chunk_list[index])) | |
| sample = self.post_process.paired_post_process_chunk(sample) | |
| sample = self.post_process.create_dict_nobatch(sample) | |
| # aug pcd | |
| # st() | |
| if self.perturb_pcd_scale > 0: | |
| if random.random() > 0.5: | |
| t = np.random.rand(sample['fps_pcd'].shape[0], 1, 1) * self.perturb_pcd_scale | |
| sample['fps_pcd'] = sample['fps_pcd'] + t * np.random.randn(*sample['fps_pcd'].shape) # type: ignore | |
| sample['fps_pcd'] = np.clip(sample['fps_pcd'], -0.45, 0.45) # truncate noisy augmentation | |
| return sample | |
| class ChunkObjaverseDatasetDDPM(ChunkObjaverseDataset): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=False, | |
| single_view_for_i23d=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| frame_0_as_canonical=True, | |
| split_chunk_size=10, | |
| mv_input=True, | |
| append_depth=False, | |
| append_xyz=False, | |
| pcd_path=None, | |
| load_pcd=False, | |
| read_normal=False, | |
| mv_latent_dir='', | |
| load_raw=False, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs): | |
| super().__init__( | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=False, | |
| single_view_for_i23d=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| frame_0_as_canonical=True, | |
| split_chunk_size=split_chunk_size, | |
| mv_input=True, | |
| append_depth=False, | |
| append_xyz=False, | |
| pcd_path=None, | |
| load_pcd=False, | |
| read_normal=False, | |
| load_raw=load_raw, | |
| mv_latent_dir=mv_latent_dir, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs) | |
| self.n_cond_frames = 6 | |
| self.perspective_transformer = v2.RandomPerspective(distortion_scale=0.4, p=0.15, fill=1, | |
| interpolation=torchvision.transforms.InterpolationMode.NEAREST) | |
| self.mv_resize_cls = torchvision.transforms.Resize(320, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, | |
| max_size=None, antialias=True) | |
| # ! read img c, caption. | |
| def get_plucker_ray(self, c): | |
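| # Stacks a 6-channel Plucker embedding (o x d, d) per camera, shaped [V, 6, h, w]. | |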
| rays_plucker = [] | |
| for idx in range(c.shape[0]): | |
| rays_o, rays_d = self.gen_rays(c[idx]) | |
| rays_plucker.append( | |
| torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
| dim=-1).permute(2, 0, 1)) # [h, w, 6] -> 6,h,w | |
| rays_plucker = torch.stack(rays_plucker, 0) | |
| return rays_plucker | |
| def read_chunk(self, chunk_path): | |
| # equivalent to decode_zip() in wds | |
| # reshape chunk | |
| raw_img = imageio.imread( | |
| os.path.join(chunk_path, f'raw_img.{self.img_ext}')).astype(np.float32) | |
| h, bw, c = raw_img.shape | |
| raw_img = raw_img.reshape(h, self.chunk_size, -1, c).transpose( | |
| (1, 0, 2, 3)) | |
| c = np.load(os.path.join(chunk_path, 'c.npy')).astype(np.float32) | |
| with open(os.path.join(chunk_path, 'caption.txt'), | |
| 'r', | |
| encoding="utf-8") as f: | |
| caption = f.read() | |
| with open(os.path.join(chunk_path, 'ins.txt'), 'r', | |
| encoding="utf-8") as f: | |
| ins = f.read() | |
| return raw_img, c, caption, ins | |
| def _load_latent(self, ins): | |
| # if 'adv' in self.mv_latent_dir: # new latent codes saved have 3 augmentations | |
| # idx = random.choice([0,1,2]) | |
| # latent = np.load(os.path.join(self.mv_latent_dir, ins, f'latent-{idx}.npy')) # pre-calculated VAE latent | |
| # else: | |
| latent = np.load(os.path.join(self.mv_latent_dir, ins, 'latent.npy')) # pre-calculated VAE latent | |
| latent = repeat(latent, 'C H W -> B C H W', B=2) | |
| # return {'latent': latent} | |
| return latent | |
| def normalize_camera(self, c, c_frame0): | |
| # assert c.shape[0] == self.chunk_size # 8 o r10 | |
| B = c.shape[0] | |
| camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 | |
| canonical_camera_poses = c_frame0[:, :16].reshape(1, 4, 4) | |
| inverse_canonical_pose = np.linalg.inv(canonical_camera_poses) | |
| inverse_canonical_pose = np.repeat(inverse_canonical_pose, B, 0) | |
| cam_radius = np.linalg.norm( | |
| c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3], | |
| axis=-1, | |
| keepdims=False) # since g-buffer adopts dynamic radius here. | |
| frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0) | |
| frame1_fixed_pos[:, 2, -1] = -cam_radius | |
| transform = frame1_fixed_pos @ inverse_canonical_pose | |
| new_camera_poses = np.repeat( | |
| transform, 1, axis=0 | |
| ) @ camera_poses # [V, 4, 4]. np.repeat() is th.repeat_interleave() | |
| c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], | |
| axis=-1) | |
| return c | |
| # @autocast | |
| # def plucker_embedding(self, c): | |
| # rays_o, rays_d = self.gen_rays(c) | |
| # rays_plucker = torch.cat( | |
| # [torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
| # dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w | |
| # return rays_plucker | |
| def __getitem__(self, index) -> Any: | |
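| # DDPM sample: read the image chunk, resize to self.reso and map to [-1, 1], | |
| # load the pre-computed VAE latent, shuffle the views, apply a random perspective | |
| # jitter, split them into two groups of n_cond_frames, and normalize each group's | |
| # cameras against its own first frame. | |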
| raw_img, c, caption, ins = self.read_chunk( | |
| os.path.join(self.file_path, self.chunk_list[index])) | |
| # sample = self.post_process.paired_post_process_chunk(sample) | |
| # ! random zoom in (scale augmentation) | |
| # for i in range(img.shape[0]): | |
| # for v in range(img.shape[1]): | |
| # if random.random() > 0.8: | |
| # rand_bg_scale = random.randint(60,99) / 100 | |
| # st() | |
| # img[i,v] = recenter(img[i,v], np.ones_like(img[i,v]), border_ratio=rand_bg_scale) | |
| # ! process | |
| raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] | |
| if raw_img.shape[-1] != self.reso: | |
| raw_img = torch.nn.functional.interpolate( | |
| input=raw_img, | |
| size=(self.reso, self.reso), | |
| mode='bilinear', | |
| align_corners=False, | |
| ) | |
| img = raw_img * 2 - 1 # as gt | |
| # ! load latent | |
| latent, _ = self._load_latent(ins)  # NOTE: relies on _load_latent yielding two elements (the base version returns a 2-copy array that unpacks along its first axis) | |
| # ! shuffle | |
| indices = np.random.permutation(self.chunk_size) | |
| img = img[indices] | |
| c = c[indices] | |
| img = self.perspective_transformer(img) # create 3D inconsistency | |
| # ! split along V and repeat other stuffs accordingly | |
| img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames] | |
| c = rearrange(c, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames] # 2 6 25 | |
| # rand perspective aug | |
| caption = [caption, caption] | |
| ins = [ins, ins] | |
| # load plucker coord | |
| # st() | |
| # plucker_c = self.get_plucker_ray(rearrange(c[:, 1:1+self.n_cond_frames], "b t ... -> (b t) ...")) | |
| # plucker_c = rearrange(c, '(B V) ... -> B V ...', B=2) # 2 6 25 | |
| # use view-space camera tradition | |
| c[0] = self.normalize_camera(c[0], c[0,0:1]) | |
| c[1] = self.normalize_camera(c[1], c[1,0:1]) | |
| # https://github.com/TencentARC/InstantMesh/blob/7fe95627cf819748f7830b2b278f302a9d798d17/src/model.py#L70 | |
| # c = np.concatenate([c[..., :12], c[..., 16:17], c[..., 20:21], c[..., 18:19], c[..., 21:22]], axis=-1) | |
| # c = c + np.random.randn(*c.shape) * 0.04 - 0.02 | |
| # ! to dict | |
| # sample = self.post_process.create_dict_nobatch(sample) | |
| ret_dict = { | |
| 'caption': caption, | |
| 'ins': ins, | |
| 'c': c, | |
| 'img': img, # fix inp img range to [-1,1] | |
| 'latent': latent, | |
| # **latent | |
| } | |
| # st() | |
| return ret_dict | |
| class ChunkObjaverseDatasetDDPMgs(ChunkObjaverseDatasetDDPM): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=False, | |
| single_view_for_i23d=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| frame_0_as_canonical=True, | |
| split_chunk_size=10, | |
| mv_input=True, | |
| append_depth=False, | |
| append_xyz=False, | |
| pcd_path=None, | |
| load_pcd=False, | |
| read_normal=False, | |
| mv_latent_dir='', | |
| load_raw=False, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs): | |
| super().__init__( | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=preprocess, | |
| classes=classes, | |
| load_depth=load_depth, | |
| test=test, | |
| scene_scale=scene_scale, | |
| overfitting=overfitting, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size, | |
| overfitting_bs=overfitting_bs, | |
| interval=interval, | |
| plucker_embedding=plucker_embedding, | |
| shuffle_across_cls=shuffle_across_cls, | |
| wds_split=wds_split, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=four_view_for_latent, | |
| single_view_for_i23d=single_view_for_i23d, | |
| load_extra_36_view=load_extra_36_view, | |
| gs_cam_format=gs_cam_format, | |
| frame_0_as_canonical=frame_0_as_canonical, | |
| split_chunk_size=split_chunk_size, | |
| mv_input=mv_input, | |
| append_depth=append_depth, | |
| append_xyz=append_xyz, | |
| pcd_path=pcd_path, | |
| load_pcd=load_pcd, | |
| read_normal=read_normal, | |
| mv_latent_dir=mv_latent_dir, | |
| load_raw=load_raw, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs) | |
| self.avoid_loading_first = False | |
| # self.feat_scale_factor = torch.Tensor([0.99227685, 1.014337 , 0.20842505, 0.98727155, 0.3305389 , | |
| # 0.38729668, 1.0155401 , 0.9728264 , 1.0009694 , 0.97328585, | |
| # 0.2881106 , 0.1652732 , 0.3482468 , 0.9971449 , 0.99895126, | |
| # 0.18491288]).float().reshape(1,1,-1) | |
| # stats for normalization | |
| # self.xyz_mean = torch.Tensor([-0.00053714, 0.08095618, -0.01914407] ).reshape(1, 3).float() | |
| # self.xyz_std = np.array([0.14593576, 0.15753542, 0.18873914] ).reshape(1,3).astype(np.float32) | |
| self.xyz_std = 0.164 # a global scaler | |
| self.kl_mean = np.array([ 0.0184, 0.0024, 0.0926, 0.0517, 0.1781, 0.7137, -0.0355, 0.0267, | |
| 0.0183, 0.0164, -0.5090, 0.2406, 0.2733, -0.0256, -0.0285, 0.0761]).reshape(1,16).astype(np.float32) | |
| self.kl_std = np.array([1.0018, 1.0309, 1.3001, 1.0160, 0.8182, 0.8023, 1.0591, 0.9789, 0.9966, | |
| 0.9448, 0.8908, 1.4595, 0.7957, 0.9871, 1.0236, 1.2923]).reshape(1,16).astype(np.float32) | |
| def normalize_pcd_act(self, x): | |
| return x / self.xyz_std | |
| def normalize_kl_feat(self, latent): | |
| # return latent / self.feat_scale_factor | |
| return (latent-self.kl_mean) / self.kl_std | |
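| # normalize_kl_feat standardizes each of the 16 latent channels with the stats | |
| # above, while normalize_pcd_act rescales xyz by a single global std. A minimal | |
| # illustrative sketch (shapes follow the comments in _load_latent below and are | |
| # assumptions, not enforced here): | |
| # latent: (..., 16) -> (latent - kl_mean) / kl_std # roughly zero-mean, unit-std per channel | |
| # fps_xyz: (..., 3) -> fps_xyz / 0.164 # single global scaler | |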
| def _load_latent(self, ins, rand_pick_one=False, pick_both=False): | |
| if 'adv' in self.mv_latent_dir: # newly saved latent codes include 3 augmentations | |
| idx = random.choice([0,1,2]) | |
| # idx = random.choice([0]) | |
| latent = np.load(os.path.join(self.mv_latent_dir, ins, f'latent-{idx}.npz')) # pre-calculated VAE latent | |
| else: | |
| latent = np.load(os.path.join(self.mv_latent_dir, ins, 'latent.npz')) # pre-calculated VAE latent | |
| latent, fps_xyz = latent['latent_normalized'], latent['query_pcd_xyz'] # 2,768,16; 2,768,3 | |
| if not pick_both: | |
| if rand_pick_one: | |
| rand_idx = random.randint(0,1) | |
| else: | |
| rand_idx = 0 | |
| latent, fps_xyz = latent[rand_idx:rand_idx+1], fps_xyz[rand_idx:rand_idx+1] | |
| # per-channel normalize to std=1 & concat | |
| # latent_pcd = np.concatenate([self.normalize_kl_feat(latent), self.normalize_pcd_act(fps_xyz)], -1) | |
| # latent_pcd = np.concatenate([latent, self.normalize_pcd_act(fps_xyz)], -1) | |
| # return latent_pcd, fps_xyz | |
| return latent, fps_xyz | |
| def __getitem__(self, index) -> Any: | |
| raw_img, c, caption, ins = self.read_chunk( | |
| os.path.join(self.file_path, self.chunk_list[index])) | |
| # sample = self.post_process.paired_post_process_chunk(sample) | |
| # ! random zoom in (scale augmentation) | |
| # for i in range(img.shape[0]): | |
| # for v in range(img.shape[1]): | |
| # if random.random() > 0.8: | |
| # rand_bg_scale = random.randint(60,99) / 100 | |
| # st() | |
| # img[i,v] = recenter(img[i,v], np.ones_like(img[i,v]), border_ratio=rand_bg_scale) | |
| # ! process | |
| raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] | |
| if raw_img.shape[-1] != self.reso: | |
| raw_img = torch.nn.functional.interpolate( | |
| input=raw_img, | |
| size=(self.reso, self.reso), | |
| mode='bilinear', | |
| align_corners=False, | |
| ) | |
| img = raw_img * 2 - 1 # as gt | |
| # ! load latent | |
| # latent, _ = self._load_latent(ins) | |
| latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion | |
| # latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here | |
| # fps_xyz = fps_xyz / self.scaling_factor # for xyz training | |
| normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) | |
| if self.avoid_loading_first: # for training mv model | |
| index = list(range(1,6)) + list(range(7,12)) # drop views 0 and 6, i.e. the first view of each half (assumes a 12-view chunk) | |
| img = img[index] | |
| c = c[index] | |
| # ! shuffle | |
| indices = np.random.permutation(img.shape[0]) | |
| img = img[indices] | |
| c = c[indices] | |
| img = self.perspective_transformer(img) # create 3D inconsistency | |
| # ! split along V and repeat the other fields accordingly | |
| img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames] | |
| c = rearrange(c, '(B V) ... -> B V ...', B=2)[:, :self.n_cond_frames] # 2 6 25 | |
| # rand perspective aug | |
| caption = [caption, caption] | |
| ins = [ins, ins] | |
| # load plucker coord | |
| # st() | |
| # plucker_c = self.get_plucker_ray(rearrange(c[:, 1:1+self.n_cond_frames], "b t ... -> (b t) ...")) | |
| # plucker_c = rearrange(c, '(B V) ... -> B V ...', B=2) # 2 6 25 | |
| # use view-space camera tradition | |
| c[0] = self.normalize_camera(c[0], c[0,0:1]) | |
| c[1] = self.normalize_camera(c[1], c[1,0:1]) | |
| # ! to dict | |
| # sample = self.post_process.create_dict_nobatch(sample) | |
| ret_dict = { | |
| 'caption': caption, | |
| 'ins': ins, | |
| 'c': c, | |
| 'img': img, # fix inp img range to [-1,1] | |
| 'latent': latent, | |
| 'normalized-fps-xyz': normalized_fps_xyz | |
| # **latent | |
| } | |
| # st() | |
| return ret_dict | |
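| # Illustrative sketch (not used by the pipeline; names and shapes are assumptions): | |
| # how the '(B V) ...' rearrange in __getitem__ groups a 12-view chunk into two | |
| # groups of 6 and keeps the first n_cond_frames of each group. | |
| def _sketch_split_views(n_views=12, n_cond_frames=6, reso=256): | |
| dummy = np.zeros((n_views, 3, reso, reso), dtype=np.float32) | |
| grouped = rearrange(dummy, '(B V) ... -> B V ...', B=2) # (2, n_views // 2, 3, reso, reso) | |
| return grouped[:, :n_cond_frames] # (2, n_cond_frames, 3, reso, reso) | |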
| class ChunkObjaverseDatasetDDPMgsT23D(ChunkObjaverseDatasetDDPMgs): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=False, | |
| single_view_for_i23d=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| frame_0_as_canonical=True, | |
| split_chunk_size=10, | |
| mv_input=True, | |
| append_depth=False, | |
| append_xyz=False, | |
| pcd_path=None, | |
| load_pcd=False, | |
| read_normal=False, | |
| mv_latent_dir='', | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs): | |
| super().__init__( | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=preprocess, | |
| classes=classes, | |
| load_depth=load_depth, | |
| test=test, | |
| scene_scale=scene_scale, | |
| overfitting=overfitting, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size, | |
| overfitting_bs=overfitting_bs, | |
| interval=interval, | |
| plucker_embedding=plucker_embedding, | |
| shuffle_across_cls=shuffle_across_cls, | |
| wds_split=wds_split, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=four_view_for_latent, | |
| single_view_for_i23d=single_view_for_i23d, | |
| load_extra_36_view=load_extra_36_view, | |
| gs_cam_format=gs_cam_format, | |
| frame_0_as_canonical=frame_0_as_canonical, | |
| split_chunk_size=split_chunk_size, | |
| mv_input=mv_input, | |
| append_depth=append_depth, | |
| append_xyz=append_xyz, | |
| pcd_path=pcd_path, | |
| load_pcd=load_pcd, | |
| read_normal=read_normal, | |
| mv_latent_dir=mv_latent_dir, | |
| load_raw=True, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs) | |
| # def __len__(self): | |
| # return 40 | |
| def __len__(self): | |
| return len(self.rgb_list) | |
| def __getitem__(self, index) -> Any: | |
| rgb_path = self.rgb_list[index] | |
| ins = str(Path(rgb_path).relative_to(self.file_path).parent.parent.parent) | |
| # load caption | |
| caption = self.caption_data['/'.join(ins.split('/')[1:])] | |
| # chunk_path = os.path.join(self.file_path, self.chunk_list[index]) | |
| # # load caption | |
| # with open(os.path.join(chunk_path, 'caption.txt'), | |
| # 'r', | |
| # encoding="utf-8") as f: | |
| # caption = f.read() | |
| # # load latent | |
| # with open(os.path.join(chunk_path, 'ins.txt'), 'r', | |
| # encoding="utf-8") as f: | |
| # ins = f.read() | |
| latent, fps_xyz = self._load_latent(ins, True) # analyzing xyz/latent disentangled diffusion | |
| latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here | |
| # fps_xyz = fps_xyz / self.scaling_factor # for xyz training | |
| normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) | |
| # ! to dict | |
| ret_dict = { | |
| # 'caption': caption, | |
| 'latent': latent, | |
| # 'img': img, | |
| 'fps-xyz': fps_xyz, | |
| 'normalized-fps-xyz': normalized_fps_xyz, | |
| 'caption': caption | |
| } | |
| return ret_dict | |
| class ChunkObjaverseDatasetDDPMgsI23D(ChunkObjaverseDatasetDDPMgs): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=False, | |
| single_view_for_i23d=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| frame_0_as_canonical=True, | |
| split_chunk_size=10, | |
| mv_input=True, | |
| append_depth=False, | |
| append_xyz=False, | |
| pcd_path=None, | |
| load_pcd=False, | |
| read_normal=False, | |
| mv_latent_dir='', | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs): | |
| super().__init__( | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=preprocess, | |
| classes=classes, | |
| load_depth=load_depth, | |
| test=test, | |
| scene_scale=scene_scale, | |
| overfitting=overfitting, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size, | |
| overfitting_bs=overfitting_bs, | |
| interval=interval, | |
| plucker_embedding=plucker_embedding, | |
| shuffle_across_cls=shuffle_across_cls, | |
| wds_split=wds_split, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=four_view_for_latent, | |
| single_view_for_i23d=single_view_for_i23d, | |
| load_extra_36_view=load_extra_36_view, | |
| gs_cam_format=gs_cam_format, | |
| frame_0_as_canonical=frame_0_as_canonical, | |
| split_chunk_size=split_chunk_size, | |
| mv_input=mv_input, | |
| append_depth=append_depth, | |
| append_xyz=append_xyz, | |
| pcd_path=pcd_path, | |
| load_pcd=load_pcd, | |
| read_normal=read_normal, | |
| mv_latent_dir=mv_latent_dir, | |
| load_raw=True, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs) | |
| assert self.load_raw | |
| self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914]) | |
| def __len__(self): | |
| return len(self.rgb_list) | |
| # def __len__(self): | |
| # return 40 | |
| def __getitem__(self, index) -> Any: | |
| rgb_path = self.rgb_list[index] | |
| ins = str(Path(rgb_path).relative_to(self.file_path).parent.parent.parent) | |
| raw_img = imageio.imread(rgb_path).astype(np.float32) | |
| alpha_mask = raw_img[..., -1:] / 255 | |
| raw_img = alpha_mask * raw_img[..., :3] + ( | |
| 1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 | |
| raw_img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_CUBIC) | |
| raw_img = torch.from_numpy(raw_img).permute(2,0,1).clip(0,255) # [0,255] | |
| img = raw_img / 127.5 - 1 | |
| # with open(os.path.join(chunk_path, 'caption.txt'), | |
| # 'r', | |
| # encoding="utf-8") as f: | |
| # caption = f.read() | |
| # latent = self._load_latent(ins, True)[0] | |
| latent, fps_xyz = self._load_latent(ins, True) # analyzing xyz/latent disentangled diffusion | |
| latent, fps_xyz = latent[0], fps_xyz[0] | |
| # fps_xyz = fps_xyz / self.scaling_factor # for xyz training | |
| normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) | |
| # load caption | |
| caption = self.caption_data['/'.join(ins.split('/')[1:])] | |
| # ! to dict | |
| ret_dict = { | |
| # 'caption': caption, | |
| 'latent': latent, | |
| 'img': img.numpy(), # return numpy; unclear whether returning a Tensor triggers 'too many open files' in DataLoader workers | |
| 'fps-xyz': fps_xyz, | |
| 'normalized-fps-xyz': normalized_fps_xyz, | |
| 'caption': caption | |
| } | |
| return ret_dict | |
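| # Illustrative sketch of the RGBA -> white-background compositing used in the | |
| # __getitem__ above; a hypothetical helper, not called anywhere in this file. | |
| def _sketch_composite_on_white(rgba_uint8, reso): | |
| alpha = rgba_uint8[..., -1:] / 255.0 | |
| rgb = alpha * rgba_uint8[..., :3] + (1 - alpha) * 255.0 # paste onto white | |
| rgb = cv2.resize(rgb.astype(np.float32), (reso, reso), interpolation=cv2.INTER_CUBIC) | |
| return torch.from_numpy(rgb).permute(2, 0, 1).clip(0, 255) / 127.5 - 1 # [-1, 1] | |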
| class ChunkObjaverseDatasetDDPMgsMV23D(ChunkObjaverseDatasetDDPMgs): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=False, | |
| single_view_for_i23d=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| frame_0_as_canonical=True, | |
| split_chunk_size=10, | |
| mv_input=True, | |
| append_depth=False, | |
| append_xyz=False, | |
| pcd_path=None, | |
| load_pcd=False, | |
| read_normal=False, | |
| mv_latent_dir='', | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs): | |
| super().__init__( | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=preprocess, | |
| classes=classes, | |
| load_depth=load_depth, | |
| test=test, | |
| scene_scale=scene_scale, | |
| overfitting=overfitting, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size, | |
| overfitting_bs=overfitting_bs, | |
| interval=interval, | |
| plucker_embedding=plucker_embedding, | |
| shuffle_across_cls=shuffle_across_cls, | |
| wds_split=wds_split, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=four_view_for_latent, | |
| single_view_for_i23d=single_view_for_i23d, | |
| load_extra_36_view=load_extra_36_view, | |
| gs_cam_format=gs_cam_format, | |
| frame_0_as_canonical=frame_0_as_canonical, | |
| split_chunk_size=split_chunk_size, | |
| mv_input=mv_input, | |
| append_depth=append_depth, | |
| append_xyz=append_xyz, | |
| pcd_path=pcd_path, | |
| load_pcd=load_pcd, | |
| read_normal=read_normal, | |
| mv_latent_dir=mv_latent_dir, | |
| load_raw=False, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs) | |
| assert not self.load_raw | |
| # self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914]) | |
| self.n_cond_frames = 4 # an easy version for now. | |
| self.avoid_loading_first = True | |
| def __getitem__(self, index) -> Any: | |
| raw_img, c, caption, ins = self.read_chunk( | |
| os.path.join(self.file_path, self.chunk_list[index])) | |
| # ! process | |
| raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] | |
| if raw_img.shape[-1] != self.reso: | |
| raw_img = torch.nn.functional.interpolate( | |
| input=raw_img, | |
| size=(self.reso, self.reso), | |
| mode='bilinear', | |
| align_corners=False, | |
| ) | |
| img = raw_img * 2 - 1 # as gt | |
| # ! load latent | |
| # latent, _ = self._load_latent(ins) | |
| latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion | |
| # latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here | |
| # fps_xyz = fps_xyz / self.scaling_factor # for xyz training | |
| normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) | |
| if self.avoid_loading_first: # for training mv model | |
| index = list(range(1,self.chunk_size//2)) + list(range(self.chunk_size//2+1, self.chunk_size)) # drop the first view of each half-chunk | |
| img = img[index] | |
| c = c[index] | |
| # ! shuffle | |
| indices = np.random.permutation(img.shape[0]) | |
| img = img[indices] | |
| c = c[indices] | |
| aug_img = self.perspective_transformer(img) # create 3D inconsistency | |
| # ! split along V and repeat the other fields accordingly | |
| img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, 0:1] # only return first view (randomly sampled) | |
| aug_img = rearrange(aug_img, '(B V) ... -> B V ...', B=2)[:, 1:self.n_cond_frames+1] | |
| c = rearrange(c, '(B V) ... -> B V ...', B=2)[:, 1:self.n_cond_frames+1] # 2 6 25 | |
| # use view-space camera tradition | |
| c[0] = self.normalize_camera(c[0], c[0,0:1]) | |
| c[1] = self.normalize_camera(c[1], c[1,0:1]) | |
| caption = [caption, caption] | |
| ins = [ins, ins] | |
| # ! to dict | |
| # sample = self.post_process.create_dict_nobatch(sample) | |
| ret_dict = { | |
| 'caption': caption, | |
| 'ins': ins, | |
| 'c': c, | |
| 'img': img, # fix inp img range to [-1,1] | |
| 'mv_img': aug_img, | |
| 'latent': latent, | |
| 'normalized-fps-xyz': normalized_fps_xyz | |
| # **latent | |
| } | |
| # st() | |
| return ret_dict | |
| class ChunkObjaverseDatasetDDPMgsMV23DSynthetic(ChunkObjaverseDatasetDDPMgs): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=False, | |
| single_view_for_i23d=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| frame_0_as_canonical=True, | |
| split_chunk_size=10, | |
| mv_input=True, | |
| append_depth=False, | |
| append_xyz=False, | |
| pcd_path=None, | |
| load_pcd=False, | |
| read_normal=False, | |
| mv_latent_dir='', | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs): | |
| super().__init__( | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=preprocess, | |
| classes=classes, | |
| load_depth=load_depth, | |
| test=test, | |
| scene_scale=scene_scale, | |
| overfitting=overfitting, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size, | |
| overfitting_bs=overfitting_bs, | |
| interval=interval, | |
| plucker_embedding=plucker_embedding, | |
| shuffle_across_cls=shuffle_across_cls, | |
| wds_split=wds_split, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=four_view_for_latent, | |
| single_view_for_i23d=single_view_for_i23d, | |
| load_extra_36_view=load_extra_36_view, | |
| gs_cam_format=gs_cam_format, | |
| frame_0_as_canonical=frame_0_as_canonical, | |
| split_chunk_size=split_chunk_size, | |
| mv_input=mv_input, | |
| append_depth=append_depth, | |
| append_xyz=append_xyz, | |
| pcd_path=pcd_path, | |
| load_pcd=load_pcd, | |
| read_normal=read_normal, | |
| mv_latent_dir=mv_latent_dir, | |
| load_raw=True, | |
| load_instance_only=True, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs) | |
| # assert not self.load_raw | |
| # self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914]) | |
| self.n_cond_frames = 6 # an easy version for now. | |
| self.avoid_loading_first = True | |
| self.indices = np.array([0,1,2,3,4,5]) | |
| self.img_root_dir = '/cpfs01/user/lanyushi.p/data/unzip4_img' | |
| azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) | |
| elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) | |
| zero123pp_pose, _ = generate_input_camera(1.8, [[elevations[i], azimuths[i]] for i in range(6)], fov=30) | |
| K = torch.Tensor([1.3889, 0.0000, 0.5000, 0.0000, 1.3889, 0.5000, 0.0000, 0.0000, 0.0039]).to(zero123pp_pose) # keeps the same | |
| zero123pp_pose = torch.cat([zero123pp_pose.reshape(6,-1), K.unsqueeze(0).repeat(6,1)], dim=-1) | |
| eval_camera = zero123pp_pose[self.indices].float().cpu().numpy() # for normalization | |
| self.eval_camera = self.normalize_camera(eval_camera, eval_camera[0:1]) # the first img is not used. | |
| # self.load_synthetic_only = False | |
| self.load_synthetic_only = True | |
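| # The six poses above follow the fixed zero123++ view layout: azimuths | |
| # (30, 90, 150, 210, 270, 330) deg with alternating elevations (20, -10) deg, | |
| # produced by generate_input_camera(1.8, ..., fov=30). Each flattened 4x4 pose is | |
| # concatenated with the shared row-major 3x3 intrinsics K (16 + 9 = 25 dims), and | |
| # all six are then canonicalized against the first pose via normalize_camera. | |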
| def __len__(self): | |
| return len(self.rgb_list) | |
| def _getitem_synthetic(self, index) -> Any: | |
| rgb_fname = Path(self.rgb_list[index]) | |
| # ins = self.mvi_objv_mapping(rgb_fname.parent.parent.stem) | |
| # ins = str(Path(rgb_fname).parent.parent.stem) | |
| ins = str((Path(rgb_fname).relative_to(self.file_path)).parent.parent) | |
| mv_img = imageio.imread(rgb_fname) | |
| # st() | |
| mv_img = rearrange(mv_img, '(n h) (m w) c -> (n m) h w c', n=3, m=2)[self.indices] # (6, 320, 320, 3), channel-last | |
| mv_img = np.stack([recenter(img, np.ones_like(img), border_ratio=0.1) for img in mv_img], axis=0) | |
| mv_img = rearrange(mv_img, 'b h w c -> b c h w') # to torch tradition | |
| mv_img = torch.from_numpy(mv_img) / 127.5 - 1 | |
| # ! load single-view image here | |
| img_idx = self.mvi_objv_mapping[rgb_fname.stem] | |
| img_path = os.path.join(self.img_root_dir, rgb_fname.parent.relative_to(self.file_path), img_idx, f'{img_idx}.png') | |
| raw_img = imageio.imread(img_path).astype(np.float32) | |
| alpha_mask = raw_img[..., -1:] / 255 | |
| raw_img = alpha_mask * raw_img[..., :3] + ( | |
| 1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 | |
| raw_img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_CUBIC) | |
| raw_img = torch.from_numpy(raw_img).permute(2,0,1).clip(0,255) # [0,255] | |
| img = raw_img / 127.5 - 1 | |
| latent, fps_xyz = self._load_latent(ins, pick_both=False) # analyzing xyz/latent disentangled diffusion | |
| latent, fps_xyz = latent[0], fps_xyz[0] | |
| normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) # for stage-1 | |
| # use view-space camera tradition | |
| # ins = [ins, ins] | |
| # st() | |
| caption = self.caption_data['/'.join(ins.split('/')[1:])] | |
| # ! to dict | |
| # sample = self.post_process.create_dict_nobatch(sample) | |
| ret_dict = { | |
| 'caption': caption, | |
| # 'ins': ins, | |
| 'c': self.eval_camera, | |
| 'img': img, # fix inp img range to [-1,1] | |
| 'mv_img': mv_img, | |
| 'latent': latent, | |
| 'normalized-fps-xyz': normalized_fps_xyz, | |
| 'fps-xyz': fps_xyz, | |
| } | |
| return ret_dict | |
| def _getitem_gt(self, index) -> Any: | |
| raw_img, c, caption, ins = self.read_chunk( | |
| os.path.join(self.gt_mv_file_path, self.gt_chunk_list[index])) | |
| # ! process | |
| raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] | |
| if raw_img.shape[-1] != self.reso: | |
| raw_img = torch.nn.functional.interpolate( | |
| input=raw_img, | |
| size=(self.reso, self.reso), | |
| mode='bilinear', | |
| align_corners=False, | |
| ) | |
| img = raw_img * 2 - 1 # as gt | |
| # ! load latent | |
| # latent, _ = self._load_latent(ins) | |
| latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion | |
| # latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here | |
| # fps_xyz = fps_xyz / self.scaling_factor # for xyz training | |
| normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) | |
| if self.avoid_loading_first: # for training mv model | |
| index = list(range(1,self.chunk_size//2)) + list(range(self.chunk_size//2+1, self.chunk_size)) | |
| img = img[index] | |
| c = c[index] | |
| # ! shuffle | |
| indices = np.random.permutation(img.shape[0]) | |
| img = img[indices] | |
| c = c[indices] | |
| # st() | |
| aug_img = self.mv_resize_cls(img) | |
| aug_img = self.perspective_transformer(aug_img) # create 3D inconsistency | |
| # ! split along V and repeat the other fields accordingly | |
| img = rearrange(img, '(B V) ... -> B V ...', B=2)[:, 0:1] # only return first view (randomly sampled) | |
| aug_img = rearrange(aug_img, '(B V) ... -> B V ...', B=2)[:, 1:self.n_cond_frames+1] | |
| c = rearrange(c, '(B V) ... -> B V ...', B=2)[:, 1:self.n_cond_frames+1] # 2 6 25 | |
| # use view-space camera tradition | |
| c[0] = self.normalize_camera(c[0], c[0,0:1]) | |
| c[1] = self.normalize_camera(c[1], c[1,0:1]) | |
| caption = [caption, caption] | |
| ins = [ins, ins] | |
| # ! to dict | |
| # sample = self.post_process.create_dict_nobatch(sample) | |
| ret_dict = { | |
| 'caption': caption, | |
| 'ins': ins, | |
| 'c': c, | |
| 'img': img, # fix inp img range to [-1,1] | |
| 'mv_img': aug_img, | |
| 'latent': latent, | |
| 'normalized-fps-xyz': normalized_fps_xyz, | |
| 'fps-xyz': fps_xyz, | |
| } | |
| return ret_dict | |
| def __getitem__(self, index) -> Any: | |
| # load synthetic version | |
| try: | |
| synthetic_mv = self._getitem_synthetic(index) | |
| except Exception as e: | |
| # logger.log(Path(self.rgb_list[index]), 'missing') | |
| synthetic_mv = self._getitem_synthetic(random.randint(0, len(self.rgb_list)//2)) | |
| if self.load_synthetic_only: | |
| return synthetic_mv | |
| else: | |
| # load gt mv chunk | |
| gt_chunk_index = random.randint(0, len(self.gt_chunk_list)-1) | |
| gt_mv = self._getitem_gt(gt_chunk_index) | |
| # merge them together along batch dim | |
| merged_mv = {} | |
| for k, v in synthetic_mv.items(): # merge, synthetic - gt order | |
| if k not in ['caption', 'ins']: | |
| if k == 'img': | |
| merged_mv[k] = np.concatenate([v[None], gt_mv[k][:, 0]], axis=0).astype(np.float32) | |
| else: | |
| merged_mv[k] = np.concatenate([v[None], gt_mv[k]], axis=0).astype(np.float32) | |
| else: | |
| merged_mv[k] = [v] + gt_mv[k] # list | |
| return merged_mv | |
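| # Illustrative sketch of the grid-to-views rearrange used in _getitem_synthetic | |
| # (a 960x640 zero123++ grid is an assumption here): | |
| # grid = imageio.imread('mv.png') # (960, 640, 3): 3 rows x 2 cols of 320x320 views | |
| # views = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2) # (6, 320, 320, 3) | |
| # Views come out row-major: view k sits at grid row k // 2, column k % 2. | |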
| class ChunkObjaverseDatasetDDPMgsI23D_loadMV(ChunkObjaverseDatasetDDPMgs): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=False, | |
| single_view_for_i23d=False, | |
| load_extra_36_view=False, | |
| gs_cam_format=False, | |
| frame_0_as_canonical=True, | |
| split_chunk_size=10, | |
| mv_input=True, | |
| append_depth=False, | |
| append_xyz=False, | |
| pcd_path=None, | |
| load_pcd=False, | |
| read_normal=False, | |
| mv_latent_dir='', | |
| canonicalize_pcd=False, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs): | |
| super().__init__( | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=preprocess, | |
| classes=classes, | |
| load_depth=load_depth, | |
| test=test, | |
| scene_scale=scene_scale, | |
| overfitting=overfitting, | |
| imgnet_normalize=imgnet_normalize, | |
| dataset_size=dataset_size, | |
| overfitting_bs=overfitting_bs, | |
| interval=interval, | |
| plucker_embedding=plucker_embedding, | |
| shuffle_across_cls=shuffle_across_cls, | |
| wds_split=wds_split, # 4 splits to accelerate preprocessing | |
| four_view_for_latent=four_view_for_latent, | |
| single_view_for_i23d=single_view_for_i23d, | |
| load_extra_36_view=load_extra_36_view, | |
| gs_cam_format=gs_cam_format, | |
| frame_0_as_canonical=frame_0_as_canonical, | |
| split_chunk_size=split_chunk_size, | |
| mv_input=mv_input, | |
| append_depth=append_depth, | |
| append_xyz=append_xyz, | |
| pcd_path=pcd_path, | |
| load_pcd=load_pcd, | |
| read_normal=read_normal, | |
| mv_latent_dir=mv_latent_dir, | |
| load_raw=False, | |
| # shards_folder_num=4, | |
| # eval=False, | |
| **kwargs) | |
| assert not self.load_raw | |
| # self.scaling_factor = np.array([0.14593576, 0.15753542, 0.18873914]) | |
| self.n_cond_frames = 5 # an easy version for now. | |
| self.avoid_loading_first = True | |
| # self.canonicalize_pcd = canonicalize_pcd | |
| # self.canonicalize_pcd = True | |
| self.canonicalize_pcd = False | |
| def canonicalize_xyz(self, c, pcd): | |
| B = c.shape[0] | |
| camera_poses_rot = c[:, :16].reshape(B, 4, 4)[:, :3, :3] | |
| R_inv = np.transpose(camera_poses_rot, (0,2,1)) # w2c rotation | |
| new_pcd = (R_inv @ np.transpose(pcd, (0,2,1))) # B 3 3 @ B 3 N | |
| new_pcd = np.transpose(new_pcd, (0,2,1)) | |
| return new_pcd | |
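| # canonicalize_xyz rotates the world-space point cloud into each view's camera | |
| # frame: R_inv is the transpose of the c2w rotation (i.e. the w2c rotation), | |
| # applied as R_inv @ pcd^T per batch element. Translation is not applied here, | |
| # so only the orientation is canonicalized. | |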
| def __getitem__(self, index) -> Any: | |
| raw_img, c, caption, ins = self.read_chunk( | |
| os.path.join(self.file_path, self.chunk_list[index])) | |
| # ! process | |
| raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, 2) / 255.0 # [0,1] | |
| if raw_img.shape[-1] != self.reso: | |
| raw_img = torch.nn.functional.interpolate( | |
| input=raw_img, | |
| size=(self.reso, self.reso), | |
| mode='bilinear', | |
| align_corners=False, | |
| ) | |
| img = raw_img * 2 - 1 # as gt | |
| # ! load latent | |
| # latent, _ = self._load_latent(ins) | |
| if self.avoid_loading_first: # for training mv model | |
| index = list(range(1,self.chunk_size//2)) + list(range(self.chunk_size//2+1, self.chunk_size)) | |
| img = img[index] | |
| c = c[index] | |
| # ! shuffle | |
| indices = np.random.permutation(img.shape[0])[:self.n_cond_frames*2] | |
| img = img[indices] | |
| c = c[indices] | |
| latent, fps_xyz = self._load_latent(ins, pick_both=True) # analyzing xyz/latent disentangled diffusion | |
| # latent, fps_xyz = latent[0], fps_xyz[0] # remove batch dim here | |
| fps_xyz = np.repeat(fps_xyz, self.n_cond_frames, 0) | |
| latent = np.repeat(latent, self.n_cond_frames, 0) | |
| normalized_fps_xyz = self.normalize_pcd_act(fps_xyz) | |
| if self.canonicalize_pcd: | |
| normalized_fps_xyz = self.canonicalize_xyz(c, normalized_fps_xyz) | |
| # repeat | |
| caption = [caption] * self.n_cond_frames * 2 | |
| ins = [ins] * self.n_cond_frames * 2 | |
| ret_dict = { | |
| 'caption': caption, | |
| 'ins': ins, | |
| 'c': c, | |
| 'img': img, # fix inp img range to [-1,1] | |
| 'latent': latent, | |
| 'normalized-fps-xyz': normalized_fps_xyz, | |
| 'fps-xyz': fps_xyz, | |
| # **latent | |
| } | |
| return ret_dict | |
| class RealDataset(Dataset): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| ) -> None: | |
| super().__init__() | |
| self.file_path = file_path | |
| self.overfitting = overfitting | |
| self.scene_scale = scene_scale | |
| self.reso = reso | |
| self.reso_encoder = reso_encoder | |
| self.classes = False | |
| self.load_depth = load_depth | |
| self.preprocess = preprocess | |
| self.plucker_embedding = plucker_embedding | |
| self.rgb_list = [] | |
| all_fname = [ | |
| t for t in os.listdir(self.file_path) | |
| if t.split('.')[-1] in ['png', 'jpg'] | |
| ] | |
| all_fname = [name for name in all_fname if '-input' in name ] | |
| self.rgb_list += ([ | |
| os.path.join(self.file_path, fname) for fname in all_fname | |
| ]) | |
| # st() | |
| # if len(self.rgb_list) == 1: | |
| # # placeholder | |
| # self.rgb_list = self.rgb_list * 40 | |
| # ! setup normalization | |
| transformations = [ | |
| transforms.ToTensor(), # [0,1] range | |
| ] | |
| assert imgnet_normalize | |
| if imgnet_normalize: | |
| transformations.append( | |
| transforms.Normalize((0.485, 0.456, 0.406), | |
| (0.229, 0.224, 0.225)) # type: ignore | |
| ) | |
| else: | |
| transformations.append( | |
| transforms.Normalize((0.5, 0.5, 0.5), | |
| (0.5, 0.5, 0.5))) # type: ignore | |
| self.normalize = transforms.Compose(transformations) | |
| # camera = torch.load('eval_pose.pt', map_location='cpu') | |
| # self.eval_camera = camera | |
| # pre-cache | |
| # self.calc_rays_plucker() | |
| def __len__(self): | |
| return len(self.rgb_list) | |
| def __getitem__(self, index) -> Any: | |
| # return super().__getitem__(index) | |
| rgb_fname = self.rgb_list[index] | |
| # ! preprocess, normalize | |
| raw_img = imageio.imread(rgb_fname) | |
| # interpolation=cv2.INTER_AREA) | |
| if raw_img.shape[-1] == 4: | |
| alpha_mask = raw_img[..., 3:4] / 255.0 | |
| bg_white = np.ones_like(alpha_mask) * 255.0 | |
| raw_img = raw_img[..., :3] * alpha_mask + ( | |
| 1 - alpha_mask) * bg_white #[3, reso_encoder, reso_encoder] | |
| raw_img = raw_img.astype(np.uint8) | |
| # raw_img = recenter(raw_img, np.ones_like(raw_img), border_ratio=0.2) | |
| # log gt | |
| img = cv2.resize(raw_img, (self.reso, self.reso), | |
| interpolation=cv2.INTER_LANCZOS4) | |
| img = torch.from_numpy(img)[..., :3].permute( | |
| 2, 0, 1 | |
| ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
| ret_dict = { | |
| # 'rgb_fname': rgb_fname, | |
| # 'img_to_encoder': | |
| # img_to_encoder.unsqueeze(0).repeat_interleave(40, 0), | |
| 'img': img, | |
| # 'c': self.eval_camera, # TODO, get pre-calculated samples | |
| # 'ins': 'placeholder', | |
| # 'bbox': 'placeholder', | |
| # 'caption': 'placeholder', | |
| } | |
| # ! repeat as an instance | |
| return ret_dict | |
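| # Minimal, hypothetical usage sketch for RealDataset (paths and args are assumptions): | |
| # ds = RealDataset(file_path='demo_images/', reso=256, reso_encoder=224) | |
| # loader = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0) | |
| # batch = next(iter(loader)) # batch['img']: (1, 3, 256, 256), values in [-1, 1] | |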
| class RealDataset_GSO(Dataset): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| ) -> None: | |
| super().__init__() | |
| self.file_path = file_path | |
| self.overfitting = overfitting | |
| self.scene_scale = scene_scale | |
| self.reso = reso | |
| self.reso_encoder = reso_encoder | |
| self.classes = False | |
| self.load_depth = load_depth | |
| self.preprocess = preprocess | |
| self.plucker_embedding = plucker_embedding | |
| self.rgb_list = [] | |
| # ! for gso-rendering | |
| all_objs = os.listdir(self.file_path) | |
| all_objs.sort() | |
| if True: # instant-mesh picked images | |
| # if False: | |
| all_instances = os.listdir(self.file_path) | |
| # all_fname = [ | |
| # t for t in all_instances | |
| # if t.split('.')[1] in ['png', 'jpg'] | |
| # ] | |
| # all_fname = [name for name in all_fname if '-input' in name ] | |
| # all_fname = ['house2-input.png', 'plant-input.png'] | |
| all_fname = ['house2-input.png'] | |
| self.rgb_list = [os.path.join(self.file_path, name) for name in all_fname] | |
| if False: | |
| for obj_folder in tqdm(all_objs[515:]): | |
| # for obj_folder in tqdm(all_objs[:515]): | |
| # for obj_folder in tqdm(all_objs[:]): | |
| # for obj_folder in tqdm(sorted(os.listdir(self.file_path))[515:]): | |
| # for idx in range(0,25,5): | |
| for idx in [0]: # only query frontal view is enough | |
| self.rgb_list.append(os.path.join(self.file_path, obj_folder, 'rgba', f'{idx:03d}.png')) | |
| # for free-3d rendering | |
| if False: | |
| # if True: | |
| # all_instances = sorted(os.listdir(self.file_path)) | |
| all_instances = ['BAGEL_WITH_CHEESE', | |
| 'BALANCING_CACTUS', | |
| 'Baby_Elements_Stacking_Cups', | |
| 'Breyer_Horse_Of_The_Year_2015', | |
| 'COAST_GUARD_BOAT', | |
| 'CONE_SORTING', | |
| 'CREATIVE_BLOCKS_35_MM', | |
| 'Cole_Hardware_Mini_Honey_Dipper', | |
| 'FAIRY_TALE_BLOCKS', | |
| 'FIRE_ENGINE', | |
| 'FOOD_BEVERAGE_SET', | |
| 'GEOMETRIC_PEG_BOARD', | |
| 'Great_Dinos_Triceratops_Toy', | |
| 'JUICER_SET', | |
| 'STACKING_BEAR', | |
| 'STACKING_RING', | |
| 'Schleich_African_Black_Rhino'] | |
| for instance in all_instances: | |
| self.rgb_list += ([ | |
| # os.path.join(self.file_path, instance, 'rgb', f'{fname:06d}.png') for fname in range(0,250,50) | |
| # os.path.join(self.file_path, instance, 'rgb', f'{fname:06d}.png') for fname in range(0,250,100) | |
| # os.path.join(self.file_path, instance, f'{fname:03d}.png') for fname in range(0,25,5) | |
| os.path.join(self.file_path, instance, 'render_mvs_25', 'model', f'{fname:03d}.png') for fname in range(0,25,4) | |
| ]) | |
| # if True: # g-objv animals images for i23d eval | |
| if False: | |
| # if True: | |
| objv_dataset = '/mnt/sfs-common/yslan/Dataset/Obajverse/chunk-jpeg-normal/bs_16_fixsave3/170K/512/' | |
| dataset_json = os.path.join(objv_dataset, 'dataset.json') | |
| with open(dataset_json, 'r') as f: | |
| dataset_json = json.load(f) | |
| # all_objs = dataset_json['Animals'][::3][:6250] | |
| all_objs = dataset_json['Animals'][::3][1100:2200][:600] | |
| for obj_folder in tqdm(all_objs[:]): | |
| for idx in [0]: # only query frontal view is enough | |
| self.rgb_list.append(os.path.join(self.file_path, obj_folder, f'{idx}.jpg')) | |
| # ! setup normalization | |
| transformations = [ | |
| transforms.ToTensor(), # [0,1] range | |
| ] | |
| assert imgnet_normalize | |
| if imgnet_normalize: | |
| transformations.append( | |
| transforms.Normalize((0.485, 0.456, 0.406), | |
| (0.229, 0.224, 0.225)) # type: ignore | |
| ) | |
| else: | |
| transformations.append( | |
| transforms.Normalize((0.5, 0.5, 0.5), | |
| (0.5, 0.5, 0.5))) # type: ignore | |
| self.normalize = transforms.Compose(transformations) | |
| # camera = torch.load('eval_pose.pt', map_location='cpu') | |
| # self.eval_camera = camera | |
| # pre-cache | |
| # self.calc_rays_plucker() | |
| def __len__(self): | |
| return len(self.rgb_list) | |
| def __getitem__(self, index) -> Any: | |
| # return super().__getitem__(index) | |
| rgb_fname = self.rgb_list[index] | |
| # ! preprocess, normalize | |
| raw_img = imageio.imread(rgb_fname) | |
| # interpolation=cv2.INTER_AREA) | |
| if raw_img.shape[-1] == 4: | |
| alpha_mask = raw_img[..., 3:4] / 255.0 | |
| bg_white = np.ones_like(alpha_mask) * 255.0 | |
| raw_img = raw_img[..., :3] * alpha_mask + ( | |
| 1 - alpha_mask) * bg_white #[3, reso_encoder, reso_encoder] | |
| raw_img = raw_img.astype(np.uint8) | |
| # raw_img = recenter(raw_img, np.ones_like(raw_img), border_ratio=0.2) | |
| # log gt | |
| img = cv2.resize(raw_img, (self.reso, self.reso), | |
| interpolation=cv2.INTER_LANCZOS4) | |
| img = torch.from_numpy(img)[..., :3].permute( | |
| 2, 0, 1 | |
| ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
| ret_dict = { | |
| 'img': img, | |
| # 'ins': str(Path(rgb_fname).parent.parent.stem), # for gso-rendering | |
| 'ins': str(Path(rgb_fname).relative_to(self.file_path)), # for gso-rendering | |
| # 'ins': rgb_fname, | |
| } | |
| return ret_dict | |
| class RealMVDataset(Dataset): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| interval=1, | |
| plucker_embedding=False, | |
| shuffle_across_cls=False, | |
| wds_split=1, # 4 splits to accelerate preprocessing | |
| ) -> None: | |
| super().__init__() | |
| self.file_path = file_path | |
| self.overfitting = overfitting | |
| self.scene_scale = scene_scale | |
| self.reso = reso | |
| self.reso_encoder = reso_encoder | |
| self.classes = False | |
| self.load_depth = load_depth | |
| self.preprocess = preprocess | |
| self.plucker_embedding = plucker_embedding | |
| self.rgb_list = [] | |
| all_fname = [ | |
| t for t in os.listdir(self.file_path) | |
| if t.split('.')[-1] in ['png', 'jpg'] | |
| ] | |
| all_fname = [name for name in all_fname if '-input' in name ] | |
| # all_fname = [name for name in all_fname if 'sorting_board-input' in name ] | |
| # all_fname = [name for name in all_fname if 'teasure_chest-input' in name ] | |
| # all_fname = [name for name in all_fname if 'bubble_mart_blue-input' in name ] | |
| # all_fname = [name for name in all_fname if 'chair_comfort-input' in name ] | |
| self.rgb_list += ([ | |
| os.path.join(self.file_path, fname) for fname in all_fname | |
| ]) | |
| # if len(self.rgb_list) == 1: | |
| # # placeholder | |
| # self.rgb_list = self.rgb_list * 40 | |
| # ! setup normalization | |
| transformations = [ | |
| transforms.ToTensor(), # [0,1] range | |
| ] | |
| azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) | |
| elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) | |
| # zero123pp_pose, _ = generate_input_camera(1.6, [[elevations[i], azimuths[i]] for i in range(6)], fov=30) | |
| zero123pp_pose, _ = generate_input_camera(1.8, [[elevations[i], azimuths[i]] for i in range(6)], fov=30) | |
| K = torch.Tensor([1.3889, 0.0000, 0.5000, 0.0000, 1.3889, 0.5000, 0.0000, 0.0000, 0.0039]).to(zero123pp_pose) # keeps the same | |
| # st() | |
| zero123pp_pose = torch.cat([zero123pp_pose.reshape(6,-1), K.unsqueeze(0).repeat(6,1)], dim=-1) | |
| # ! directly adopt gt input | |
| # self.indices = np.array([0,2,4,5]) | |
| # eval_camera = zero123pp_pose[self.indices] | |
| # self.eval_camera = torch.cat([torch.zeros_like(eval_camera[0:1]),eval_camera], 0) # first c not used as condition here, just placeholder | |
| # ! adopt mv-diffusion output as input. | |
| # self.indices = np.array([1,0,2,4,5]) | |
| self.indices = np.array([0,1,2,3,4,5]) | |
| eval_camera = zero123pp_pose[self.indices].float().cpu().numpy() # for normalization | |
| # eval_camera = zero123pp_pose[self.indices] | |
| # self.eval_camera = eval_camera | |
| # self.eval_camera = torch.cat([torch.zeros_like(eval_camera[0:1]),eval_camera], 0) # first c not used as condition here, just placeholder | |
| # # * normalize here | |
| self.eval_camera = self.normalize_camera(eval_camera, eval_camera[0:1]) # the first img is not used. | |
| # self.mv_resize_cls = torchvision.transforms.Resize(320, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, | |
| # max_size=None, antialias=True) | |
| def normalize_camera(self, c, c_frame0): | |
| # assert c.shape[0] == self.chunk_size # 8 or 10 | |
| B = c.shape[0] | |
| camera_poses = c[:, :16].reshape(B, 4, 4) # 3x4 | |
| canonical_camera_poses = c_frame0[:, :16].reshape(1, 4, 4) | |
| inverse_canonical_pose = np.linalg.inv(canonical_camera_poses) | |
| inverse_canonical_pose = np.repeat(inverse_canonical_pose, B, 0) | |
| cam_radius = np.linalg.norm( | |
| c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3], | |
| axis=-1, | |
| keepdims=False) # since g-buffer adopts dynamic radius here. | |
| frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0) | |
| frame1_fixed_pos[:, 2, -1] = -cam_radius | |
| transform = frame1_fixed_pos @ inverse_canonical_pose | |
| new_camera_poses = np.repeat( | |
| transform, 1, axis=0 | |
| ) @ camera_poses # [V, 4, 4]. np.repeat() is th.repeat_interleave() | |
| c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], | |
| axis=-1) | |
| return c | |
| def __len__(self): | |
| return len(self.rgb_list) | |
| def __getitem__(self, index) -> Any: | |
| # return super().__getitem__(index) | |
| rgb_fname = self.rgb_list[index] | |
| raw_img = imageio.imread(rgb_fname)[..., :3] | |
| raw_img = cv2.resize(raw_img, (self.reso, self.reso), interpolation=cv2.INTER_CUBIC) | |
| raw_img = torch.from_numpy(raw_img).permute(2,0,1).clip(0,255) # [0,255] | |
| img = raw_img / 127.5 - 1 | |
| # ! if loading mv-diff output views | |
| mv_img = imageio.imread(rgb_fname.replace('-input', '')) | |
| mv_img = rearrange(mv_img, '(n h) (m w) c -> (n m) h w c', n=3, m=2)[self.indices] # (6, 320, 320, 3), channel-last | |
| mv_img = np.stack([recenter(img, np.ones_like(img), border_ratio=0.1) for img in mv_img], axis=0) | |
| mv_img = rearrange(mv_img, 'b h w c -> b c h w') # to torch tradition | |
| mv_img = torch.from_numpy(mv_img) / 127.5 - 1 | |
| ret_dict = { | |
| 'img': img, | |
| 'mv_img': mv_img, | |
| 'c': self.eval_camera, | |
| 'caption': 'null', | |
| } | |
| return ret_dict | |
| class NovelViewObjverseDataset(MultiViewObjverseDataset): | |
| """novel view prediction version. | |
| """ | |
| def __init__(self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| **kwargs): | |
| super().__init__(file_path, reso, reso_encoder, preprocess, classes, | |
| load_depth, test, scene_scale, overfitting, | |
| imgnet_normalize, dataset_size, overfitting_bs, | |
| **kwargs) | |
| def __getitem__(self, idx): | |
| input_view = super().__getitem__( | |
| idx) # get previous input view results | |
| # get novel view of the same instance | |
| novel_view = super().__getitem__( | |
| (idx // self.instance_data_length) * self.instance_data_length + | |
| random.randint(0, self.instance_data_length - 1)) | |
| # assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' | |
| input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) | |
| return input_view | |
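| # The novel-view sampling above assumes views of one instance are stored | |
| # contiguously in blocks of L = self.instance_data_length, so | |
| # (idx // L) * L + random.randint(0, L - 1) indexes a random view of the same | |
| # instance as idx. | |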
| class MultiViewObjverseDatasetforLMDB(MultiViewObjverseDataset): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| shuffle_across_cls=False, | |
| wds_split=1, | |
| four_view_for_latent=False, | |
| ): | |
| super().__init__(file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess, | |
| classes, | |
| load_depth, | |
| test, | |
| scene_scale, | |
| overfitting, | |
| imgnet_normalize, | |
| dataset_size, | |
| overfitting_bs, | |
| shuffle_across_cls=shuffle_across_cls, | |
| wds_split=wds_split, | |
| four_view_for_latent=four_view_for_latent) | |
| # assert self.reso == 256 | |
| self.load_caption = True | |
| with open( | |
| # '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json' | |
| '/nas/shared/public/yslan/data/text_captions_cap3d.json') as f: | |
| # '/nas/shared/V2V/yslan/aigc3d/text_captions_cap3d.json') as f: | |
| self.caption_data = json.load(f) | |
| # lmdb_path = '/cpfs01/user/yangpeiqing.p/yslan/data/Furnitures_uncompressed/' | |
| # with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f: | |
| # self.idx_to_ins_mapping = json.load(f) | |
| def __len__(self): | |
| return super().__len__() | |
| # return 100 # for speed debug | |
| def quantize_depth(self, depth): | |
| # https://developers.google.com/depthmap-metadata/encoding | |
| # RangeInverse encoding | |
| bg = depth == 0 | |
| depth[bg] = 3 # no need to allocate capacity to it | |
| disparity = 1 / depth | |
| far = disparity.max().item() # np array here | |
| near = disparity.min().item() | |
| # d_normalized = (far * (depth-near) / (depth * far - near)) # [0,1] range | |
| d_normalized = (disparity - near) / (far - near) # [0,1] range | |
| # imageio.imwrite('depth_negative.jpeg', (((depth - near) / (far - near) * 255)<0).numpy().astype(np.uint8)) | |
| # imageio.imwrite('depth_negative.jpeg', ((depth <0)*255).numpy().astype(np.uint8)) | |
| d_normalized = np.nan_to_num(d_normalized.cpu().numpy()) | |
| d_normalized = (np.clip(d_normalized, 0, 1) * 255).astype(np.uint8) | |
| # imageio.imwrite('depth.png', d_normalized) | |
| # d = 1 / ( (d_normalized / 255) * (far-near) + near) | |
| # diff = (d[~bg.numpy()] - depth[~bg].numpy()).sum() | |
| return d_normalized, near, far # return disp | |
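| # RangeInverse round trip (sketch of what quantize_depth stores): | |
| # disparity = 1 / depth, d_norm = (disparity - near) / (far - near) in [0, 1], | |
| # then quantized to uint8. A decoder can recover depth up to quantization via | |
| # depth ~= 1 / ((d_norm / 255) * (far - near) + near), | |
| # matching the commented reconstruction check above. | |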
| def __getitem__(self, idx): | |
| # ret_dict = super().__getitem__(idx) | |
| rgb_fname = self.rgb_list[idx] | |
| pose_fname = self.pose_list[idx] | |
| raw_img = imageio.imread(rgb_fname) # [..., :3] | |
| assert raw_img.shape[-1] == 4 | |
| # st() # cv2.imwrite('img_CV2_90.jpg', a, [int(cv2.IMWRITE_JPEG_QUALITY), 90]) | |
| # if raw_img.shape[-1] == 4: # ! set bg to white | |
| alpha_mask = raw_img[..., -1:] / 255 # [0,1] | |
| raw_img = alpha_mask * raw_img[..., :3] + ( | |
| 1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 | |
| raw_img = np.concatenate([raw_img, alpha_mask * 255], -1) | |
| raw_img = raw_img.astype(np.uint8) | |
| raw_img = cv2.resize(raw_img, (self.reso, self.reso), | |
| interpolation=cv2.INTER_LANCZOS4) | |
| alpha_mask = raw_img[..., -1] / 255 | |
| raw_img = raw_img[..., :3] | |
| # alpha_mask = cv2.resize(alpha_mask, (self.reso, self.reso), | |
| # interpolation=cv2.INTER_LANCZOS4) | |
| c2w = read_camera_matrix_single(pose_fname) #[1, 4, 4] -> [1, 16] | |
| c = np.concatenate([c2w.reshape(16), self.intrinsics], | |
| axis=0).reshape(25).astype( | |
| np.float32) # 25, no '1' dim needed. | |
| c = torch.from_numpy(c) | |
| # c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed. | |
| # if self.load_depth: | |
| # depth, depth_mask, depth_mask_sr = read_dnormal(self.depth_list[idx], | |
| # try: | |
| depth, normal = read_dnormal(self.depth_list[idx], c2w[:3, 3:], | |
| self.reso, self.reso) | |
| # ! quantize depth for fast decoding | |
| # d_normalized, d_near, d_far = self.quantize_depth(depth) | |
| # ! add frame_0 alignment | |
| # try: | |
| ins = str( | |
| (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent) | |
| # if self.shuffle_across_cls: | |
| if self.load_caption: | |
| caption = self.caption_data['/'.join(ins.split('/')[1:])] | |
| bbox = self.load_bbox(torch.from_numpy(alpha_mask) > 0) | |
| else: | |
| caption = '' # since in g-alignment-xl, some instances will fail. | |
| bbox = self.load_bbox(torch.from_numpy(np.ones_like(alpha_mask)) > 0) | |
| # else: | |
| # caption = self.caption_data[ins] | |
| ret_dict = { | |
| 'normal': normal, | |
| 'raw_img': raw_img, | |
| 'c': c, | |
| # 'depth_mask': depth_mask, # 64x64 here? | |
| 'bbox': bbox, | |
| 'ins': ins, | |
| 'caption': caption, | |
| 'alpha_mask': alpha_mask, | |
| 'depth': depth, # return for pcd creation | |
| # 'd_normalized': d_normalized, | |
| # 'd_near': d_near, | |
| # 'd_far': d_far, | |
| # 'fname': rgb_fname, | |
| } | |
| return ret_dict | |
| class MultiViewObjverseDatasetforLMDB_nocaption(MultiViewObjverseDatasetforLMDB): | |
| def __init__( | |
| self, | |
| file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess=None, | |
| classes=False, | |
| load_depth=False, | |
| test=False, | |
| scene_scale=1, | |
| overfitting=False, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| overfitting_bs=-1, | |
| shuffle_across_cls=False, | |
| wds_split=1, | |
| four_view_for_latent=False, | |
| ): | |
| super().__init__(file_path, | |
| reso, | |
| reso_encoder, | |
| preprocess, | |
| classes, | |
| load_depth, | |
| test, | |
| scene_scale, | |
| overfitting, | |
| imgnet_normalize, | |
| dataset_size, | |
| overfitting_bs, | |
| shuffle_across_cls=shuffle_across_cls, | |
| wds_split=wds_split, | |
| four_view_for_latent=four_view_for_latent) | |
| self.load_caption = False | |
| class Objv_LMDBDataset_MV_Compressed(LMDBDataset_MV_Compressed): | |
| def __init__(self, | |
| lmdb_path, | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| test=False, | |
| **kwargs): | |
| super().__init__(lmdb_path, | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize, | |
| dataset_size=dataset_size, | |
| **kwargs) | |
| self.instance_data_length = 40 # ! could save some key attributes in LMDB | |
| if test: | |
| self.length = self.instance_data_length | |
| elif dataset_size > 0: | |
| self.length = dataset_size * self.instance_data_length | |
| # load caption data, and idx-to-ins mapping | |
| with open( | |
| '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json' | |
| ) as f: | |
| self.caption_data = json.load(f) | |
| with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f: | |
| self.idx_to_ins_mapping = json.load(f) | |
| def _load_data(self, idx): | |
| # ''' | |
| raw_img, depth, c, bbox = self._load_lmdb_data(idx) | |
| # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) | |
| # resize depth and bbox | |
| caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]] | |
| return { | |
| **self._post_process_sample(raw_img, depth), | |
| 'c': c, | |
| 'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8), | |
| # 'bbox': (bbox*(self.reso/256.0)).astype(np.uint8), # TODO, double check 512 in wds? | |
| 'caption': caption | |
| } | |
| # ''' | |
| # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) | |
| # st() | |
| # return {} | |
| def __getitem__(self, idx): | |
| return self._load_data(idx) | |
| class Objv_LMDBDataset_MV_NoCompressed(Objv_LMDBDataset_MV_Compressed): | |
| def __init__(self, | |
| lmdb_path, | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| test=False, | |
| **kwargs): | |
| super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, | |
| dataset_size, test, **kwargs) | |
| def _load_data(self, idx): | |
| # ''' | |
| raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) | |
| # resize depth and bbox | |
| caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]] | |
| return { | |
| **self._post_process_sample(raw_img, depth), 'c': c, | |
| 'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8), | |
| 'caption': caption | |
| } | |
| return {} | |
| class Objv_LMDBDataset_NV_NoCompressed(Objv_LMDBDataset_MV_NoCompressed): | |
| def __init__(self, | |
| lmdb_path, | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| test=False, | |
| **kwargs): | |
| super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, | |
| dataset_size, test, **kwargs) | |
| def __getitem__(self, idx): | |
| input_view = self._load_data(idx) # get previous input view results | |
| # get novel view of the same instance | |
| try: | |
| novel_view = self._load_data( | |
| (idx // self.instance_data_length) * | |
| self.instance_data_length + | |
| random.randint(0, self.instance_data_length - 1)) | |
| except Exception as e: | |
| raise NotImplementedError(idx) | |
| # assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' | |
| input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) | |
| return input_view | |
| class Objv_LMDBDataset_MV_Compressed_for_lmdb(LMDBDataset_MV_Compressed): | |
| def __init__(self, | |
| lmdb_path, | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| test=False, | |
| **kwargs): | |
| super().__init__(lmdb_path, | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize, | |
| dataset_size=dataset_size, | |
| **kwargs) | |
| self.instance_data_length = 40 # ! could save some key attributes in LMDB | |
| if test: | |
| self.length = self.instance_data_length | |
| elif dataset_size > 0: | |
| self.length = dataset_size * self.instance_data_length | |
| # load caption data, and idx-to-ins mapping | |
| with open( | |
| '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json' | |
| ) as f: | |
| self.caption_data = json.load(f) | |
| with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f: | |
| self.idx_to_ins_mapping = json.load(f) | |
| # def _load_data(self, idx): | |
| # # ''' | |
| # raw_img, depth, c, bbox = self._load_lmdb_data(idx) | |
| # # resize depth and bbox | |
| # caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]] | |
| # # st() | |
| # return { | |
| # **self._post_process_sample(raw_img, depth), 'c': c, | |
| # 'bbox': (bbox*(self.reso/512.0)).astype(np.uint8), | |
| # 'caption': caption | |
| # } | |
| # # ''' | |
| # # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) | |
| # # st() | |
| # # return {} | |
| def load_bbox(self, mask): | |
| nonzero_value = torch.nonzero(mask) | |
| height, width = nonzero_value.max(dim=0)[0] | |
| top, left = nonzero_value.min(dim=0)[0] | |
| bbox = torch.tensor([top, left, height, width], dtype=torch.float32) | |
| return bbox | |
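| # Worked example (illustrative): for a mask whose nonzero pixels span rows 10..20 | |
| # and cols 5..30, nonzero_value.min/max give (10, 5) and (20, 30), so | |
| # bbox = [10, 5, 20, 30], i.e. [top, left, max_row, max_col] rather than a true | |
| # height/width despite the variable names. | |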
| def __getitem__(self, idx): | |
| raw_img, depth, c, bbox = self._load_lmdb_data(idx) | |
| return {'raw_img': raw_img, 'depth': depth, 'c': c, 'bbox': bbox} | |
| class Objv_LMDBDataset_NV_Compressed(Objv_LMDBDataset_MV_Compressed): | |
| def __init__(self, | |
| lmdb_path, | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize=True, | |
| dataset_size=-1, | |
| **kwargs): | |
| super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, | |
| dataset_size, **kwargs) | |
| def __getitem__(self, idx): | |
| input_view = self._load_data(idx) # get previous input view results | |
| # get novel view of the same instance | |
| try: | |
| novel_view = self._load_data( | |
| (idx // self.instance_data_length) * | |
| self.instance_data_length + | |
| random.randint(0, self.instance_data_length - 1)) | |
| except Exception as e: | |
| raise NotImplementedError(idx) from e  # preserve the original error context | |
| # assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' | |
| input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) | |
| return input_view | |
| # | |
| # test tar loading | |
| def load_wds_ResampledShard(file_path, | |
| batch_size, | |
| num_workers, | |
| reso, | |
| reso_encoder, | |
| test=False, | |
| preprocess=None, | |
| imgnet_normalize=True, | |
| plucker_embedding=False, | |
| decode_encode_img_only=False, | |
| load_instance=False, | |
| mv_input=False, | |
| split_chunk_input=False, | |
| duplicate_sample=True, | |
| append_depth=False, | |
| append_normal=False, | |
| gs_cam_format=False, | |
| orthog_duplicate=False, | |
| load_gzip=False,  # deprecated gzip branch below; defined here so the elif resolves | |
| split_chunk_size=8,  # used by the chunked shuffle/batch sizes below | |
| **kwargs): | |
| # return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd'] | |
| post_process_cls = PostProcess( | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize=imgnet_normalize, | |
| plucker_embedding=plucker_embedding, | |
| decode_encode_img_only=decode_encode_img_only, | |
| mv_input=mv_input, | |
| split_chunk_input=split_chunk_input, | |
| duplicate_sample=duplicate_sample, | |
| append_depth=append_depth, | |
| gs_cam_format=gs_cam_format, | |
| orthog_duplicate=orthog_duplicate, | |
| append_normal=append_normal, | |
| ) | |
| # ! add shuffling | |
| if isinstance(file_path, list): # lst of shard urls | |
| all_shards = [] | |
| for url_path in file_path: | |
| all_shards.extend(wds.shardlists.expand_source(url_path)) | |
| logger.log('all_shards', all_shards) | |
| else: | |
| all_shards = file_path # to be expanded | |
| if not load_instance: # during reconstruction training, load pair | |
| if not split_chunk_input: | |
| dataset = wds.DataPipeline( | |
| wds.ResampledShards(all_shards), # url_shard | |
| # at this point we have an iterator over all the shards | |
| wds.shuffle(50), | |
| wds.split_by_worker, # if multi-node | |
| wds.tarfile_to_samples(), | |
| # add wds.split_by_node here if you are using multiple nodes | |
| wds.shuffle( | |
| 1000 | |
| ), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.decode(wds.autodecode.basichandlers), # TODO | |
| wds.to_tuple( | |
| "sample.pyd"), # extract the pyd from top level dict | |
| wds.map(post_process_cls.decode_zip), | |
| wds.map(post_process_cls.paired_post_process | |
| ), # create input-novelview paired samples | |
| # wds.map(post_process_cls._post_process_sample), | |
| # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.batched( | |
| 16, | |
| partial=True, | |
| # collation_fn=collate | |
| ) # streaming more data at once, and rebatch later | |
| ) | |
| elif load_gzip:  # deprecated, no performance improvement | |
| dataset = wds.DataPipeline( | |
| wds.ResampledShards(all_shards), # url_shard | |
| # at this point we have an iterator over all the shards | |
| wds.shuffle(10), | |
| wds.split_by_worker, # if multi-node | |
| wds.tarfile_to_samples(), | |
| # add wds.split_by_node here if you are using multiple nodes | |
| # wds.shuffle( | |
| # 100 | |
| # ), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.decode('rgb8'), # TODO | |
| # wds.decode(wds.autodecode.basichandlers), # TODO | |
| # wds.to_tuple('raw_img.jpeg', 'depth.jpeg', 'alpha_mask.jpeg', | |
| # 'd_near.npy', 'd_far.npy', "c.npy", 'bbox.npy', | |
| # 'ins.txt', 'caption.txt'), | |
| wds.to_tuple('raw_img.png', 'depth_alpha.png'), | |
| # wds.to_tuple('raw_img.jpg', "c.npy", 'bbox.npy', 'depth.pyd', 'ins.txt', 'caption.txt'), | |
| # wds.to_tuple('raw_img.jpg', "c.npy", 'bbox.npy', 'ins.txt', 'caption.txt'), | |
| wds.map(post_process_cls.decode_gzip), | |
| # wds.map(post_process_cls.paired_post_process_chunk | |
| # ), # create input-novelview paired samples | |
| wds.batched( | |
| 20, | |
| partial=True, | |
| # collation_fn=collate | |
| ) # streaming more data at once, and rebatch later | |
| ) | |
| else: | |
| dataset = wds.DataPipeline( | |
| wds.ResampledShards(all_shards), # url_shard | |
| # at this point we have an iterator over all the shards | |
| wds.shuffle(100), | |
| wds.split_by_worker, # if multi-node | |
| wds.tarfile_to_samples(), | |
| # add wds.split_by_node here if you are using multiple nodes | |
| wds.shuffle( | |
| 4000 // split_chunk_size | |
| ), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.decode(wds.autodecode.basichandlers), # TODO | |
| wds.to_tuple( | |
| "sample.pyd"), # extract the pyd from top level dict | |
| wds.map(post_process_cls.decode_zip), | |
| wds.map(post_process_cls.paired_post_process_chunk | |
| ), # create input-novelview paired samples | |
| # wds.map(post_process_cls._post_process_sample), | |
| # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.batched( | |
| 120 // split_chunk_size, | |
| partial=True, | |
| # collation_fn=collate | |
| ) # streaming more data at once, and rebatch later | |
| ) | |
| loader_shard = wds.WebLoader( | |
| dataset, | |
| num_workers=num_workers, | |
| drop_last=False, | |
| batch_size=None, | |
| shuffle=False, | |
| persistent_workers=num_workers > 0).unbatched().shuffle( | |
| 1000 // split_chunk_size).batched(batch_size).map( | |
| post_process_cls.create_dict) | |
| if mv_input: | |
| loader_shard = loader_shard.map(post_process_cls.prepare_mv_input) | |
| else: # load single instance during test/eval | |
| assert batch_size == 1 | |
| dataset = wds.DataPipeline( | |
| wds.ResampledShards(all_shards), # url_shard | |
| # at this point we have an iterator over all the shards | |
| wds.shuffle(50), | |
| wds.split_by_worker, # if multi-node | |
| wds.tarfile_to_samples(), | |
| # add wds.split_by_node here if you are using multiple nodes | |
| wds.detshuffle( | |
| 100 | |
| ), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.decode(wds.autodecode.basichandlers), # TODO | |
| wds.to_tuple("sample.pyd"), # extract the pyd from top level dict | |
| wds.map(post_process_cls.decode_zip), | |
| # wds.map(post_process_cls.paired_post_process), # create input-novelview paired samples | |
| wds.map(post_process_cls._post_process_batch_sample), | |
| # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.batched( | |
| 2, | |
| partial=True, | |
| # collation_fn=collate | |
| ) # streaming more data at once, and rebatch later | |
| ) | |
| loader_shard = wds.WebLoader( | |
| dataset, | |
| num_workers=num_workers, | |
| drop_last=False, | |
| batch_size=None, | |
| shuffle=False, | |
| persistent_workers=num_workers | |
| > 0).unbatched().shuffle(200).batched(batch_size).map( | |
| post_process_cls.single_instance_sample_create_dict) | |
| # persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict) | |
| # 1000).batched(batch_size).map(post_process_cls.create_dict) | |
| # .map(collate) | |
| # .map(collate) | |
| # .batched(batch_size) | |
| # | |
| # .unbatched().shuffle(1000).batched(batch_size).map(post_process) | |
| # # https://github.com/webdataset/webdataset/issues/187 | |
| # return next(iter(loader_shard)) | |
| #return dataset | |
| return loader_shard | |
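| # Editor's note (sketch): the pipelines above all follow the same WebDataset | |
| # pattern: resample shards, split across workers, unpack tars, decode, map a | |
| # post-processing function, and pre-batch before the WebLoader re-batches to | |
| # the final batch size. A minimal, generic version of that pattern (the shard | |
| # path and the identity map below are placeholders, not values from this repo): | |
| def _example_minimal_wds_pipeline(shard_urls='shards/data-{000000..000009}.tar'): | |
|     pipeline = wds.DataPipeline( | |
|         wds.ResampledShards(shard_urls), | |
|         wds.shuffle(50), | |
|         wds.split_by_worker, | |
|         wds.tarfile_to_samples(), | |
|         wds.shuffle(1000), | |
|         wds.decode(wds.autodecode.basichandlers), | |
|         wds.to_tuple('sample.pyd'), | |
|         wds.map(lambda sample: sample),  # placeholder for a PostProcess method | |
|         wds.batched(16, partial=True), | |
|     ) | |
|     return wds.WebLoader(pipeline, num_workers=0, batch_size=None) | |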
| class PostProcessForDiff: | |
| def __init__( | |
| self, | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize, | |
| plucker_embedding, | |
| decode_encode_img_only, | |
| mv_latent_dir, | |
| ) -> None: | |
| self.plucker_embedding = plucker_embedding | |
| self.mv_latent_dir = mv_latent_dir | |
| self.decode_encode_img_only = decode_encode_img_only | |
| transformations = [ | |
| transforms.ToTensor(), # [0,1] range | |
| ] | |
| if imgnet_normalize: | |
| transformations.append( | |
| transforms.Normalize((0.485, 0.456, 0.406), | |
| (0.229, 0.224, 0.225)) # type: ignore | |
| ) | |
| else: | |
| transformations.append( | |
| transforms.Normalize((0.5, 0.5, 0.5), | |
| (0.5, 0.5, 0.5))) # type: ignore | |
| self.normalize = transforms.Compose(transformations) | |
| self.reso_encoder = reso_encoder | |
| self.reso = reso | |
| self.instance_data_length = 40 | |
| # self.pair_per_instance = 1 # compat | |
| self.pair_per_instance = 2 # check whether improves IO | |
| # self.pair_per_instance = 3 # check whether improves IO | |
| # self.pair_per_instance = 4 # check whether improves IO | |
| self.camera = torch.load('eval_pose.pt', map_location='cpu').numpy() | |
| self.canonical_frame = self.camera[25:26] # 1, 25 # inverse this | |
| self.canonical_frame_pos = self.canonical_frame[:, :16].reshape(4, 4) | |
| def get_rays_kiui(self, c, opengl=True): | |
| h, w = self.reso_encoder, self.reso_encoder | |
| intrinsics, pose = c[16:], c[:16].reshape(4, 4) | |
| # cx, cy, fx, fy = intrinsics[2], intrinsics[5] | |
| fx = fy = 525 # pixel space | |
| cx = cy = 256 # rendering default K | |
| factor = self.reso / (cx * 2) # 128 / 512 | |
| fx = fx * factor | |
| fy = fy * factor | |
| x, y = torch.meshgrid( | |
| torch.arange(w, device=pose.device), | |
| torch.arange(h, device=pose.device), | |
| indexing="xy", | |
| ) | |
| x = x.flatten() | |
| y = y.flatten() | |
| cx = w * 0.5 | |
| cy = h * 0.5 | |
| # focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) | |
| camera_dirs = F.pad( | |
| torch.stack( | |
| [ | |
| (x - cx + 0.5) / fx, | |
| (y - cy + 0.5) / fy * (-1.0 if opengl else 1.0), | |
| ], | |
| dim=-1, | |
| ), | |
| (0, 1), | |
| value=(-1.0 if opengl else 1.0), | |
| ) # [hw, 3] | |
| rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] | |
| rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] | |
| rays_o = rays_o.view(h, w, 3) | |
| rays_d = safe_normalize(rays_d).view(h, w, 3) | |
| return rays_o, rays_d | |
| def gen_rays(self, c): | |
| # Generate rays | |
| intrinsics, c2w = c[16:], c[:16].reshape(4, 4) | |
| self.h = self.reso_encoder | |
| self.w = self.reso_encoder | |
| yy, xx = torch.meshgrid( | |
| torch.arange(self.h, dtype=torch.float32) + 0.5, | |
| torch.arange(self.w, dtype=torch.float32) + 0.5, | |
| indexing='ij') | |
| # normalize to 0-1 pixel range | |
| yy = yy / self.h | |
| xx = xx / self.w | |
| # K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) | |
| cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[ | |
| 0], intrinsics[4] | |
| # cx *= self.w | |
| # cy *= self.h | |
| # f_x = f_y = fx * h / res_raw | |
| c2w = torch.from_numpy(c2w).float() | |
| xx = (xx - cx) / fx | |
| yy = (yy - cy) / fy | |
| zz = torch.ones_like(xx) | |
| dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention | |
| dirs /= torch.norm(dirs, dim=-1, keepdim=True) | |
| dirs = dirs.reshape(-1, 3, 1) | |
| del xx, yy, zz | |
| # st() | |
| dirs = (c2w[None, :3, :3] @ dirs)[..., 0] | |
| origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous() | |
| origins = origins.view(self.h, self.w, 3) | |
| dirs = dirs.view(self.h, self.w, 3) | |
| return origins, dirs | |
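| # Editor's note (sketch): gen_rays returns per-pixel ray origins and unit | |
| # directions in world space ([h, w, 3] each). The 6-channel Plucker embedding | |
| # used by the plucker_embedding branches in this file is assembled from them | |
| # as (o x d, d), e.g. (illustrative, mirrors the usage further below): | |
| #   rays_o, rays_d = self.gen_rays(c) | |
| #   plucker = torch.cat( | |
| #       [torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
| #       dim=-1).permute(2, 0, 1)  # [6, h, w] | |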
| def normalize_camera(self, c): | |
| # assert c.shape[0] == self.chunk_size  # 8 or 10 | |
| c = c[None] # api compat | |
| B = c.shape[0] | |
| camera_poses = c[:, :16].reshape(B, 4, 4)  # [B, 4, 4] | |
| cam_radius = np.linalg.norm( | |
| self.canonical_frame_pos.reshape(4, 4)[:3, 3], | |
| axis=-1, | |
| keepdims=False) # since g-buffer adopts dynamic radius here. | |
| frame1_fixed_pos = np.eye(4) | |
| frame1_fixed_pos[2, -1] = -cam_radius | |
| transform = frame1_fixed_pos @ np.linalg.inv( | |
| self.canonical_frame_pos) # 4,4 | |
| # from LGM, https://github.com/3DTopia/LGM/blob/fe8d12cff8c827df7bb77a3c8e8b37408cb6fe4c/core/provider_objaverse.py#L127 | |
| # transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(c[[0,4]]) | |
| new_camera_poses = transform[None] @ camera_poses # [V, 4, 4] | |
| c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], | |
| axis=-1) | |
| return c[0] | |
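| # Editor's note (sketch): normalize_camera re-expresses a pose relative to the | |
| # canonical frame: `transform` maps canonical_frame_pos onto a fixed pose with | |
| # identity rotation at distance cam_radius along -z, so by construction | |
| # (illustrative check, not executed here): | |
| #   np.allclose(transform @ self.canonical_frame_pos, frame1_fixed_pos) | |
| # and every other pose in `c` is moved by that same rigid transform. | |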
| def _post_process_sample(self, data_sample): | |
| # raw_img, depth, c, bbox, caption, ins = data_sample | |
| raw_img, c, caption, ins = data_sample | |
| # c = self.normalize_camera(c) @ if relative pose. | |
| img = raw_img # 256x256 | |
| img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1 | |
| # load latent. | |
| # latent_path = Path(self.mv_latent_dir, ins, 'latent.npy') # ! a converged version, before adding augmentation | |
| # if random.random() > 0.5: | |
| # latent_path = Path(self.mv_latent_dir, ins, 'latent.npy') | |
| # else: # augmentation, double the dataset | |
| latent_path = Path( | |
| self.mv_latent_dir.replace('v=4-final', 'v=4-rotate'), ins, | |
| 'latent.npy') | |
| latent = np.load(latent_path) | |
| # return (img_to_encoder, img, c, caption, ins) | |
| return (latent, img, c, caption, ins) | |
| def rand_sample_idx(self): | |
| return random.randint(0, self.instance_data_length - 1) | |
| def rand_pair(self): | |
| return (self.rand_sample_idx() for _ in range(2)) | |
| def paired_post_process(self, sample): | |
| # repeat n times? | |
| all_inp_list = [] | |
| all_nv_list = [] | |
| caption, ins = sample[-2:] | |
| # expanded_return = [] | |
| for _ in range(self.pair_per_instance): | |
| cano_idx, nv_idx = self.rand_pair() | |
| cano_sample = self._post_process_sample(item[cano_idx] | |
| for item in sample[:-2]) | |
| nv_sample = self._post_process_sample(item[nv_idx] | |
| for item in sample[:-2]) | |
| all_inp_list.extend(cano_sample) | |
| all_nv_list.extend(nv_sample) | |
| return (*all_inp_list, *all_nv_list, caption, ins) | |
| # return [cano_sample, nv_sample, caption, ins] | |
| # return (*cano_sample, *nv_sample, caption, ins) | |
| # def single_sample_create_dict(self, sample, prefix=''): | |
| # # if len(sample) == 1: | |
| # # sample = sample[0] | |
| # # assert len(sample) == 6 | |
| # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
| # return { | |
| # # **sample, | |
| # f'{prefix}img_to_encoder': img_to_encoder, | |
| # f'{prefix}img': img, | |
| # f'{prefix}depth_mask': fg_mask_reso, | |
| # f'{prefix}depth': depth_reso, | |
| # f'{prefix}c': c, | |
| # f'{prefix}bbox': bbox, | |
| # } | |
| def single_sample_create_dict(self, sample, prefix=''): | |
| # if len(sample) == 1: | |
| # sample = sample[0] | |
| # assert len(sample) == 6 | |
| # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
| # img_to_encoder, img, c, caption, ins = sample | |
| # img, c, caption, ins = sample | |
| latent, img, c, caption, ins = sample | |
| # load latent | |
| return { | |
| # **sample, | |
| # 'img_to_encoder': img_to_encoder, | |
| 'latent': latent, | |
| 'img': img, | |
| 'c': c, | |
| 'caption': caption, | |
| 'ins': ins | |
| } | |
| def decode_zip(self, sample_pyd, shape=(256, 256)): | |
| if isinstance(sample_pyd, tuple): | |
| sample_pyd = sample_pyd[0] | |
| assert isinstance(sample_pyd, dict) | |
| raw_img = decompress_and_open_image_gzip( | |
| sample_pyd['raw_img'], | |
| is_img=True, | |
| decompress=True, | |
| decompress_fn=lz4.frame.decompress) | |
| caption = sample_pyd['caption'].decode('utf-8') | |
| ins = sample_pyd['ins'].decode('utf-8') | |
| c = decompress_array(sample_pyd['c'], (25, ), | |
| np.float32, | |
| decompress=True, | |
| decompress_fn=lz4.frame.decompress) | |
| # bbox = decompress_array( | |
| # sample_pyd['bbox'], | |
| # ( | |
| # 40, | |
| # 4, | |
| # ), | |
| # np.float32, | |
| # # decompress=False) | |
| # decompress=True, | |
| # decompress_fn=lz4.frame.decompress) | |
| # if self.decode_encode_img_only: | |
| # depth = np.zeros(shape=(40, *shape)) # save loading time | |
| # else: | |
| # depth = decompress_array(sample_pyd['depth'], (40, *shape), | |
| # np.float32, | |
| # decompress=True, | |
| # decompress_fn=lz4.frame.decompress) | |
| # return {'raw_img': raw_img, 'depth': depth, 'bbox': bbox, 'caption': caption, 'ins': ins, 'c': c} | |
| # return raw_img, depth, c, bbox, caption, ins | |
| # return raw_img, bbox, caption, ins | |
| # return bbox, caption, ins | |
| return raw_img, c, caption, ins | |
| # ! run single-instance pipeline first | |
| # return raw_img[0], depth[0], c[0], bbox[0], caption, ins | |
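| # Editor's note (sketch): based on the keys read above, each 'sample.pyd' entry | |
| # is expected to be a dict roughly of the form (illustrative only): | |
| #   { | |
| #       'raw_img': <lz4-compressed encoded image bytes>, | |
| #       'c':       <lz4-compressed float32 camera array, shape (25,)>, | |
| #       'caption': b'utf-8 caption bytes', | |
| #       'ins':     b'utf-8 instance id bytes', | |
| #       # optional fields handled by the commented-out branches: 'depth', 'bbox' | |
| #   } | |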
| def create_dict(self, sample): | |
| # sample = [item[0] for item in sample] # wds wrap items in [] | |
| # cano_sample_list = [[] for _ in range(6)] | |
| # nv_sample_list = [[] for _ in range(6)] | |
| # for idx in range(0, self.pair_per_instance): | |
| # cano_sample = sample[6*idx:6*(idx+1)] | |
| # nv_sample = sample[6*self.pair_per_instance+6*idx:6*self.pair_per_instance+6*(idx+1)] | |
| # for item_idx in range(6): | |
| # cano_sample_list[item_idx].append(cano_sample[item_idx]) | |
| # nv_sample_list[item_idx].append(nv_sample[item_idx]) | |
| # # ! cycle input/output view for more pairs | |
| # cano_sample_list[item_idx].append(nv_sample[item_idx]) | |
| # nv_sample_list[item_idx].append(cano_sample[item_idx]) | |
| cano_sample = self.single_sample_create_dict(sample, prefix='') | |
| # nv_sample = self.single_sample_create_dict((torch.cat(item_list) for item_list in nv_sample_list) , prefix='nv_') | |
| return cano_sample | |
| # return { | |
| # **cano_sample, | |
| # # **nv_sample, | |
| # 'caption': sample[-2], | |
| # 'ins': sample[-1] | |
| # } | |
| # test tar loading | |
| def load_wds_diff_ResampledShard(file_path, | |
| batch_size, | |
| num_workers, | |
| reso, | |
| reso_encoder, | |
| test=False, | |
| preprocess=None, | |
| imgnet_normalize=True, | |
| plucker_embedding=False, | |
| decode_encode_img_only=False, | |
| mv_latent_dir='', | |
| **kwargs): | |
| # return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd'] | |
| post_process_cls = PostProcessForDiff( | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize=imgnet_normalize, | |
| plucker_embedding=plucker_embedding, | |
| decode_encode_img_only=decode_encode_img_only, | |
| mv_latent_dir=mv_latent_dir, | |
| ) | |
| if isinstance(file_path, list): # lst of shard urls | |
| all_shards = [] | |
| for url_path in file_path: | |
| all_shards.extend(wds.shardlists.expand_source(url_path)) | |
| logger.log('all_shards', all_shards) | |
| else: | |
| all_shards = file_path # to be expanded | |
| dataset = wds.DataPipeline( | |
| wds.ResampledShards(all_shards), # url_shard | |
| # at this point we have an iterator over all the shards | |
| wds.shuffle(100), | |
| wds.split_by_worker, # if multi-node | |
| wds.tarfile_to_samples(), | |
| # add wds.split_by_node here if you are using multiple nodes | |
| wds.shuffle( | |
| 20000 | |
| ), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.decode(wds.autodecode.basichandlers), # TODO | |
| wds.to_tuple("sample.pyd"), # extract the pyd from top level dict | |
| wds.map(post_process_cls.decode_zip), | |
| # wds.map(post_process_cls.paired_post_process), # create input-novelview paired samples | |
| wds.map(post_process_cls._post_process_sample), | |
| # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.batched( | |
| 100, | |
| partial=True, | |
| # collation_fn=collate | |
| ) # streaming more data at once, and rebatch later | |
| ) | |
| loader_shard = wds.WebLoader( | |
| dataset, | |
| num_workers=num_workers, | |
| drop_last=False, | |
| batch_size=None, | |
| shuffle=False, | |
| persistent_workers=num_workers | |
| > 0).unbatched().shuffle(2500).batched(batch_size).map( | |
| post_process_cls.create_dict) | |
| # persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict) | |
| # 1000).batched(batch_size).map(post_process_cls.create_dict) | |
| # .map(collate) | |
| # .map(collate) | |
| # .batched(batch_size) | |
| # | |
| # .unbatched().shuffle(1000).batched(batch_size).map(post_process) | |
| # # https://github.com/webdataset/webdataset/issues/187 | |
| # return next(iter(loader_shard)) | |
| #return dataset | |
| return loader_shard | |
| def load_wds_data( | |
| file_path="", | |
| reso=64, | |
| reso_encoder=224, | |
| batch_size=1, | |
| num_workers=6, | |
| plucker_embedding=False, | |
| decode_encode_img_only=False, | |
| load_wds_diff=False, | |
| load_wds_latent=False, | |
| load_instance=False, # for evaluation | |
| mv_input=False, | |
| split_chunk_input=False, | |
| duplicate_sample=True, | |
| mv_latent_dir='', | |
| append_depth=False, | |
| gs_cam_format=False, | |
| orthog_duplicate=False, | |
| **args): | |
| if load_wds_diff: | |
| # assert num_workers == 0 # on aliyun, worker=0 performs much much faster | |
| wds_loader = load_wds_diff_ResampledShard( | |
| file_path, | |
| batch_size=batch_size, | |
| num_workers=num_workers, | |
| reso=reso, | |
| reso_encoder=reso_encoder, | |
| plucker_embedding=plucker_embedding, | |
| decode_encode_img_only=decode_encode_img_only, | |
| mv_input=mv_input, | |
| split_chunk_input=split_chunk_input, | |
| append_depth=append_depth, | |
| mv_latent_dir=mv_latent_dir, | |
| gs_cam_format=gs_cam_format, | |
| orthog_duplicate=orthog_duplicate, | |
| ) | |
| elif load_wds_latent: | |
| # for diffusion training, cache latent | |
| wds_loader = load_wds_latent_ResampledShard( | |
| file_path, | |
| batch_size=batch_size, | |
| num_workers=num_workers, | |
| reso=reso, | |
| reso_encoder=reso_encoder, | |
| plucker_embedding=plucker_embedding, | |
| decode_encode_img_only=decode_encode_img_only, | |
| mv_input=mv_input, | |
| split_chunk_input=split_chunk_input, | |
| ) | |
| # elif load_instance: | |
| # wds_loader = load_wds_instance_ResampledShard( | |
| # file_path, | |
| # batch_size=batch_size, | |
| # num_workers=num_workers, | |
| # reso=reso, | |
| # reso_encoder=reso_encoder, | |
| # plucker_embedding=plucker_embedding, | |
| # decode_encode_img_only=decode_encode_img_only | |
| # ) | |
| else: | |
| wds_loader = load_wds_ResampledShard( | |
| file_path, | |
| batch_size=batch_size, | |
| num_workers=num_workers, | |
| reso=reso, | |
| reso_encoder=reso_encoder, | |
| plucker_embedding=plucker_embedding, | |
| decode_encode_img_only=decode_encode_img_only, | |
| load_instance=load_instance, | |
| mv_input=mv_input, | |
| split_chunk_input=split_chunk_input, | |
| duplicate_sample=duplicate_sample, | |
| append_depth=append_depth, | |
| gs_cam_format=gs_cam_format, | |
| orthog_duplicate=orthog_duplicate, | |
| ) | |
| while True: | |
| yield from wds_loader | |
| # yield from wds_loader | |
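| # Editor's note (sketch): load_wds_data is a generator that loops over the | |
| # chosen WebLoader forever, so callers pull batches with next(). A minimal, | |
| # hypothetical invocation of the cached-latent diffusion path; the shard path | |
| # and latent directory are placeholders, and it assumes 'eval_pose.pt' and the | |
| # cached latents required by PostProcessForDiff are available: | |
| def _example_load_wds_data_diff(): | |
|     data = load_wds_data( | |
|         file_path='shards/objaverse-train-{000000..000099}.tar', | |
|         reso=128, | |
|         reso_encoder=256, | |
|         batch_size=4, | |
|         num_workers=2, | |
|         load_wds_diff=True, | |
|         mv_latent_dir='/path/to/latents/v=4-final',  # hypothetical cache dir | |
|     ) | |
|     return next(data)  # dict with 'latent', 'img', 'c', 'caption', 'ins' | |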
| class PostProcess_forlatent: | |
| def __init__( | |
| self, | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize, | |
| plucker_embedding, | |
| decode_encode_img_only, | |
| ) -> None: | |
| self.plucker_embedding = plucker_embedding | |
| self.decode_encode_img_only = decode_encode_img_only | |
| transformations = [ | |
| transforms.ToTensor(), # [0,1] range | |
| ] | |
| if imgnet_normalize: | |
| transformations.append( | |
| transforms.Normalize((0.485, 0.456, 0.406), | |
| (0.229, 0.224, 0.225)) # type: ignore | |
| ) | |
| else: | |
| transformations.append( | |
| transforms.Normalize((0.5, 0.5, 0.5), | |
| (0.5, 0.5, 0.5))) # type: ignore | |
| self.normalize = transforms.Compose(transformations) | |
| self.reso_encoder = reso_encoder | |
| self.reso = reso | |
| self.instance_data_length = 40 | |
| # self.pair_per_instance = 1 # compat | |
| self.pair_per_instance = 2 # check whether improves IO | |
| # self.pair_per_instance = 3 # check whether improves IO | |
| # self.pair_per_instance = 4 # check whether improves IO | |
| def _post_process_sample(self, data_sample): | |
| # raw_img, depth, c, bbox, caption, ins = data_sample | |
| raw_img, c, caption, ins = data_sample | |
| # bbox = (bbox*(self.reso/256)).astype(np.uint8) # normalize bbox to the reso range | |
| if raw_img.shape[-2] != self.reso_encoder: | |
| img_to_encoder = cv2.resize(raw_img, | |
| (self.reso_encoder, self.reso_encoder), | |
| interpolation=cv2.INTER_LANCZOS4) | |
| else: | |
| img_to_encoder = raw_img | |
| img_to_encoder = self.normalize(img_to_encoder) | |
| if self.plucker_embedding: | |
| # note: PostProcess_forlatent defines no gen_rays; this branch assumes | |
| # plucker_embedding=False (the latent pipeline below does not map this method) | |
| rays_o, rays_d = self.gen_rays(c) | |
| rays_plucker = torch.cat( | |
| [torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
| dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w | |
| img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0) | |
| img = cv2.resize(raw_img, (self.reso, self.reso), | |
| interpolation=cv2.INTER_LANCZOS4) | |
| img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1 | |
| return (img_to_encoder, img, c, caption, ins) | |
| def rand_sample_idx(self): | |
| return random.randint(0, self.instance_data_length - 1) | |
| def rand_pair(self): | |
| return (self.rand_sample_idx() for _ in range(2)) | |
| def paired_post_process(self, sample): | |
| # repeat n times? | |
| all_inp_list = [] | |
| all_nv_list = [] | |
| caption, ins = sample[-2:] | |
| # expanded_return = [] | |
| for _ in range(self.pair_per_instance): | |
| cano_idx, nv_idx = self.rand_pair() | |
| cano_sample = self._post_process_sample(item[cano_idx] | |
| for item in sample[:-2]) | |
| nv_sample = self._post_process_sample(item[nv_idx] | |
| for item in sample[:-2]) | |
| all_inp_list.extend(cano_sample) | |
| all_nv_list.extend(nv_sample) | |
| return (*all_inp_list, *all_nv_list, caption, ins) | |
| # return [cano_sample, nv_sample, caption, ins] | |
| # return (*cano_sample, *nv_sample, caption, ins) | |
| # def single_sample_create_dict(self, sample, prefix=''): | |
| # # if len(sample) == 1: | |
| # # sample = sample[0] | |
| # # assert len(sample) == 6 | |
| # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
| # return { | |
| # # **sample, | |
| # f'{prefix}img_to_encoder': img_to_encoder, | |
| # f'{prefix}img': img, | |
| # f'{prefix}depth_mask': fg_mask_reso, | |
| # f'{prefix}depth': depth_reso, | |
| # f'{prefix}c': c, | |
| # f'{prefix}bbox': bbox, | |
| # } | |
| def single_sample_create_dict(self, sample, prefix=''): | |
| # if len(sample) == 1: | |
| # sample = sample[0] | |
| # assert len(sample) == 6 | |
| # img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
| img_to_encoder, img, c, caption, ins = sample | |
| return { | |
| # **sample, | |
| 'img_to_encoder': img_to_encoder, | |
| 'img': img, | |
| 'c': c, | |
| 'caption': caption, | |
| 'ins': ins | |
| } | |
| def decode_zip(self, sample_pyd, shape=(256, 256)): | |
| if isinstance(sample_pyd, tuple): | |
| sample_pyd = sample_pyd[0] | |
| assert isinstance(sample_pyd, dict) | |
| latent = sample_pyd['latent'] | |
| caption = sample_pyd['caption'].decode('utf-8') | |
| c = sample_pyd['c'] | |
| # img = sample_pyd['img'] | |
| # st() | |
| return latent, caption, c | |
| def create_dict(self, sample): | |
| return { | |
| # **sample, | |
| 'latent': sample[0], | |
| 'caption': sample[1], | |
| 'c': sample[2], | |
| } | |
| # test tar loading | |
| def load_wds_latent_ResampledShard(file_path, | |
| batch_size, | |
| num_workers, | |
| reso, | |
| reso_encoder, | |
| test=False, | |
| preprocess=None, | |
| imgnet_normalize=True, | |
| plucker_embedding=False, | |
| decode_encode_img_only=False, | |
| **kwargs): | |
| # return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd'] | |
| post_process_cls = PostProcess_forlatent( | |
| reso, | |
| reso_encoder, | |
| imgnet_normalize=imgnet_normalize, | |
| plucker_embedding=plucker_embedding, | |
| decode_encode_img_only=decode_encode_img_only, | |
| ) | |
| if isinstance(file_path, list): # lst of shard urls | |
| all_shards = [] | |
| for url_path in file_path: | |
| all_shards.extend(wds.shardlists.expand_source(url_path)) | |
| logger.log('all_shards', all_shards) | |
| else: | |
| all_shards = file_path # to be expanded | |
| dataset = wds.DataPipeline( | |
| wds.ResampledShards(all_shards), # url_shard | |
| # at this point we have an iterator over all the shards | |
| wds.shuffle(50), | |
| wds.split_by_worker, # if multi-node | |
| wds.tarfile_to_samples(), | |
| # add wds.split_by_node here if you are using multiple nodes | |
| wds.detshuffle( | |
| 2500 | |
| ), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.decode(wds.autodecode.basichandlers), # TODO | |
| wds.to_tuple("sample.pyd"), # extract the pyd from top level dict | |
| wds.map(post_process_cls.decode_zip), | |
| # wds.map(post_process_cls._post_process_sample), | |
| # wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading | |
| wds.batched( | |
| 150, | |
| partial=True, | |
| # collation_fn=collate | |
| ) # streaming more data at once, and rebatch later | |
| ) | |
| loader_shard = wds.WebLoader( | |
| dataset, | |
| num_workers=num_workers, | |
| drop_last=False, | |
| batch_size=None, | |
| shuffle=False, | |
| persistent_workers=num_workers | |
| > 0).unbatched().shuffle(1000).batched(batch_size).map( | |
| post_process_cls.create_dict) | |
| # persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict) | |
| # 1000).batched(batch_size).map(post_process_cls.create_dict) | |
| # .map(collate) | |
| # .map(collate) | |
| # .batched(batch_size) | |
| # | |
| # .unbatched().shuffle(1000).batched(batch_size).map(post_process) | |
| # # https://github.com/webdataset/webdataset/issues/187 | |
| # return next(iter(loader_shard)) | |
| #return dataset | |
| return loader_shard | |