"""
MapAnything Datasets
"""
import torch
from mapanything.datasets.wai.ase import ASEWAI # noqa
from mapanything.datasets.wai.bedlam import BedlamWAI # noqa
from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI # noqa
from mapanything.datasets.wai.dl3dv import DL3DVWAI # noqa
from mapanything.datasets.wai.dtu import DTUWAI # noqa
from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI # noqa
from mapanything.datasets.wai.eth3d import ETH3DWAI # noqa
from mapanything.datasets.wai.gta_sfm import GTASfMWAI # noqa
from mapanything.datasets.wai.matrixcity import MatrixCityWAI # noqa
from mapanything.datasets.wai.megadepth import MegaDepthWAI # noqa
from mapanything.datasets.wai.mpsd import MPSDWAI # noqa
from mapanything.datasets.wai.mvs_synth import MVSSynthWAI # noqa
from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI # noqa
from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI # noqa
from mapanything.datasets.wai.scannetpp import ScanNetPPWAI # noqa
from mapanything.datasets.wai.spring import SpringWAI # noqa
from mapanything.datasets.wai.structured3d import Structured3DWAI # noqa
from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI # noqa
from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI # noqa
from mapanything.datasets.wai.xrooms import XRoomsWAI # noqa
from mapanything.utils.train_tools import get_rank, get_world_size
def get_test_data_loader(
dataset, batch_size, num_workers=8, shuffle=False, drop_last=False, pin_mem=True
):
"Get simple PyTorch dataloader corresponding to the testing dataset"
# PyTorch dataset
if isinstance(dataset, str):
dataset = eval(dataset)
world_size = get_world_size()
rank = get_rank()
if torch.distributed.is_initialized():
sampler = torch.utils.data.DistributedSampler(
dataset,
num_replicas=world_size,
rank=rank,
shuffle=shuffle,
drop_last=drop_last,
)
elif shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
else:
sampler = torch.utils.data.SequentialSampler(dataset)
data_loader = torch.utils.data.DataLoader(
dataset,
sampler=sampler,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_mem,
drop_last=drop_last,
)
return data_loader
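
# Illustrative usage sketch (comment only, not executed on import). The dataset
# constructor arguments and batch settings below are placeholders, not a verified
# configuration:
#
#   test_dataset = ETH3DWAI(...)  # any of the WAI datasets imported above
#   test_loader = get_test_data_loader(test_dataset, batch_size=4, num_workers=4)
#   for batch in test_loader:
#       ...  # run evaluation on the batch
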
def get_test_many_ar_data_loader(
dataset, batch_size, num_workers=8, drop_last=False, pin_mem=True
):
"Get PyTorch dataloader corresponding to the testing dataset that supports many aspect ratios"
# PyTorch dataset
if isinstance(dataset, str):
dataset = eval(dataset)
world_size = get_world_size()
rank = get_rank()
# Get BatchedMultiFeatureRandomSampler
sampler = dataset.make_sampler(
batch_size,
shuffle=True,
world_size=world_size,
rank=rank,
drop_last=drop_last,
use_dynamic_sampler=False,
)
    # Initialize the data loader
data_loader = torch.utils.data.DataLoader(
dataset,
sampler=sampler,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_mem,
drop_last=drop_last,
)
return data_loader
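
# Illustrative usage sketch (comment only, not executed on import). It assumes the
# dataset exposes make_sampler(...) as used above; the arguments are placeholders:
#
#   test_loader = get_test_many_ar_data_loader(test_dataset, batch_size=8)
#   for batch in test_loader:
#       ...  # batches are drawn by the multi-feature sampler (presumably grouped by aspect ratio)
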
class DynamicBatchDatasetWrapper:
"""
Wrapper dataset that handles DynamicBatchedMultiFeatureRandomSampler output.
The dynamic sampler returns batches (lists of tuples) instead of individual samples.
This wrapper ensures that the underlying dataset's __getitem__ method gets called
with individual tuples as expected.
"""
def __init__(self, dataset):
self.dataset = dataset
def __getitem__(self, batch_indices):
"""
Handle batch of indices from DynamicBatchedMultiFeatureRandomSampler.
Args:
batch_indices: List of tuples like [(sample_idx, feat_idx_1, feat_idx_2, ...), ...]
Returns:
List of samples from the underlying dataset
"""
        if (
            isinstance(batch_indices, (list, tuple))
            and len(batch_indices) > 0
            and isinstance(batch_indices[0], (list, tuple))
        ):
            # Batch of index tuples: fetch each sample individually
            return [self.dataset[idx] for idx in batch_indices]
        else:
            # Single index tuple (or plain index): call the dataset directly
            return self.dataset[batch_indices]
def __len__(self):
return len(self.dataset)
def __getattr__(self, name):
# Delegate all other attributes to the wrapped dataset
return getattr(self.dataset, name)
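
# Illustrative sketch of the wrapper's indexing behavior (index values are
# hypothetical). A dynamic batch sampler may yield an entire batch of
# (sample_idx, feature_idx, ...) tuples, while a plain sampler yields one
# index tuple at a time:
#
#   wrapped = DynamicBatchDatasetWrapper(dataset)
#   wrapped[[(0, 2, 1), (5, 2, 1)]]  # -> [dataset[(0, 2, 1)], dataset[(5, 2, 1)]]
#   wrapped[(0, 2, 1)]               # -> dataset[(0, 2, 1)]
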
def get_train_data_loader(
dataset,
max_num_of_imgs_per_gpu,
num_workers=8,
shuffle=True,
drop_last=True,
pin_mem=True,
):
"Dynamic PyTorch dataloader corresponding to the training dataset"
# PyTorch dataset
if isinstance(dataset, str):
dataset = eval(dataset)
world_size = get_world_size()
rank = get_rank()
# Get DynamicBatchedMultiFeatureRandomSampler
batch_sampler = dataset.make_sampler(
shuffle=shuffle,
world_size=world_size,
rank=rank,
drop_last=drop_last,
max_num_of_images_per_gpu=max_num_of_imgs_per_gpu,
use_dynamic_sampler=True,
)
# Wrap the dataset to handle batch format from dynamic sampler
wrapped_dataset = DynamicBatchDatasetWrapper(dataset)
# Init the dynamic data loader
data_loader = torch.utils.data.DataLoader(
wrapped_dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=pin_mem,
)
return data_loader
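
# Illustrative usage sketch (comment only, not executed on import). It assumes a
# training dataset whose make_sampler(..., use_dynamic_sampler=True) returns a
# batch sampler; the variable names and values below are placeholders:
#
#   train_loader = get_train_data_loader(train_dataset, max_num_of_imgs_per_gpu=24)
#   for batch in train_loader:
#       ...  # batch size varies per step, bounded by max_num_of_imgs_per_gpu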