# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Callable, Optional

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from mmengine.device import get_device
from mmengine.dist import init_dist, is_distributed, master_only
from mmengine.model import convert_sync_batchnorm, is_model_wrapper
from mmengine.registry import MODEL_WRAPPERS, STRATEGIES

from .single_device import SingleDeviceStrategy


@STRATEGIES.register_module()
class DDPStrategy(SingleDeviceStrategy):
    """Distribution strategy for distributed data parallel training.

    Args:
        model_wrapper (dict, optional): Config dict for the model wrapper.
            Defaults to None.
        sync_bn (str, optional): Type of synchronized batch normalization to
            convert ``BatchNorm`` layers to. Options are 'torch' and 'mmcv'.
            Defaults to None.
        **kwargs: Other arguments for :class:`BaseStrategy`.
    """

    def __init__(
        self,
        *,
        model_wrapper: Optional[dict] = None,
        sync_bn: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_wrapper = model_wrapper
        self.sync_bn = sync_bn
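
    # A minimal usage sketch (assuming this strategy is registered in
    # ``STRATEGIES`` and driven by a runner that accepts a strategy config,
    # e.g. mmengine's ``FlexibleRunner``); the values below are illustrative,
    # not defaults:
    #
    #   strategy = dict(
    #       type='DDPStrategy',
    #       sync_bn='torch',  # convert BatchNorm to torch SyncBatchNorm
    #       model_wrapper=dict(
    #           type='MMDistributedDataParallel',
    #           find_unused_parameters=True))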

    def _setup_distributed(  # type: ignore
        self,
        launcher: str = 'pytorch',
        backend: str = 'nccl',
        **kwargs,
    ):
        """Setup the distributed environment.

        Args:
            launcher (str): Way to launch multiple processes. Supported
                launchers are 'pytorch', 'mpi' and 'slurm'. Defaults to
                'pytorch'.
            backend (str): Communication backend. Supported backends are
                'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
            **kwargs: Other arguments for :func:`init_dist`.
        """
        if not is_distributed():
            init_dist(launcher, backend, **kwargs)
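
    # With the 'pytorch' launcher, ``init_dist`` reads the usual
    # torch.distributed environment variables (RANK, WORLD_SIZE, LOCAL_RANK,
    # MASTER_ADDR, MASTER_PORT), which are populated automatically when the
    # training script is started with torchrun, e.g. (illustrative command,
    # not taken from this repo):
    #
    #   torchrun --nproc_per_node=8 train.py --launcher pytorch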

    def convert_model(self, model: nn.Module) -> nn.Module:
        """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
        (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.

        Args:
            model (nn.Module): Model to be converted.

        Returns:
            nn.Module: Converted model.
        """
        if self.sync_bn is not None:
            try:
                model = convert_sync_batchnorm(model, self.sync_bn)
            except ValueError as e:
                self.logger.error('cfg.sync_bn should be "torch" or '
                                  f'"mmcv", but got {self.sync_bn}')
                raise e
        return model
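
    # Rough sketch of what the conversion does for ``sync_bn='torch'``
    # (``convert_sync_batchnorm`` walks the module tree recursively; the
    # model below is only an illustration):
    #
    #   model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    #   model = convert_sync_batchnorm(model, 'torch')
    #   # model[1] is now torch.nn.SyncBatchNorm, which aggregates batch
    #   # statistics across all DDP processes during training.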

    def _wrap_model(self, model: nn.Module) -> DistributedDataParallel:
        """Wrap the model with :obj:`MMDistributedDataParallel` or another
        custom distributed data-parallel module wrapper.

        Args:
            model (nn.Module): Model to be wrapped.

        Returns:
            nn.Module or DistributedDataParallel: nn.Module or subclass of
            ``DistributedDataParallel``.
        """
        if is_model_wrapper(model):
            return model

        model = model.to(get_device())
        model = self.convert_model(model)

        if self.model_wrapper is None:
            # set broadcast_buffers to False to keep compatibility with
            # OpenMMLab repos
            self.model_wrapper = dict(
                type='MMDistributedDataParallel', broadcast_buffers=False)

        default_args = dict(
            type='MMDistributedDataParallel',
            module=model,
            device_ids=[int(os.environ['LOCAL_RANK'])])
        model = MODEL_WRAPPERS.build(
            self.model_wrapper, default_args=default_args)
        return model
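
    # The wrapper type and its keyword arguments come from the
    # ``model_wrapper`` config, so any wrapper registered in
    # ``MODEL_WRAPPERS`` can be substituted. A hedged sketch with a
    # hypothetical custom wrapper class ``MyDDPWrapper``:
    #
    #   @MODEL_WRAPPERS.register_module()
    #   class MyDDPWrapper(DistributedDataParallel):
    #       ...
    #
    #   strategy = DDPStrategy(model_wrapper=dict(type='MyDDPWrapper'))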

    @master_only
    def save_checkpoint(
        self,
        filename: str,
        *,
        save_optimizer: bool = True,
        save_param_scheduler: bool = True,
        extra_ckpt: Optional[dict] = None,
        callback: Optional[Callable] = None,
    ) -> None:
        """Save the checkpoint only on the master (rank 0) process, so that
        all ranks do not write the same file concurrently.

        See :meth:`SingleDeviceStrategy.save_checkpoint` for the meaning of
        the arguments.
        """
        super().save_checkpoint(
            filename=filename,
            save_optimizer=save_optimizer,
            save_param_scheduler=save_param_scheduler,
            extra_ckpt=extra_ckpt,
            callback=callback)
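
    # Illustrative call site (assuming the strategy has already prepared the
    # model and optimizer; the filename is made up):
    #
    #   strategy.save_checkpoint('work_dirs/epoch_1.pth')
    #
    # Because of ``master_only``, non-zero ranks return immediately and only
    # rank 0 writes the file.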