Spaces:
Paused
Paused
| from torch.utils.data import Dataset | |
| from starvector.data.util import ImageTrainProcessor, use_placeholder, rasterize_svg | |
| from starvector.util import instantiate_from_config | |
| import numpy as np | |
| from datasets import load_dataset | |
| class SVGDatasetBase(Dataset): | |
| def __init__(self, dataset_name, split, im_size, num_samples=-1, **kwargs): | |
| self.split = split | |
| self.im_size = im_size | |
| transforms = kwargs.get('transforms', False) | |
| if transforms: | |
| self.transforms = instantiate_from_config(transforms) | |
| self.p = self.transforms.p | |
| else: | |
| self.transforms = None | |
| self.p = 0.0 | |
| normalization = kwargs.get('normalize', False) | |
| if normalization: | |
| mean = tuple(normalization.get('mean', None)) | |
| std = tuple(normalization.get('std', None)) | |
| else: | |
| mean = None | |
| std = None | |
| self.processor = ImageTrainProcessor(size=self.im_size, mean=mean, std=std) | |
| self.data = load_dataset(dataset_name, split=split) | |
| print(f"Loaded {len(self.data)} samples from {dataset_name} {split} split") | |
| def __len__(self): | |
| return len(self.data_json) | |
| def get_svg_and_image(self, svg_str, sample_id): | |
| do_augment = np.random.choice([True, False], p=[self.p, 1 - self.p]) | |
| svg, image = None, None | |
| # Try to augment the image if conditions are met | |
| if self.transforms is not None and do_augment: | |
| try: | |
| svg, image = self.transforms.augment(svg_str) | |
| except Exception as e: | |
| print(f"Error augmenting {sample_id} due to {str(e)}, trying to rasterize SVG") | |
| # If augmentation failed or wasn't attempted, try to rasterize the SVG | |
| if svg is None or image is None: | |
| try: | |
| svg, image = svg_str, rasterize_svg(svg_str, self.im_size) | |
| except Exception as e: | |
| print(f"Error rasterizing {sample_id} due to {str(e)}, using placeholder image") | |
| svg = use_placeholder() | |
| image = rasterize_svg(svg, self.im_size) | |
| # If the image is completely white, use a placeholder image | |
| if np.array(image).mean() == 255.0: | |
| print(f"Image is full white, using placeholder image for {sample_id}") | |
| svg = use_placeholder() | |
| image = rasterize_svg(svg) | |
| # Process the image | |
| if 'siglip' in self.image_processor: | |
| image = self.processor(image).pixel_values[0] | |
| else: | |
| image = self.processor(image) | |
| return svg, image | |
| def __getitem__(self, idx): | |
| raise NotImplementedError("This method should be implemented by subclasses") | |