import logging
import logging.config
import random
import subprocess
from datetime import datetime

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
import torch.optim
import torch.utils.data
from packaging import version
from omegaconf import OmegaConf

def set_random_seed(seed):
    # Seed every RNG the training code touches: Python, NumPy, and torch (CPU and all GPUs).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

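# Usage sketch (added note, not part of the original file): with DDP training, one
# common pattern is to offset the seed by the process rank so that stochastic
# augmentation differs across workers while staying reproducible, e.g.
#
#     rank = dist.get_rank() if dist.is_initialized() else 0
#     set_random_seed(cfg.random_seed + rank)
#
# The `cfg.random_seed` field above is an assumption for illustration only.
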
def is_logging_process():
    # Only rank 0 (or a non-distributed run) should emit log records.
    return not dist.is_initialized() or dist.get_rank() == 0

def get_logger(cfg, name=None):
    # log_file_path is used when unit testing
    if is_logging_process():
        logging.config.dictConfig(
            OmegaConf.to_container(cfg.job_logging_config, resolve=True)
        )
        return logging.getLogger(name)

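# Illustrative note (not in the original file): `cfg.job_logging_config` is expected to
# be a mapping that logging.config.dictConfig understands. A minimal Hydra/OmegaConf
# YAML fragment that would satisfy it might look like the following; the exact keys the
# surrounding project uses may differ.
#
#     job_logging_config:
#       version: 1
#       formatters:
#         simple:
#           format: "[%(asctime)s][%(name)s][%(levelname)s] %(message)s"
#       handlers:
#         console:
#           class: logging.StreamHandler
#           formatter: simple
#       root:
#         level: INFO
#         handlers: [console]
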
# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
class SyncFunction(torch.autograd.Function):
    # @torch.no_grad()
    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]

        gathered_tensor = [
            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
        ]

        torch.distributed.all_gather(gathered_tensor, tensor)
        gathered_tensor = torch.cat(gathered_tensor, 0)

        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(
            grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False
        )

        idx_from = torch.distributed.get_rank() * ctx.batch_size
        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]

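# Usage sketch (added for illustration, not part of the original file):
# SyncFunction.apply gathers a tensor from every rank while keeping it attached to the
# autograd graph, unlike a plain torch.distributed.all_gather call. Inside a
# training_step, assuming `features` is this rank's [batch, dim] embedding tensor, a
# typical call site for a contrastive loss over the global batch would be:
#
#     if dist.is_initialized() and dist.get_world_size() > 1:
#         features = SyncFunction.apply(features)
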
def get_timestamp():
    return datetime.now().strftime("%y%m%d-%H%M%S")

def get_commit_hash():
    message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
    return message.strip().decode("utf-8")

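# Usage sketch (illustration only): the two helpers above are typically combined into a
# unique run name for checkpoints and log directories, e.g.
#
#     run_name = f"{get_timestamp()}_{get_commit_hash()}"  # e.g. "240131-093015_ab12cd3"
#
# Note that get_commit_hash() raises subprocess.CalledProcessError (or FileNotFoundError
# if git is unavailable) when not executed inside a git repository.
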
class DDP(DistributedDataParallel):
    """
    Override the forward call of DistributedDataParallel so that it dispatches to the
    wrapped module's training_step, test_step, or validation_step (Lightning-style
    hooks) instead of the module's forward().
    """

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        if version.parse(torch.__version__[:6]) < version.parse("1.11"):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            assert len(self.device_ids) == 1
            if self.module.training:
                output = self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                output = self.module.test_step(*inputs[0], **kwargs[0])
            else:
                output = self.module.validation_step(*inputs[0], **kwargs[0])
            if torch.is_grad_enabled():
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        else:
            from torch.nn.parallel.distributed import (
                logging,
                Join,
                _DDPSink,
                _tree_flatten_with_rref,
                _tree_unflatten_with_rref,
            )

            with torch.autograd.profiler.record_function(
                "DistributedDataParallel.forward"
            ):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this process has not joined, if
                # needed.
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(
                        work, self._divide_by_initial_world_size
                    )

                # Call _rebuild_buckets before the forward computation; it may
                # allocate new buckets before deallocating old buckets inside
                # _rebuild_buckets. To save peak memory usage, call
                # _rebuild_buckets before the peak memory usage increases during
                # forward computation.
                # This should be called only once during the whole training period.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    logging.info("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                # Sync params according to location (before/after forward) the user
                # specified as part of the hook, if a hook was specified.
                buffer_hook_registered = hasattr(self, "buffer_hook")
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    # Notify joined ranks whether they should sync in backwards pass or not.
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])

                # Sync params according to location (before/after forward) the user
                # specified as part of the hook, if a hook was specified.
                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # We'll return the output object verbatim since it is a freeform
                    # object. We need to find any tensors in this object, though,
                    # because we need to figure out which parameters were used during
                    # this forward pass, to ensure we short circuit reduction for any
                    # unused parameters. Only if `find_unused_parameters` is set.
                    if self.find_unused_parameters and not self.static_graph:
                        # Do not need to populate this for static graph.
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

                # TODO: DDPSink is currently enabled for unused parameter detection and
                # static graph training for first iteration.
                if (self.find_unused_parameters and not self.static_graph) or (
                    self.static_graph and self.num_iterations == 1
                ):
                    state_dict = {
                        "static_graph": self.static_graph,
                        "num_iterations": self.num_iterations,
                    }

                    output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
                        output
                    )
                    output_placeholders = [None for _ in range(len(output_tensor_list))]
                    # Do not touch tensors that have no grad_fn, which can cause issues
                    # such as https://github.com/pytorch/pytorch/issues/60733
                    for i, output in enumerate(output_tensor_list):
                        if torch.is_tensor(output) and output.grad_fn is None:
                            output_placeholders[i] = output

                    # When find_unused_parameters=True, makes tensors which require grad
                    # run through the DDPSink backward pass. When not all outputs are
                    # used in loss, this makes those corresponding tensors receive
                    # undefined gradient which the reducer then handles to ensure
                    # param.grad field is not touched and we don't error out.
                    passthrough_tensor_list = _DDPSink.apply(
                        self.reducer,
                        state_dict,
                        *output_tensor_list,
                    )
                    for i in range(len(output_placeholders)):
                        if output_placeholders[i] is None:
                            output_placeholders[i] = passthrough_tensor_list[i]

                    # Reconstruct output data structure.
                    output = _tree_unflatten_with_rref(
                        output_placeholders, treespec, output_is_rref
                    )
        return output
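
# Usage sketch (added for illustration; the names below are assumptions, not part of the
# original file). The wrapped module is expected to expose training_step /
# validation_step / test_step methods and a boolean `testing` attribute, which this DDP
# override dispatches to instead of forward():
#
#     model = MyModel(cfg)                # hypothetical nn.Module defining the hooks above
#     model.testing = False
#     model = DDP(model.to(rank), device_ids=[rank])
#
#     model.train()
#     loss = model(batch)                 # routed to MyModel.training_step(batch)
#
#     model.eval()
#     with torch.no_grad():
#         val_loss = model(batch)         # routed to MyModel.validation_step(batch)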