Spaces:
Running
on
Zero
Running
on
Zero
| import os.path as osp | |
| import numpy as np | |
| import os | |
| import sys | |
| sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) | |
| from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset | |
| from dust3r.utils.image import imread_cv2 | |
| import h5py | |
| from tqdm import tqdm | |
| class BlendedMVS_Multi(BaseMultiViewDataset): | |
| """Dataset of outdoor street scenes, 5 images each time""" | |
| def __init__(self, *args, ROOT, split=None, **kwargs): | |
| self.ROOT = ROOT | |
| self.video = False | |
| self.is_metric = False | |
| super().__init__(*args, **kwargs) | |
| # assert split is None | |
| self._load_data() | |
| def _load_data(self): | |
| self.data_dict = self.read_h5_file(os.path.join(self.ROOT, "new_overlap.h5")) | |
| self.num_imgs = sum( | |
| [len(self.data_dict[s]["basenames"]) for s in self.data_dict.keys()] | |
| ) | |
| self.num_scenes = len(self.data_dict.keys()) | |
| self.invalid_scenes = [] | |
| self.is_reachable_cache = {scene: {} for scene in self.data_dict.keys()} | |
| def read_h5_file(self, h5_file_path): | |
| data_dict = {} | |
| self.all_ref_imgs = [] | |
| with h5py.File(h5_file_path, "r") as f: | |
| for scene_dir in tqdm(f.keys()): | |
| group = f[scene_dir] | |
| basenames = group["basenames"][:] | |
| indices = group["indices"][:] | |
| values = group["values"][:] | |
| shape = group.attrs["shape"] | |
| # Reconstruct the sparse matrix | |
| score_matrix = np.zeros(shape, dtype=np.float32) | |
| score_matrix[indices[0], indices[1]] = values | |
| data_dict[scene_dir] = { | |
| "basenames": basenames, | |
| "score_matrix": self.build_adjacency_list(score_matrix), | |
| } | |
| self.all_ref_imgs.extend( | |
| [(scene_dir, b) for b in range(len(basenames))] | |
| ) | |
| return data_dict | |
| def build_adjacency_list(S, thresh=0.2): | |
| adjacency_list = [[] for _ in range(len(S))] | |
| S = S - thresh | |
| S[S < 0] = 0 | |
| rows, cols = np.nonzero(S) | |
| for i, j in zip(rows, cols): | |
| adjacency_list[i].append((j, S[i][j])) | |
| return adjacency_list | |
| def is_reachable(adjacency_list, start_index, k): | |
| visited = set() | |
| stack = [start_index] | |
| while stack and len(visited) < k: | |
| node = stack.pop() | |
| if node not in visited: | |
| visited.add(node) | |
| for neighbor in adjacency_list[node]: | |
| if neighbor[0] not in visited: | |
| stack.append(neighbor[0]) | |
| return len(visited) >= k | |
| def random_sequence_no_revisit_with_backtracking( | |
| adjacency_list, k, start_index, rng: np.random.Generator | |
| ): | |
| path = [start_index] | |
| visited = set([start_index]) | |
| neighbor_iterators = [] | |
| # Initialize the iterator for the start index | |
| neighbors = adjacency_list[start_index] | |
| neighbor_idxs = [n[0] for n in neighbors] | |
| neighbor_weights = [n[1] for n in neighbors] | |
| neighbor_idxs = rng.choice( | |
| neighbor_idxs, | |
| size=len(neighbor_idxs), | |
| replace=False, | |
| p=np.array(neighbor_weights) / np.sum(neighbor_weights), | |
| ).tolist() | |
| neighbor_iterators.append(iter(neighbor_idxs)) | |
| while len(path) < k: | |
| if not neighbor_iterators: | |
| # No possible sequence | |
| return None | |
| current_iterator = neighbor_iterators[-1] | |
| try: | |
| next_index = next(current_iterator) | |
| if next_index not in visited: | |
| path.append(next_index) | |
| visited.add(next_index) | |
| # Prepare iterator for the next node | |
| neighbors = adjacency_list[next_index] | |
| neighbor_idxs = [n[0] for n in neighbors] | |
| neighbor_weights = [n[1] for n in neighbors] | |
| neighbor_idxs = rng.choice( | |
| neighbor_idxs, | |
| size=len(neighbor_idxs), | |
| replace=False, | |
| p=np.array(neighbor_weights) / np.sum(neighbor_weights), | |
| ).tolist() | |
| neighbor_iterators.append(iter(neighbor_idxs)) | |
| except StopIteration: | |
| # No more neighbors to try at this node, backtrack | |
| neighbor_iterators.pop() | |
| visited.remove(path.pop()) | |
| return path | |
| def random_sequence_with_optional_repeats( | |
| adjacency_list, | |
| k, | |
| start_index, | |
| rng: np.random.Generator, | |
| max_k=None, | |
| max_attempts=100, | |
| ): | |
| if max_k is None: | |
| max_k = k | |
| path = [start_index] | |
| visited = set([start_index]) | |
| current_index = start_index | |
| attempts = 0 | |
| while len(path) < max_k and attempts < max_attempts: | |
| attempts += 1 | |
| neighbors = adjacency_list[current_index] | |
| neighbor_idxs = [n[0] for n in neighbors] | |
| neighbor_weights = [n[1] for n in neighbors] | |
| if not neighbor_idxs: | |
| # No neighbors, cannot proceed further | |
| break | |
| # Try to find unvisited neighbors | |
| unvisited_neighbors = [ | |
| (idx, wgt) | |
| for idx, wgt in zip(neighbor_idxs, neighbor_weights) | |
| if idx not in visited | |
| ] | |
| if unvisited_neighbors: | |
| # Select among unvisited neighbors | |
| unvisited_idxs = [idx for idx, _ in unvisited_neighbors] | |
| unvisited_weights = [wgt for _, wgt in unvisited_neighbors] | |
| probabilities = np.array(unvisited_weights) / np.sum(unvisited_weights) | |
| next_index = rng.choice(unvisited_idxs, p=probabilities) | |
| visited.add(next_index) | |
| else: | |
| # All neighbors visited, but we need to reach length max_k | |
| # So we can revisit nodes | |
| probabilities = np.array(neighbor_weights) / np.sum(neighbor_weights) | |
| next_index = rng.choice(neighbor_idxs, p=probabilities) | |
| path.append(next_index) | |
| current_index = next_index | |
| if len(set(path)) >= k: | |
| # If path is shorter than max_k, extend it by repeating existing elements | |
| while len(path) < max_k: | |
| # Randomly select nodes from the existing path to repeat | |
| next_index = rng.choice(path) | |
| path.append(next_index) | |
| return path | |
| else: | |
| # Could not reach k unique nodes | |
| return None | |
| def __len__(self): | |
| return len(self.all_ref_imgs) | |
| def get_image_num(self): | |
| return self.num_imgs | |
| def get_stats(self): | |
| return f"{len(self)} imgs from {self.num_scenes} scenes" | |
| def generate_sequence( | |
| self, scene, adj_list, num_views, start_index, rng, allow_repeat=False | |
| ): | |
| cutoff = num_views if not allow_repeat else max(num_views // 5, 3) | |
| if start_index in self.is_reachable_cache[scene]: | |
| if not self.is_reachable_cache[scene][start_index]: | |
| print( | |
| f"Cannot reach {num_views} unique elements from index {start_index}." | |
| ) | |
| return None | |
| else: | |
| self.is_reachable_cache[scene][start_index] = self.is_reachable( | |
| adj_list, start_index, cutoff | |
| ) | |
| if not self.is_reachable_cache[scene][start_index]: | |
| print( | |
| f"Cannot reach {num_views} unique elements from index {start_index}." | |
| ) | |
| return None | |
| if not allow_repeat: | |
| sequence = self.random_sequence_no_revisit_with_backtracking( | |
| adj_list, cutoff, start_index, rng | |
| ) | |
| else: | |
| sequence = self.random_sequence_with_optional_repeats( | |
| adj_list, cutoff, start_index, rng, max_k=num_views | |
| ) | |
| if not sequence: | |
| self.is_reachable_cache[scene][start_index] = False | |
| print("Failed to generate a sequence without revisiting.") | |
| return sequence | |
| def _get_views(self, idx, resolution, rng: np.random.Generator, num_views): | |
| scene_info, ref_img_idx = self.all_ref_imgs[idx] | |
| invalid_seq = True | |
| ordered_video = False | |
| while invalid_seq: | |
| basenames = self.data_dict[scene_info]["basenames"] | |
| if ( | |
| sum( | |
| [ | |
| (1 - int(x)) | |
| for x in list(self.is_reachable_cache[scene_info].values()) | |
| ] | |
| ) | |
| > len(basenames) - self.num_views | |
| ): | |
| self.invalid_scenes.append(scene_info) | |
| while scene_info in self.invalid_scenes: | |
| idx = rng.integers(low=0, high=len(self.all_ref_imgs)) | |
| scene_info, ref_img_idx = self.all_ref_imgs[idx] | |
| basenames = self.data_dict[scene_info]["basenames"] | |
| score_matrix = self.data_dict[scene_info]["score_matrix"] | |
| imgs_idxs = self.generate_sequence( | |
| scene_info, score_matrix, num_views, ref_img_idx, rng, self.allow_repeat | |
| ) | |
| if imgs_idxs is None: | |
| random_direction = 2 * rng.choice(2) - 1 | |
| for offset in range(1, len(basenames)): | |
| tentative_im_idx = ( | |
| ref_img_idx + (random_direction * offset) | |
| ) % len(basenames) | |
| if ( | |
| tentative_im_idx not in self.is_reachable_cache[scene_info] | |
| or self.is_reachable_cache[scene_info][tentative_im_idx] | |
| ): | |
| ref_img_idx = tentative_im_idx | |
| break | |
| else: | |
| invalid_seq = False | |
| views = [] | |
| for view_idx in imgs_idxs: | |
| scene_dir = osp.join(self.ROOT, scene_info) | |
| impath = basenames[view_idx].decode("utf-8") | |
| image = imread_cv2(osp.join(scene_dir, impath + ".jpg")) | |
| depthmap = imread_cv2(osp.join(scene_dir, impath + ".exr")) | |
| camera_params = np.load(osp.join(scene_dir, impath + ".npz")) | |
| intrinsics = np.float32(camera_params["intrinsics"]) | |
| camera_pose = np.eye(4, dtype=np.float32) | |
| camera_pose[:3, :3] = camera_params["R_cam2world"] | |
| camera_pose[:3, 3] = camera_params["t_cam2world"] | |
| image, depthmap, intrinsics = self._crop_resize_if_necessary( | |
| image, depthmap, intrinsics, resolution, rng, info=(scene_dir, impath) | |
| ) | |
| views.append( | |
| dict( | |
| img=image, | |
| depthmap=depthmap, | |
| camera_pose=camera_pose, # cam2world | |
| camera_intrinsics=intrinsics, | |
| dataset="BlendedMVS", | |
| label=osp.relpath(scene_dir, self.ROOT), | |
| is_metric=self.is_metric, | |
| is_video=ordered_video, | |
| instance=osp.join(scene_dir, impath + ".jpg"), | |
| quantile=np.array(0.97, dtype=np.float32), | |
| img_mask=True, | |
| ray_mask=False, | |
| camera_only=False, | |
| depth_only=False, | |
| single_view=False, | |
| reset=False, | |
| ) | |
| ) | |
| assert len(views) == num_views | |
| return views | |