import torch
import torch.nn as nn
from itertools import chain
from torchvision import models
from typing import Sequence
from collections import OrderedDict


def get_network(net_type: str = 'vgg'):
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


def normalize_activation(x, eps=1e-10):
    # Unit-normalize each feature vector along the channel dimension.
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # Shift/scale constants from the official LPIPS implementation,
        # registered as buffers so they follow the module across devices.
        # Inputs are expected in [-1, 1].
        self.register_buffer(
            'mean', torch.tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        # Collect normalized activations at the target layers, stopping
        # early once all of them have been gathered.
        output = []
        for i, layer in enumerate(self.layers.children(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(pretrained=True).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(pretrained=True).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(pretrained=True).features
        # 1-based indices of relu1_2, relu2_2, relu3_3, relu4_3, relu5_3
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        # One frozen 1x1 convolution per feature scale; its weights are the
        # learned per-channel importances loaded in get_state_dict below.
        super(LinLayers, self).__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys so they match LinLayers' parameter names
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict
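

# Example of the renaming performed by get_state_dict (the exact key names
# are an assumption based on the published v0.1 weight files):
#   'lin0.model.1.weight' -> '0.1.weight'
# which lines up with LinLayers' ModuleList/Sequential parameter names.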


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        net_type (str): the network used to extract features:
            'alex' | 'squeeze' | 'vgg'. Default: 'vgg'.
        version (str): the version of LPIPS. Default: '0.1'.
    """
    def __init__(self, net_type: str = 'vgg', version: str = '0.1'):
        assert version in ['0.1'], 'only v0.1 is supported for now'

        super(LPIPS, self).__init__()

        # pretrained feature extractor; device placement is left to the
        # caller (e.g. LPIPS().to(device)) rather than hard-coding CUDA,
        # which crashes on CPU-only hosts
        self.net = get_network(net_type)

        # frozen linear layers loaded with the official LPIPS weights
        self.lin = LinLayers(self.net.n_channels_list)
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        # squared difference of normalized activations at each scale,
        # weighted by the 1x1 convolutions and averaged spatially
        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [lin(d).mean((2, 3), True) for d, lin in zip(diff, self.lin)]

        # sum over scales, averaged over the batch
        return torch.sum(torch.cat(res, 0)) / x.shape[0]
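

# ----------------------------------------------------------------------
# Minimal usage sketch, not part of the original Space. Assumptions: inputs
# are RGB tensors scaled to [-1, 1] (the range the shift/scale buffers
# expect), and the random tensors below merely stand in for real images.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    criterion = LPIPS(net_type='vgg').to(device).eval()

    x = torch.rand(1, 3, 256, 256, device=device) * 2 - 1
    y = torch.rand(1, 3, 256, 256, device=device) * 2 - 1

    with torch.no_grad():
        score = criterion(x, y)
    print(f'LPIPS: {score.item():.4f}')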