File size: 4,756 Bytes
a96891a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
from typing import Tuple
from PIL import Image
import numpy as np

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Import MedIMeta
from medimeta import MedIMeta

from data.hoi_dataset import BongardDataset
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

from data.fewshot_datasets import *
import data.augmix_ops as augmentations

import medmnist
from medmnist import INFO, Evaluator

# Maps a dataset identifier (the `set_id` given to `build_dataset`) to the
# name of the directory that holds that dataset's images.
ID_to_DIRNAME = {
    # Single-letter identifiers for ImageNet and its variants.
    'I': 'ImageNet',
    'A': 'imagenet-a',
    'K': 'ImageNet-Sketch',
    'R': 'imagenet-r',
    'V': 'imagenetv2-matched-frequency-format-val',
    # Named identifiers.
    'flower102': 'Flower102',
    'dtd': 'DTD',
    'pets': 'OxfordPets',
    'cars': 'StanfordCars',
    'ucf101': 'UCF101',
    'caltech101': 'Caltech101',
    'food101': 'Food101',
    'sun397': 'SUN397',
    'aircraft': 'fgvc_aircraft',
    'eurosat': 'eurosat',
    'idrid': 'IDRID',
    'isic2018': 'ISIC2018',
    'pneumonia_guangzhou': 'PneumoniaGuangzhou',
    'shenzhen_cxr': 'ShenzhenCXR',
    'montgomery_cxr': 'MontgomeryCXR',
    'covid': 'Covid',
}

def build_dataset(set_id, transform, data_root, mode='test', n_shot=None, split="all", bongard_anno=False):
    """Build an ``ImageFolder`` dataset for the given dataset identifier.

    Images are read from ``<data_root>/<set_id>/<ID_to_DIRNAME[set_id]>``.

    Args:
        set_id: Dataset identifier; must be a key of ``ID_to_DIRNAME``.
        transform: Transform forwarded to ``datasets.ImageFolder``.
        data_root: Root directory under which the dataset folders live.
        mode, n_shot, split, bongard_anno: Accepted for interface
            compatibility but currently not used by this function.

    Returns:
        The constructed ``datasets.ImageFolder`` instance.

    Raises:
        KeyError: If ``set_id`` is not a key of ``ID_to_DIRNAME``.
    """
    # NOTE: an earlier per-dataset dispatch (ImageNet val split, few-shot
    # datasets, Bongard) is disabled; every set_id now goes through the same
    # ImageFolder layout.
    dirname = ID_to_DIRNAME[set_id]
    image_root = os.path.join(data_root, set_id, dirname)
    return datasets.ImageFolder(image_root, transform=transform)

def build_medimeta_dataset(data_root, task='bus', disease='Disease', transform=None):
    """Construct a ``MedIMeta`` dataset.

    ``data_root``, ``task`` and ``disease`` are passed positionally, in that
    order, and ``transform`` as a keyword argument. Their exact meaning is
    defined by the ``medimeta`` package.

    Returns:
        The constructed ``MedIMeta`` instance.
    """
    return MedIMeta(data_root, task, disease, transform=transform)

def build_medmnist_dataset(data_root, set_id, transform, split='test', size=224, download=False):
    """Construct the MedMNIST dataset registered under ``set_id``.

    The dataset class name is looked up via ``INFO[set_id]['python_class']``
    and resolved as an attribute of the ``medmnist`` module.

    Args:
        data_root: Passed to the dataset class as ``root``.
        set_id: Key into ``medmnist.INFO``.
        transform: Forwarded to the dataset class.
        split: Forwarded to the dataset class (default ``'test'``).
        size: Forwarded to the dataset class (default ``224``).
        download: Forwarded to the dataset class (default ``False``).

    Returns:
        The constructed dataset instance.
    """
    class_name = INFO[set_id]['python_class']
    dataset_cls = getattr(medmnist, class_name)
    return dataset_cls(
        split=split,
        transform=transform,
        size=size,
        download=download,
        root=data_root,
    )

