Spaces:
Runtime error
Runtime error
| import os | |
| import torch.nn.functional as F | |
| import torch | |
| from utils.filters_tensor import GaussianSmoothing, bgr2gray | |
| from utils import pytorch_ssim | |
| from torch import nn | |
| from .hourglass import HourGlass | |
| from torchvision.models.vgg import vgg19 | |
def l2_loss(y_input, y_target):
    """Return the mean squared error between ``y_input`` and ``y_target``.

    Thin wrapper around ``F.mse_loss`` with its default reduction.
    """
    squared_error = F.mse_loss(y_input, y_target)
    return squared_error
def l1_loss(y_input, y_target):
    """Return the mean absolute error between ``y_input`` and ``y_target``.

    Thin wrapper around ``F.l1_loss`` with its default reduction.
    """
    absolute_error = F.l1_loss(y_input, y_target)
    return absolute_error
def gaussianL2(yInput, yTarget):
    """Return the MSE between Gaussian-smoothed versions of input and target.

    ``yInput`` is smoothed as-is, while ``yTarget`` is first passed through
    ``bgr2gray``. Both go through the same single-channel smoother
    (kernel size 11, sigma 2.0).
    """
    # data range [-1,1]
    blur = GaussianSmoothing(channels=1, kernel_size=11, sigma=2.0)
    blurredInput = blur(yInput)
    blurredTarget = blur(bgr2gray(yTarget))
    return F.mse_loss(blurredInput, blurredTarget)
def binL1(yInput):
    """Return the mean absolute distance of ``|yInput|`` from 1.

    The value is zero exactly when every element of ``yInput`` is +1 or -1.
    """
    # data range is [-1,1]
    magnitude = torch.abs(yInput)
    distanceToBinary = torch.abs(magnitude - 1.0)
    return distanceToBinary.mean()
def ssimLoss(yInput, yTarget):
    """Return ``1 - SSIM`` between ``yInput`` and the gray version of ``yTarget``.

    Both tensors are rescaled with ``x / 2 + 0.5`` before comparison. The
    target is converted with ``bgr2gray`` after rescaling.
    """
    # data range is [-1,1]; the affine map below moves it to [0,1]
    rescaledInput = yInput / 2. + 0.5
    rescaledTarget = bgr2gray(yTarget / 2. + 0.5)
    similarity = pytorch_ssim.ssim(rescaledInput, rescaledTarget, window_size=11)
    return 1. - similarity
class InverseHalf(nn.Module):
    """Module wrapping a single-channel-in, single-channel-out ``HourGlass``."""

    def __init__(self):
        super().__init__()
        self.net = HourGlass(inChannel=1, outChannel=1)

    def forward(self, x):
        """Run ``x`` through the wrapped hourglass network and return its output."""
        return self.net(x)
class FeatureLoss:
    """L2 loss between a pretrained ``InverseHalf`` network's output and a target.

    The network is restored from ``checkpoint['state_dict']`` in the file at
    ``pretrainedPath``. Calling the instance runs the network on ``yInput``
    and compares the result with ``yTarget`` using ``l2_loss``.

    Args:
        pretrainedPath: Path of the checkpoint file passed to ``torch.load``.
        requireGrad: When false (default), every network parameter has
            ``requires_grad`` set to ``False``.
        multiGpu: When true (default), the network is wrapped in
            ``torch.nn.DataParallel`` and moved to CUDA.
    """

    def __init__(self, pretrainedPath, requireGrad=False, multiGpu=True):
        self.featureExactor = InverseHalf()
        if multiGpu:
            self.featureExactor = torch.nn.DataParallel(self.featureExactor).cuda()
        print("-loading feature extractor: {} ...".format(pretrainedPath))
        # Deserialize onto the CPU so a checkpoint saved from CUDA tensors can
        # still be read on a host without a GPU; load_state_dict copies the
        # values onto whatever device the parameters live on.
        checkpoint = torch.load(pretrainedPath, map_location="cpu")
        stateDict = self._alignStateDictKeys(
            checkpoint['state_dict'], wrapped=bool(multiGpu)
        )
        self.featureExactor.load_state_dict(stateDict)
        print("-feature network loaded")
        if not requireGrad:
            for param in self.featureExactor.parameters():
                param.requires_grad = False

    @staticmethod
    def _alignStateDictKeys(stateDict, wrapped):
        """Add or strip the ``module.`` key prefix to match the model wrapping.

        ``DataParallel`` registers the wrapped network under the name
        ``module``, so its state-dict keys carry a ``module.`` prefix that a
        bare network does not expect (and vice versa). The mapping is returned
        unchanged when it already matches.
        """
        prefix = "module."
        hasPrefix = bool(stateDict) and all(
            key.startswith(prefix) for key in stateDict
        )
        if wrapped and not hasPrefix:
            return {prefix + key: value for key, value in stateDict.items()}
        if hasPrefix and not wrapped:
            return {key[len(prefix):]: value for key, value in stateDict.items()}
        return stateDict

    def __call__(self, yInput, yTarget):
        """Return the L2 loss between the network output for ``yInput`` and ``yTarget``."""
        inFeature = self.featureExactor(yInput)
        return l2_loss(inFeature, yTarget)
class Vgg19Loss:
    """L2 loss between truncated-VGG19 feature maps of two image batches.

    Construction sets ``TORCH_HOME``, loads a pretrained ``vgg19``, keeps the
    first 28 layers of ``vgg.features`` in eval mode with frozen parameters,
    and optionally wraps them in ``DataParallel`` on CUDA.
    """

    def __init__(self, multiGpu=True):
        # NOTE: this unconditionally overrides any TORCH_HOME already set.
        os.environ['TORCH_HOME'] = '~/bigdata/0ProgramS/checkpoints'
        # data in BGR format, [0,1] range: the per-channel statistics are
        # written in their usual order and then reversed to match.
        self.mean = [0.485, 0.456, 0.406][::-1]
        self.std = [0.229, 0.224, 0.225][::-1]
        backbone = vgg19(pretrained=True)
        # up to and including the maxpool after conv4_4
        truncated = list(backbone.features)[:28]
        self.featureExactor = nn.Sequential(*truncated).eval()
        for weight in self.featureExactor.parameters():
            weight.requires_grad = False
        if multiGpu:
            self.featureExactor = torch.nn.DataParallel(self.featureExactor).cuda()
        print('[*] Vgg19Loss init!')

    def normalize(self, tensor):
        """Return a copy of ``tensor`` with ``(x - mean) / std`` applied along dim 1.

        The input is cloned first, so the caller's tensor is not modified.
        """
        result = tensor.clone()
        stats = dict(dtype=torch.float32, device=result.device)
        # Reshape the statistics to (1, C, 1, 1) so they broadcast over dim 1.
        meanView = torch.as_tensor(self.mean, **stats)[None, :, None, None]
        stdView = torch.as_tensor(self.std, **stats)[None, :, None, None]
        result.sub_(meanView)
        result.div_(stdView)
        return result

    def __call__(self, yInput, yTarget):
        """Return the L2 loss between the feature maps of ``yInput`` and ``yTarget``.

        Each batch is normalized and then has its channel axis (dim 1)
        reversed before being fed to the feature extractor.
        """
        extract = self.featureExactor
        inFeature = extract(self.normalize(yInput).flip(1))
        targetFeature = extract(self.normalize(yTarget).flip(1))
        return l2_loss(inFeature, targetFeature)