File size: 4,756 Bytes
a96891a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
from typing import Tuple
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Import MedIMeta
from medimeta import MedIMeta
from data.hoi_dataset import BongardDataset
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
from data.fewshot_datasets import *
import data.augmix_ops as augmentations
import medmnist
from medmnist import INFO, Evaluator
# Maps a dataset id (as passed to ``build_dataset``) to the name of the
# directory that holds its images. Insertion order is significant only for
# callers that iterate the mapping.
ID_to_DIRNAME = {
    # ImageNet and its variants (A, Sketch, R, V2).
    'I': 'ImageNet',
    'A': 'imagenet-a',
    'K': 'ImageNet-Sketch',
    'R': 'imagenet-r',
    'V': 'imagenetv2-matched-frequency-format-val',
    # Other natural-image classification sets.
    'flower102': 'Flower102',
    'dtd': 'DTD',
    'pets': 'OxfordPets',
    'cars': 'StanfordCars',
    'ucf101': 'UCF101',
    'caltech101': 'Caltech101',
    'food101': 'Food101',
    'sun397': 'SUN397',
    'aircraft': 'fgvc_aircraft',
    'eurosat': 'eurosat',
    # Medical-imaging sets (grouping inferred from names).
    'idrid': 'IDRID',
    'isic2018': 'ISIC2018',
    'pneumonia_guangzhou': 'PneumoniaGuangzhou',
    'shenzhen_cxr': 'ShenzhenCXR',
    'montgomery_cxr': 'MontgomeryCXR',
    'covid': 'Covid',
}
def build_dataset(set_id, transform, data_root, mode='test', n_shot=None, split="all", bongard_anno=False):
    """Build an ``ImageFolder`` dataset for ``set_id``.

    Images are read from ``<data_root>/<set_id>/<ID_to_DIRNAME[set_id]>``.

    Args:
        set_id: Key into ``ID_to_DIRNAME``; also used as the first
            sub-directory under ``data_root``.
        transform: Transform passed through to ``datasets.ImageFolder``.
        data_root: Root directory containing one folder per ``set_id``.
        mode, n_shot, split, bongard_anno: Not used by this implementation.
            They are kept so existing call sites that pass them keep working.

    Returns:
        A ``torchvision.datasets.ImageFolder`` over the resolved directory.

    Raises:
        KeyError: If ``set_id`` is not a key of ``ID_to_DIRNAME``.
    """
    try:
        dirname = ID_to_DIRNAME[set_id]
    except KeyError:
        # Still a KeyError so existing handlers keep working, but list the
        # valid ids to make misconfiguration easier to diagnose.
        raise KeyError(
            f"Unknown set_id {set_id!r}; expected one of {sorted(ID_to_DIRNAME)}"
        ) from None
    # Equivalent to the former nested os.path.join(os.path.join(...)) call.
    testdir = os.path.join(data_root, set_id, dirname)
    return datasets.ImageFolder(testdir, transform=transform)
def build_medimeta_dataset(data_root, task='bus', disease='Disease', transform=None):
    """Construct and return a ``MedIMeta`` dataset.

    ``data_root``, ``task`` and ``disease`` are forwarded positionally to
    ``MedIMeta``, and ``transform`` is forwarded by keyword.
    NOTE(review): the meaning of the third positional argument (passed as
    ``disease``) is defined by MedIMeta; confirm against its docs.
    """
    return MedIMeta(data_root, task, disease, transform=transform)
def build_medmnist_dataset(data_root, set_id, transform, split='test', size=224, download=False):
    """Instantiate the MedMNIST dataset class registered for ``set_id``.

    The class is found by name via ``INFO[set_id]['python_class']`` on the
    ``medmnist`` module. ``split``, ``transform``, ``size``, ``download`` and
    ``data_root`` (as ``root``) are passed through to its constructor.
    """
    dataset_cls = getattr(medmnist, INFO[set_id]['python_class'])
    return dataset_cls(
        split=split,
        transform=transform,
        size=size,
        download=download,
        root=data_root,
    )
# MedMNIST dataset ids this module knows about; each is expected to be a key
# of ``medmnist.INFO`` (see ``build_medmnist_dataset``).
medmnist_datasets = [
    'tissuemnist',
    'pathmnist',
    'chestmnist',
    'dermamnist',
    'octmnist',
    'pneumoniamnist',
    'retinamnist',
    'breastmnist',
    'bloodmnist',
    'organamnist',
    'organcmnist',
    'organsmnist',
]
# AugMix Transforms
def get_preaugment():
    """Return the random geometric pre-augmentation used by ``augmix``.

    It is a random resized crop to 224 followed by a random horizontal flip.
    """
    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
    return transforms.Compose(steps)
def augmix(image, preprocess, aug_list, severity=1):
    """Produce one AugMix view of ``image``.

    The image is first randomly cropped and flipped (``get_preaugment``) and
    then run through ``preprocess``. If ``aug_list`` is empty, that result is
    returned directly.

    Otherwise three augmentation chains are built. Each chain applies 1-3 ops
    drawn at random from ``aug_list``, each called as ``op(img, severity)``.
    The preprocessed chains are summed with Dirichlet(1, 1, 1) weights, and the
    sum is blended with the clean view using a Beta(1, 1) factor ``m``:
    ``m * clean + (1 - m) * mix``.

    ``preprocess`` must return something ``torch.zeros_like`` accepts
    (presumably a tensor). Mixing randomness comes from NumPy's global RNG.
    """
    cropped = get_preaugment()(image)
    clean = preprocess(cropped)
    if len(aug_list) == 0:
        return clean

    # Draw order (dirichlet, then beta, then per-chain draws) is kept fixed
    # so results are reproducible under a given NumPy seed.
    chain_weights = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
    keep = np.float32(np.random.beta(1.0, 1.0))
    blended = torch.zeros_like(clean)
    for weight in chain_weights:
        chained = cropped.copy()
        depth = np.random.randint(1, 4)
        for _ in range(depth):
            op = np.random.choice(aug_list)
            chained = op(chained, severity)
        blended += weight * preprocess(chained)
    return keep * clean + (1 - keep) * blended
class AugMixAugmenter:
    """Callable that turns one input image into a list of views.

    Calling it returns ``1 + n_views`` items. The first is
    ``preprocess(base_transform(x))``; the rest are independent ``augmix``
    views of ``x``. When ``augmix`` is False the op list is empty, so each
    extra view is only a random crop/flip followed by ``preprocess``.
    """

    def __init__(self, base_transform, preprocess, n_views=2, augmix=False,
                 severity=1):
        self.base_transform = base_transform
        self.preprocess = preprocess
        self.n_views = n_views
        # Only touch the augmentations module when AugMix ops are requested.
        self.aug_list = augmentations.augmentations if augmix else []
        self.severity = severity

    def __call__(self, x):
        anchor = self.preprocess(self.base_transform(x))
        extra_views = []
        for _ in range(self.n_views):
            # Inside this method, ``augmix`` is the module-level function; the
            # ``augmix`` flag only exists as an ``__init__`` parameter.
            extra_views.append(
                augmix(x, self.preprocess, self.aug_list, self.severity)
            )
        return [anchor, *extra_views]
|