Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| import albumentations as A | |
| import base64 | |
| import cv2 | |
| import gradio as gr | |
| import inspect | |
| import io | |
| import numpy as np | |
| import os | |
| from dataclasses import dataclass | |
| from loguru import logger | |
| from copy import deepcopy | |
| from functools import wraps | |
| from PIL import Image, ImageDraw | |
| from typing import get_type_hints, Optional | |
| from pydantic_core._pydantic_core import ValidationError | |
| # from mixpanel import Mixpanel | |
| from utils import is_not_supported_transform | |
| # MIXPANEL_TOKEN = os.getenv("MIXPANEL_TOKEN") | |
| # mp = Mixpanel(MIXPANEL_TOKEN) | |
# Page header rendered via gr.Markdown. The logo image supplies the leading "A",
# hence the visible text "lbumentations".
HEADER = f"""
<div align="center">
    <p>
        <img src="https://avatars.githubusercontent.com/u/57894582?s=200&v=4" alt="A" width="50" height="50" style="display:inline;">
        <span style="font-size: 30px; vertical-align: bottom;"> lbumentations Demo ({A.__version__})</span>
    </p>
    <p style="margin-top: -15px;">
        <a href="https://albumentations.ai/docs/" target="_blank" style="color: grey;">Documentation</a>
        &nbsp;
        <a href="https://github.com/albumentations-team/albumentations" target="_blank" style="color: grey;">GitHub Repository</a>
    </p>
</div>
"""

DEFAULT_TRANSFORM = "Rotate"
# Used when the ?transform= query parameter names an unsupported transform.
# (The misspelled constant name is kept because other code refers to it.)
NO_OPERATION_TRANFORM = "NoOp"
DEFAULT_IMAGE_PATH = "images/doctor.webp"

# Decode the demo image once, at import time.
# - The context manager closes the file handle that Image.open keeps open lazily.
# - convert("RGB") guarantees 3 channels, so the image always matches the RGB mask
#   overlay blended with cv2.addWeighted. This is presumably also what gr.Image
#   passes to callbacks (its default image_mode is "RGB").
with Image.open(DEFAULT_IMAGE_PATH) as _default_pil_image:
    DEFAULT_IMAGE = np.array(_default_pil_image.convert("RGB"))
DEFAULT_IMAGE_HEIGHT = DEFAULT_IMAGE.shape[0]
DEFAULT_IMAGE_WIDTH = DEFAULT_IMAGE.shape[1]

# Demo annotations for DEFAULT_IMAGE, in pixel coordinates.
# Boxes use the pascal_voc layout: [x_min, y_min, x_max, y_max].
DEFAULT_BOXES = [
    [265, 121, 326, 177],  # Mask
    [192, 169, 401, 395],  # Coverall
]
# Keypoints are [x, y] pairs, four per region.
mask_keypoints = [[270, 123], [320, 130], [270, 151], [321, 158]]
pocket_keypoints = [[226, 379], [272, 386], [307, 388], [364, 380]]
arm_keypoints = [[215, 194], [372, 192], [214, 322], [378, 330]]
DEFAULT_KEYPOINTS = mask_keypoints + pocket_keypoints + arm_keypoints
| BASE64_DEFAULT_MASKS = [ | |
| { | |
| "label": "Coverall", | |
| # light green color | |
| "color": (144, 238, 144), | |
| "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrm
qSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==", | |
| }, | |
| { | |
| "label": "Mask", | |
| # light blue color | |
| "color": (173, 216, 230), | |
| "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC", | |
| }, | |
| ] | |
# Public names that pass the class filter but must not be offered in the demo:
# the abstract base classes plus dtype/normalization transforms.
_EXCLUDED_TRANSFORM_NAMES = {
    "DualTransform",
    "ImageOnlyTransform",
    "ReferenceBasedTransform",
    "ToFloat",
    "Normalize",
}


def _is_demo_transform(name, obj):
    """Return True if ``obj`` (exported from ``A`` as ``name``) is a supported 2D transform class."""
    return (
        inspect.isclass(obj)
        and issubclass(obj, (A.DualTransform, A.ImageOnlyTransform))
        and not is_not_supported_transform(obj)
        and not name.endswith("3D")
    )


