Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| import torch | |
| import torchvision.transforms | |
| import torchvision.transforms.functional as F | |
| # "Pair": apply a transform on a pair | |
| # "Both": apply the exact same transform to both images | |
class ComposePair(torchvision.transforms.Compose):
    """Chain transforms that each take and return an ``(img1, img2)`` pair.

    Unlike ``torchvision.transforms.Compose``, every stage is called with two
    positional images and must return exactly two values, which are fed into
    the next stage. The final pair is returned as a tuple.
    """

    def __call__(self, img1, img2):
        current = (img1, img2)
        for stage in self.transforms:
            # Unpacking enforces that each stage yields exactly two items.
            left, right = stage(*current)
            current = (left, right)
        return current
class NormalizeBoth(torchvision.transforms.Normalize):
    """Normalize both images of a pair with the same mean/std."""

    def forward(self, img1, img2):
        normalize_one = super().forward
        return normalize_one(img1), normalize_one(img2)
class ToTensorBoth(torchvision.transforms.ToTensor):
    """Convert both images of a pair with ``ToTensor``."""

    def __call__(self, img1, img2):
        to_tensor = super().__call__
        return to_tensor(img1), to_tensor(img2)
class RandomCropPair(torchvision.transforms.RandomCrop):
    """Randomly crop each image of a pair independently.

    ``RandomCrop.forward`` is invoked once per image, so the crop location is
    sampled separately for ``img1`` and ``img2``. The two crops are therefore
    intentionally different in general.
    """

    def forward(self, img1, img2):
        crop = super().forward
        cropped_first = crop(img1)
        cropped_second = crop(img2)
        return cropped_first, cropped_second
class ColorJitterPair(torchvision.transforms.ColorJitter):
    """Color jitter for an image pair, shared or independent per image.

    The first image is always jittered with freshly sampled parameters. With
    probability ``assymetric_prob`` the second image then gets its own newly
    sampled parameters, including a new operation order. Otherwise it is
    jittered with exactly the same parameters as the first image.

    Args:
        assymetric_prob: probability of sampling separate parameters for the
            second image. ``1.0`` always resamples; ``0.0`` never does.
            The misspelling is part of the public keyword and is kept.
        **kwargs: forwarded unchanged to ``torchvision.transforms.ColorJitter``.
    """

    def __init__(self, assymetric_prob, **kwargs):
        super().__init__(**kwargs)
        self.assymetric_prob = assymetric_prob

    def jitter_one(
        self,
        img,
        fn_idx,
        brightness_factor,
        contrast_factor,
        saturation_factor,
        hue_factor,
    ):
        """Apply the given color adjustments to ``img`` in ``fn_idx`` order.

        Each entry of ``fn_idx`` selects an operation: 0 brightness,
        1 contrast, 2 saturation, 3 hue. An operation whose factor is
        ``None`` is skipped. Returns the adjusted image.
        """
        operations = (
            (F.adjust_brightness, brightness_factor),
            (F.adjust_contrast, contrast_factor),
            (F.adjust_saturation, saturation_factor),
            (F.adjust_hue, hue_factor),
        )
        for requested in fn_idx:
            for code, (adjust, factor) in enumerate(operations):
                # Match with == rather than a dict lookup: fn_idx entries may be
                # 0-d tensors (get_params samples a permutation tensor), and
                # tensors hash by identity, not value.
                if requested == code and factor is not None:
                    img = adjust(img, factor)
                    break
        return img

    def forward(self, img1, img2):
        """Return the jittered ``(img1, img2)`` pair."""
        ranges = (self.brightness, self.contrast, self.saturation, self.hue)
        # get_params yields (fn_idx, brightness, contrast, saturation, hue),
        # which matches jitter_one's positional parameters after ``img``.
        params = self.get_params(*ranges)
        out_first = self.jitter_one(img1, *params)
        # torch.rand(1) is uniform on [0, 1), so a probability of 1.0 always
        # resamples and 0.0 never does.
        if torch.rand(1) < self.assymetric_prob:
            params = self.get_params(*ranges)
        out_second = self.jitter_one(img2, *params)
        return out_first, out_second
def get_pair_transforms(transform_str, totensor=True, normalize=True):
    """Build a pair transform from a ``+``-separated augmentation spec.

    Recognised tokens, e.g. ``"crop224+acolor"``:
        ``crop<N>``  -- random ``N``x``N`` crop, sampled independently for each
                        image (``RandomCropPair``).
        ``acolor``   -- asymmetric color jitter: brightness, contrast and
                        saturation factors in [0.6, 1.4], hue=0.0 (no hue
                        jitter). Parameters are always resampled for the
                        second image (``assymetric_prob=1.0``).
        ``""``       -- empty token (e.g. from an empty ``transform_str``);
                        ignored.

    Args:
        transform_str: augmentation spec as described above.
        totensor: if True, append ``ToTensorBoth`` after the augmentations.
        normalize: if True, append ``NormalizeBoth`` with the ImageNet
            mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225).

    Returns:
        ``None`` when no transform was requested.
        The single transform itself when exactly one was requested.
        Otherwise a ``ComposePair`` of all of them.
        Every non-``None`` result is called as ``t(img1, img2)`` and returns
        the transformed pair.

    Raises:
        NotImplementedError: on an unrecognised token.
        ValueError: if a ``crop`` token is not followed by an integer.
    """
    trfs = []
    for s in transform_str.split("+"):
        if s.startswith("crop"):
            size = int(s[len("crop"):])
            trfs.append(RandomCropPair(size))
        elif s == "acolor":
            trfs.append(
                ColorJitterPair(
                    assymetric_prob=1.0,
                    brightness=(0.6, 1.4),
                    contrast=(0.6, 1.4),
                    saturation=(0.6, 1.4),
                    hue=0.0,
                )
            )
        elif s == "":  # empty spec, or an empty token such as "crop224+"
            pass
        else:
            raise NotImplementedError("Unknown augmentation: " + s)

    if totensor:
        trfs.append(ToTensorBoth())
    if normalize:
        trfs.append(
            NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )

    if not trfs:
        return None
    if len(trfs) == 1:
        # Return the transform itself, not the one-element list, so the result
        # is callable as t(img1, img2) exactly like the ComposePair case.
        return trfs[0]
    return ComposePair(trfs)