# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Wrapper around FSDP for more convenient use in the training loops.
"""
from contextlib import contextmanager
import typing as tp

import dora
import torch

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    MixedPrecision, ShardingStrategy, FullStateDictConfig, StateDictType)
from torch.distributed._shard.sharded_tensor.api import ShardedTensor


def is_fsdp_used() -> bool:
    """Return whether we are using FSDP."""
    # A bit of a hack but should work from anywhere.
    if dora.is_xp():
        cfg = dora.get_xp().cfg
        if hasattr(cfg, 'fsdp'):
            return cfg.fsdp.use
    return False


def is_sharded_tensor(x: tp.Any) -> bool:
    return isinstance(x, ShardedTensor)


@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    # Another bug in FSDP makes it that we cannot use the `state_dict_type` API,
    # so let's do things manually.
    for model in models:
        FSDP.set_state_dict_type(  # type: ignore
            model, StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True))
    try:
        yield
    finally:
        for model in models:
            FSDP.set_state_dict_type(model, StateDictType.LOCAL_STATE_DICT)  # type: ignore
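
# Hedged example of using the context manager above to save a consolidated
# checkpoint. `wrapped` stands for a module returned by `wrap_with_fsdp`; the
# file name is illustrative only.
#
#     with switch_to_full_state_dict([wrapped]):
#         full_state = wrapped.state_dict()  # gathered on rank 0, offloaded to CPU
#     if torch.distributed.get_rank() == 0:
#         torch.save(full_state, 'consolidated.th')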


def wrap_with_fsdp(cfg, model: torch.nn.Module,
                   block_classes: tp.Optional[tp.Set[tp.Type]] = None) -> FSDP:
    """Wraps a model with FSDP."""
    # Some of the typing is disabled until this gets integrated
    # into the stable version of PyTorch.
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # type: ignore

    # we import this here to prevent circular import.
    from ..modules.transformer import StreamingTransformerLayer
    from ..modules.conditioners import ConditioningProvider

    _fix_post_backward_hook()

    assert cfg.use
    sharding_strategy_dict = {
        "no_shard": ShardingStrategy.NO_SHARD,
        "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,
        "full_shard": ShardingStrategy.FULL_SHARD,
    }

    dtype_dict = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    mixed_precision_config = MixedPrecision(
        param_dtype=dtype_dict[cfg.param_dtype],
        reduce_dtype=dtype_dict[cfg.reduce_dtype],
        buffer_dtype=dtype_dict[cfg.buffer_dtype],
    )

    sharding_strategy_config = sharding_strategy_dict[cfg.sharding_strategy]
    # The following is going to require being a bit smart
    # when doing LM, because this would flush the weights for every time step
    # during generation. One possibility is to use hybrid sharding:
    # See: https://pytorch.org/docs/master/fsdp.html#torch.distributed.fsdp.ShardingStrategy
    assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
        "Not supported at the moment, requires a bit more work."

    local_rank = dora.distrib.get_distrib_spec().local_rank
    assert local_rank < torch.cuda.device_count(), "Please upgrade Dora!"

    auto_wrap_policy = None
    if block_classes is None:
        block_classes = {StreamingTransformerLayer, ConditioningProvider}
    if cfg.per_block:
        auto_wrap_policy = ModuleWrapPolicy(block_classes)
    wrapped = _FSDPFixStateDict(
        model,
        sharding_strategy=sharding_strategy_config,
        mixed_precision=mixed_precision_config,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=auto_wrap_policy,
    )  # type: ignore
    FSDP.set_state_dict_type(wrapped, StateDictType.LOCAL_STATE_DICT)  # type: ignore

    # Let the wrapped model know about the wrapping!
    # We use __dict__ to avoid it going into the state dict.
    # This is a bit dirty, but needed during generation, as otherwise
    # the wrapped model would call itself and bypass FSDP.
    for module in FSDP.fsdp_modules(wrapped):
        original = module._fsdp_wrapped_module
        original.__dict__['_fsdp'] = module
    return wrapped
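
# For reference, the `cfg` namespace consumed by `wrap_with_fsdp` reads the fields
# below. The values shown are one valid combination for illustration, not
# necessarily the project defaults:
#
#     fsdp:
#       use: true                          # also checked by is_fsdp_used()
#       param_dtype: float16               # float32 | float16 | bfloat16
#       reduce_dtype: float32
#       buffer_dtype: float32
#       sharding_strategy: shard_grad_op   # no_shard | shard_grad_op (full_shard is rejected above)
#       per_block: true                    # wrap each StreamingTransformerLayer / ConditioningProvider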


def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model. This should
    allow setting the best state or switching to the EMA.
    """
    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore
    for module in FSDP.fsdp_modules(model):
        if hasattr(module, "_handles"):
            # support for FSDP with torch<2.1.0
            handles = module._handles
            if not handles:
                continue
            handle = handles[0]
            unsharded_flat_param = handle._get_padded_unsharded_flat_param()
            storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
            if storage_size == 0:
                continue
            true_list = [True for h in handles]
            _reshard(module, handles, true_list)
        else:
            handle = module._handle
            if not handle:
                continue
            unsharded_flat_param = handle._get_padded_unsharded_flat_param()
            storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
            if storage_size == 0:
                continue
            _reshard(module, handle, True)
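
# Hedged usage note: call purge_fsdp(wrapped) after modifying the parameters of an
# FSDP-wrapped model outside of `load_state_dict` (e.g. swapping in EMA or "best"
# weights), so the next forward re-gathers the updated shards instead of reusing
# stale cached full weights. `copy_ema_into` is a hypothetical helper.
#
#     copy_ema_into(ema_state, wrapped)  # writes directly into the sharded parameters
#     purge_fsdp(wrapped)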


class _FSDPFixStateDict(FSDP):
    @staticmethod
    def _name_without_fsdp_prefix(name: str) -> str:
        from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE  # type: ignore
        parts = name.split('.')
        new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE]
        return '.'.join(new_parts)

    def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]:  # type: ignore
        state = dict(super().state_dict(*args, **kwargs))
        for key, value in list(state.items()):
            if is_sharded_tensor(value):
                del state[key]
        return state

    def load_state_dict(self, state: tp.Dict[str, tp.Any]):  # type: ignore
        if self._state_dict_type is StateDictType.FULL_STATE_DICT:
            super().load_state_dict(state)
            purge_fsdp(self)
            return
        # Fix FSDP load state dict in all situations.
        # Use this only with LOCAL_STATE_DICT!
        current_state = dict(super().state_dict())
        for key, value in state.items():
            key = _FSDPFixStateDict._name_without_fsdp_prefix(key)
            if key not in current_state:
                # Emulate strict loading manually.
                raise RuntimeError(f"Unknown state key {key}")
            current_state[key].copy_(value)

        # Purging cached weights from previous forward.
        purge_fsdp(self)


_hook_fixed = False


def _fix_post_backward_hook():
    global _hook_fixed
    if _hook_fixed:
        return
    _hook_fixed = True

    from torch.distributed.fsdp import _runtime_utils
    from torch.distributed.fsdp._common_utils import TrainingState, HandleTrainingState
    old_hook = _runtime_utils._post_backward_hook

    def _post_backward_hook(state, handle, *args, **kwargs):
        checkpointed = getattr(state._fsdp_wrapped_module, '_audiocraft_checkpointed', False)
        if checkpointed:
            # there will be one more forward in the backward with checkpointing and that will
            # massively confuse FSDP, so we have to make it think everything
            # is going according to the plan.
            state.training_state = TrainingState.FORWARD_BACKWARD
            handle._training_state = HandleTrainingState.BACKWARD_PRE
        old_hook(state, handle, *args, **kwargs)

    _runtime_utils._post_backward_hook = _post_backward_hook