Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # author: adefossez | |
| import functools | |
| import logging | |
| from contextlib import contextmanager | |
| import inspect | |
| import time | |
# Logger shared by every helper in this module, named after the module.
logger = logging.getLogger(name=__name__)
# Tiny constant added to a denominator and to a log10 argument (see cal_snr)
# so those stay finite when a sum is zero.
EPS = 1e-8
def capture_init(init):
    """capture_init.

    Decorate `__init__` with this, and you can then recover the *args and
    **kwargs passed to it in `self._init_args_kwargs` (as an ``(args, kwargs)``
    tuple, excluding ``self``). ``serialize_model`` reads this attribute.

    ``functools.wraps`` copies the initializer's name and docstring and sets
    ``__wrapped__``. As a result, ``inspect.signature`` on the decorated class
    (used by ``deserialize_model``) reports the real constructor parameters
    instead of ``(*args, **kwargs)``.
    """
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        # Record the raw call arguments, then delegate to the real initializer.
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__
def deserialize_model(package, strict=False):
    """Rebuild a model from a package dict (as produced by ``serialize_model``).

    Args:
        package (dict): mapping with keys ``'class'`` (class to instantiate),
            ``'args'`` / ``'kwargs'`` (constructor arguments) and ``'state'``
            (passed to ``load_state_dict``).
        strict (bool): if True, pass the stored kwargs through unchanged.
            If False, drop (with a warning) every keyword argument that is not
            a parameter of the class constructor according to
            ``inspect.signature`` (presumably to tolerate options that no
            longer exist on the class).

    Returns:
        The constructed model, with ``package['state']`` loaded into it.

    The caller's ``package`` (including ``package['kwargs']``) is not modified.
    """
    klass = package['class']
    if strict:
        model = klass(*package['args'], **package['kwargs'])
    else:
        sig = inspect.signature(klass)
        # Build a filtered copy instead of deleting from package['kwargs'],
        # which would silently mutate the caller's package.
        kw = {}
        for key, value in package['kwargs'].items():
            if key in sig.parameters:
                kw[key] = value
            else:
                logger.warning("Dropping inexistant parameter %s", key)
        model = klass(*package['args'], **kw)
    model.load_state_dict(package['state'])
    return model
def copy_state(state):
    """Return a new dict mapping each key of ``state`` to a CPU copy of its value.

    Every value is moved with ``.cpu()`` and then ``.clone()``-d, so the
    returned entries are independent copies rather than the original objects.
    """
    snapshot = {}
    for name, tensor in state.items():
        snapshot[name] = tensor.cpu().clone()
    return snapshot
def serialize_model(model):
    """Pack ``model`` into a plain dict with keys class/args/kwargs/state.

    Reads the constructor arguments from ``model._init_args_kwargs`` (set by
    ``capture_init``) and snapshots ``model.state_dict()`` via ``copy_state``.
    """
    init_args, init_kwargs = model._init_args_kwargs
    return {
        "class": model.__class__,
        "args": init_args,
        "kwargs": init_kwargs,
        "state": copy_state(model.state_dict()),
    }
@contextmanager
def swap_state(model, state):
    """
    Context manager that swaps the state of a model, e.g:

        # model is in old state
        with swap_state(model, new_state):
            # model in new state
        # model back to old state

    On entry a copy of the current state is taken with ``copy_state`` and
    ``state`` is loaded. On exit the saved copy is reloaded; this also
    happens when the ``with`` body raises, because the restore is in a
    ``finally``.
    """
    old_state = copy_state(model.state_dict())
    model.load_state_dict(state)
    try:
        yield
    finally:
        model.load_state_dict(old_state)
def pull_metric(history, name):
    """Collect the values stored under ``name`` from each mapping in ``history``.

    Mappings that do not contain ``name`` are skipped; the order of
    ``history`` is preserved.
    """
    return [entry[name] for entry in history if name in entry]
class LogProgress:
    """
    Sort of like tqdm but using log lines and not as real time.

    Args:
        - logger: logger obtained from `logging.getLogger`,
        - iterable: iterable object to wrap
        - updates (int): number of lines that will be printed, e.g.
            if `updates=5`, log every 1/5th of the total length.
        - total (int): length of the iterable, in case it does not support
            `len`.
        - name (str): prefix to use in the log.
        - level: logging level (like `logging.INFO`).

    Call `update(**infos)` inside the loop to set key/value pairs that are
    appended (as "Key value") to subsequent log lines.
    """

    def __init__(self,
                 logger,
                 iterable,
                 updates=5,
                 total=None,
                 name="LogProgress",
                 level=logging.INFO):
        self.iterable = iterable
        self.total = total or len(iterable)
        self.updates = updates
        self.name = name
        self.logger = logger
        self.level = level

    def update(self, **infos):
        """Replace the extra key/value pairs shown on subsequent log lines."""
        self._infos = infos

    def __iter__(self):
        # Reset progress state so the object can be iterated again.
        self._iterator = iter(self.iterable)
        self._index = -1
        self._infos = {}
        self._begin = time.time()
        return self

    def __next__(self):
        self._index += 1
        try:
            return next(self._iterator)
        finally:
            # Runs on a normal return and on StopIteration alike, so the
            # step that exhausts the iterator can also emit a line.
            log_every = max(1, self.total // self.updates)
            # logging is delayed by 1 it, in order to have the metrics from update
            if self._index >= 1 and self._index % log_every == 0:
                self._log()

    def _log(self):
        elapsed = time.time() - self._begin
        # time.time() may have coarse resolution on some platforms, so quick
        # iterations can report exactly zero elapsed time; avoid dividing by 0.
        if elapsed == 0:
            self._speed = float("inf")
        else:
            self._speed = (1 + self._index) / elapsed
        infos = " | ".join(f"{k.capitalize()} {v}" for k, v in self._infos.items())
        if self._speed < 1e-4:
            speed = "oo sec/it"
        elif self._speed < 0.1:
            speed = f"{1/self._speed:.1f} sec/it"
        else:
            speed = f"{self._speed:.1f} it/sec"
        out = f"{self.name} | {self._index}/{self.total} | {speed}"
        if infos:
            out += " | " + infos
        self.logger.log(self.level, out)
def colorize(text, color):
    """
    Wrap ``text`` between the ANSI escape sequence ``ESC[<color>m`` and the
    reset sequence ``ESC[0m`` so a terminal renders it with that attribute.
    """
    start = f"\033[{color}m"
    reset = "\033[0m"
    return start + text + reset
def bold(text):
    """
    Return ``text`` wrapped in the ANSI bold attribute (code "1") via ``colorize``.
    """
    bold_code = "1"
    return colorize(text, bold_code)
def cal_snr(lbl, est):
    """Signal-to-noise ratio of ``est`` against reference ``lbl``, in dB.

    Reduces over the last dimension:
    ``10 * log10(sum(lbl**2) / (sum((est - lbl)**2) + EPS) + EPS)``.
    The first ``EPS`` keeps the division finite when ``est == lbl``; the
    second keeps the log finite when ``lbl`` is all zeros.
    """
    # Local import (presumably so the module can be imported without torch).
    import torch
    signal_power = torch.sum(lbl ** 2, dim=-1)
    noise_power = torch.sum((est - lbl) ** 2, dim=-1)
    ratio = signal_power / (noise_power + EPS) + EPS
    return 10.0 * torch.log10(ratio)