Spaces:
Running
Running
| import cv2 | |
| import inspect | |
| import numpy as np | |
| import albumentations as A | |
| import gradio as gr | |
| from typing import get_type_hints | |
| from PIL import Image, ImageDraw | |
| import base64 | |
| import io | |
| from PIL import Image | |
| from functools import wraps | |
| from copy import deepcopy | |
# Transform pre-selected in the dropdown when the app starts.
DEFAULT_TRANSFORM = "CoarseDropout"
# Demo image loaded into the input image component.
DEFAULT_IMAGE = "images/doctor.webp"
# Expected demo-image size in pixels (the embedded masks below are 600x400).
DEFAULT_IMAGE_HEIGHT = 400
DEFAULT_IMAGE_WIDTH = 600
# Demo bounding boxes as [x_min, y_min, x_max, y_max] in pixels.
DEFAULT_BOXES = [[265, 121, 326, 177], [192, 169, 401, 395]]
# Corners of every box, in the order: top-left, bottom-right, bottom-left, top-right.
CORENERS = [
    [[left, top], [right, bottom], [left, bottom], [right, top]]
    for left, top, right, bottom in DEFAULT_BOXES
]
# Demo keypoints: the integer center of each box, followed by all box corners.
DEFAULT_KEYPOINTS = [
    [(left + right) // 2, (top + bottom) // 2]
    for left, top, right, bottom in DEFAULT_BOXES
] + [corner for box_corners in CORENERS for corner in box_corners]
| BASE64_DEFAULT_MASKS = [ | |
| { | |
| "label": "Coverall", | |
| # light green color | |
| "color": (144, 238, 144), | |
| "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrm
qSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==", | |
| }, | |
| { | |
| "label": "Mask", | |
| # light blue color | |
| "color": (173, 216, 230), | |
| "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC", | |
| }, | |
| ] | |
# Registry of every albumentations transform class exported at package level that
# derives from DualTransform or ImageOnlyTransform, keyed by its attribute name.
# The two base classes used for the filter are themselves left out.
transforms_map = {
    attr_name: attr
    for attr_name, attr in vars(A).items()
    if attr_name not in ("DualTransform", "ImageOnlyTransform")
    and inspect.isclass(attr)
    and issubclass(attr, (A.DualTransform, A.ImageOnlyTransform))
}
# Alphabetical transform names; the dropdown (type="index") selects by position in this list.
transforms_keys = sorted(transforms_map)
# Replace each embedded base64 PNG, in place, with its pixels converted to 8-bit
# grayscale (PIL mode "L") as a numpy array.
for mask in BASE64_DEFAULT_MASKS:
    png_stream = io.BytesIO(base64.b64decode(mask["mask"]))
    mask["mask"] = np.array(Image.open(png_stream).convert("L"))
def _drop_unsupported_targets(compose, kwargs, message):
    """Remove the inputs/processors that a NotImplementedError ``message`` names.

    Returns True if anything was actually removed (so a retry can behave differently).
    """
    dropped = False
    if "bbox" in message and ("bboxes" in kwargs or "bboxes" in compose.processors):
        kwargs.pop("bboxes", None)
        # Box labels are only meaningful alongside the boxes.
        kwargs.pop("category_id", None)
        compose.processors.pop("bboxes", None)
        dropped = True
    if "keypoint" in message and ("keypoints" in kwargs or "keypoints" in compose.processors):
        kwargs.pop("keypoints", None)
        compose.processors.pop("keypoints", None)
        dropped = True
    if "mask" in message and "mask" in kwargs:
        kwargs.pop("mask")
        dropped = True
    return dropped


def run_with_retry(compose, max_attempts=4):
    """Wrap ``compose`` so annotation targets a transform does not support are dropped.

    When a call raises ``NotImplementedError`` whose message mentions "bbox",
    "keypoint" or "mask", the matching keyword inputs (and, for boxes and
    keypoints, the entry in ``compose.processors``) are removed and the call is
    retried. ``compose.processors`` is restored after every call, however it ends.

    Args:
        compose: callable with a ``processors`` dict attribute (an ``A.Compose``).
        max_attempts: maximum number of calls per invocation; must be >= 1.

    Returns:
        A function accepting the same arguments as ``compose``.

    Raises:
        ValueError: if ``max_attempts`` < 1.
        NotImplementedError: re-raised when its message names no target that is
            still present, or when every attempt failed.
        Any other exception from ``compose`` propagates unchanged.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    def wrapper(*args, **kwargs):
        # Snapshot before the first call so processors removed during retries come back.
        processors = deepcopy(compose.processors)
        try:
            last_error = None
            for _ in range(max_attempts):
                try:
                    return compose(*args, **kwargs)
                except NotImplementedError as error:
                    print(f"Caught NotImplementedError: {error}")
                    last_error = error
                    if not _drop_unsupported_targets(compose, kwargs, str(error)):
                        # Nothing left to remove: an identical retry would fail the same way.
                        raise
            raise last_error
        finally:
            compose.processors = processors

    return wrapper
def draw_boxes(image, boxes, color=(255, 0, 0), thickness=2) -> np.ndarray:
    """Outline every box on ``image`` using PIL and return the result as a new array.

    Args:
        image: array accepted by ``PIL.Image.fromarray``.
        boxes: iterable of ``(x_min, y_min, x_max, y_max)`` pixel coordinates
            (the pascal_voc order that ``update`` requests from albumentations).
        color: outline color.
        thickness: outline width in pixels.
    """
    canvas = Image.fromarray(image)
    pen = ImageDraw.Draw(canvas)
    for x_min, y_min, x_max, y_max in boxes:
        pen.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
    return np.array(canvas)
def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
    """Mark each ``(x, y)`` keypoint with a filled circle and return a new array.

    Args:
        image: array accepted by ``PIL.Image.fromarray``.
        keypoints: iterable of ``(x, y)`` pixel coordinates.
        color: fill color of the dots.
        radius: dot radius in pixels.
    """
    canvas = Image.fromarray(image)
    pen = ImageDraw.Draw(canvas)
    for x, y in keypoints:
        bounding_box = [x - radius, y - radius, x + radius, y + radius]
        pen.ellipse(bounding_box, fill=color)
    return np.array(canvas)
def get_rgb_mask(masks, height=None, width=None):
    """Combine binary masks into a single RGB image, one color per mask.

    Args:
        masks: iterable of dicts with a 2-D ``"mask"`` array (pixels > 0 belong to
            the mask) and an RGB ``"color"`` triple. Where masks overlap, the later
            entry's color wins.
        height: canvas height; defaults to the first mask's height, or to
            ``DEFAULT_IMAGE_HEIGHT`` when ``masks`` is empty.
        width: canvas width; defaults to the first mask's width, or to
            ``DEFAULT_IMAGE_WIDTH`` when ``masks`` is empty.

    Returns:
        ``(height, width, 3)`` uint8 array, black wherever no mask is set.
    """
    masks = list(masks)
    if height is None or width is None:
        # Size the canvas from the masks themselves so masks of any size work,
        # instead of failing on a shape mismatch with the global default size.
        if masks:
            mask_height, mask_width = np.asarray(masks[0]["mask"]).shape[:2]
        else:
            mask_height, mask_width = DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH
        height = mask_height if height is None else height
        width = mask_width if width is None else width
    rgb_mask = np.zeros((height, width, 3), dtype=np.uint8)
    for data in masks:
        rgb_mask[np.asarray(data["mask"]) > 0] = np.array(data["color"])
    return rgb_mask
def draw_mask(image, mask):
    """Blend ``mask`` over ``image`` with equal 0.5/0.5 weights and return the blend.

    ``cv2.addWeighted`` requires both arrays to have the same size and number
    of channels.
    """
    return cv2.addWeighted(image, 0.5, mask, 0.5, 0)
def draw_not_implemented_image(image):
    """Draw a red "not implemented" notice centered on ``image``.

    The result serves as the gallery placeholder for annotation types the
    selected transform does not support.

    Args:
        image: array accepted by ``PIL.Image.fromarray``.

    Returns:
        numpy array with the notice drawn in PIL's default font.
    """
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    text = "NOT IMPLEMENTED FOR THIS TYPE OF ANNOTATIONS"
    # Measure the rendered text and center it on this image's own dimensions,
    # both horizontally and vertically, so any input size is handled.
    left, top, right, bottom = draw.textbbox((0, 0), text)
    x = (pil_image.width - (right - left)) // 2 - left
    y = (pil_image.height - (bottom - top)) // 2 - top
    draw.text((x, y), text, fill=(255, 0, 0))
    return np.array(pil_image)
def get_formatted_signature(function_or_class, indentation=4):
    """Render a callable's parameters as an editable, multi-line keyword-argument list.

    Each parameter becomes ``name=default,`` on its own line (``name=,`` when it
    has no default, leaving the value for the user to fill in), followed by
    ``# <annotation>`` when the parameter is annotated. ``p`` is always rendered
    as ``p=1.0``. ``*args``/``**kwargs`` are skipped because they cannot be
    written as ``name=value``.

    Args:
        function_or_class: function or class; for a class, type hints are read
            from its ``__init__``.
        indentation: spaces before each parameter line; the closing parenthesis
            is indented by ``indentation - 4``.

    Returns:
        A string such as ``"(\\n    a=1, # int\\n)"``.
    """
    signature = inspect.signature(function_or_class)
    # get_type_hints on a class returns class-attribute annotations, not the
    # constructor's parameters, so resolve hints from __init__ instead.
    hints_source = (
        function_or_class.__init__ if inspect.isclass(function_or_class) else function_or_class
    )
    try:
        type_hints = get_type_hints(hints_source)
    except Exception:
        # Best effort: unresolvable forward references etc. fall back to raw annotations.
        type_hints = {}

    args = []
    for param in signature.parameters.values():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        if param.name == "p":
            str_param = "p=1.0,"
        elif param.default is inspect.Parameter.empty:
            str_param = f"{param.name}=,"
        elif isinstance(param.default, str):
            str_param = f'{param.name}="{param.default}",'
        else:
            str_param = f"{param.name}={param.default},"

        annotation = type_hints.get(param.name, param.annotation)
        if annotation is not inspect.Parameter.empty:
            if isinstance(annotation, type):
                str_param += f" # {annotation.__name__}"
            else:
                str_param += f" # {str(annotation).replace('typing.', '')}"

        args.append("\n" + " " * indentation + str_param)
    return "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"
def update(image, code):
    """Run the user's transform on ``image`` plus the demo annotations; build gallery views.

    Args:
        image: input image as a numpy array (from the ``gr.Image`` component).
        code: Python expression that evaluates to an albumentations transform,
            e.g. the text produced by ``get_formatted_transform``.

    Returns:
        List of three ``(image, caption)`` pairs ("Mask", "Boxes", "Keypoints").
        Each is the augmented image with that annotation drawn, or the
        transformed "not implemented" placeholder when the target was dropped.

    Raises:
        Whatever ``eval(code)`` or the transform raises, apart from the
        NotImplementedError cases that ``run_with_retry`` handles.
    """
    # NOTE(security): ``code`` is free text from the web UI and eval() runs it with
    # this module's globals, so any user of the app can execute Python on the
    # server. Editing this expression is the app's purpose; sandbox or restrict
    # it before exposing the app to untrusted users.
    augmentation = eval(code)
    compose = A.Compose(
        [augmentation],
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
        keypoint_params=A.KeypointParams(format="xy"),
        # The placeholder is registered as an extra "image" target so it goes
        # through the same transform and comes out the same size as the result.
        additional_targets={"not_implemented_image": "image"},
    )
    # Drops bbox/keypoint/mask inputs the transform raises NotImplementedError for.
    compose = run_with_retry(compose)

    bboxes = DEFAULT_BOXES
    augmented = compose(
        image=image,
        not_implemented_image=draw_not_implemented_image(image),
        mask=get_rgb_mask(BASE64_DEFAULT_MASKS),
        keypoints=DEFAULT_KEYPOINTS,
        bboxes=bboxes,
        category_id=range(len(bboxes)),
    )

    image = augmented["image"]
    not_implemented_image = augmented["not_implemented_image"]
    # Targets dropped by run_with_retry are absent from the result.
    mask = augmented.get("mask")
    bboxes = augmented.get("bboxes")
    keypoints = augmented.get("keypoints")

    image_with_mask = draw_mask(image.copy(), mask) if mask is not None else not_implemented_image
    image_with_bboxes = draw_boxes(image.copy(), bboxes) if bboxes is not None else not_implemented_image
    image_with_keypoints = (
        draw_keypoints(image.copy(), keypoints) if keypoints is not None else not_implemented_image
    )
    return [
        (image_with_mask, "Mask"),
        (image_with_bboxes, "Boxes"),
        (image_with_keypoints, "Keypoints"),
    ]
def update_image_info(image):
    """Summarize an image array as text: height x width, dtype, and min/max value."""
    height, width = image.shape[:2]
    summary = [
        "Image info:",
        f"\t - shape: {height}x{width}",
        f"\t - dtype: {image.dtype}",
        f"\t - min/max: {image.min()}/{image.max()}",
    ]
    return "\n".join(summary)
def get_formatted_transform(transform_number):
    """Return editable constructor source, e.g. ``A.Blur(...)``, for the dropdown index given."""
    transform_cls = transforms_map[transforms_keys[transform_number]]
    signature_text = get_formatted_signature(transform_cls)
    return "A." + transform_cls.__name__ + signature_text
def get_formatted_transform_docs(transform_number):
    """Return the docstring of the transform at dropdown index ``transform_number``.

    Leading and trailing newlines are stripped. A class without a docstring
    (``__doc__`` is None) yields an empty string instead of raising AttributeError.
    """
    transform_name = transforms_keys[transform_number]
    transform = transforms_map[transform_name]
    return (transform.__doc__ or "").strip("\n")
# Gradio UI: pick a transform, read its docs, edit its constructor call, and run
# it on the demo image; results appear as a three-image gallery.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                select = gr.Dropdown(
                    label="Select a transformation",
                    choices=transforms_keys,
                    value=DEFAULT_TRANSFORM,
                    type="index",
                    interactive=True,
                )
                with gr.Accordion("Documentation", open=False):
                    docs = gr.TextArea(
                        get_formatted_transform_docs(
                            transforms_keys.index(DEFAULT_TRANSFORM)
                        ),
                        show_label=False,
                        interactive=False,
                    )
                code = gr.Code(
                    language="python",
                    value=get_formatted_transform(transforms_keys.index(DEFAULT_TRANSFORM)),
                    interactive=True,
                    lines=5,
                )
                button = gr.Button("Run")
        image = gr.Image(
            value=DEFAULT_IMAGE,
            type="numpy",
            height=500,
            width=300,
            sources=[],
        )
    with gr.Row():
        augmented_image = gr.Gallery(rows=1, columns=3)

    # Keep both the editable code and the documentation in sync with the dropdown.
    select.change(fn=get_formatted_transform, inputs=[select], outputs=[code])
    select.change(fn=get_formatted_transform_docs, inputs=[select], outputs=[docs])
    button.click(fn=update, inputs=[image, code], outputs=[augmented_image])

if __name__ == "__main__":
    demo.launch()