"""
MapAnything Datasets
"""

import torch

from mapanything.datasets.wai.ase import ASEWAI  # noqa
from mapanything.datasets.wai.bedlam import BedlamWAI  # noqa
from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI  # noqa
from mapanything.datasets.wai.dl3dv import DL3DVWAI  # noqa
from mapanything.datasets.wai.dtu import DTUWAI  # noqa
from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI  # noqa
from mapanything.datasets.wai.eth3d import ETH3DWAI  # noqa
from mapanything.datasets.wai.gta_sfm import GTASfMWAI  # noqa
from mapanything.datasets.wai.matrixcity import MatrixCityWAI  # noqa
from mapanything.datasets.wai.megadepth import MegaDepthWAI  # noqa
from mapanything.datasets.wai.mpsd import MPSDWAI  # noqa
from mapanything.datasets.wai.mvs_synth import MVSSynthWAI  # noqa
from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI  # noqa
from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI  # noqa
from mapanything.datasets.wai.scannetpp import ScanNetPPWAI  # noqa
from mapanything.datasets.wai.spring import SpringWAI  # noqa
from mapanything.datasets.wai.structured3d import Structured3DWAI  # noqa
from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI  # noqa
from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI  # noqa
from mapanything.datasets.wai.xrooms import XRoomsWAI  # noqa
from mapanything.utils.train_tools import get_rank, get_world_size
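
# Note: the "# noqa" imports above are intentional. They are not referenced
# directly in this file; they populate the module namespace so that dataset
# specs written as strings (e.g., in a config file) can be instantiated via
# eval() inside the loader factories below.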


def get_test_data_loader(
    dataset, batch_size, num_workers=8, shuffle=False, drop_last=False, pin_mem=True
):
    "Get simple PyTorch dataloader corresponding to the testing dataset"
    # PyTorch dataset
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    if torch.distributed.is_initialized():
        sampler = torch.utils.data.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
            drop_last=drop_last,
        )
    elif shuffle:
        sampler = torch.utils.data.RandomSampler(dataset)
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader
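
# Example usage (a sketch, not from the source; the ETH3DWAI constructor
# arguments shown here are illustrative assumptions, not its actual signature):
#
#   dataset_spec = "ETH3DWAI(split='test')"  # string specs are eval'd above
#   loader = get_test_data_loader(dataset_spec, batch_size=2, num_workers=4)
#   for batch in loader:
#       ...  # run evaluation on the batch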


def get_test_many_ar_data_loader(
    dataset, batch_size, num_workers=8, drop_last=False, pin_mem=True
):
    "Get PyTorch dataloader corresponding to the testing dataset that supports many aspect ratios"
    # PyTorch dataset
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # Get BatchedMultiFeatureRandomSampler
    sampler = dataset.make_sampler(
        batch_size,
        shuffle=True,
        world_size=world_size,
        rank=rank,
        drop_last=drop_last,
        use_dynamic_sampler=False,
    )

    # Init the data loader
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
    )

    return data_loader
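
# Example usage (a sketch; the dataset constructor call is an illustrative
# placeholder). dataset.make_sampler(..., use_dynamic_sampler=False) is
# expected to return a BatchedMultiFeatureRandomSampler that yields indices
# grouped so that every batch shares one aspect-ratio/feature combination:
#
#   loader = get_test_many_ar_data_loader(
#       "BlendedMVSWAI(split='test')",  # hypothetical constructor arguments
#       batch_size=8,
#   )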


class DynamicBatchDatasetWrapper:
    """
    Wrapper dataset that handles DynamicBatchedMultiFeatureRandomSampler output.

    The dynamic sampler returns batches (lists of tuples) instead of individual samples.
    This wrapper ensures that the underlying dataset's __getitem__ method gets called
    with individual tuples as expected.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, batch_indices):
        """
        Handle batch of indices from DynamicBatchedMultiFeatureRandomSampler.

        Args:
            batch_indices: List of tuples like [(sample_idx, feat_idx_1, feat_idx_2, ...), ...]

        Returns:
            List of samples from the underlying dataset
        """
        if isinstance(batch_indices, (list, tuple)) and len(batch_indices) > 0:
            # If it's a batch (list of tuples), process each item
            if isinstance(batch_indices[0], (list, tuple)):
                return [self.dataset[idx] for idx in batch_indices]
            else:
                # Single tuple, call dataset directly
                return self.dataset[batch_indices]
        else:
            # Fallback for single index
            return self.dataset[batch_indices]

    def __len__(self):
        return len(self.dataset)

    def __getattr__(self, name):
        # Delegate all other attributes to the wrapped dataset
        return getattr(self.dataset, name)
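
# Minimal sketch of the wrapper's contract (toy dataset for illustration;
# real WAI datasets are indexed with (sample_idx, feat_idx_1, ...) tuples):
#
#   class _ToyDataset:
#       def __getitem__(self, idx_tuple):
#           return {"idx": idx_tuple}
#
#       def __len__(self):
#           return 4
#
#   wrapped = DynamicBatchDatasetWrapper(_ToyDataset())
#   wrapped[(0, 1)]            # single index tuple -> one sample
#   wrapped[[(0, 1), (2, 3)]]  # batch of tuples -> list of samples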


def get_train_data_loader(
    dataset,
    max_num_of_imgs_per_gpu,
    num_workers=8,
    shuffle=True,
    drop_last=True,
    pin_mem=True,
):
    "Dynamic PyTorch dataloader corresponding to the training dataset"
    # PyTorch dataset
    if isinstance(dataset, str):
        dataset = eval(dataset)

    world_size = get_world_size()
    rank = get_rank()

    # Get DynamicBatchedMultiFeatureRandomSampler
    batch_sampler = dataset.make_sampler(
        shuffle=shuffle,
        world_size=world_size,
        rank=rank,
        drop_last=drop_last,
        max_num_of_images_per_gpu=max_num_of_imgs_per_gpu,
        use_dynamic_sampler=True,
    )

    # Wrap the dataset to handle batch format from dynamic sampler
    wrapped_dataset = DynamicBatchDatasetWrapper(dataset)

    # Init the dynamic data loader
    data_loader = torch.utils.data.DataLoader(
        wrapped_dataset,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        pin_memory=pin_mem,
    )

    return data_loader
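
# Example usage (a sketch; the dataset constructor call is an illustrative
# placeholder). Because batching is delegated to the dynamic batch sampler,
# no batch_size is passed to the DataLoader, and the number of samples per
# batch may vary so as to respect the per-GPU image budget:
#
#   train_loader = get_train_data_loader(
#       "MegaDepthWAI(split='train')",  # hypothetical constructor arguments
#       max_num_of_imgs_per_gpu=24,
#   )
#   for batch in train_loader:
#       ...  # training step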