Spaces:
Paused
Paused
| import numpy as np | |
| from svgpathtools import ( | |
| Path, Arc, CubicBezier, QuadraticBezier, | |
| svgstr2paths) | |
| import os | |
| from noise import pnoise1 | |
| import re | |
| import matplotlib.colors as mcolors | |
| from bs4 import BeautifulSoup | |
| from starvector.data.util import rasterize_svg | |
| class SVGTransforms: | |
| def __init__(self, transformations): | |
| self.transformations = transformations | |
| self.noise_std = self.transformations.get('noise_std', False) | |
| self.noise_type = self.transformations.get('noise_type', False) | |
| self.rotate = self.transformations.get('rotate', False) | |
| self.shift_re = self.transformations.get('shift_re', False) | |
| self.shift_im = self.transformations.get('shift_im', False) | |
| self.scale = self.transformations.get('scale', False) | |
| self.color_noise = self.transformations.get('color_noise', False) | |
| self.p = self.transformations.get('p', 0.5) | |
| self.color_change = self.transformations.get('color_change', False) | |
| self.colors = self.transformations.get('colors', ['#ff0000', '#0000ff', '#000000']) | |
| def sample_transformations(self): | |
| if self.rotate: | |
| a, b = self.rotate['from'], self.rotate['to'] | |
| rotation_angle = np.random.uniform(a, b) | |
| self.rotation_angle = rotation_angle | |
| if self.shift_re or self.shift_im: | |
| self.shift_real = np.random.uniform(self.shift_re['from'], self.shift_re['to']) | |
| self.shift_imag = np.random.uniform(self.shift_im['from'], self.shift_im['to']) | |
| if self.scale: | |
| self.scale = np.random.uniform(self.scale['from'], self.scale['to']) | |
| if self.color_noise: | |
| self.color_noise_std = np.random.uniform(self.color_noise['from'], self.color_noise['to']) | |
| def paths2str(self, groupped_paths, svg_opening_tag='<svg xmlns="http://www.w3.org/2000/svg" version="1.1">'): | |
| keys_to_exclude = ['d', 'cx', 'cy', 'rx', 'ry'] | |
| all_groups_srt = '' | |
| for group, elements in groupped_paths.items(): | |
| group_attributes, paths_and_attributes = elements.get('attrs', {}), elements.get('paths', []) | |
| group_attr_str = ' '.join(f'{key}="{value}"' for key, value in group_attributes.items()) | |
| path_strings = [] | |
| path_str = '' | |
| for path, attributes in paths_and_attributes: | |
| path_attr_str = '' | |
| d_str = path.d() | |
| for key, value in attributes.items(): | |
| if key not in keys_to_exclude: | |
| path_attr_str += f' {key}="{value}"' | |
| path_strings.append(f'<path d="{d_str}"{path_attr_str} />') | |
| path_str = "\n".join(path_strings) | |
| if 'no_group'in group: | |
| group_str = path_str | |
| else: | |
| group_str = f'<g {group_attr_str}>\n{path_str}\n</g>\n' | |
| all_groups_srt += group_str | |
| svg = f'{svg_opening_tag}\n{all_groups_srt}</svg>' | |
| return svg | |
| def add_noise(self, seg): | |
| noise_scale = np.random.uniform(self.noise_std['from'], self.noise_std['to']) | |
| if self.noise_type == 'gaussian': | |
| noise_sample = np.random.normal(loc=0.0, scale=noise_scale) + \ | |
| 1j * np.random.normal(loc=0.0, scale=noise_scale) | |
| elif self.noise_type == 'perlin': | |
| noise_sample = complex(pnoise1(np.random.random(), octaves=2), pnoise1(np.random.random(), octaves=2))*noise_scale | |
| if isinstance(seg, CubicBezier): | |
| seg.control1 = seg.control1 + noise_sample | |
| seg.control2 = seg.control2 + noise_sample | |
| elif isinstance(seg, QuadraticBezier): | |
| seg.control = seg.control + noise_sample | |
| elif isinstance(seg, Arc): | |
| seg.radius = seg.radius + noise_sample | |
| return seg | |
| def do_rotate(self, path, viewbox_width, viewbox_height): | |
| if self.rotate: | |
| new_path = path.rotated(self.rotation_angle, complex(viewbox_width/2, viewbox_height/2)) | |
| return new_path | |
| else: | |
| return path | |
| def do_shift(self, path): | |
| if self.shift_re or self.shift_im: | |
| return path.translated(complex(self.shift_real, self.shift_imag)) | |
| else: | |
| return path | |
| def do_scale(self, path): | |
| if self.scale: | |
| return path.scaled(self.scale) | |
| else: | |
| return path | |
| def add_color_noise(self, source_color): | |
| # Convert color to RGB | |
| if source_color.startswith("#"): | |
| base_color = mcolors.hex2color(source_color) | |
| else: | |
| base_color = mcolors.hex2color(mcolors.CSS4_COLORS.get(source_color, '#FFFFFF')) | |
| # Add noise to each RGB component | |
| noise = np.random.normal(0, self.color_noise_std, 3) | |
| noisy_color = np.clip(np.array(base_color) + noise, 0, 1) | |
| # Convert the RGB color back to hex | |
| hex_color = mcolors.rgb2hex(noisy_color) | |
| return hex_color | |
| def do_color_change(self, attr): | |
| if 'fill' in attr: | |
| if self.color_noise or self.color_change: | |
| fill_value = attr['fill'] | |
| if fill_value == 'none': | |
| new_fill_value = 'none' | |
| else: | |
| if self.color_noise: | |
| new_fill_value = self.add_color_noise(fill_value) | |
| elif self.color_change: | |
| new_fill_value = np.random.choice(self.colors) | |
| attr['fill'] = new_fill_value | |
| return attr | |
| def clean_attributes(self, attr): | |
| attr_out = {} | |
| if 'fill' in attr: | |
| attr_out = attr | |
| elif 'style' in attr: | |
| fill_values = re.findall('fill:[^;]+', attr['style']) | |
| if fill_values: | |
| fill_value = fill_values[0].replace('fill:', '').strip() | |
| attr_out['fill'] = fill_value | |
| else: | |
| attr_out = attr | |
| else: | |
| attr_out = attr | |
| return attr_out | |
| def get_viewbox_size(self, svg): | |
| # Try to extract viewBox attribute | |
| match = re.search(r'viewBox="([^"]+)"', svg) | |
| if match: | |
| viewbox = match.group(1) | |
| else: | |
| # If viewBox is not found, try to extract width and height attributes | |
| match = re.search(r'width="([^"]+)px" height="([^"]+)px"', svg) | |
| if match: | |
| width, height = match.groups() | |
| viewbox = f"0 0 {width} {height}" | |
| else: | |
| viewbox = "0 0 256 256" # Default if neither viewBox nor width/height are found | |
| viewbox = [float(x) for x in viewbox.split()] | |
| viewbox_width, viewbox_height = viewbox[2], viewbox[3] | |
| return viewbox_width, viewbox_height | |
| def augment(self, svg): | |
| if os.path.isfile(svg): | |
| # open svg file | |
| with open(svg, 'r') as f: | |
| svg = f.read() | |
| # Sample transformations for this sample | |
| self.sample_transformations() | |
| # Parse the SVG content | |
| soup = BeautifulSoup(svg, 'xml') | |
| # Get opening tag | |
| svg_opening_tag = re.findall('<svg[^>]+>', svg)[0] | |
| viewbox_width, viewbox_height = self.get_viewbox_size(svg) | |
| # Get all svg parents | |
| groups = soup.findAll() | |
| # Create the groups of paths based on their original <g> tag | |
| grouped_paths = {} | |
| for i, g in enumerate(groups): | |
| if g.name == 'g': | |
| group_id = group_id = g.get('id') if g.get('id') else f'none_{i}' | |
| group_attrs = g.attrs | |
| elif g.name == 'svg' or g.name == 'metadata' or g.name == 'defs': | |
| continue | |
| else: | |
| group_id = f'no_group_{i}' | |
| group_attrs = {} | |
| group_svg_string = f'{svg_opening_tag}{str(g)}</svg>' | |
| try: | |
| paths, attributes = svgstr2paths(group_svg_string) | |
| except: | |
| return svg, rasterize_svg(svg) | |
| if not paths: | |
| continue | |
| paths_and_attributes = [] | |
| # Rotation, shift, scale, noise addition | |
| new_paths = [] | |
| new_attributes = [] | |
| for path, attribute in zip(paths, attributes): | |
| attr = self.clean_attributes(attribute) | |
| new_path = self.do_rotate(path, viewbox_width, viewbox_height) | |
| new_path = self.do_shift(new_path) | |
| new_path = self.do_scale(new_path) | |
| if self.noise_std: | |
| # Add noise to path to deform svg | |
| noisy_path = [] | |
| for seg in new_path: | |
| noisy_seg = self.add_noise(seg) | |
| noisy_path.append(noisy_seg) | |
| new_paths.append(Path(*noisy_path)) | |
| else: | |
| new_paths.append(new_path) | |
| # Color change | |
| attr = self.do_color_change(attr) | |
| paths_and_attributes.append((new_path, attr)) | |
| grouped_paths[group_id] = { | |
| 'paths': paths_and_attributes, | |
| 'attrs': group_attrs | |
| } | |
| svg = self.paths2str(grouped_paths, svg_opening_tag) | |
| image = rasterize_svg(svg) | |
| return svg, image | |