Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from collections import defaultdict | |
| import logging | |
| import typing as tp | |
| import flashy | |
| import torch | |
| from ..optim import ModuleDictEMA | |
| from .utils import copy_state | |
| logger = logging.getLogger(__name__) | |
class BestStateDictManager(flashy.state.StateDictSource):
    """BestStateDictManager maintains a copy of the best state_dict() for each registered source.

    BestStateDictManager has two main attributes:
        states (dict): Maps each registered name to a copy (made with `copy_state`) of that
            source's state_dict().
        param_ids (dict): Maps a bucket name to a dict {id(tensor): state_dict key}. Each
            ModuleDictEMA source gets its own bucket keyed by its registered name; every
            other source is merged into the shared 'base' bucket.

    When registering new sources, the BestStateDictManager ensures that two conflicting sources
    (a ModuleDictEMA and the original modules) are not both registered, as that would create
    ambiguity about what to consider for best state. Conflicts are detected as top-level
    tensors appearing (by object id) in two different buckets.

    Args:
        device (torch.device or str): Device on which we keep the copy.
        dtype (torch.dtype, optional): Data type forwarded to `copy_state` for the stored copies.
    """
    def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
                 dtype: tp.Optional[torch.dtype] = None):
        self.device = device
        self.states: dict = {}
        # defaultdict so that the shared 'base' bucket is created lazily on first update().
        self.param_ids: dict = defaultdict(dict)
        self.dtype = dtype

    def _get_parameter_ids(self, state_dict):
        """Return {id(tensor): key} for the top-level tensor values of `state_dict`.

        Non-tensor values (and tensors nested inside sub-dicts) are ignored.
        """
        # NOTE(review): if `state_dict` holds freshly created tensor objects (e.g.
        # nn.Module.state_dict() returns detached tensors by default), their ids can be
        # recycled once the dict is freed, which weakens this overlap check — confirm the
        # sources used here return persistent tensor objects.
        return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)}

    def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
        """Assert that `param_ids` shares no tensor id with any bucket other than `name`.

        Args:
            name: Bucket the new ids are about to be added to (skipped in the comparison).
            param_ids: {id(tensor): key} mapping of the source being registered.
        Raises:
            AssertionError: if some ids are already registered under another bucket.
        """
        for registered_name, registered_param_ids in self.param_ids.items():
            if registered_name == name:
                continue
            # dict_keys views support `&` directly; `set.intersection(keys, keys)` would raise
            # TypeError because a dict_keys object is not a `set` instance.
            overlap = registered_param_ids.keys() & param_ids.keys()
            # The message is parenthesized so both f-string parts are actually part of it, and
            # it lists parameter *names* (ids are ints and cannot be str-joined).
            assert len(overlap) == 0, (
                f"Found {len(overlap)} / {len(param_ids)} overlapping parameters"
                f" in {name} and already registered {registered_name}:"
                f" {' '.join(sorted(str(param_ids[pid]) for pid in overlap))}")

    def update(self, name: str, source: flashy.state.StateDictSource):
        """Replace the stored copy for `name` with a fresh copy of `source.state_dict()`.

        Raises:
            ValueError: if `name` was never registered.
        """
        if name not in self.states:
            raise ValueError(f"{name} missing from registered states.")
        self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)

    def register(self, name: str, source: flashy.state.StateDictSource):
        """Register `source` under `name` and store an initial copy of its state.

        ModuleDictEMA sources get their own parameter-id bucket (`name`); all other sources
        share the 'base' bucket.

        Raises:
            ValueError: if `name` is already registered.
            AssertionError: if the source's tensors overlap with another bucket.
        """
        if name in self.states:
            raise ValueError(f"{name} already present in states.")
        # Fetch the state once and reuse it for both id collection and the stored copy.
        source_state = source.state_dict()
        # Registering parameter ids for EMA and non-EMA states allows us to check that
        # there is no overlap that would create ambiguity about how to handle the best state
        param_ids = self._get_parameter_ids(source_state)
        if isinstance(source, ModuleDictEMA):
            logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params")
            self._validate_no_parameter_ids_overlap(name, param_ids)
            self.param_ids[name] = param_ids
        else:
            logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params")
            self._validate_no_parameter_ids_overlap('base', param_ids)
            self.param_ids['base'].update(param_ids)
        # Register state
        self.states[name] = copy_state(source_state, device=self.device, dtype=self.dtype)

    def state_dict(self) -> flashy.state.StateDict:
        """Return the internal `states` dict itself (not a copy)."""
        return self.states

    def load_state_dict(self, state: flashy.state.StateDict):
        """Copy values from `state` in place into the already-registered stored states.

        Every name in `state` must already be registered and every key must already exist
        (a KeyError is raised otherwise). Each stored value must support `.copy_` (i.e. be a
        tensor), since values are copied in place rather than re-assigned.
        """
        for name, sub_state in state.items():
            for k, v in sub_state.items():
                self.states[name][k].copy_(v)