import math
from typing import Callable, Optional

import attr
import torch
import torch.nn as nn
import torch.nn.functional as F

FilterFn = Callable[[torch.Tensor], torch.Tensor]

class ZeroKeyBiasGrad(torch.autograd.Function):
    # Identity in the forward pass; the backward pass zeroes the gradient of
    # the key chunk of a packed [q | k | v] bias, so the key bias is never
    # updated. forward/backward must be staticmethods on autograd Functions.
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, output_grad):
        output_grad = output_grad.clone()
        output_grad.chunk(3)[1].zero_()
        return output_grad

def zero_key_bias_grad(x: torch.Tensor) -> torch.Tensor:
    return ZeroKeyBiasGrad.apply(x)
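
# --- Usage sketch (not part of the original file) ---
# A minimal, hedged demo of zero_key_bias_grad on a packed [q | k | v] bias.
# The (3 * head_dim,) shape is an assumption matching the chunk(3) in
# ZeroKeyBiasGrad.backward: only the middle (key) third has its grad zeroed.
def _demo_zero_key_bias_grad() -> None:
    qkv_bias = torch.zeros(9, requires_grad=True)
    zero_key_bias_grad(qkv_bias).sum().backward()
    q_grad, k_grad, v_grad = qkv_bias.grad.chunk(3)
    assert torch.all(k_grad == 0)  # key grads blocked
    assert torch.all(q_grad == 1) and torch.all(v_grad == 1)  # q/v untouched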

# attrs generates __init__ from the attr.ib fields and then calls
# __attrs_post_init__; eq=False keeps identity-based hashing, which
# nn.Module traversal relies on.
@attr.s(eq=False, repr=False)
class LayerNorm(nn.Module):
    n_state: int = attr.ib()
    eps: float = attr.ib(default=1e-6)
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()
        self.g = nn.Parameter(torch.ones((self.n_state,), dtype=torch.float32, device=self.device))
        self.b = nn.Parameter(torch.zeros((self.n_state,), dtype=torch.float32, device=self.device))
        # Gain and bias are exempt from weight decay.
        self.g.weight_decay_level = "disable"  # type: ignore
        self.b.weight_decay_level = "disable"  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Always normalize in float32, whatever the input dtype.
        return F.layer_norm(
            x.type(torch.float32), torch.Size((self.n_state,)), self.g, self.b, self.eps
        )
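
# --- Usage sketch (not part of the original file) ---
# Hedged demo: LayerNorm normalizes in float32 regardless of input dtype, so
# a float16 input comes back float32. device="cpu" overrides the CUDA default
# purely so the demo runs anywhere.
def _demo_layer_norm() -> None:
    ln = LayerNorm(n_state=8, device=torch.device("cpu"))
    y = ln(torch.randn(2, 8, dtype=torch.float16))
    assert y.dtype == torch.float32 and y.shape == (2, 8)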

@attr.s(eq=False, repr=False)
class Affine(nn.Module):
    n_in: int = attr.ib()
    n_out: int = attr.ib()
    use_bias: bool = attr.ib(default=True)
    use_admnet_init: bool = attr.ib(default=False)
    std: Optional[float] = attr.ib(default=None)
    extra_init_scale: Optional[float] = attr.ib(default=None)
    bias_filter_fn: FilterFn = attr.ib(default=lambda x: x)
    device: torch.device = attr.ib(default=torch.device("cuda"))

    def __attrs_post_init__(self) -> None:
        super().__init__()

        if not self.use_admnet_init:
            # Glorot-style std unless one was given explicitly, optionally
            # scaled; the weight itself is allocated empty, so values are
            # presumably filled in or loaded from a checkpoint elsewhere.
            self.std = self.std if self.std is not None else math.sqrt(2 / (self.n_in + self.n_out))
            self.std = (
                self.std if self.extra_init_scale is None else self.std * self.extra_init_scale
            )

            w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)
            self.w = nn.Parameter(w)

            if self.use_bias:
                self.b = nn.Parameter(
                    torch.zeros((self.n_out,), dtype=torch.float32, device=self.device)
                )
                self.b.weight_decay_level = "disable"  # type: ignore
        else:
            if self.extra_init_scale is not None:
                raise ValueError("extra_init_scale incompatible with admnet init")

            w = torch.empty((self.n_out, self.n_in), dtype=torch.float32, device=self.device)

            if self.use_bias:
                b = torch.empty((self.n_out,), dtype=torch.float32, device=self.device)

            self.w = nn.Parameter(w)

            if self.use_bias:
                self.b = nn.Parameter(b)
                self.b.weight_decay_level = "disable"  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast parameters to the activation dtype so mixed precision works.
        w = self.w if self.w.dtype == x.dtype else self.w.to(x.dtype)
        b = (
            self.bias_filter_fn(self.b if self.b.dtype == x.dtype else self.b.to(x.dtype))
            if self.use_bias
            else None
        )
        return F.linear(x, w, b)
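
# --- Usage sketch (not part of the original file) ---
# Hedged demo: Affine casts its parameters to the input dtype before F.linear,
# so float16 activations produce float16 outputs. use_admnet_init=True leaves
# w and b as torch.empty (assumed to be filled externally in real use), so the
# demo overwrites them for deterministic output.
def _demo_affine() -> None:
    layer = Affine(n_in=4, n_out=2, use_admnet_init=True, device=torch.device("cpu"))
    with torch.no_grad():
        layer.w.fill_(0.5)
        layer.b.fill_(1.0)
    y = layer(torch.ones(3, 4, dtype=torch.float16))
    assert y.dtype == torch.float16
    assert torch.all(y == 3.0)  # 4 * 0.5 + 1.0 per output element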