Spaces:
Running
on
Zero
Running
on
Zero
| import PIL | |
| import numpy as np | |
| import torch | |
| import random | |
| import itertools | |
| from dust3r.datasets.base.easy_dataset import EasyDataset | |
| from dust3r.datasets.utils.transforms import ImgNorm, SeqColorJitter | |
| from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates | |
| import dust3r.datasets.utils.cropping as cropping | |
| from dust3r.datasets.utils.corr import extract_correspondences_from_pts3d | |
| def get_ray_map(c2w1, c2w2, intrinsics, h, w): | |
| c2w = np.linalg.inv(c2w1) @ c2w2 | |
| i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy") | |
| grid = np.stack([i, j, np.ones_like(i)], axis=-1) | |
| ro = c2w[:3, 3] | |
| rd = np.linalg.inv(intrinsics) @ grid.reshape(-1, 3).T | |
| rd = (c2w @ np.vstack([rd, np.ones_like(rd[0])])).T[:, :3].reshape(h, w, 3) | |
| rd = rd / np.linalg.norm(rd, axis=-1, keepdims=True) | |
| ro = np.broadcast_to(ro, (h, w, 3)) | |
| ray_map = np.concatenate([ro, rd], axis=-1) | |
| return ray_map | |
| class BaseMultiViewDataset(EasyDataset): | |
| """Define all basic options. | |
| Usage: | |
| class MyDataset (BaseMultiViewDataset): | |
| def _get_views(self, idx, rng): | |
| # overload here | |
| views = [] | |
| views.append(dict(img=, ...)) | |
| return views | |
| """ | |
| def __init__( | |
| self, | |
| *, # only keyword arguments | |
| num_views=None, | |
| split=None, | |
| resolution=None, # square_size or (width, height) or list of [(width,height), ...] | |
| transform=ImgNorm, | |
| aug_crop=False, | |
| n_corres=0, | |
| nneg=0, | |
| seed=None, | |
| allow_repeat=False, | |
| seq_aug_crop=False, | |
| ): | |
| assert num_views is not None, "undefined num_views" | |
| self.num_views = num_views | |
| self.split = split | |
| self._set_resolutions(resolution) | |
| self.n_corres = n_corres | |
| self.nneg = nneg | |
| assert ( | |
| self.n_corres == "all" | |
| or isinstance(self.n_corres, int) | |
| or ( | |
| isinstance(self.n_corres, list) and len(self.n_corres) == self.num_views | |
| ) | |
| ), f"Error, n_corres should either be 'all', a single integer or a list of length {self.num_views}" | |
| assert ( | |
| self.nneg == 0 or self.n_corres != "all" | |
| ), "nneg should be 0 if n_corres is all" | |
| self.is_seq_color_jitter = False | |
| if isinstance(transform, str): | |
| transform = eval(transform) | |
| if transform == SeqColorJitter: | |
| transform = SeqColorJitter() | |
| self.is_seq_color_jitter = True | |
| self.transform = transform | |
| self.aug_crop = aug_crop | |
| self.seed = seed | |
| self.allow_repeat = allow_repeat | |
| self.seq_aug_crop = seq_aug_crop | |
| def __len__(self): | |
| return len(self.scenes) | |
| def efficient_random_intervals( | |
| start, | |
| num_elements, | |
| interval_range, | |
| fixed_interval_prob=0.8, | |
| weights=None, | |
| seed=42, | |
| ): | |
| if random.random() < fixed_interval_prob: | |
| intervals = random.choices(interval_range, weights=weights) * ( | |
| num_elements - 1 | |
| ) | |
| else: | |
| intervals = [ | |
| random.choices(interval_range, weights=weights)[0] | |
| for _ in range(num_elements - 1) | |
| ] | |
| return list(itertools.accumulate([start] + intervals)) | |
| def sample_based_on_timestamps(self, i, timestamps, num_views, interval=1): | |
| time_diffs = np.abs(timestamps - timestamps[i]) | |
| ids_candidate = np.where(time_diffs < interval)[0] | |
| ids_candidate = np.sort(ids_candidate) | |
| if (self.allow_repeat and len(ids_candidate) < num_views // 3) or ( | |
| len(ids_candidate) < num_views | |
| ): | |
| return [] | |
| ids_sel_list = [] | |
| ids_candidate_left = ids_candidate.copy() | |
| while len(ids_candidate_left) >= num_views: | |
| ids_sel = np.random.choice(ids_candidate_left, num_views, replace=False) | |
| ids_sel_list.append(sorted(ids_sel)) | |
| ids_candidate_left = np.setdiff1d(ids_candidate_left, ids_sel) | |
| if len(ids_candidate_left) > 0 and len(ids_candidate) >= num_views: | |
| ids_sel = np.concatenate( | |
| [ | |
| ids_candidate_left, | |
| np.random.choice( | |
| np.setdiff1d(ids_candidate, ids_candidate_left), | |
| num_views - len(ids_candidate_left), | |
| replace=False, | |
| ), | |
| ] | |
| ) | |
| ids_sel_list.append(sorted(ids_sel)) | |
| if self.allow_repeat: | |
| ids_sel_list.append( | |
| sorted(np.random.choice(ids_candidate, num_views, replace=True)) | |
| ) | |
| # add sequences with fixed intervals (all possible intervals) | |
| pos_i = np.where(ids_candidate == i)[0][0] | |
| curr_interval = 1 | |
| stop = len(ids_candidate) < num_views | |
| while not stop: | |
| pos_sel = [pos_i] | |
| count = 0 | |
| while len(pos_sel) < num_views: | |
| if count % 2 == 0: | |
| curr_pos_i = pos_sel[-1] + curr_interval | |
| if curr_pos_i >= len(ids_candidate): | |
| stop = True | |
| break | |
| pos_sel.append(curr_pos_i) | |
| else: | |
| curr_pos_i = pos_sel[0] - curr_interval | |
| if curr_pos_i < 0: | |
| stop = True | |
| break | |
| pos_sel.insert(0, curr_pos_i) | |
| count += 1 | |
| if not stop and len(pos_sel) == num_views: | |
| ids_sel = sorted([ids_candidate[pos] for pos in pos_sel]) | |
| if ids_sel not in ids_sel_list: | |
| ids_sel_list.append(ids_sel) | |
| curr_interval += 1 | |
| return ids_sel_list | |
| def blockwise_shuffle(x, rng, block_shuffle): | |
| if block_shuffle is None: | |
| return rng.permutation(x).tolist() | |
| else: | |
| assert block_shuffle > 0 | |
| blocks = [x[i : i + block_shuffle] for i in range(0, len(x), block_shuffle)] | |
| shuffled_blocks = [rng.permutation(block).tolist() for block in blocks] | |
| shuffled_list = [item for block in shuffled_blocks for item in block] | |
| return shuffled_list | |
| def get_seq_from_start_id( | |
| self, | |
| num_views, | |
| id_ref, | |
| ids_all, | |
| rng, | |
| min_interval=1, | |
| max_interval=25, | |
| video_prob=0.5, | |
| fix_interval_prob=0.5, | |
| block_shuffle=None, | |
| ): | |
| """ | |
| args: | |
| num_views: number of views to return | |
| id_ref: the reference id (first id) | |
| ids_all: all the ids | |
| rng: random number generator | |
| max_interval: maximum interval between two views | |
| returns: | |
| pos: list of positions of the views in ids_all, i.e., index for ids_all | |
| is_video: True if the views are consecutive | |
| """ | |
| assert min_interval > 0, f"min_interval should be > 0, got {min_interval}" | |
| assert ( | |
| min_interval <= max_interval | |
| ), f"min_interval should be <= max_interval, got {min_interval} and {max_interval}" | |
| assert id_ref in ids_all | |
| pos_ref = ids_all.index(id_ref) | |
| all_possible_pos = np.arange(pos_ref, len(ids_all)) | |
| remaining_sum = len(ids_all) - 1 - pos_ref | |
| if remaining_sum >= num_views - 1: | |
| if remaining_sum == num_views - 1: | |
| assert ids_all[-num_views] == id_ref | |
| return [pos_ref + i for i in range(num_views)], True | |
| max_interval = min(max_interval, 2 * remaining_sum // (num_views - 1)) | |
| intervals = [ | |
| rng.choice(range(min_interval, max_interval + 1)) | |
| for _ in range(num_views - 1) | |
| ] | |
| # if video or collection | |
| if rng.random() < video_prob: | |
| # if fixed interval or random | |
| if rng.random() < fix_interval_prob: | |
| # regular interval | |
| fixed_interval = rng.choice( | |
| range( | |
| 1, | |
| min(remaining_sum // (num_views - 1) + 1, max_interval + 1), | |
| ) | |
| ) | |
| intervals = [fixed_interval for _ in range(num_views - 1)] | |
| is_video = True | |
| else: | |
| is_video = False | |
| pos = list(itertools.accumulate([pos_ref] + intervals)) | |
| pos = [p for p in pos if p < len(ids_all)] | |
| pos_candidates = [p for p in all_possible_pos if p not in pos] | |
| pos = ( | |
| pos | |
| + rng.choice( | |
| pos_candidates, num_views - len(pos), replace=False | |
| ).tolist() | |
| ) | |
| pos = ( | |
| sorted(pos) | |
| if is_video | |
| else self.blockwise_shuffle(pos, rng, block_shuffle) | |
| ) | |
| else: | |
| # assert self.allow_repeat | |
| uniq_num = remaining_sum | |
| new_pos_ref = rng.choice(np.arange(pos_ref + 1)) | |
| new_remaining_sum = len(ids_all) - 1 - new_pos_ref | |
| new_max_interval = min(max_interval, new_remaining_sum // (uniq_num - 1)) | |
| new_intervals = [ | |
| rng.choice(range(1, new_max_interval + 1)) for _ in range(uniq_num - 1) | |
| ] | |
| revisit_random = rng.random() | |
| video_random = rng.random() | |
| if rng.random() < fix_interval_prob and video_random < video_prob: | |
| # regular interval | |
| fixed_interval = rng.choice(range(1, new_max_interval + 1)) | |
| new_intervals = [fixed_interval for _ in range(uniq_num - 1)] | |
| pos = list(itertools.accumulate([new_pos_ref] + new_intervals)) | |
| is_video = False | |
| if revisit_random < 0.5 or video_prob == 1.0: # revisit, video / collection | |
| is_video = video_random < video_prob | |
| pos = ( | |
| self.blockwise_shuffle(pos, rng, block_shuffle) | |
| if not is_video | |
| else pos | |
| ) | |
| num_full_repeat = num_views // uniq_num | |
| pos = ( | |
| pos * num_full_repeat | |
| + pos[: num_views - len(pos) * num_full_repeat] | |
| ) | |
| elif revisit_random < 0.9: # random | |
| pos = rng.choice(pos, num_views, replace=True) | |
| else: # ordered | |
| pos = sorted(rng.choice(pos, num_views, replace=True)) | |
| assert len(pos) == num_views | |
| return pos, is_video | |
| def get_img_and_ray_masks(self, is_metric, v, rng, p=[0.8, 0.15, 0.05]): | |
| # generate img mask and raymap mask | |
| if v == 0 or (not is_metric): | |
| img_mask = True | |
| raymap_mask = False | |
| else: | |
| rand_val = rng.random() | |
| if rand_val < p[0]: | |
| img_mask = True | |
| raymap_mask = False | |
| elif rand_val < p[0] + p[1]: | |
| img_mask = False | |
| raymap_mask = True | |
| else: | |
| img_mask = True | |
| raymap_mask = True | |
| return img_mask, raymap_mask | |
| def get_stats(self): | |
| return f"{len(self)} groups of views" | |
| def __repr__(self): | |
| resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]" | |
| return ( | |
| f"""{type(self).__name__}({self.get_stats()}, | |
| {self.num_views=}, | |
| {self.split=}, | |
| {self.seed=}, | |
| resolutions={resolutions_str}, | |
| {self.transform=})""".replace( | |
| "self.", "" | |
| ) | |
| .replace("\n", "") | |
| .replace(" ", "") | |
| ) | |
| def _get_views(self, idx, resolution, rng, num_views): | |
| raise NotImplementedError() | |
| def __getitem__(self, idx): | |
| # print("Receiving:" , idx) | |
| if isinstance(idx, (tuple, list, np.ndarray)): | |
| # the idx is specifying the aspect-ratio | |
| idx, ar_idx, nview = idx | |
| else: | |
| assert len(self._resolutions) == 1 | |
| ar_idx = 0 | |
| nview = self.num_views | |
| assert nview >= 1 and nview <= self.num_views | |
| # set-up the rng | |
| if self.seed: # reseed for each __getitem__ | |
| self._rng = np.random.default_rng(seed=self.seed + idx) | |
| elif not hasattr(self, "_rng"): | |
| seed = torch.randint(0, 2**32, (1,)).item() | |
| self._rng = np.random.default_rng(seed=seed) | |
| if self.aug_crop > 1 and self.seq_aug_crop: | |
| self.delta_target_resolution = self._rng.integers(0, self.aug_crop) | |
| # over-loaded code | |
| resolution = self._resolutions[ | |
| ar_idx | |
| ] # DO NOT CHANGE THIS (compatible with BatchedRandomSampler) | |
| views = self._get_views(idx, resolution, self._rng, nview) | |
| assert len(views) == nview | |
| if "camera_pose" not in views[0]: | |
| views[0]["camera_pose"] = np.ones((4, 4), dtype=np.float32) | |
| first_view_camera_pose = views[0]["camera_pose"] | |
| transform = SeqColorJitter() if self.is_seq_color_jitter else self.transform | |
| for v, view in enumerate(views): | |
| assert ( | |
| "pts3d" not in view | |
| ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}" | |
| view["idx"] = (idx, ar_idx, v) | |
| # encode the image | |
| width, height = view["img"].size | |
| view["true_shape"] = np.int32((height, width)) | |
| view["img"] = transform(view["img"]) | |
| view["sky_mask"] = view["depthmap"] < 0 | |
| assert "camera_intrinsics" in view | |
| if "camera_pose" not in view: | |
| view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32) | |
| else: | |
| assert np.isfinite( | |
| view["camera_pose"] | |
| ).all(), f"NaN in camera pose for view {view_name(view)}" | |
| ray_map = get_ray_map( | |
| first_view_camera_pose, | |
| view["camera_pose"], | |
| view["camera_intrinsics"], | |
| height, | |
| width, | |
| ) | |
| view["ray_map"] = ray_map.astype(np.float32) | |
| assert "pts3d" not in view | |
| assert "valid_mask" not in view | |
| assert np.isfinite( | |
| view["depthmap"] | |
| ).all(), f"NaN in depthmap for view {view_name(view)}" | |
| pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view) | |
| view["pts3d"] = pts3d | |
| view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1) | |
| # check all datatypes | |
| for key, val in view.items(): | |
| res, err_msg = is_good_type(key, val) | |
| assert res, f"{err_msg} with {key}={val} for view {view_name(view)}" | |
| K = view["camera_intrinsics"] | |
| if self.n_corres > 0: | |
| ref_view = views[0] | |
| for view in views: | |
| corres1, corres2, valid = extract_correspondences_from_pts3d( | |
| ref_view, view, self.n_corres, self._rng, nneg=self.nneg | |
| ) | |
| view["corres"] = (corres1, corres2) | |
| view["valid_corres"] = valid | |
| # last thing done! | |
| for view in views: | |
| view["rng"] = int.from_bytes(self._rng.bytes(4), "big") | |
| return views | |
| def _set_resolutions(self, resolutions): | |
| assert resolutions is not None, "undefined resolution" | |
| if not isinstance(resolutions, list): | |
| resolutions = [resolutions] | |
| self._resolutions = [] | |
| for resolution in resolutions: | |
| if isinstance(resolution, int): | |
| width = height = resolution | |
| else: | |
| width, height = resolution | |
| assert isinstance( | |
| width, int | |
| ), f"Bad type for {width=} {type(width)=}, should be int" | |
| assert isinstance( | |
| height, int | |
| ), f"Bad type for {height=} {type(height)=}, should be int" | |
| self._resolutions.append((width, height)) | |
| def _crop_resize_if_necessary( | |
| self, image, depthmap, intrinsics, resolution, rng=None, info=None | |
| ): | |
| """This function: | |
| - first downsizes the image with LANCZOS inteprolation, | |
| which is better than bilinear interpolation in | |
| """ | |
| if not isinstance(image, PIL.Image.Image): | |
| image = PIL.Image.fromarray(image) | |
| # downscale with lanczos interpolation so that image.size == resolution | |
| # cropping centered on the principal point | |
| W, H = image.size | |
| cx, cy = intrinsics[:2, 2].round().astype(int) | |
| min_margin_x = min(cx, W - cx) | |
| min_margin_y = min(cy, H - cy) | |
| assert min_margin_x > W / 5, f"Bad principal point in view={info}" | |
| assert min_margin_y > H / 5, f"Bad principal point in view={info}" | |
| # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy) | |
| l, t = cx - min_margin_x, cy - min_margin_y | |
| r, b = cx + min_margin_x, cy + min_margin_y | |
| crop_bbox = (l, t, r, b) | |
| image, depthmap, intrinsics = cropping.crop_image_depthmap( | |
| image, depthmap, intrinsics, crop_bbox | |
| ) | |
| # transpose the resolution if necessary | |
| W, H = image.size # new size | |
| # high-quality Lanczos down-scaling | |
| target_resolution = np.array(resolution) | |
| if self.aug_crop > 1: | |
| target_resolution += ( | |
| rng.integers(0, self.aug_crop) | |
| if not self.seq_aug_crop | |
| else self.delta_target_resolution | |
| ) | |
| image, depthmap, intrinsics = cropping.rescale_image_depthmap( | |
| image, depthmap, intrinsics, target_resolution | |
| ) | |
| # actual cropping (if necessary) with bilinear interpolation | |
| intrinsics2 = cropping.camera_matrix_of_crop( | |
| intrinsics, image.size, resolution, offset_factor=0.5 | |
| ) | |
| crop_bbox = cropping.bbox_from_intrinsics_in_out( | |
| intrinsics, intrinsics2, resolution | |
| ) | |
| image, depthmap, intrinsics2 = cropping.crop_image_depthmap( | |
| image, depthmap, intrinsics, crop_bbox | |
| ) | |
| return image, depthmap, intrinsics2 | |
| def is_good_type(key, v): | |
| """returns (is_good, err_msg)""" | |
| if isinstance(v, (str, int, tuple)): | |
| return True, None | |
| if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8): | |
| return False, f"bad {v.dtype=}" | |
| return True, None | |
| def view_name(view, batch_index=None): | |
| def sel(x): | |
| return x[batch_index] if batch_index not in (None, slice(None)) else x | |
| db = sel(view["dataset"]) | |
| label = sel(view["label"]) | |
| instance = sel(view["instance"]) | |
| return f"{db}/{label}/{instance}" | |
| def transpose_to_landscape(view): | |
| height, width = view["true_shape"] | |
| if width < height: | |
| # rectify portrait to landscape | |
| assert view["img"].shape == (3, height, width) | |
| view["img"] = view["img"].swapaxes(1, 2) | |
| assert view["valid_mask"].shape == (height, width) | |
| view["valid_mask"] = view["valid_mask"].swapaxes(0, 1) | |
| assert view["depthmap"].shape == (height, width) | |
| view["depthmap"] = view["depthmap"].swapaxes(0, 1) | |
| assert view["pts3d"].shape == (height, width, 3) | |
| view["pts3d"] = view["pts3d"].swapaxes(0, 1) | |
| # transpose x and y pixels | |
| view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]] | |
| assert view["ray_map"].shape == (height, width, 6) | |
| view["ray_map"] = view["ray_map"].swapaxes(0, 1) | |
| assert view["sky_mask"].shape == (height, width) | |
| view["sky_mask"] = view["sky_mask"].swapaxes(0, 1) | |
| if "corres" in view: | |
| # transpose correspondences x and y | |
| view["corres"][0] = view["corres"][0][:, [1, 0]] | |
| view["corres"][1] = view["corres"][1][:, [1, 0]] | |