# Dataset identifiers that are treated as MedMNIST datasets.
medmnist_datasets = [
    'tissuemnist',
    'pathmnist',
    'chestmnist',
    'dermamnist',
    'octmnist',
    'pneumoniamnist',
    'retinamnist',
    'breastmnist',
    'bloodmnist',
    'organamnist',
    'organcmnist',
    'organsmnist',
]

# AugMix Transforms
def get_preaugment():
    """Return the pre-augmentation pipeline used by ``augmix``.

    The pipeline is a random resized crop to 224 followed by a random
    horizontal flip, composed with ``transforms.Compose``.
    """
    random_crop = transforms.RandomResizedCrop(224)
    random_flip = transforms.RandomHorizontalFlip()
    return transforms.Compose([random_crop, random_flip])

def augmix(image, preprocess, aug_list, severity=1):
    """Produce one augmented view of ``image``.

    The image is first passed through ``get_preaugment()`` and then through
    ``preprocess``. If ``aug_list`` is empty that result is returned as is.
    Otherwise three augmentation chains are built, each applying between one
    and three randomly chosen operations from ``aug_list``. The preprocessed
    chains are combined with Dirichlet-sampled weights, and the mixture is
    blended with the un-augmented result using a Beta-sampled coefficient.

    Args:
        image: Input image. After pre-augmentation it must provide
            ``.copy()`` (presumably a PIL image; confirm against callers).
        preprocess: Callable turning an image into a tensor.
        aug_list: Sequence of callables invoked as ``op(image, severity)``.
        severity: Severity value forwarded to each augmentation operation.

    Returns:
        The preprocessed view, or the blended AugMix tensor.
    """
    cropped = get_preaugment()(image)
    clean = preprocess(cropped)
    if len(aug_list) == 0:
        return clean

    # Sampling order (Dirichlet, Beta, then per-chain draws) matches the
    # original so that seeded runs stay reproducible.
    chain_weights = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
    blend = np.float32(np.random.beta(1.0, 1.0))

    mixture = torch.zeros_like(clean)
    for weight in chain_weights:
        augmented = cropped.copy()
        depth = np.random.randint(1, 4)  # 1 to 3 operations per chain
        for _ in range(depth):
            operation = np.random.choice(aug_list)
            augmented = operation(augmented, severity)
        mixture += weight * preprocess(augmented)

    return blend * clean + (1 - blend) * mixture


class AugMixAugmenter:
    """Callable that turns one input into a base image plus augmented views.

    Calling an instance returns a list whose first element is
    ``preprocess(base_transform(x))``, followed by ``n_views`` results of
    ``augmix(x, preprocess, aug_list, severity)``.
    """

    def __init__(self, base_transform, preprocess, n_views=2, augmix=False,
                    severity=1):
        """Store the transforms and select the augmentation list.

        Args:
            base_transform: Transform applied to the input for the base image.
            preprocess: Transform applied after ``base_transform`` and inside
                ``augmix``.
            n_views: Number of augmented views produced per call.
            augmix: If true, use ``augmentations.augmentations`` as the
                operation list; otherwise use an empty list.
            severity: Severity value forwarded to ``augmix``.
        """
        self.base_transform = base_transform
        self.preprocess = preprocess
        self.n_views = n_views
        self.severity = severity
        # An empty list makes `augmix` return just the preprocessed crop.
        self.aug_list = augmentations.augmentations if augmix else []

    def __call__(self, x):
        """Return ``[base_image, view_1, ..., view_n]`` for input ``x``."""
        outputs = [self.preprocess(self.base_transform(x))]
        # Views are generated from the raw input, not from the base image.
        for _ in range(self.n_views):
            outputs.append(
                augmix(x, self.preprocess, self.aug_list, self.severity)
            )
        return outputs