# Every transform offered in the dropdown, keyed by its public albumentations name.
transforms_map = {
    name: obj
    for name, obj in vars(A).items()
    if _is_demo_transform(name, obj) and name not in _EXCLUDED_TRANSFORM_NAMES
}
# Dropdown choices, alphabetically ordered.
transforms_keys = sorted(transforms_map)
# Decode every base64-encoded PNG in BASE64_DEFAULT_MASKS, in place, into a
# single-channel ("L") numpy array stored back under the same "mask" key.
for mask in BASE64_DEFAULT_MASKS:
    _png_bytes = base64.b64decode(mask["mask"])
    with Image.open(io.BytesIO(_png_bytes)) as _png:
        mask["mask"] = np.array(_png.convert("L"))
@dataclass
class RequestParams:
    """Parameters parsed from the incoming Gradio request.

    ``@dataclass`` generates the keyword constructor that ``get_params`` relies on
    (``RequestParams(user_ip=..., transform_name=...)``); without it the bare
    annotated class has no such ``__init__`` and construction raises TypeError.
    """

    user_ip: str  # client host taken from the request
    transform_name: Optional[str]  # value of the "transform" query parameter, or None
def track_event(event_name, user_id="unknown", properties=None):
    """Record an analytics event.

    Remote tracking (Mixpanel) is disabled, so the event name and its properties
    are only written to the log. ``user_id`` is accepted for that tracker's API
    and is currently not logged.

    Args:
        event_name: Name of the event, e.g. ``"transform_applied"``.
        user_id: Identifier of the user the event belongs to.
        properties: Extra event data; ``None`` means no properties.
    """
    payload = {} if properties is None else properties
    # To re-enable remote tracking: mp.track(user_id, event_name, payload)
    logger.info(f"Event tracked: {event_name} - {payload}")
def get_params(request: gr.Request) -> RequestParams:
    """Parse input request parameters and record an ``app_opened`` event.

    Args:
        request: The Gradio request for the current page load.

    Returns:
        RequestParams holding the client host (``"unknown"`` when the request
        carries no client address) and the optional ``transform`` query parameter.
    """
    # The underlying request's `client` may be None (peer address unavailable, e.g.
    # behind some proxies or in test clients). Analytics must not break page load.
    client = getattr(request, "client", None)
    ip = client.host if client is not None else "unknown"
    transform_name = request.query_params.get("transform", None)
    params = RequestParams(user_ip=ip, transform_name=transform_name)
    track_event(
        "app_opened",
        user_id=params.user_ip,
        properties={"transform_name": params.transform_name},
    )
    return params
