import base64
import json
import logging
import re
from enum import Enum
from io import BytesIO

import requests
import torch
import torch.nn as nn
from PIL import Image


class ImageType(Enum):
    """Identifies an image quadrant (``REAL_*`` members) or the ``FAKE`` category.

    The numeric values walk the quadrants clockwise from the upper-left
    corner, so ``REAL_DOWN_R`` (2) comes before ``REAL_DOWN_L`` (3).
    """

    REAL_UP_L = 0  # upper-left quadrant
    REAL_UP_R = 1  # upper-right quadrant
    REAL_DOWN_R = 2  # lower-right quadrant
    REAL_DOWN_L = 3  # lower-left quadrant
    FAKE = 4  # not a quadrant


def crop_image_part(image: torch.Tensor,
                    part: ImageType) -> torch.Tensor:
    """Return one quadrant of a 4-D image tensor.

    The tensor is split at the midpoint of dimension 2 (rows) and at the
    midpoint of dimension 3 (columns); the first two dimensions are kept
    whole.  Each axis uses its own midpoint, so non-square images are still
    cut into four proper quadrants.  For an odd extent the upper/left part
    gets ``extent // 2`` elements and the lower/right part gets the rest.

    Args:
        image: 4-D tensor; presumably (batch, channels, height, width) --
            TODO confirm against callers.
        part: Which quadrant to return; must be one of the four ``REAL_*``
            members of ``ImageType``.

    Returns:
        The selected quadrant.  It is produced by basic slicing, so it is a
        view that shares storage with ``image``.

    Raises:
        ValueError: If ``part`` is not one of the four quadrant members
            (for example ``ImageType.FAKE``).
    """
    half_rows = image.shape[2] // 2
    half_cols = image.shape[3] // 2

    upper, lower = slice(None, half_rows), slice(half_rows, None)
    left, right = slice(None, half_cols), slice(half_cols, None)

    quadrants = {
        ImageType.REAL_UP_L: (upper, left),
        ImageType.REAL_UP_R: (upper, right),
        ImageType.REAL_DOWN_L: (lower, left),
        ImageType.REAL_DOWN_R: (lower, right),
    }

    try:
        rows, cols = quadrants[part]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, so every bad ``part`` is
        # reported the same way.
        raise ValueError('invalid part') from None
    return image[:, :, rows, cols]


def init_weights(module: nn.Module):
    """Initialise a single layer in place; can be passed to ``Module.apply``.

    * ``nn.Conv2d``: weight drawn from N(mean=0.0, std=0.02); the bias is
      left untouched.
    * ``nn.BatchNorm2d``: weight drawn from N(mean=1.0, std=0.02) and bias
      set to zero.  Layers built with ``affine=False`` have neither
      parameter and are skipped instead of raising.
    * Any other module type is left unchanged.

    Args:
        module: The layer to initialise.

    Returns:
        None; the parameters are modified in place.
    """
    if isinstance(module, nn.Conv2d):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
    elif isinstance(module, nn.BatchNorm2d):
        # With affine=False both weight and bias are None.
        if module.weight is not None:
            nn.init.normal_(module.weight, mean=1.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def load_image_from_local(image_path, image_resize=None):
    """Open an image with PIL and optionally resize it.

    Args:
        image_path: Passed straight to ``Image.open``.
        image_resize: Target size forwarded to ``Image.resize`` (PIL expects
            ``(width, height)``).  Only a ``tuple`` triggers a resize; any
            other value, including ``None`` or a list, leaves the image at
            its original size.

    Returns:
        The opened image, resized when ``image_resize`` is a tuple.
    """
    image = Image.open(image_path)
    if isinstance(image_resize, tuple):
        return image.resize(image_resize)
    return image


def load_image_from_url(image_url, rgba_mode=False, image_resize=None,
                        default_image=None, timeout=10):
    """Download an image, falling back to a local default on any failure.

    Args:
        image_url: URL to fetch with ``requests.get``.
        rgba_mode: When true, convert the downloaded image to ``"RGBA"``.
            The conversion is not applied to the fallback image.
        image_resize: Target size forwarded to ``Image.resize``; only a
            ``tuple`` triggers a resize.  Also forwarded to the fallback
            loader.
        default_image: Path of a local image to load when the download or
            decoding fails.  A falsy value disables the fallback.
        timeout: Seconds handed to ``requests.get`` so an unresponsive
            server cannot block forever.  Pass ``None`` to wait
            indefinitely.

    Returns:
        The downloaded image, or the fallback image, or ``None`` when the
        download failed and no ``default_image`` was given.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        # An error page must not be mistaken for image data.
        response.raise_for_status()
        # Buffer the whole body: it is content-decoded, and the connection
        # is released instead of being held open by a streamed raw socket.
        image = Image.open(BytesIO(response.content))
        # Decode now so corrupt data is caught here rather than at first use.
        image.load()
        if rgba_mode:
            image = image.convert("RGBA")
        if isinstance(image_resize, tuple):
            image = image.resize(image_resize)
    except Exception:
        # Deliberately best-effort: report the failure, then fall back.
        logging.getLogger(__name__).warning(
            "Could not load image from %s", image_url, exc_info=True
        )
        image = None
        if default_image:
            image = load_image_from_local(default_image, image_resize=image_resize)
    return image


def image_to_base64(image_array):
    """Encode an image as a PNG ``data:`` URI string.

    Args:
        image_array: Despite the name, an object exposing a PIL-style
            ``save(fp, format=...)`` method (e.g. a ``PIL.Image.Image``).

    Returns:
        ``"data:image/png;base64, "`` followed by the base64 text of the PNG
        bytes written by ``save``.
    """
    png_buffer = BytesIO()
    image_array.save(png_buffer, format="PNG")
    payload = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    # NOTE(review): the space after "base64," is kept so existing consumers
    # see the same string; strict data-URI parsers may reject it -- confirm
    # with callers before removing.
    return f"data:image/png;base64, {payload}"