import pytorch_lightning as pl
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import SegformerFeatureExtractor, BatchFeature
from typing import Optional
class SegmentationDataset(Dataset):
    """Image segmentation dataset."""

    def __init__(self, pixel_values: torch.Tensor, labels: torch.Tensor):
        """
        Dataset for image segmentation.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Tensor of shape (N, C, H, W) containing the pixel values of the images.
        labels : torch.Tensor
            Tensor of shape (N, H, W) containing the segmentation maps of the images.
        """
        assert pixel_values.shape[0] == labels.shape[0], "pixel_values and labels must contain the same number of samples"
        self.pixel_values = pixel_values
        self.labels = labels
        self.length = pixel_values.shape[0]
        print(f"Created dataset with {self.length} samples")

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        image = self.pixel_values[index]
        label = self.labels[index]
        return BatchFeature({"pixel_values": image, "labels": label})
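
# Usage sketch for SegmentationDataset (the shapes below are hypothetical,
# chosen only to illustrate the expected layout):
#   pixel_values = torch.randn(10, 3, 512, 512)    # 10 RGB images
#   labels = torch.randint(0, 35, (10, 512, 512))  # 10 segmentation maps
#   ds = SegmentationDataset(pixel_values, labels)
#   ds[0]["pixel_values"].shape  # -> torch.Size([3, 512, 512])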

class SidewalkSegmentationDataLoader(pl.LightningDataModule):
    def __init__(self, hub_dir: str, batch_size: int, split: Optional[str] = None):
        super().__init__()
        self.hub_dir = hub_dir
        self.batch_size = batch_size
        # Despite the attribute name, this is an image processor, not a text
        # tokenizer. reduce_labels=True remaps the background class to 255 so
        # the loss can ignore it.
        self.tokenizer = SegformerFeatureExtractor(reduce_labels=True)
        # `split` should name a concrete split (e.g. "train"); with split=None,
        # load_dataset returns a DatasetDict, which setup() cannot index by column.
        self.dataset = load_dataset(self.hub_dir, split=split)
        self.len = len(self.dataset)
    def tokenize_data(self, *args, **kwargs):
        return self.tokenizer(*args, **kwargs)
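
    # Example call (hypothetical inputs): for a PIL image `img` and a mask `mask`,
    #   enc = self.tokenize_data(images=img, segmentation_maps=mask, return_tensors="pt")
    # yields enc["pixel_values"] of shape (1, 3, 512, 512) and enc["labels"] of
    # shape (1, 512, 512) with the default SegFormer processor settings.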
    def setup(self, stage: Optional[str] = None):
        encoded_dataset = self.tokenize_data(
            images=self.dataset["pixel_values"],
            segmentation_maps=self.dataset["label"],
            return_tensors="pt",
        )
        dataset = SegmentationDataset(encoded_dataset["pixel_values"], encoded_dataset["labels"])
        # 80/20 split. Derive the validation length from the remainder so the
        # lengths always sum to the dataset size; int() truncation on both
        # halves would make random_split raise for sizes not divisible by 5.
        train_len = int(self.len * 0.8)
        self.train_dataset, self.val_dataset = random_split(dataset, [train_len, self.len - train_len])
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=12)

    def val_dataloader(self):
        # num_workers=12 assumes a machine with that many cores; tune to your hardware.
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=12)
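
if __name__ == "__main__":
    # Minimal end-to-end sketch. "segments/sidewalk-semantic" is assumed here as
    # an example hub id for a dataset with "pixel_values" and "label" columns;
    # substitute your own. batch_size=8 is arbitrary.
    datamodule = SidewalkSegmentationDataLoader(
        hub_dir="segments/sidewalk-semantic", batch_size=8, split="train"
    )
    datamodule.setup()
    batch = next(iter(datamodule.train_dataloader()))
    print(batch["pixel_values"].shape)  # e.g. (8, 3, 512, 512) with the default processor
    print(batch["labels"].shape)        # e.g. (8, 512, 512)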