Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Union | |
| import torch | |
| from torch import Tensor | |
| from torch import nn | |
class LayerScale(nn.Module):
    """Multiply the input by a learnable scale vector ``gamma``.

    ``gamma`` is registered as an ``nn.Parameter`` built from
    ``init_values * torch.ones(dim)``. The forward pass multiplies the
    input by ``gamma`` either out of place (default) or in place.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        """Create the scale parameter.

        Args:
            dim: Length of the ``torch.ones`` vector the scale is built from.
            init_values: Factor applied to that vector of ones; a float or
                a Tensor.
            inplace: When true, ``forward`` mutates its argument with
                ``mul_`` instead of allocating a new tensor.
        """
        super().__init__()
        self.inplace = inplace
        initial_scale = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_scale)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` scaled by ``gamma``.

        With ``inplace`` set, ``x`` itself is modified and returned;
        otherwise a new tensor is returned and ``x`` is left unchanged.
        """
        # NOTE(review): presumably x's last dimension equals ``dim`` so the
        # multiply broadcasts; nothing here checks it -- confirm with callers.
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma