import logging
import random
from collections import Counter
from operator import itemgetter
from typing import Iterator, List, Optional, Union

import numpy as np
from torch.utils.data import Dataset, DistributedSampler
from torch.utils.data.sampler import Sampler

LOGGER = logging.getLogger(__name__)


class DatasetFromSampler(Dataset):
    """Dataset to create indexes from `Sampler`.

    Args:
        sampler: PyTorch sampler
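
    Example (a minimal sketch; any sized sampler works the same way):

    >>> from torch.utils.data.sampler import SequentialSampler
    >>> dataset = DatasetFromSampler(SequentialSampler(range(3)))
    >>> list(dataset)
    [0, 1, 2]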
    """

    def __init__(self, sampler: Sampler):
        """Initialisation for DatasetFromSampler."""
        self.sampler = sampler
        self.sampler_list = None

    def __getitem__(self, index: int):
        """Gets element of the dataset.

        Args:
            index: index of the element in the dataset

        Returns:
            Single element by index
        """
        if self.sampler_list is None:
            self.sampler_list = list(self.sampler)
        return self.sampler_list[index]

    def __len__(self) -> int:
        """
        Returns:
            int: length of the dataset
        """
        return len(self.sampler)


class BalanceClassSampler(Sampler):
    """Allows you to create a stratified sample from unbalanced classes.

    Args:
        labels: list of class labels for each elem in the dataset
        mode: Strategy to balance classes.
            Must be one of ["downsampling", "upsampling"]
            or an integer setting the number of samples per class.
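
    A minimal sketch of how ``mode`` affects the epoch length (the label counts
    are illustrative):

    >>> labels = [0] * 90 + [1] * 10
    >>> len(BalanceClassSampler(labels, mode="downsampling"))  # 10 per class
    20
    >>> len(BalanceClassSampler(labels, mode="upsampling"))  # 90 per class
    180
    >>> len(BalanceClassSampler(labels, mode=50))  # 50 per class
    100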

    Python API examples:

    .. code-block:: python

        import os
        from torch import nn, optim
        from torch.utils.data import DataLoader
        from catalyst import dl
        from catalyst.data import ToTensor, BalanceClassSampler
        from catalyst.contrib.datasets import MNIST

        train_data = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
        train_labels = train_data.targets.cpu().numpy().tolist()
        train_sampler = BalanceClassSampler(train_labels, mode=5000)
        valid_data = MNIST(os.getcwd(), train=False)

        loaders = {
            "train": DataLoader(train_data, sampler=train_sampler, batch_size=32),
            "valid": DataLoader(valid_data, batch_size=32),
        }

        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        runner = dl.SupervisedRunner()
        # model training
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir="./logs",
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            verbose=True,
        )
    """

    def __init__(self, labels: List[int], mode: Union[str, int] = "downsampling"):
        """Sampler initialisation."""
        super().__init__(labels)

        labels = np.array(labels)
        samples_per_class = {label: (labels == label).sum() for label in set(labels)}

        self.lbl2idx = {
            label: np.arange(len(labels))[labels == label].tolist()
            for label in set(labels)
        }

        if isinstance(mode, str):
            assert mode in ["downsampling", "upsampling"]

        if isinstance(mode, int) or mode == "upsampling":
            samples_per_class = (
                mode if isinstance(mode, int) else max(samples_per_class.values())
            )
        else:
            samples_per_class = min(samples_per_class.values())

        self.labels = labels
        self.samples_per_class = samples_per_class
        self.length = self.samples_per_class * len(set(labels))

    def __iter__(self) -> Iterator[int]:
        """
        Returns:
            iterator of indices of stratified sample
        """
        indices = []
        for key in sorted(self.lbl2idx):
            replace_flag = self.samples_per_class > len(self.lbl2idx[key])
            indices += np.random.choice(
                self.lbl2idx[key], self.samples_per_class, replace=replace_flag
            ).tolist()
        assert len(indices) == self.length
        np.random.shuffle(indices)

        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            length of result sample
        """
        return self.length


class BatchBalanceClassSampler(Sampler):
    """
    This kind of sampler can be used for both metric learning and classification tasks.

    BatchSampler with the given strategy for a dataset with C unique classes:

    - Selection of `num_classes` of the C classes for each batch
    - Selection of `num_samples` instances of each class in the batch

    The epoch ends after `num_batches`.
    So, the batch size is `num_classes` * `num_samples`.

    One of the purposes of this sampler is to be used for
    forming triplets and pos/neg pairs inside the batch.
    To guarantee the existence of these pairs in the batch,
    `num_classes` and `num_samples` should be > 1. (1)

    This type of sampling can be found in the classic Person Re-Id paper,
    where P (`num_classes`) equals 32 and K (`num_samples`) equals 4:
    `In Defense of the Triplet Loss for Person Re-Identification`_.

    Args:
        labels: list of class labels for each elem in the dataset
        num_classes: number of classes in a batch, should be > 1
        num_samples: number of instances of each class in a batch, should be > 1
        num_batches: number of batches in epoch
            (default = len(labels) // (num_classes * num_samples))

    .. _In Defense of the Triplet Loss for Person Re-Identification:
        https://arxiv.org/abs/1703.07737
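
    A minimal sketch of what the sampler yields (the label list is illustrative;
    each element of the iterator is one batch of dataset indices):

    >>> labels = [0, 0, 1, 1, 2, 2, 3, 3]
    >>> sampler = BatchBalanceClassSampler(labels, num_classes=2, num_samples=2, num_batches=3)
    >>> batches = list(sampler)
    >>> len(batches), len(batches[0])
    (3, 4)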

    Python API examples:

    .. code-block:: python

        import os
        from torch import nn, optim
        from torch.utils.data import DataLoader
        from catalyst import dl
        from catalyst.data import ToTensor, BatchBalanceClassSampler
        from catalyst.contrib.datasets import MNIST

        train_data = MNIST(os.getcwd(), train=True, download=True)
        train_labels = train_data.targets.cpu().numpy().tolist()
        train_sampler = BatchBalanceClassSampler(
            train_labels, num_classes=10, num_samples=4)
        valid_data = MNIST(os.getcwd(), train=False)

        loaders = {
            "train": DataLoader(train_data, batch_sampler=train_sampler),
            "valid": DataLoader(valid_data, batch_size=32),
        }

        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        runner = dl.SupervisedRunner()
        # model training
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir="./logs",
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            verbose=True,
        )
    """

    def __init__(
        self,
        labels: Union[List[int], np.ndarray],
        num_classes: int,
        num_samples: int,
        num_batches: Optional[int] = None,
    ):
        """Sampler initialisation."""
        super().__init__(labels)
        classes = set(labels)

        assert isinstance(num_classes, int) and isinstance(num_samples, int)
        assert (1 < num_classes <= len(classes)) and (1 < num_samples)
        assert all(
            n > 1 for n in Counter(labels).values()
        ), "Each class should contain at least 2 instances to fit (1)"

        labels = np.array(labels)
        self._labels = list(set(labels.tolist()))
        self._num_classes = num_classes
        self._num_samples = num_samples
        self._batch_size = self._num_classes * self._num_samples
        self._num_batches = num_batches or len(labels) // self._batch_size
        self.lbl2idx = {
            label: np.arange(len(labels))[labels == label].tolist()
            for label in set(labels)
        }

    @property
    def batch_size(self) -> int:
        """
        Returns:
            this value should be used in DataLoader as batch size
        """
        return self._batch_size

    @property
    def batches_in_epoch(self) -> int:
        """
        Returns:
            number of batches in an epoch
        """
        return self._num_batches

    def __len__(self) -> int:
        """
        Returns:
            number of batches in an epoch
        """
        return self._num_batches

    def __iter__(self) -> Iterator[List[int]]:
        """
        Returns:
            iterator over batches of dataset indices for one epoch
        """
        indices = []
        for _ in range(self._num_batches):
            batch_indices = []
            classes_for_batch = random.sample(self._labels, self._num_classes)
            while self._num_classes != len(set(classes_for_batch)):
                classes_for_batch = random.sample(self._labels, self._num_classes)
            for cls_id in classes_for_batch:
                replace_flag = self._num_samples > len(self.lbl2idx[cls_id])
                batch_indices += np.random.choice(
                    self.lbl2idx[cls_id], self._num_samples, replace=replace_flag
                ).tolist()
            indices.append(batch_indices)
        return iter(indices)


class DynamicBalanceClassSampler(Sampler):
    """
    This kind of sampler can be used for classification tasks with significant
    class imbalance.

    The idea of this sampler is that we start from the original class
    distribution and gradually move towards a uniform class distribution, as in
    downsampling.

    Let's define :math:`D_i = |C_i| / |C_{min}|`, where :math:`|C_i|` is the size
    of class :math:`i` and :math:`|C_{min}|` is the size of the rarest class, so
    :math:`D_i` defines the class distribution. Also define :math:`g(n_{epoch})`,
    an exponential scheduler. On each epoch the current :math:`D_i` is calculated
    as :math:`D_i^{current} = D_i^{g(n_{epoch})}`, after which the data is sampled
    according to this distribution.

    Notes:
        At the end of training, epochs will contain only
        min_class_size * n_classes examples, so it may not be
        necessary to run validation on each epoch. For this reason, use
        ControlFlowCallback.

    Examples:

    >>> import torch
    >>> import numpy as np

    >>> from catalyst.data import DynamicBalanceClassSampler
    >>> from torch.utils import data

    >>> features = torch.Tensor(np.random.random((200, 100)))
    >>> labels = np.random.randint(0, 4, size=(200,))
    >>> sampler = DynamicBalanceClassSampler(labels)
    >>> labels = torch.LongTensor(labels)
    >>> dataset = data.TensorDataset(features, labels)
    >>> loader = data.DataLoader(dataset, sampler=sampler, batch_size=8)

    >>> for batch in loader:
    ...     b_features, b_labels = batch
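
    A smaller sketch of the shrinking epoch length (the numbers assume the
    default ``exp_lambda=0.9``):

    >>> labels = [0] * 100 + [1] * 10
    >>> sampler = DynamicBalanceClassSampler(labels, ignore_warning=True)
    >>> len(sampler)  # epoch 0 keeps the original distribution: 100 + 10
    110
    >>> _ = list(sampler)  # one pass re-computes the class ratios
    >>> len(sampler) < 110  # the majority class is gradually downsampled
    True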

    The sampler was inspired by https://arxiv.org/abs/1901.06783
    """

    def __init__(
        self,
        labels: List[Union[int, str]],
        exp_lambda: float = 0.9,
        start_epoch: int = 0,
        max_d: Optional[int] = None,
        mode: Union[str, int] = "downsampling",
        ignore_warning: bool = False,
    ):
        """
        Args:
            labels: list of labels for each elem in the dataset
            exp_lambda: exponent figure for the schedule
            start_epoch: start epoch number, can be useful for multi-stage
                experiments
            max_d: if not None, limit on the ratio between the most
                frequent and the rarest classes, heuristic
            mode: number of samples per class at the end of training. Must be
                "downsampling" or a number. Before changing it, make sure you
                understand how it works
            ignore_warning: ignore warning about min class size
        """
        assert isinstance(start_epoch, int)
        assert 0 < exp_lambda < 1, "exp_lambda must be in (0, 1)"
        super().__init__(labels)
        self.exp_lambda = exp_lambda
        if max_d is None:
            max_d = np.inf
        self.max_d = max_d
        self.epoch = start_epoch
        labels = np.array(labels)
        samples_per_class = Counter(labels)
        self.min_class_size = min(samples_per_class.values())

        if self.min_class_size < 100 and not ignore_warning:
            LOGGER.warning(
                f"the smallest class contains only"
                f" {self.min_class_size} examples. At the end of"
                f" training, epochs will contain only"
                f" {self.min_class_size * len(samples_per_class)}"
                f" examples"
            )

        self.original_d = {
            key: value / self.min_class_size for key, value in samples_per_class.items()
        }
        self.label2idxes = {
            label: np.arange(len(labels))[labels == label].tolist()
            for label in set(labels)
        }

        if isinstance(mode, int):
            self.min_class_size = mode
        else:
            assert mode == "downsampling"

        self.labels = labels
        self._update()

    def _update(self) -> None:
        """Update d coefficients."""
        current_d = {
            key: min(value ** self._exp_scheduler(), self.max_d)
            for key, value in self.original_d.items()
        }
        samples_per_classes = {
            key: int(value * self.min_class_size) for key, value in current_d.items()
        }
        self.samples_per_classes = samples_per_classes
        self.length = np.sum(list(samples_per_classes.values()))
        self.epoch += 1

    def _exp_scheduler(self) -> float:
        return self.exp_lambda ** self.epoch

    def __iter__(self) -> Iterator[int]:
        """
        Returns:
            iterator of indices of stratified sample
        """
        indices = []
        for key in sorted(self.label2idxes):
            samples_per_class = self.samples_per_classes[key]
            replace_flag = samples_per_class > len(self.label2idxes[key])
            indices += np.random.choice(
                self.label2idxes[key], samples_per_class, replace=replace_flag
            ).tolist()
        assert len(indices) == self.length
        np.random.shuffle(indices)
        self._update()
        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            length of result sample
        """
        return self.length


class MiniEpochSampler(Sampler):
    """
    Sampler that iterates over the dataset in mini epochs of ``mini_epoch_len`` samples.

    Args:
        data_len: Size of the dataset
        mini_epoch_len: Num samples from the dataset used in one
            mini epoch.
        drop_last: If ``True``, the sampler will drop the last, incomplete
            mini epoch if it is shorter than ``mini_epoch_len``
        shuffle: one of ``"per_mini_epoch"``, ``"per_epoch"``, or ``None``.
            The sampler will shuffle indices

            - "per_mini_epoch" -- every mini epoch (every ``__iter__`` call)
            - "per_epoch" -- every real epoch
            - None -- don't shuffle

    Example:

    >>> MiniEpochSampler(len(dataset), mini_epoch_len=100)
    >>> MiniEpochSampler(len(dataset), mini_epoch_len=100, drop_last=True)
    >>> MiniEpochSampler(len(dataset), mini_epoch_len=100,
    ...     shuffle="per_epoch")
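
    A minimal sketch of the mini-epoch lengths for a 10-sample dataset
    (the trailing remainder is kept because ``drop_last=False``):

    >>> sampler = MiniEpochSampler(data_len=10, mini_epoch_len=4)
    >>> len(sampler)
    4
    >>> [len(list(sampler)) for _ in range(3)]
    [4, 4, 2]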
    """

    def __init__(
        self,
        data_len: int,
        mini_epoch_len: int,
        drop_last: bool = False,
        shuffle: str = None,
    ):
        """Sampler initialisation."""
        super().__init__(None)

        self.data_len = int(data_len)
        self.mini_epoch_len = int(mini_epoch_len)

        self.steps = int(data_len / self.mini_epoch_len)
        self.state_i = 0

        has_reminder = data_len - self.steps * mini_epoch_len > 0
        if self.steps == 0:
            self.divider = 1
        elif has_reminder and not drop_last:
            self.divider = self.steps + 1
        else:
            self.divider = self.steps

        self._indices = np.arange(self.data_len)
        self.indices = self._indices
        self.end_pointer = max(self.data_len, self.mini_epoch_len)

        if not (shuffle is None or shuffle in ["per_mini_epoch", "per_epoch"]):
            raise ValueError(
                "Shuffle must be one of ['per_mini_epoch', 'per_epoch']. "
                + f"Got {shuffle}"
            )
        self.shuffle_type = shuffle

    def shuffle(self) -> None:
        """Shuffle sampler indices."""
        if self.shuffle_type == "per_mini_epoch" or (
            self.shuffle_type == "per_epoch" and self.state_i == 0
        ):
            if self.data_len >= self.mini_epoch_len:
                self.indices = self._indices
                np.random.shuffle(self.indices)
            else:
                self.indices = np.random.choice(
                    self._indices, self.mini_epoch_len, replace=True
                )

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator
        """
        self.state_i = self.state_i % self.divider
        self.shuffle()

        start = self.state_i * self.mini_epoch_len
        stop = (
            self.end_pointer
            if (self.state_i == self.steps)
            else (self.state_i + 1) * self.mini_epoch_len
        )
        indices = self.indices[start:stop].tolist()

        self.state_i += 1
        return iter(indices)

    def __len__(self) -> int:
        """
        Returns:
            int: length of the mini-epoch
        """
        return self.mini_epoch_len


class DistributedSamplerWrapper(DistributedSampler):
    """
    Wrapper over `Sampler` for distributed training.
    Allows you to use any sampler in distributed mode.

    It is especially useful in conjunction with
    `torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSamplerWrapper instance as a DataLoader
    sampler, and load a subset of subsampled data of the original dataset
    that is exclusive to it.

    .. note::
        Sampler is assumed to be of constant size.
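
    Python API usage sketch (assumes the default process group is already
    initialised and that ``labels``, ``dataset``, and ``num_epochs`` are defined
    elsewhere):

    .. code-block:: python

        from torch.utils.data import DataLoader

        train_sampler = DistributedSamplerWrapper(
            BalanceClassSampler(labels, mode="upsampling")
        )
        loader = DataLoader(dataset, sampler=train_sampler, batch_size=32)

        for epoch in range(num_epochs):
            train_sampler.set_epoch(epoch)  # re-seed shuffling each epoch
            for batch in loader:
                ...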
    """

    def __init__(
        self,
        sampler,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """
        Args:
            sampler: Sampler used for subsampling
            num_replicas (int, optional): Number of processes participating in
                distributed training
            rank (int, optional): Rank of the current process
                within ``num_replicas``
            shuffle (bool, optional): If true (default),
                sampler will shuffle the indices
        """
        super(DistributedSamplerWrapper, self).__init__(
            DatasetFromSampler(sampler),
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
        )
        self.sampler = sampler

    def __iter__(self) -> Iterator[int]:
        """Iterate over sampler.

        Returns:
            python iterator
        """
        self.dataset = DatasetFromSampler(self.sampler)
        indexes_of_indexes = super().__iter__()
        subsampler_indexes = self.dataset
        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))


__all__ = [
    "BalanceClassSampler",
    "BatchBalanceClassSampler",
    "DistributedSamplerWrapper",
    "DynamicBalanceClassSampler",
    "MiniEpochSampler",
]