import math
import random
import time
from collections.abc import MutableMapping
from functools import wraps

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal
class eval_mode:
    """Context manager that switches the given modules to eval mode and
    restores each module's previous training/eval state on exit."""

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        self.prev_states = []
        for model in self.models:
            self.prev_states.append(model.training)
            model.train(False)

    def __exit__(self, *args):
        for model, state in zip(self.models, self.prev_states):
            model.train(state)
        return False
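
# Usage sketch (illustrative only; the small module below is a placeholder):
# inside the `with` block the module is in eval mode, and its original
# training flag is restored afterwards.
def _example_eval_mode():
    net = nn.Linear(4, 2)
    net.train(True)
    with eval_mode(net):
        assert not net.training  # switched to eval inside the block
    assert net.training  # previous training state restored on exit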
def set_seed_everywhere(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
def soft_update_params(net, target_net, tau):
    # Polyak averaging: target <- tau * online + (1 - tau) * target
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)


def hard_update_params(net, target_net):
    for param, target_param in zip(net.parameters(), target_net.parameters()):
        target_param.data.copy_(param.data)
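
# Usage sketch (illustrative only; the networks and tau value are placeholders):
# initialize a target network from the online network, then track it slowly,
# as is typical for actor-critic target networks.
def _example_target_update():
    online = nn.Linear(8, 8)
    target = nn.Linear(8, 8)
    hard_update_params(online, target)             # start from identical weights
    soft_update_params(online, target, tau=0.005)  # then update slowly each step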
def weight_init(m):
    """Custom weight init for Conv2D and Linear layers."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data, gain)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
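
# Usage sketch (illustrative only; the encoder architecture is a placeholder):
# `weight_init` is meant to be applied recursively with `nn.Module.apply`.
def _example_weight_init():
    encoder = nn.Sequential(nn.Conv2d(3, 32, 3), nn.ReLU(), nn.Flatten())
    encoder.apply(weight_init)  # orthogonal init for Conv2d/Linear, zero biases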
class Until:
    """Returns True while `step` is below the (action-repeat-adjusted) limit.
    A limit of None means "forever"."""

    def __init__(self, until, action_repeat=1):
        self._until = until
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._until is None:
            return True
        until = self._until // self._action_repeat
        return step < until


class Every:
    """Returns True every `every` steps (adjusted for action repeat).
    A period of None disables the trigger."""

    def __init__(self, every, action_repeat=1):
        self._every = every
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._every is None:
            return False
        every = self._every // self._action_repeat
        if step % every == 0:
            return True
        return False
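
# Usage sketch (illustrative only; the step counts are placeholders): typical
# training-loop scheduling with the two helpers above.
def _example_schedulers():
    train_until = Until(100000, action_repeat=2)
    eval_every = Every(5000, action_repeat=2)
    step = 0
    while train_until(step):
        if eval_every(step):
            pass  # run evaluation here
        step += 1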
class Timer:
    def __init__(self):
        self._start_time = time.time()
        self._last_time = time.time()

    def reset(self):
        # Returns (time since last reset, time since construction).
        elapsed_time = time.time() - self._last_time
        self._last_time = time.time()
        total_time = time.time() - self._start_time
        return elapsed_time, total_time

    def total_time(self):
        return time.time() - self._start_time
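
# Usage sketch (illustrative only): `reset` is convenient for per-interval
# logging, since it reports both the time since the last reset and the total
# time since construction.
def _example_timer():
    timer = Timer()
    time.sleep(0.01)
    elapsed, total = timer.reset()
    return elapsed, total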
class TruncatedNormal(pyd.Normal):
    """Normal distribution whose samples are clamped to [low, high], with a
    straight-through gradient estimator through the clamping."""

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low = low
        self.high = high
        self.eps = eps

    def _clamp(self, x):
        clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
        x = x - x.detach() + clamped_x.detach()
        return x

    def sample(self, sample_shape=torch.Size(), stddev_clip=None):
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        eps *= self.scale
        # Optionally clip the noise before adding it to the mean.
        if stddev_clip is not None:
            eps = torch.clamp(eps, -stddev_clip, stddev_clip)
        x = self.loc + eps
        return self._clamp(x)
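
# Usage sketch (illustrative only; the tensors and stddev are placeholders):
# sampling an exploration action around a policy mean with clipped noise, in
# the style of DrQ-v2-like agents.
def _example_truncated_normal():
    mean = torch.zeros(1, 6)
    dist = TruncatedNormal(mean, torch.full_like(mean, 0.2))
    action = dist.sample(stddev_clip=0.3)  # noise clipped, result kept in (-1, 1)
    return action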
class TanhTransform(pyd.transforms.Transform):
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here, as it may degrade the performance
        # of certain algorithms; one should use `cache_size=1` instead.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # Numerically stable formula; see
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))
class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Gaussian squashed through tanh, so samples lie in (-1, 1)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale
        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu
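
# Usage sketch (illustrative only; shapes are placeholders): a tanh-squashed
# Gaussian policy head; log_prob accounts for the tanh change of variables via
# TanhTransform above.
def _example_squashed_normal():
    dist = SquashedNormal(torch.zeros(1, 6), torch.ones(1, 6) * 0.5)
    action = dist.rsample()                   # differentiable sample in (-1, 1)
    log_prob = dist.log_prob(action).sum(-1)  # per-sample log-likelihood
    return action, log_prob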
def retry(func):
    """Decorator that retries a function for a limited number of attempts when
    it raises OSError or PermissionError."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        attempts = 0
        max_attempts = 1000
        while attempts < max_attempts:
            try:
                return func(*args, **kwargs)
            except (OSError, PermissionError):
                attempts += 1
                time.sleep(0.1)
        raise OSError("Retry failed")

    return wrapper
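
# Usage sketch (illustrative only; the payload and path are placeholders):
# retrying a checkpoint write that may transiently fail, e.g. on a networked
# filesystem.
@retry
def _example_save_checkpoint(payload, path="/tmp/checkpoint.pt"):
    torch.save(payload, path)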
def flatten_dict(dictionary, parent_key='', separator='_'):
    """Flattens a nested mapping into a single-level dict, joining keys with
    `separator`."""
    items = []
    for key in dictionary.keys():
        try:
            value = dictionary[key]
        except Exception:
            value = '??? <MISSING>'
        new_key = parent_key + separator + key if parent_key else key
        if isinstance(value, MutableMapping):
            items.extend(flatten_dict(value, new_key, separator=separator).items())
        else:
            items.append((new_key, value))
    return dict(items)
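
# Usage sketch (illustrative only; the config keys are placeholders):
# flattening a nested config dict, e.g. before logging hyperparameters to a
# flat key-value store.
def _example_flatten_dict():
    cfg = {'agent': {'lr': 1e-4, 'critic': {'hidden_dim': 256}}, 'seed': 1}
    flat = flatten_dict(cfg)
    # -> {'agent_lr': 0.0001, 'agent_critic_hidden_dim': 256, 'seed': 1}
    return flat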
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    '''
    Spherical linear interpolation.
    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray): Starting vector
        v1 (np.ndarray): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
                               collinear. Not recommended to alter this.
    Returns:
        v2 (np.ndarray): Interpolation vector between v0 and v1
    '''
    # Remember whether the inputs were torch tensors so the result can be
    # converted back at the end.
    c = False
    if not isinstance(v0, np.ndarray):
        c = True
        v0 = v0.detach().cpu().numpy()
    if not isinstance(v1, np.ndarray):
        c = True
        v1 = v1.detach().cpu().numpy()
    if len(v0.shape) == 1:
        v0 = v0.reshape(1, -1)
    if len(v1.shape) == 1:
        v1 = v1.reshape(1, -1)
    # Copy the vectors to reuse them later
    v0_copy = np.copy(v0)
    v1_copy = np.copy(v1)
    # Normalize the vectors to get the directions and angles
    v0 = v0 / np.linalg.norm(v0, axis=-1, keepdims=True)
    v1 = v1 / np.linalg.norm(v1, axis=-1, keepdims=True)
    # Dot product of the normalized vectors
    dot = np.sum(v0 * v1, axis=-1)
    # If the absolute value of the dot product is almost 1, the vectors are
    # nearly collinear, so plain linear interpolation should be used instead.
    if (np.abs(dot) > DOT_THRESHOLD).any():
        raise NotImplementedError('lerp not implemented')  # return lerp(t, v0_copy, v1_copy)
    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    # Angle at timestep t
    theta_t = theta_0 * t
    sin_theta_t = np.sin(theta_t)
    # Finish the slerp algorithm
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    v2 = s0.reshape(-1, 1) * v0_copy + s1.reshape(-1, 1) * v1_copy
    if c:
        # Inputs were torch tensors; move the result back to the GPU.
        res = torch.from_numpy(v2).to("cuda")
    else:
        res = v2
    return res
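
# Usage sketch (illustrative only; the vectors are placeholders): interpolating
# halfway between two vectors. NumPy inputs are used so the result stays on the
# CPU; torch inputs would be moved back to CUDA by the function above.
def _example_slerp():
    v0 = np.array([1.0, 0.0, 0.0])
    v1 = np.array([0.0, 1.0, 0.0])
    midpoint = slerp(0.5, v0, v1)  # shape (1, 3), on the great circle between v0 and v1
    return midpoint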