Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.utils import data | |
| import torch.nn.functional as F | |
| import numpy as np | |
class TorchDataset(data.Dataset):
    """Map-style dataset that converts image/label samples to tensors.

    Images are read from ``datasamples["image"]`` and labels from
    ``datasamples["label"]``. Each image is converted to a ``float32``
    tensor with a new leading axis and scaled by ``1 / 255`` (presumably
    8-bit pixel data -- confirm against the caller). Labels are returned
    as ``float32`` one-hot vectors of length ``num_classes``.

    Args:
        datasamples: Mapping-like container indexable by the keys
            ``"image"`` and ``"label"``; each value must support
            ``len()`` and integer indexing.
        is_inference: When true, ``__getitem__`` returns only the image
            tensor and the ``"label"`` entry is optional.
        num_classes: Length of the one-hot label vector. Defaults to 10,
            the value that was previously hard-coded.

    Raises:
        KeyError: If ``"image"`` is missing, or if ``"label"`` is missing
            while ``is_inference`` is false.
    """

    def __init__(self, datasamples, is_inference: bool, num_classes: int = 10):
        super().__init__()
        self.x = datasamples["image"]
        try:
            self.y = datasamples["label"]
        except KeyError:
            # Labels are never read in inference mode, so an unlabeled
            # sample set is acceptable there; training still needs them.
            if not is_inference:
                raise
            self.y = None
        self.is_inference = is_inference
        self.num_classes = num_classes

    def _image_tensor(self, idx):
        """Return image ``idx`` as a float32 tensor scaled by 1/255."""
        # ``[None, ...]`` prepends a singleton axis to the image array.
        pixels = np.array(self.x[idx])[None, ...]
        return torch.tensor(pixels, dtype=torch.float32) / 255

    def __getitem__(self, idx):
        """Return the image tensor, plus the one-hot label when training."""
        image = self._image_tensor(idx)
        if self.is_inference:
            return image
        label = torch.as_tensor(self.y[idx], dtype=torch.int64)
        target = F.one_hot(label, self.num_classes).type(torch.float32)
        return image, target

    def __len__(self):
        """Return the number of image samples."""
        return len(self.x)