"""
MapAnything Datasets
"""

import torch

from mapanything.datasets.wai.ase import ASEWAI # noqa
from mapanything.datasets.wai.bedlam import BedlamWAI # noqa
from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI # noqa
from mapanything.datasets.wai.dl3dv import DL3DVWAI # noqa
from mapanything.datasets.wai.dtu import DTUWAI # noqa
from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI # noqa
from mapanything.datasets.wai.eth3d import ETH3DWAI # noqa
from mapanything.datasets.wai.gta_sfm import GTASfMWAI # noqa
from mapanything.datasets.wai.matrixcity import MatrixCityWAI # noqa
from mapanything.datasets.wai.megadepth import MegaDepthWAI # noqa
from mapanything.datasets.wai.mpsd import MPSDWAI # noqa
from mapanything.datasets.wai.mvs_synth import MVSSynthWAI # noqa
from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI # noqa
from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI # noqa
from mapanything.datasets.wai.scannetpp import ScanNetPPWAI # noqa
from mapanything.datasets.wai.spring import SpringWAI # noqa
from mapanything.datasets.wai.structured3d import Structured3DWAI # noqa
from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI # noqa
from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI # noqa
from mapanything.datasets.wai.xrooms import XRoomsWAI # noqa
from mapanything.utils.train_tools import get_rank, get_world_size


def get_test_data_loader(
    dataset, batch_size, num_workers=8, shuffle=False, drop_last=False, pin_mem=True
):
    "Get a simple PyTorch dataloader for the testing dataset"
    # PyTorch dataset (the dataset can be passed as a string expression to evaluate)
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # Use a distributed sampler under torch.distributed; otherwise fall back to
    # plain random or sequential sampling
    if torch.distributed.is_initialized():
        sampler = torch.utils.data.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
            drop_last=drop_last,
        )
    elif shuffle:
        sampler = torch.utils.data.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )
    return data_loader
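
# Example usage (an illustrative sketch, not executed by this module): `dataset`
# may be an already-constructed dataset object, or a string expression that
# eval() resolves to one of the WAI dataset classes imported above. The
# constructor arguments shown here are hypothetical.
#
#   test_set = ETH3DWAI(split="test", ...)
#   loader = get_test_data_loader(test_set, batch_size=4)
#   for batch in loader:
#       ...  # run evaluation on the batch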


def get_test_many_ar_data_loader(
    dataset, batch_size, num_workers=8, drop_last=False, pin_mem=True
):
    "Get a PyTorch dataloader for a testing dataset that supports many aspect ratios"
    # PyTorch dataset (the dataset can be passed as a string expression to evaluate)
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # Get BatchedMultiFeatureRandomSampler
    sampler = dataset.make_sampler(
        batch_size,
        shuffle=True,
        world_size=world_size,
        rank=rank,
        drop_last=drop_last,
        use_dynamic_sampler=False,
    )

    # Init the data loader
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )
    return data_loader
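
# Example usage (a sketch under assumptions): the dataset must expose
# make_sampler(), as the WAI dataset classes above do. The sampler batches
# samples by a shared feature index, presumably the aspect-ratio/resolution
# variant given the function name; the constructor arguments are hypothetical.
#
#   test_set = BlendedMVSWAI(split="test", ...)
#   loader = get_test_many_ar_data_loader(test_set, batch_size=8)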


class DynamicBatchDatasetWrapper:
    """
    Wrapper dataset that handles DynamicBatchedMultiFeatureRandomSampler output.
    The dynamic sampler returns batches (lists of tuples) instead of individual samples.
    This wrapper ensures that the underlying dataset's __getitem__ method gets called
    with individual tuples as expected.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, batch_indices):
        """
        Handle a batch of indices from DynamicBatchedMultiFeatureRandomSampler.

        Args:
            batch_indices: List of tuples like [(sample_idx, feat_idx_1, feat_idx_2, ...), ...]

        Returns:
            List of samples from the underlying dataset
        """
        if isinstance(batch_indices, (list, tuple)) and len(batch_indices) > 0:
            # If it's a batch (list of tuples), process each item
            if isinstance(batch_indices[0], (list, tuple)):
                return [self.dataset[idx] for idx in batch_indices]
            else:
                # Single tuple, call the dataset directly
                return self.dataset[batch_indices]
        else:
            # Fallback for a single index
            return self.dataset[batch_indices]

    def __len__(self):
        return len(self.dataset)

    def __getattr__(self, name):
        # Delegate all other attributes to the wrapped dataset
        return getattr(self.dataset, name)
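
# Illustrative sketch of the wrapper's behavior (not executed by this module;
# the index values are made up): a list of index tuples yields a list of
# samples, while a single index tuple is forwarded to the wrapped dataset as-is.
#
#   wrapped = DynamicBatchDatasetWrapper(dataset)
#   one_sample = wrapped[(3, 0)]       # single (sample_idx, feat_idx, ...) tuple
#   batch = wrapped[[(3, 0), (7, 0)]]  # list of tuples -> list of samples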


def get_train_data_loader(
    dataset,
    max_num_of_imgs_per_gpu,
    num_workers=8,
    shuffle=True,
    drop_last=True,
    pin_mem=True,
):
    "Get a dynamic PyTorch dataloader for the training dataset"
    # PyTorch dataset (the dataset can be passed as a string expression to evaluate)
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # Get DynamicBatchedMultiFeatureRandomSampler
    batch_sampler = dataset.make_sampler(
        shuffle=shuffle,
        world_size=world_size,
        rank=rank,
        drop_last=drop_last,
        max_num_of_images_per_gpu=max_num_of_imgs_per_gpu,
        use_dynamic_sampler=True,
    )

    # Wrap the dataset to handle the batch format from the dynamic sampler
    wrapped_dataset = DynamicBatchDatasetWrapper(dataset)

    # Init the dynamic data loader
    data_loader = torch.utils.data.DataLoader(
        wrapped_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=pin_mem,
    )
    return data_loader
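
# Example usage (a sketch under assumptions; the dataset constructor is
# hypothetical): the dynamic batch sampler presumably sizes each batch so the
# total number of images handled per GPU stays within max_num_of_imgs_per_gpu.
#
#   train_set = ScanNetPPWAI(split="train", ...)
#   loader = get_train_data_loader(train_set, max_num_of_imgs_per_gpu=48)
#   for batch in loader:
#       ...  # training step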