Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import random | |
| import warnings | |
| from typing import Any, Mapping, Sequence | |
| import numpy as np | |
| import torch | |
| from torch.utils.data._utils.collate import \ | |
| default_collate as torch_default_collate | |
| from mmengine.registry import FUNCTIONS | |
| from mmengine.structures import BaseDataElement | |
# `FUNCTIONS` is new in MMEngine v0.7.0. Keep `COLLATE_FUNCTIONS` as an alias
# so code written against the pre-v0.7.0 registry name keeps working.
COLLATE_FUNCTIONS = FUNCTIONS
def worker_init_fn(worker_id: int,
                   num_workers: int,
                   rank: int,
                   seed: int,
                   disable_subprocess_warning: bool = False) -> None:
    """This function will be called on each worker subprocess after seeding and
    before data loading.

    It derives a per-worker seed from ``(num_workers, rank, worker_id, seed)``
    and applies it to ``numpy``, ``random`` and ``torch`` so that data
    augmentation is reproducible but still different across workers.

    Args:
        worker_id (int): Worker id in [0, num_workers - 1].
        num_workers (int): How many subprocesses to use for data loading.
        rank (int): Rank of process in distributed environment. If in
            non-distributed environment, it is a constant number `0`.
        seed (int): Random seed.
        disable_subprocess_warning (bool): Whether to suppress Python warnings
            in worker subprocesses other than worker 0. Defaults to False.
    """
    # The seed of each worker equals to
    # num_worker * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)
    # Only silence warnings in non-zero workers so worker 0 still surfaces
    # one copy of each warning.
    if disable_subprocess_warning and worker_id != 0:
        warnings.simplefilter('ignore')
def pseudo_collate(data_batch: Sequence) -> Any:
    """Convert list of data sampled from dataset into a batch of data, of
    which type is consistent with the type of each element in ``data_batch``.

    The default behavior of dataloader is to merge a list of samples to form
    a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate``
    will not stack tensors to batch tensors, and convert int, float, ndarray to
    tensors.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Args:
        data_batch (Sequence): Batch of data from dataloader.

    Returns:
        Any: Transposed data in the same format as the element of
        ``data_batch``.
    """  # noqa: E501
    data_item = data_batch[0]
    data_item_type = type(data_item)
    if isinstance(data_item, (str, bytes)):
        # Strings/bytes are atomic: keep the batch as-is.
        return data_batch
    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
        # named tuple: collate each field and rebuild the namedtuple type.
        return data_item_type(*(pseudo_collate(samples)
                                for samples in zip(*data_batch)))
    elif isinstance(data_item, Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(data_batch)
        data_item_size = len(next(it))
        if not all(len(data_item) == data_item_size for data_item in it):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        transposed = list(zip(*data_batch))

        if isinstance(data_item, tuple):
            return [pseudo_collate(samples)
                    for samples in transposed]  # Compat with Pytorch.
        else:
            try:
                return data_item_type(
                    [pseudo_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)`
                # (e.g., `range`).
                return [pseudo_collate(samples) for samples in transposed]
    elif isinstance(data_item, Mapping):
        # Collate per key, preserving the mapping type.
        return data_item_type({
            key: pseudo_collate([d[key] for d in data_batch])
            for key in data_item
        })
    else:
        # Scalars and unknown types are returned untouched (no tensor
        # conversion/stacking, by design).
        return data_batch
def default_collate(data_batch: Sequence) -> Any:
    """Convert list of data sampled from dataset into a batch of data, of
    which type is consistent with the type of each element in ``data_batch``.

    Different from :func:`pseudo_collate`, ``default_collate`` will stack
    tensor contained in ``data_batch`` into a batched tensor with the
    first dimension batch size, and then move input tensor to the target
    device.

    Different from ``default_collate`` in pytorch, ``default_collate`` will
    not process ``BaseDataElement``.

    This code is referenced from:
    `Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_.

    Note:
        ``default_collate`` only accept input tensor with the same shape.

    Args:
        data_batch (Sequence): Data sampled from dataset.

    Returns:
        Any: Data in the same format as the element of ``data_batch``, of
        which tensors have been stacked, and ndarray, int, float have been
        converted to tensors.
    """  # noqa: E501
    data_item = data_batch[0]
    data_item_type = type(data_item)
    if isinstance(data_item, (BaseDataElement, str, bytes)):
        # BaseDataElement is deliberately left unbatched (see docstring);
        # strings/bytes are atomic.
        return data_batch
    elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
        # named_tuple: collate each field and rebuild the namedtuple type.
        return data_item_type(*(default_collate(samples)
                                for samples in zip(*data_batch)))
    elif isinstance(data_item, Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(data_batch)
        data_item_size = len(next(it))
        if not all(len(data_item) == data_item_size for data_item in it):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        transposed = list(zip(*data_batch))

        if isinstance(data_item, tuple):
            return [default_collate(samples)
                    for samples in transposed]  # Compat with Pytorch.
        else:
            try:
                return data_item_type(
                    [default_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)`
                # (e.g., `range`).
                return [default_collate(samples) for samples in transposed]
    elif isinstance(data_item, Mapping):
        # Collate per key, preserving the mapping type.
        return data_item_type({
            key: default_collate([d[key] for d in data_batch])
            for key in data_item
        })
    else:
        # Leaf values (tensor/ndarray/int/float/...): delegate stacking and
        # tensor conversion to PyTorch's default_collate.
        return torch_default_collate(data_batch)