# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R
import numpy as np

from dust3r.datasets.base.batched_sampler import (
    BatchedRandomSampler,
    CustomRandomSampler,
)


class EasyDataset:
    """A dataset that you can easily resize and combine.

    Examples:
    ---------
        2 * dataset ==> duplicate each element 2x
        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)
        dataset1 + dataset2 ==> concatenate datasets
    """

    def __add__(self, other):
        return CatDataset([self, other])

    def __rmul__(self, factor):
        return MulDataset(factor, self)

    def __rmatmul__(self, factor):
        return ResizedDataset(factor, self)

    def set_epoch(self, epoch):
        pass  # nothing to do by default

    def make_sampler(
        self, batch_size, shuffle=True, drop_last=True, world_size=1, rank=0, fixed_length=False
    ):
        if not shuffle:
            raise NotImplementedError()  # cannot deal yet
        num_of_aspect_ratios = len(self._resolutions)
        num_of_views = self.num_views
        sampler = CustomRandomSampler(
            self,
            batch_size,
            num_of_aspect_ratios,
            4 if not fixed_length else num_of_views,
            num_of_views,
            world_size,
            warmup=1,
            drop_last=drop_last,
        )
        return BatchedRandomSampler(sampler, batch_size, drop_last)
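

# Illustrative sketch (a hypothetical stand-in, not part of the upstream API):
# make_sampler() and the wrappers below assume a concrete dataset exposes
# `_resolutions`, `num_views`, `__len__`, and a `__getitem__` that also
# accepts a 3-tuple (the sample index plus two auxiliary indices forwarded by
# the sampler, e.g. aspect-ratio and view slots).
class _ToyDataset(EasyDataset):
    _resolutions = [(224, 224)]  # a single aspect ratio
    num_views = 2

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        if isinstance(idx, tuple):
            idx, _aux1, _aux2 = idx  # extra indices forwarded by the sampler
        return {"idx": idx}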


class MulDataset(EasyDataset):
    """Artificially augmenting the size of a dataset."""

    multiplicator: int

    def __init__(self, multiplicator, dataset):
        assert isinstance(multiplicator, int) and multiplicator > 0
        self.multiplicator = multiplicator
        self.dataset = dataset

    def __len__(self):
        return self.multiplicator * len(self.dataset)

    def __repr__(self):
        return f"{self.multiplicator}*{repr(self.dataset)}"

    def __getitem__(self, idx):
        if isinstance(idx, tuple):
            idx, other, another = idx
            return self.dataset[idx // self.multiplicator, other, another]
        else:
            return self.dataset[idx // self.multiplicator]

    @property
    def _resolutions(self):
        return self.dataset._resolutions

    @property
    def num_views(self):
        return self.dataset.num_views


class ResizedDataset(EasyDataset):
    """Artificially changing the size of a dataset."""

    new_size: int

    def __init__(self, new_size, dataset):
        assert isinstance(new_size, int) and new_size > 0
        self.new_size = new_size
        self.dataset = dataset

    def __len__(self):
        return self.new_size

    def __repr__(self):
        # insert '_' thousands separators into the size for readability
        size_str = str(self.new_size)
        for i in range((len(size_str) - 1) // 3):
            sep = -4 * i - 3
            size_str = size_str[:sep] + "_" + size_str[sep:]
        return f"{size_str} @ {repr(self.dataset)}"

    def set_epoch(self, epoch):
        # this random shuffle only depends on the epoch
        rng = np.random.default_rng(seed=epoch + 777)
        # shuffle all indices
        perm = rng.permutation(len(self.dataset))
        # rotary extension until target size is met
        # (e.g. len(dataset)=4, new_size=10: tile the permutation to
        # length 12, then keep the first 10 entries)
        shuffled_idxs = np.concatenate(
            [perm] * (1 + (len(self) - 1) // len(self.dataset))
        )
        self._idxs_mapping = shuffled_idxs[: self.new_size]
        assert len(self._idxs_mapping) == self.new_size

    def __getitem__(self, idx):
        assert hasattr(
            self, "_idxs_mapping"
        ), "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
        if isinstance(idx, tuple):
            idx, other, another = idx
            return self.dataset[self._idxs_mapping[idx], other, another]
        else:
            return self.dataset[self._idxs_mapping[idx]]

    @property
    def _resolutions(self):
        return self.dataset._resolutions

    @property
    def num_views(self):
        return self.dataset.num_views


class CatDataset(EasyDataset):
    """Concatenation of several datasets"""

    def __init__(self, datasets):
        for dataset in datasets:
            assert isinstance(dataset, EasyDataset)
        self.datasets = datasets
        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])

    def __len__(self):
        return self._cum_sizes[-1]

    def __repr__(self):
        # remove uselessly long transform
        return " + ".join(
            repr(dataset).replace(
                ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
                "",
            )
            for dataset in self.datasets
        )

    def set_epoch(self, epoch):
        for dataset in self.datasets:
            dataset.set_epoch(epoch)

    def __getitem__(self, idx):
        other = another = None
        if isinstance(idx, tuple):
            idx, other, another = idx
        if not (0 <= idx < len(self)):
            raise IndexError()
        # locate the sub-dataset the flat index falls into, then rebase it
        db_idx = np.searchsorted(self._cum_sizes, idx, "right")
        dataset = self.datasets[db_idx]
        new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
        if other is not None and another is not None:
            new_idx = (new_idx, other, another)
        return dataset[new_idx]

    @property
    def _resolutions(self):
        resolutions = self.datasets[0]._resolutions
        for dataset in self.datasets[1:]:
            assert tuple(dataset._resolutions) == tuple(resolutions)
        return resolutions

    @property
    def num_views(self):
        num_views = self.datasets[0].num_views
        for dataset in self.datasets[1:]:
            assert dataset.num_views == num_views
        return num_views
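

# Illustrative smoke test for the operators documented on EasyDataset, using
# the hypothetical _ToyDataset above: `2 *` duplicates elements, `15 @`
# resamples to a fixed size, `+` concatenates.
if __name__ == "__main__":
    base = _ToyDataset()
    combo = 2 * base + 15 @ base  # parsed as (2 * base) + (15 @ base)
    assert len(combo) == 2 * len(base) + 15
    combo.set_epoch(0)  # required before indexing into a ResizedDataset
    print(repr(combo))
    print(combo[0])  # served by the `2 * base` part
    print(combo[25])  # served by the `15 @ base` part
    print(combo[3, 0, 1])  # tuple indices are forwarded to the wrapped dataset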