| """ | |
| MapAnything Datasets | |
| """ | |
| import torch | |
| from mapanything.datasets.wai.ase import ASEWAI # noqa | |
| from mapanything.datasets.wai.bedlam import BedlamWAI # noqa | |
| from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI # noqa | |
| from mapanything.datasets.wai.dl3dv import DL3DVWAI # noqa | |
| from mapanything.datasets.wai.dtu import DTUWAI # noqa | |
| from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI # noqa | |
| from mapanything.datasets.wai.eth3d import ETH3DWAI # noqa | |
| from mapanything.datasets.wai.gta_sfm import GTASfMWAI # noqa | |
| from mapanything.datasets.wai.matrixcity import MatrixCityWAI # noqa | |
| from mapanything.datasets.wai.megadepth import MegaDepthWAI # noqa | |
| from mapanything.datasets.wai.mpsd import MPSDWAI # noqa | |
| from mapanything.datasets.wai.mvs_synth import MVSSynthWAI # noqa | |
| from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI # noqa | |
| from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI # noqa | |
| from mapanything.datasets.wai.scannetpp import ScanNetPPWAI # noqa | |
| from mapanything.datasets.wai.spring import SpringWAI # noqa | |
| from mapanything.datasets.wai.structured3d import Structured3DWAI # noqa | |
| from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI # noqa | |
| from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI # noqa | |
| from mapanything.datasets.wai.xrooms import XRoomsWAI # noqa | |
| from mapanything.utils.train_tools import get_rank, get_world_size | |


def get_test_data_loader(
    dataset, batch_size, num_workers=8, shuffle=False, drop_last=False, pin_mem=True
):
    "Get simple PyTorch dataloader corresponding to the testing dataset"
    # PyTorch dataset
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    if torch.distributed.is_initialized():
        sampler = torch.utils.data.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
            drop_last=drop_last,
        )
    elif shuffle:
        sampler = torch.utils.data.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader
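

# Illustrative usage (a sketch, not part of the original file; the ETH3DWAI
# constructor arguments are assumed placeholders, not a configuration shipped
# with this module):
#
#   test_dataset = ETH3DWAI(split="test", num_views=2)  # hypothetical arguments
#   test_loader = get_test_data_loader(test_dataset, batch_size=8)
#   for batch in test_loader:
#       ...  # run evaluation on the batch
#
# When torch.distributed is initialized, each rank receives its own shard of
# the dataset through the DistributedSampler branch above.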


def get_test_many_ar_data_loader(
    dataset, batch_size, num_workers=8, drop_last=False, pin_mem=True
):
    "Get PyTorch dataloader corresponding to the testing dataset that supports many aspect ratios"
    # PyTorch dataset
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # Get BatchedMultiFeatureRandomSampler
    sampler = dataset.make_sampler(
        batch_size,
        shuffle=True,
        world_size=world_size,
        rank=rank,
        drop_last=drop_last,
        use_dynamic_sampler=False,
    )

    # Init the data loader
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader
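

# Illustrative usage (a sketch; the dataset string below is an assumed example,
# not a configuration shipped with this module). Passing a string relies on the
# eval() call above, so any dataset class imported at the top of this file can
# be referenced by name:
#
#   test_loader = get_test_many_ar_data_loader(
#       "BlendedMVSWAI(split='test', num_views=4)",  # hypothetical arguments
#       batch_size=4,
#   )
#
# The sampler returned by dataset.make_sampler(..., use_dynamic_sampler=False)
# is a BatchedMultiFeatureRandomSampler, which presumably groups samples by the
# sampled feature (e.g. aspect ratio) so that each batch remains collatable.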


class DynamicBatchDatasetWrapper:
    """
    Wrapper dataset that handles DynamicBatchedMultiFeatureRandomSampler output.
    The dynamic sampler returns batches (lists of tuples) instead of individual samples.
    This wrapper ensures that the underlying dataset's __getitem__ method gets called
    with individual tuples as expected.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, batch_indices):
        """
        Handle batch of indices from DynamicBatchedMultiFeatureRandomSampler.

        Args:
            batch_indices: List of tuples like [(sample_idx, feat_idx_1, feat_idx_2, ...), ...]

        Returns:
            List of samples from the underlying dataset
        """
        if isinstance(batch_indices, (list, tuple)) and len(batch_indices) > 0:
            # If it's a batch (list of tuples), process each item
            if isinstance(batch_indices[0], (list, tuple)):
                return [self.dataset[idx] for idx in batch_indices]
            else:
                # Single tuple, call dataset directly
                return self.dataset[batch_indices]
        else:
            # Fallback for single index
            return self.dataset[batch_indices]

    def __len__(self):
        return len(self.dataset)

    def __getattr__(self, name):
        # Delegate all other attributes to the wrapped dataset
        return getattr(self.dataset, name)
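

# Behaviour sketch (the indices below are made up for illustration): the dynamic
# sampler yields one list of (sample_idx, feat_idx, ...) tuples per batch, and
# the wrapper maps each tuple onto the underlying dataset:
#
#   wrapped = DynamicBatchDatasetWrapper(dataset)
#   wrapped[[(0, 1), (5, 1), (9, 1)]]  # -> [dataset[(0, 1)], dataset[(5, 1)], dataset[(9, 1)]]
#   wrapped[(0, 1)]                    # -> dataset[(0, 1)]
#
# Since __getattr__ delegates to the wrapped dataset, attributes and methods of
# the original dataset stay reachable through the wrapper.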


def get_train_data_loader(
    dataset,
    max_num_of_imgs_per_gpu,
    num_workers=8,
    shuffle=True,
    drop_last=True,
    pin_mem=True,
):
    "Get dynamic PyTorch dataloader corresponding to the training dataset"
    # PyTorch dataset
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # Get DynamicBatchedMultiFeatureRandomSampler
    batch_sampler = dataset.make_sampler(
        shuffle=shuffle,
        world_size=world_size,
        rank=rank,
        drop_last=drop_last,
        max_num_of_images_per_gpu=max_num_of_imgs_per_gpu,
        use_dynamic_sampler=True,
    )

    # Wrap the dataset to handle batch format from dynamic sampler
    wrapped_dataset = DynamicBatchDatasetWrapper(dataset)

    # Init the dynamic data loader
    data_loader = torch.utils.data.DataLoader(
        wrapped_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=pin_mem,
    )

    return data_loader
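

# Illustrative training setup (a sketch under assumed arguments; the actual
# per-GPU image budget is configuration-dependent):
#
#   train_loader = get_train_data_loader(
#       train_dataset,
#       max_num_of_imgs_per_gpu=24,  # hypothetical budget
#       num_workers=8,
#   )
#   for batch in train_loader:
#       ...  # forward/backward pass
#
# Note that because the dynamic sampler is passed via batch_sampler, the
# DataLoader is constructed without batch_size, shuffle, or drop_last; batch
# composition is controlled entirely by the sampler's per-GPU image budget.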