# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import bisect
import logging
import socket
import time
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch.utils.data import DataLoader

from mmengine.evaluator import Evaluator
from mmengine.logging import print_log
from mmengine.registry import LOOPS
from .amp import autocast
from .base_loop import BaseLoop
from .utils import calc_dynamic_intervals


@LOOPS.register_module()
class EpochBasedTrainLoop(BaseLoop):
    """Loop for epoch-based training.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        max_epochs (int): Total training epochs.
        val_begin (int): The epoch that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1.
        dynamic_intervals (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is an interval. The interval is used after the
            corresponding milestone. Defaults to None.
    """

    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_epochs: int,
            val_begin: int = 1,
            val_interval: int = 1,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)
        self._max_epochs = int(max_epochs)
        assert self._max_epochs == max_epochs, \
            f'`max_epochs` should be an integer number, but got {max_epochs}.'
        self._max_iters = self._max_epochs * len(self.dataloader)
        self._epoch = 0
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        # This attribute will be updated by `EarlyStoppingHook`
        # when it is enabled.
        self.stop_training = False
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)

        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(
                self.val_interval, dynamic_intervals)

    @property
    def max_epochs(self):
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    def run(self) -> torch.nn.Module:
        """Launch training."""
        self.runner.call_hook('before_train')

        while self._epoch < self._max_epochs and not self.stop_training:
            self.run_epoch()

            self._decide_current_val_interval()
            if (self.runner.val_loop is not None
                    and self._epoch >= self.val_begin
                    and self._epoch % self.val_interval == 0):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train')
        return self.runner.model

    def run_epoch(self) -> None:
        """Iterate one epoch."""
        self.runner.call_hook('before_train_epoch')
        self.runner.model.train()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        self.runner.call_hook('after_train_epoch')
        self._epoch += 1

    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=idx, data_batch=data_batch)
        # Enable gradient accumulation mode and avoid unnecessary gradient
        # synchronization during gradient accumulation process.
        # outputs should be a dict of loss.
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)

        self.runner.call_hook(
            'after_train_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``."""
        step = bisect.bisect(self.dynamic_milestones, (self.epoch + 1))
        self.val_interval = self.dynamic_intervals[step - 1]


class _InfiniteDataloaderIterator:
    """An infinite dataloader iterator wrapper for ``IterBasedTrainLoop``.

    It resets the dataloader to continue iterating when the iterator has
    iterated over all the data. However, this approach is not efficient, as
    the workers need to be restarted every time the dataloader is reset. It is
    recommended to use ``mmengine.dataset.InfiniteSampler`` to enable the
    dataloader to iterate infinitely.
    """

    def __init__(self, dataloader: DataLoader) -> None:
        self._dataloader = dataloader
        self._iterator = iter(self._dataloader)
        self._epoch = 0

    def __iter__(self):
        return self

    def __next__(self) -> Sequence[dict]:
        try:
            data = next(self._iterator)
        except StopIteration:
            print_log(
                'Reach the end of the dataloader, it will be '
                'restarted and continue to iterate. It is '
                'recommended to use '
                '`mmengine.dataset.InfiniteSampler` to enable the '
                'dataloader to iterate infinitely.',
                logger='current',
                level=logging.WARNING)
            self._epoch += 1
            if hasattr(self._dataloader, 'sampler') and hasattr(
                    self._dataloader.sampler, 'set_epoch'):
                # In case the ``_SingleProcessDataLoaderIter`` has no sampler,
                # or the dataloader uses ``SequentialSampler`` in PyTorch.
                self._dataloader.sampler.set_epoch(self._epoch)

            elif hasattr(self._dataloader, 'batch_sampler') and hasattr(
                    self._dataloader.batch_sampler.sampler, 'set_epoch'):
                # In case the ``_SingleProcessDataLoaderIter`` has no batch
                # sampler. The batch sampler in PyTorch wraps the sampler as
                # one of its attributes.
                self._dataloader.batch_sampler.sampler.set_epoch(self._epoch)

            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self._iterator = iter(self._dataloader)
            data = next(self._iterator)
        return data


@LOOPS.register_module()
class IterBasedTrainLoop(BaseLoop):
    """Loop for iter-based training.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        max_iters (int): Total training iterations.
        val_begin (int): The iteration that begins validating.
            Defaults to 1.
        val_interval (int): Validation interval. Defaults to 1000.
        dynamic_intervals (List[Tuple[int, int]], optional): The
            first element in the tuple is a milestone and the second
            element is an interval. The interval is used after the
            corresponding milestone. Defaults to None.
    """

    def __init__(
            self,
            runner,
            dataloader: Union[DataLoader, Dict],
            max_iters: int,
            val_begin: int = 1,
            val_interval: int = 1000,
            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
        super().__init__(runner, dataloader)
        self._max_iters = int(max_iters)
        assert self._max_iters == max_iters, \
            f'`max_iters` should be an integer number, but got {max_iters}'
        self._max_epochs = 1  # for compatibility with EpochBasedTrainLoop
        self._epoch = 0
        self._iter = 0
        self.val_begin = val_begin
        self.val_interval = val_interval
        # This attribute will be updated by `EarlyStoppingHook`
        # when it is enabled.
        self.stop_training = False
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in visualizer will be '
                'None.',
                logger='current',
                level=logging.WARNING)
        # get the iterator of the dataloader
        self.dataloader_iterator = _InfiniteDataloaderIterator(self.dataloader)

        self.dynamic_milestones, self.dynamic_intervals = \
            calc_dynamic_intervals(
                self.val_interval, dynamic_intervals)

    @property
    def max_epochs(self):
        """int: Total epochs to train model."""
        return self._max_epochs

    @property
    def max_iters(self):
        """int: Total iterations to train model."""
        return self._max_iters

    @property
    def epoch(self):
        """int: Current epoch."""
        return self._epoch

    @property
    def iter(self):
        """int: Current iteration."""
        return self._iter

    def run(self) -> torch.nn.Module:
        """Launch training."""
        self.runner.call_hook('before_train')
        # In iteration-based training loop, we treat the whole training process
        # as a big epoch and execute the corresponding hook.
        self.runner.call_hook('before_train_epoch')
        while self._iter < self._max_iters and not self.stop_training:
            self.runner.model.train()

            data_batch = next(self.dataloader_iterator)
            self.run_iter(data_batch)

            self._decide_current_val_interval()
            if (self.runner.val_loop is not None
                    and self._iter >= self.val_begin
                    and self._iter % self.val_interval == 0):
                self.runner.val_loop.run()

        self.runner.call_hook('after_train_epoch')
        self.runner.call_hook('after_train')
        return self.runner.model

    def run_iter(self, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
        # Enable gradient accumulation mode and avoid unnecessary gradient
        # synchronization during gradient accumulation process.
        # outputs should be a dict of loss.
        outputs = self.runner.model.train_step(
            data_batch, optim_wrapper=self.runner.optim_wrapper)

        self.runner.call_hook(
            'after_train_iter',
            batch_idx=self._iter,
            data_batch=data_batch,
            outputs=outputs)
        self._iter += 1

    def _decide_current_val_interval(self) -> None:
        """Dynamically modify the ``val_interval``."""
        step = bisect.bisect(self.dynamic_milestones, (self._iter + 1))
        self.val_interval = self.dynamic_intervals[step - 1]


@LOOPS.register_module()
class ValLoop(BaseLoop):
    """Loop for validation.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 validation. Defaults to
            False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False) -> None:
        super().__init__(runner, dataloader)

        if isinstance(evaluator, (dict, list)):
            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            assert isinstance(evaluator, Evaluator), (
                'evaluator must be one of dict, list or Evaluator instance, '
                f'but got {type(evaluator)}.')
            self.evaluator = evaluator  # type: ignore
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in evaluator, metric and '
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16

    def run(self) -> dict:
        """Launch validation."""
        self.runner.call_hook('before_val')
        self.runner.call_hook('before_val_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        self.runner.call_hook('after_val_epoch', metrics=metrics)
        self.runner.call_hook('after_val')
        return metrics

    def run_iter(self, idx, data_batch: Sequence[dict]):
        """Iterate one mini-batch.

        Args:
            data_batch (Sequence[dict]): Batch of data
                from dataloader.
        """
        self.runner.call_hook(
            'before_val_iter', batch_idx=idx, data_batch=data_batch)
        # outputs should be sequence of BaseDataElement
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.val_step(data_batch)
        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_val_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)


@LOOPS.register_module()
class TestLoop(BaseLoop):
    """Loop for test.

    Args:
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        fp16 (bool): Whether to enable fp16 testing. Defaults to
            False.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[Evaluator, Dict, List],
                 fp16: bool = False):
        super().__init__(runner, dataloader)

        if isinstance(evaluator, (dict, list)):
            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            self.evaluator = evaluator  # type: ignore
        if hasattr(self.dataloader.dataset, 'metainfo'):
            self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
            self.runner.visualizer.dataset_meta = \
                self.dataloader.dataset.metainfo
        else:
            print_log(
                f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
                'metainfo. ``dataset_meta`` in evaluator, metric and '
                'visualizer will be None.',
                logger='current',
                level=logging.WARNING)
        self.fp16 = fp16

    def run(self) -> dict:
        """Launch test."""
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')
        self.runner.model.eval()
        for idx, data_batch in enumerate(self.dataloader):
            self.run_iter(idx, data_batch)

        # compute metrics
        metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics

    def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
        """Iterate one mini-batch.

        Args:
            data_batch (Sequence[dict]): Batch of data from dataloader.
        """
        self.runner.call_hook(
            'before_test_iter', batch_idx=idx, data_batch=data_batch)
        # predictions should be sequence of BaseDataElement
        with autocast(enabled=self.fp16):
            outputs = self.runner.model.test_step(data_batch)
        self.evaluator.process(data_samples=outputs, data_batch=data_batch)
        self.runner.call_hook(
            'after_test_iter',
            batch_idx=idx,
            data_batch=data_batch,
            outputs=outputs)