"""Spectrally normalized convolution wrappers and GAN building blocks:
noise injection, skip-layer excitation, up-/down-sampling blocks, decoder."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class SpectralConv2d(nn.Module):
    """nn.Conv2d wrapped in spectral normalization."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._conv = spectral_norm(
            nn.Conv2d(*args, **kwargs)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._conv(input)


class SpectralConvTranspose2d(nn.Module):
    """nn.ConvTranspose2d wrapped in spectral normalization."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self._conv = spectral_norm(
            nn.ConvTranspose2d(*args, **kwargs)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._conv(input)


class Noise(nn.Module):
    """Adds per-pixel Gaussian noise scaled by a single learned weight
    (initialized to zero, so the layer starts as an identity)."""

    def __init__(self):
        super().__init__()
        self._weight = nn.Parameter(
            torch.zeros(1),
            requires_grad=True,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        batch_size, _, height, width = input.shape
        noise = torch.randn(batch_size, 1, height, width, device=input.device)
        return self._weight * noise + input


class InitLayer(nn.Module):
    """Projects a (N, in_channels, 1, 1) latent into a 4x4 feature map."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self._layers = nn.Sequential(
            SpectralConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels * 2,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),  # halves the channels back to out_channels
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._layers(input)


class SLEBlock(nn.Module):
    """Skip-layer excitation: squeezes a low-resolution feature map into
    per-channel gates in [0, 1] and applies them to a high-resolution map."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self._layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=4),
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),  # 4x4 -> 1x1
            nn.SiLU(),
            SpectralConv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.Sigmoid(),
        )

    def forward(self,
                low_dim: torch.Tensor,
                high_dim: torch.Tensor) -> torch.Tensor:
        return high_dim * self._layers(low_dim)
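

# A minimal shape sketch (not part of the original module): the SLE block pools
# its low-resolution input to 4x4, reduces it to (N, out_channels, 1, 1) gates
# via the two convolutions above, and broadcasts them over the high-resolution
# map. The helper name and tensor sizes below are illustrative assumptions.
def _sle_shape_sketch() -> torch.Size:
    sle = SLEBlock(in_channels=512, out_channels=64)
    low = torch.randn(2, 512, 8, 8)       # e.g. an 8x8 feature map
    high = torch.randn(2, 64, 128, 128)   # e.g. a 128x128 feature map
    return sle(low, high).shape           # -> torch.Size([2, 64, 128, 128])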


class UpsampleBlockT1(nn.Module):
    """Doubles spatial resolution: nearest-neighbour upsample, conv, GLU."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self._layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels * 2,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._layers(input)


class UpsampleBlockT2(nn.Module):
    """Doubles spatial resolution with two conv + GLU stages and learned
    noise injection after each convolution."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self._layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels * 2,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            Noise(),
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
            SpectralConv2d(
                in_channels=out_channels,
                out_channels=out_channels * 2,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            Noise(),
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._layers(input)
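

# A quick shape sketch (illustrative assumption, not from the source): both
# upsample blocks map (N, in_channels, H, W) to (N, out_channels, 2H, 2W);
# the T2 variant simply adds a second conv + GLU stage plus noise injection.
def _upsample_shape_sketch() -> torch.Size:
    block = UpsampleBlockT2(in_channels=128, out_channels=64)
    feats = torch.randn(2, 128, 16, 16)
    return block(feats).shape  # -> torch.Size([2, 64, 32, 32])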


class DownsampleBlockT1(nn.Module):
    """Halves spatial resolution with a single strided convolution."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self._layers = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._layers(input)


class DownsampleBlockT2(nn.Module):
    """Halves spatial resolution with two branches, a strided conv stack and
    an average-pool shortcut, and averages their outputs."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self._layers_1 = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )

        self._layers_2 = nn.Sequential(
            nn.AvgPool2d(
                kernel_size=2,
                stride=2,
            ),
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        t1 = self._layers_1(input)
        t2 = self._layers_2(input)
        return (t1 + t2) / 2


class Decoder(nn.Module):
    """Reconstructs a 128x128 image from an intermediate feature map:
    pools to 8x8, upsamples four times, then projects to `out_channels`."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        # Channel count per feature-map resolution; only the 16-128 entries
        # are used by the layers below.
        self._channels = {
            16: 128,
            32: 64,
            64: 64,
            128: 32,
            256: 16,
            512: 8,
            1024: 4,
        }

        self._layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=8),
            UpsampleBlockT1(in_channels=in_channels, out_channels=self._channels[16]),
            UpsampleBlockT1(in_channels=self._channels[16], out_channels=self._channels[32]),
            UpsampleBlockT1(in_channels=self._channels[32], out_channels=self._channels[64]),
            UpsampleBlockT1(in_channels=self._channels[64], out_channels=self._channels[128]),
            SpectralConv2d(
                in_channels=self._channels[128],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self._layers(input)
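

# A minimal smoke-test sketch, assuming a discriminator-style feature map of
# 256 channels at 16x16; the batch size and channel counts are illustrative
# assumptions, not values fixed by the modules above.
if __name__ == "__main__":
    decoder = Decoder(in_channels=256, out_channels=3)
    feats = torch.randn(2, 256, 16, 16)
    image = decoder(feats)
    # The decoder pools to 8x8 and then upsamples four times, so the output
    # is always 128x128 regardless of the input resolution.
    assert image.shape == (2, 3, 128, 128)
    assert image.min() >= -1.0 and image.max() <= 1.0  # Tanh output range
    print(image.shape)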