def run_with_retry(compose, max_attempts=4):
    """Wrap an ``A.Compose`` so unsupported annotation targets are dropped and retried.

    Some transforms raise ``NotImplementedError`` for boxes, keypoints or masks.
    When the error message mentions ``bbox``, ``keypoint`` or ``mask``, that target
    is removed from the call's keyword arguments (and, for boxes/keypoints, from
    ``compose.processors``), and the call is retried. The result therefore simply
    lacks the unsupported targets.

    Args:
        compose: Callable pipeline exposing a mutable ``processors`` dict.
        max_attempts: Maximum number of calls per invocation (must be >= 1).

    Returns:
        A wrapper with the same call signature as ``compose``. After every call,
        whatever its outcome, ``compose.processors`` is restored to its prior state.

    Raises:
        ValueError: if ``max_attempts`` is smaller than 1 (raised immediately).
        gr.Error: if the pipeline raises ``ValueError`` / ``ValidationError``.
        NotImplementedError: if the error names no target that can still be dropped,
            or ``max_attempts`` is exhausted.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    def wrapper(*args, **kwargs):
        saved_processors = deepcopy(compose.processors)
        try:
            for attempt in range(1, max_attempts + 1):
                try:
                    return compose(*args, **kwargs)
                except NotImplementedError as e:
                    print(f"Caught NotImplementedError: {e}")
                    message = str(e)
                    sizes_before = (len(kwargs), len(compose.processors))
                    if "bbox" in message:
                        kwargs.pop("bboxes", None)
                        kwargs.pop("category_id", None)
                        compose.processors.pop("bboxes", None)
                    if "keypoint" in message:
                        kwargs.pop("keypoints", None)
                        compose.processors.pop("keypoints", None)
                    if "mask" in message:
                        kwargs.pop("mask", None)
                    nothing_dropped = sizes_before == (len(kwargs), len(compose.processors))
                    if nothing_dropped or attempt == max_attempts:
                        # Retrying an identical call cannot succeed: surface the real error
                        # instead of looping (or failing later with an unbound result).
                        raise
                except (ValueError, ValidationError) as e:
                    raise gr.Error(str(e)) from e
        finally:
            # Undo the processor removals so the pipeline is left exactly as it was.
            compose.processors = saved_processors

    return wrapper
def draw_boxes(image, boxes, color=(255, 0, 0), thickness=1) -> np.ndarray:
    """Return a copy of ``image`` with every box outlined, drawn with PIL.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` in pixel coordinates.
        color: Outline colour.
        thickness: Outline width in pixels.
    """
    canvas = Image.fromarray(image)
    painter = ImageDraw.Draw(canvas)
    for x_min, y_min, x_max, y_max in boxes:
        painter.rectangle(
            [x_min, y_min, x_max, y_max], outline=color, width=thickness
        )
    return np.array(canvas)
def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
    """Return a copy of ``image`` with each ``(x, y)`` keypoint drawn as a filled dot.

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        keypoints: Iterable of ``(x, y)`` pairs in pixel coordinates.
        color: Dot fill colour.
        radius: Dot radius in pixels.
    """
    canvas = Image.fromarray(image)
    painter = ImageDraw.Draw(canvas)
    for x, y in keypoints:
        bounding_box = [x - radius, y - radius, x + radius, y + radius]
        painter.ellipse(bounding_box, fill=color)
    return np.array(canvas)
def get_rgb_mask(masks):
    """Paint binary masks into a single RGB label image.

    Args:
        masks: Iterable of dicts, each with a 2D ``"mask"`` array (non-zero = inside)
            and an RGB ``"color"`` triple. Where masks overlap, later entries win.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` with unmasked pixels black. ``H, W``
        come from the first mask, so masks of any size work; when ``masks`` is
        empty the default image size is used.
    """
    masks = list(masks)
    if masks:
        height, width = np.asarray(masks[0]["mask"]).shape[:2]
    else:
        height, width = DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH
    rgb_mask = np.zeros((height, width, 3), dtype=np.uint8)
    for data in masks:
        rgb_mask[np.asarray(data["mask"]) > 0] = np.array(data["color"])
    return rgb_mask
def draw_mask(image, mask):
    """Blend ``mask`` over ``image`` with equal 0.5/0.5 weights using ``cv2.addWeighted``.

    OpenCV requires both arrays to have the same size and channel count. In this
    file the mask is the (augmented) RGB label image from ``get_rgb_mask``.
    """
    return cv2.addWeighted(image, 0.5, mask, 0.5, 0)
def draw_not_implemented_image(image: np.ndarray, annotation_type: str):
    """Return a copy of ``image`` with a red notice that ``annotation_type`` is unsupported.

    The notice is centred horizontally and placed at mid-height of *this* image,
    not of the default demo image. It therefore stays on-canvas after transforms
    that change the output size (crops, resizes).

    Args:
        image: Array accepted by ``PIL.Image.fromarray``.
        annotation_type: Annotation name shown upper-cased in the message.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    text = f'Transform NOT working with "{annotation_type.upper()}" annotations.'
    width, height = pil_image.size
    text_width = draw.textlength(text)
    # Clamp so the start of the message stays visible when it is wider than the image.
    x = max(0, width // 2 - text_width // 2)
    draw.text(
        (x, height // 2),
        text,
        fill=(255, 0, 0),
        align="center",
    )
    return np.array(pil_image)
def get_formatted_signature(function_or_class, indentation=4):
    """Render a callable's parameters as an editable, one-per-line code template.

    Each parameter becomes ``name=value,`` on its own line, followed by a
    ``# annotation`` comment when the parameter is annotated:

    * ``p`` is always ``1.0``, so the previewed transform is always applied;
    * a required parameter whose name contains ``height``/``width`` gets ``300``;
    * any other required parameter is left as ``name=,`` for the user to fill in
      (``update_augmented_images`` refuses code that still contains ``=,``);
    * string defaults are double-quoted; other defaults are rendered with ``str()``;
    * ``*args`` / ``**kwargs`` are skipped because they cannot be written as
      ``name=value`` keyword arguments.

    Args:
        function_or_class: Function or class; for a class, ``inspect.signature``
            reports its constructor parameters.
        indentation: Spaces before each parameter line. The closing parenthesis is
            indented by ``indentation - 4``.

    Returns:
        The parenthesised template, e.g. ``"(\\n    p=1.0,\\n)"``.
    """
    signature = inspect.signature(function_or_class)
    # NOTE(review): for a class, get_type_hints() returns class-attribute hints, not
    # __init__ hints, so most annotations then come from the raw signature below.
    try:
        type_hints = get_type_hints(function_or_class)
    except (NameError, TypeError, SyntaxError):
        # Unresolvable forward references: fall back to the raw signature annotations.
        type_hints = {}

    args = []
    for param in signature.parameters.values():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        if param.name == "p":
            str_param = "p=1.0,"
        elif param.default is inspect.Parameter.empty:
            if "height" in param.name or "width" in param.name:
                str_param = f"{param.name}=300,"
            else:
                str_param = f"{param.name}=,"
        elif isinstance(param.default, str):
            str_param = f'{param.name}="{param.default}",'
        else:
            str_param = f"{param.name}={param.default},"

        # Prefer the resolved hint; inspect the same object we print.
        annotation = type_hints.get(param.name, param.annotation)
        if annotation is not inspect.Parameter.empty:
            if isinstance(annotation, type):
                str_param += f" # {annotation.__name__}"
            else:
                str_param += " # " + str(annotation).replace("typing.", "")

        args.append("\n" + " " * indentation + str_param)

    return "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"
def get_formatted_transform(transform_name):
    """Return ``A.<ClassName>(...)`` template code for ``transform_name``.

    Also records a ``transform_selected`` event.

    Raises:
        KeyError: if ``transform_name`` is not in ``transforms_map``.
    """
    track_event("transform_selected", properties={"transform_name": transform_name})
    transform_cls = transforms_map[transform_name]
    signature_text = get_formatted_signature(transform_cls)
    return f"A.{transform_cls.__name__}{signature_text}"
def get_formatted_transform_docs(transform_name):
    """Return the docstring of ``transforms_map[transform_name]``.

    Leading and trailing newlines are stripped. A transform without a docstring
    (``__doc__`` is None) yields an empty string instead of raising AttributeError.

    Raises:
        KeyError: if ``transform_name`` is not in ``transforms_map``.
    """
    transform = transforms_map[transform_name]
    return (transform.__doc__ or "").strip("\n")
def update_augmented_images(image, code):
    """Apply the transform written in ``code`` to ``image`` and the demo annotations.

    Args:
        image: Input image array (the ``gr.Image`` value, or ``DEFAULT_IMAGE``).
        code: Python expression that builds one transform, e.g. ``"A.NoOp()"``.

    Returns:
        Gallery items ``[(image_with_mask, "Mask"), (image_with_bboxes, "Boxes"),
        (image_with_keypoints, "Keypoints")]``. When the augmented result lacks an
        annotation type (the transform could not process it), that item shows a
        "NOT working" placeholder instead.

    Raises:
        gr.Error: when required parameters are still blank (``=,``), when ``code``
            cannot be evaluated, or when the transform rejects its arguments.
    """
    if "=," in code:
        raise gr.Error("You have to fill in parameters to apply transform! See 'Code' section!")
    try:
        # SECURITY(review): `code` is text typed into the browser, and eval() runs it
        # with this module's globals, i.e. arbitrary code execution on the host.
        # Before exposing this publicly, consider accepting only a single
        # `A.<Transform>(...)` call (e.g. validated with `ast`) instead of a raw eval.
        augmentation = eval(code)
    except ValidationError as e:
        # Invalid transform arguments (presumably albumentations' pydantic validation).
        raise gr.Error(str(e)) from e
    except Exception as e:
        logger.info(code)
        logger.error(e)
        # Typos in user code are expected: report them in the UI instead of as a
        # generic server error.
        raise gr.Error(f"Could not build the transform: {type(e).__name__}: {e}") from e

    track_event(
        "transform_applied",
        properties={"transform_name": augmentation.__class__.__name__, "code": code},
    )

    compose = A.Compose(
        [augmentation],
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
        keypoint_params=A.KeypointParams(format="xy"),
    )
    # On NotImplementedError, drop the unsupported bboxes/keypoints/mask and retry.
    compose = run_with_retry(compose)

    keypoints = DEFAULT_KEYPOINTS
    bboxes = DEFAULT_BOXES
    mask = get_rgb_mask(BASE64_DEFAULT_MASKS)
    augmented = compose(
        image=image,
        mask=mask,
        keypoints=keypoints,
        bboxes=bboxes,
        category_id=range(len(bboxes)),
    )

    image = augmented["image"]
    # A target dropped by run_with_retry is missing from the result -> None here.
    mask = augmented.get("mask", None)
    bboxes = augmented.get("bboxes", None)
    keypoints = augmented.get("keypoints", None)

    # Draw the augmented annotations, or a placeholder if not implemented.
    if mask is not None:
        image_with_mask = draw_mask(image.copy(), mask)
    else:
        image_with_mask = draw_not_implemented_image(image.copy(), "mask")
    if bboxes is not None:
        image_with_bboxes = draw_boxes(image.copy(), bboxes)
    else:
        image_with_bboxes = draw_not_implemented_image(image.copy(), "boxes")
    if keypoints is not None:
        image_with_keypoints = draw_keypoints(image.copy(), keypoints)
    else:
        image_with_keypoints = draw_not_implemented_image(image.copy(), "keypoints")

    return [
        (image_with_mask, "Mask"),
        (image_with_bboxes, "Boxes"),
        (image_with_keypoints, "Keypoints"),
    ]
def update_image_info(image):
    """Return a multi-line summary of ``image``: shape (height x width), dtype, min/max."""
    height, width = image.shape[:2]
    lowest, highest = image.min(), image.max()
    return (
        "Image info:\n"
        f"\t - shape: {height}x{width}\n"
        f"\t - dtype: {image.dtype}\n"
        f"\t - min/max: {lowest}/{highest}"
    )
def update_code_and_docs(select):
    """Return ``(code, docs)`` for the transform chosen in the dropdown ``select``."""
    return get_formatted_transform(select), get_formatted_transform_docs(select)
def update_code_and_docs_on_start(url_params: gr.Request):
    """Choose the initial dropdown value from the ``?transform=`` query parameter.

    The choice follows these rules:

    * no parameter: ``DEFAULT_TRANSFORM``;
    * a supported transform name: that name;
    * anything else: show a warning and fall back to ``NO_OPERATION_TRANFORM``.
    """
    requested = get_params(url_params).transform_name
    if requested is None:
        chosen = DEFAULT_TRANSFORM
    elif requested in transforms_map:
        chosen = requested
    else:
        gr.Warning(f"Sorry, `{requested}` transform is not supported at the moment :(")
        chosen = NO_OPERATION_TRANFORM
    return gr.update(value=chosen)
# Gradio UI: transform picker with docs, editable code, and "Apply!" button; the
# input image; and a gallery of the augmented mask/boxes/keypoints.
# NOTE(review): the nesting of these context managers was reconstructed from a
# flattened copy of the file -- confirm against the original layout.
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Column():
            with gr.Group():
                # gr.Markdown(
                #     (" " * 4) + \
                #     "If a component is loading on start, please, try to refresh the page a few times. [Working on fix...]"
                # )
                select = gr.Dropdown(
                    label="Select a transformation",
                    choices=transforms_keys,
                    value=DEFAULT_TRANSFORM,
                    type="value",
                    interactive=True,
                )
                with gr.Accordion("Documentation (click to expand)", open=False):
                    docs = gr.TextArea(
                        get_formatted_transform_docs(DEFAULT_TRANSFORM),
                        show_label=False,
                        interactive=False,
                    )
            # Editable transform code; update_augmented_images eval()s this text.
            code = gr.Code(
                label="Code",
                language="python",
                value=get_formatted_transform(DEFAULT_TRANSFORM),
                interactive=True,
                lines=5,
            )
            # NOTE(review): static text; no event below updates it (update_image_info is
            # never wired up).
            info = gr.TextArea(
                value=f"Image size: {DEFAULT_IMAGE_HEIGHT} x {DEFAULT_IMAGE_WIDTH} (height x width)",
                show_label=False,
                lines=1,
                max_lines=1,
            )
            button = gr.Button("Apply!")
        # Input image passed to update_augmented_images as a numpy array.
        # sources=[] presumably disables upload/webcam input -- TODO confirm.
        image = gr.Image(
            value=DEFAULT_IMAGE_PATH,
            type="numpy",
            height=500,
            width=300,
            sources=[],
        )
    with gr.Row():
        # Initial gallery content is computed once, at import, with a no-op transform.
        augmented_image = gr.Gallery(
            value=update_augmented_images(DEFAULT_IMAGE, "A.NoOp()"),
            rows=1,
            columns=3,
            show_label=False,
        )
    # Changing the dropdown refreshes the code template and the docs.
    select.change(fn=update_code_and_docs, inputs=[select], outputs=[code, docs])
    # "Apply!" runs the code in the editor on the input image.
    button.click(
        fn=update_augmented_images, inputs=[image, code], outputs=[augmented_image]
    )
    # On page load, pick the initial transform from the ?transform= query parameter.
    demo.load(
        update_code_and_docs_on_start, inputs=None, outputs=[select], queue=False
    )

if __name__ == "__main__":
    demo.launch()
