import json
import cv2
import numpy as np
import os
from torch.utils.data import Dataset
import pycocotools.mask as maskUtils
from torchvision import transforms
import utils.transforms as custom_transforms
from PIL import Image
class SAMDataset(Dataset):
    """Pairs SA-1B-style RLE annotation JSONs (hints) with target images and prompts."""

    def __init__(self, data_path='../data/files', txt_path='../data/data_85616.txt'):
        # One record per line; each record holds 'source', 'target', and 'prompt'.
        self.data = []
        with open(txt_path, 'rt') as f:
            for line in f:
                self.data.append(json.loads(line))
        self.data_path = data_path
        randomresizedcrop = custom_transforms.RandomResizedCrop(
            512,
            scale=(0.9, 1),
        )
        # Joint transforms: applied to the (target, source) pair so crops and
        # flips stay aligned between the target image and the hint map.
        self.transform = custom_transforms.Compose([
            randomresizedcrop,
            custom_transforms.RandomHorizontalFlip(p=0.5),
            custom_transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),  # maps [0, 1] to [-1, 1]
        ])
    def __len__(self):
        return len(self.data)
    def load_rle_annotations_from_json(self, json_file_path, return_pil=True):
        with open(json_file_path, 'r', encoding='utf-8') as f:
            anno_data = json.load(f)
        annotations = anno_data['annotations']
        height = int(anno_data['image']['height'])
        width = int(anno_data['image']['width'])
        # Instance-ID map: 0 is background, annotation i gets ID i + 1.
        # uint16 allows up to 65535 instances per image.
        seg_map = np.zeros((height, width), dtype=np.uint16)
        for i, ann in enumerate(annotations):
            mask = maskUtils.decode(ann['segmentation'])
            seg_map[mask != 0] = i + 1
        if return_pil:
            # Pack the 16-bit IDs into an RGB image: R holds the low byte,
            # G the high byte, B stays zero. Recover IDs as R + 256 * G.
            res = np.zeros((seg_map.shape[0], seg_map.shape[1], 3), dtype=np.uint8)
            res[:, :, 0] = seg_map % 256
            res[:, :, 1] = seg_map // 256
            return Image.fromarray(res)
        return seg_map
    def __getitem__(self, idx):
        item = self.data[idx]
        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']
        source = self.load_rle_annotations_from_json(os.path.join(self.data_path, source_filename))
        target = Image.open(os.path.join(self.data_path, target_filename))
        target, source = self.transform(target, source)
        target = target.permute(1, 2, 0)  # CHW -> HWC for the downstream consumer
        return dict(jpg=target, txt=prompt, hint=source)
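

# --- Minimal usage sketch (not part of the original file; the batch size and
# the assumption that the default data paths exist are mine). It builds the
# dataset, draws one batch, and round-trips the packed instance IDs from a raw
# hint image. Exact tensor shapes depend on the utils.transforms internals.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = SAMDataset()  # assumes ../data/files and ../data/data_85616.txt exist
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    print(batch['jpg'].shape)   # target images, HWC, normalized to [-1, 1]
    print(batch['hint'].shape)  # transformed instance-ID hint maps
    print(batch['txt'][:2])     # text prompts

    # Round-trip check on a raw (untransformed) hint: R + 256 * G restores the IDs.
    raw = np.array(dataset.load_rle_annotations_from_json(
        os.path.join(dataset.data_path, dataset.data[0]['source'])))
    ids = raw[:, :, 0].astype(np.uint16) + 256 * raw[:, :, 1].astype(np.uint16)
    print('instances in sample 0:', ids.max())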