| import os | |
| from PIL import Image | |
| from torchvision import transforms | |
| from torch.utils.data import Dataset | |
def find_mask_file(image_path, mask_dir, mask_extensions=(".png", ".jpg", ".jpeg")):
    """Return the path of the mask file that shares *image_path*'s base name.

    The mask is looked up in *mask_dir* as ``<stem><ext>`` for each extension
    in *mask_extensions*, tried in order; the first existing regular file wins.

    Args:
        image_path: Path to the image whose mask is wanted. Only its file name
            without directory and final extension (the "stem") is used.
        mask_dir: Directory in which to search for the mask.
        mask_extensions: Candidate mask extensions (including the leading dot),
            tried in order. A tuple is used as the default so the shared default
            object cannot be mutated between calls.

    Returns:
        The first matching mask path, or ``None`` if no candidate exists.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    for ext in mask_extensions:
        candidate = os.path.join(mask_dir, stem + ext)
        # isfile (not exists): a directory with a mask-like name must not be
        # returned, since callers open the result as an image file.
        if os.path.isfile(candidate):
            return candidate
    return None
class SegmentationDataset(Dataset):
    """Paired image/mask dataset for segmentation.

    Every regular file in ``image_dir`` is treated as an image. Its mask is the
    file in ``mask_dir`` with the same base name, located by ``find_mask_file``.
    Images are loaded as RGB and masks as single-channel ("L") PIL images.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_transform=None):
        """
        Args:
            image_dir: Directory containing the input images.
            mask_dir: Directory containing the masks.
            transform: Optional callable applied to each image. It is also
                applied to each mask unless ``mask_transform`` is given.
            mask_transform: Optional callable applied to masks instead of
                ``transform``. Use it when the mask needs different handling,
                e.g. non-interpolating resize so label values are not blended.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform
        # os.listdir returns entries in arbitrary, platform-dependent order.
        # Sort so indices are reproducible, and skip sub-directories, which
        # cannot be opened as images.
        self.image_filenames = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load the image/mask pair at position ``idx``.

        Returns:
            ``(image, mask)`` after the configured transforms (if any) are applied.

        Raises:
            FileNotFoundError: If no mask with the image's base name exists.
        """
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = find_mask_file(img_path, self.mask_dir)
        if mask_path is None:
            raise FileNotFoundError(
                f"No mask found in {self.mask_dir!r} for image {img_path!r}"
            )

        # convert() returns a new image, so the opened files can be closed
        # right away instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        with Image.open(mask_path) as msk:
            mask = msk.convert("L")

        if self.transform:
            image = self.transform(image)
        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform
        if mask_tf:
            mask = mask_tf(mask)
        return image, mask
def transform_img(size=(128, 128)):
    """Build the default preprocessing pipeline.

    Args:
        size: Target size passed to ``transforms.Resize``; a ``(height, width)``
            pair. Defaults to ``(128, 128)``, the previously hard-coded value.

    Returns:
        A ``transforms.Compose`` that resizes the input to ``size`` and then
        converts it to a tensor with ``transforms.ToTensor``.
    """
    # NOTE(review): Resize uses its default interpolation here. It is presumably
    # bilinear, which blends label values if this pipeline is also applied to
    # segmentation masks. Confirm, and prefer a nearest-neighbour pipeline for masks.
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
if __name__ == "__main__":
    # Running this module directly only prints a banner; the dataset and
    # transform helpers are meant to be imported.
    banner = "Dataset class"
    print(banner)