import math
from typing import Any, Callable

import numpy as np
import torch
from PIL import Image

def get_tiled_scale_steps(width: int, height: int, tile_x: int, tile_y: int, overlap: int) -> int:
    """#### Calculate the number of steps required for tiled scaling.
    #### Args:
    - `width` (int): The width of the image.
    - `height` (int): The height of the image.
    - `tile_x` (int): The width of each tile.
    - `tile_y` (int): The height of each tile.
    - `overlap` (int): The overlap between tiles.
    #### Returns:
    - `int`: The number of steps required for tiled scaling.
    """
    return math.ceil(height / (tile_y - overlap)) * math.ceil(width / (tile_x - overlap))
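
# Illustrative usage sketch (not part of the original module): with 128x128
# tiles and an 8px overlap the scan stride is tile - overlap = 120px per axis,
# so a 512x768 image is covered in ceil(768 / 120) * ceil(512 / 120) = 7 * 5 = 35 steps.
def _demo_tiled_scale_steps() -> None:
    steps = get_tiled_scale_steps(width=512, height=768, tile_x=128, tile_y=128, overlap=8)
    assert steps == 35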

def tiled_scale(
    samples: torch.Tensor,
    function: Callable,
    tile_x: int = 64,
    tile_y: int = 64,
    overlap: int = 8,
    upscale_amount: float = 4,
    out_channels: int = 3,
    pbar: Any = None,
) -> torch.Tensor:
    """#### Perform tiled scaling on a batch of samples.
    #### Args:
    - `samples` (torch.Tensor): The input samples.
    - `function` (Callable): The function to apply to each tile.
    - `tile_x` (int, optional): The width of each tile. Defaults to 64.
    - `tile_y` (int, optional): The height of each tile. Defaults to 64.
    - `overlap` (int, optional): The overlap between tiles. Defaults to 8.
    - `upscale_amount` (float, optional): The upscale amount. Defaults to 4.
    - `out_channels` (int, optional): The number of output channels. Defaults to 3.
    - `pbar` (Any, optional): The progress bar, advanced by one step per processed tile. Defaults to None.
    #### Returns:
    - `torch.Tensor`: The scaled output tensor.
    """
    output = torch.empty(
        (
            samples.shape[0],
            out_channels,
            round(samples.shape[2] * upscale_amount),
            round(samples.shape[3] * upscale_amount),
        ),
        device="cpu",
    )
    for b in range(samples.shape[0]):
        s = samples[b : b + 1]
        # Accumulators for the feather-blended tiles and their blend weights.
        out = torch.zeros(
            (
                s.shape[0],
                out_channels,
                round(s.shape[2] * upscale_amount),
                round(s.shape[3] * upscale_amount),
            ),
            device="cpu",
        )
        out_div = torch.zeros(
            (
                s.shape[0],
                out_channels,
                round(s.shape[2] * upscale_amount),
                round(s.shape[3] * upscale_amount),
            ),
            device="cpu",
        )
        for y in range(0, s.shape[2], tile_y - overlap):
            for x in range(0, s.shape[3], tile_x - overlap):
                s_in = s[:, :, y : y + tile_y, x : x + tile_x]
                ps = function(s_in).cpu()
                # Build a feathering mask that ramps the tile borders from 0 to 1
                # over the upscaled overlap so neighbouring tiles blend smoothly.
                mask = torch.ones_like(ps)
                feather = round(overlap * upscale_amount)
                for t in range(feather):
                    mask[:, :, t : 1 + t, :] *= (1.0 / feather) * (t + 1)
                    mask[:, :, mask.shape[2] - 1 - t : mask.shape[2] - t, :] *= (
                        1.0 / feather
                    ) * (t + 1)
                    mask[:, :, :, t : 1 + t] *= (1.0 / feather) * (t + 1)
                    mask[:, :, :, mask.shape[3] - 1 - t : mask.shape[3] - t] *= (
                        1.0 / feather
                    ) * (t + 1)
                out[
                    :,
                    :,
                    round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
                    round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
                ] += ps * mask
                out_div[
                    :,
                    :,
                    round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
                    round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
                ] += mask
                if pbar is not None:
                    # Advance a tqdm-style progress bar, if one was supplied.
                    pbar.update(1)
        output[b : b + 1] = out / out_div
    return output
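
# Usage sketch (hypothetical, not part of the original module): any callable that
# maps a [B, C, h, w] tile to a [B, out_channels, h * upscale, w * upscale] tensor
# can serve as `function`; bicubic interpolation stands in for an upscaling model here.
def _demo_tiled_scale() -> None:
    def upscale(tile: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.interpolate(tile, scale_factor=4, mode="bicubic")

    samples = torch.rand(1, 3, 96, 96)
    upscaled = tiled_scale(samples, upscale, tile_x=64, tile_y=64, overlap=8, upscale_amount=4)
    assert upscaled.shape == (1, 3, 384, 384)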

def flatten(img: Image.Image, bgcolor: str) -> Image.Image:
    """#### Replace transparency with a background color.
    #### Args:
    - `img` (Image.Image): The input image.
    - `bgcolor` (str): The background color.
    #### Returns:
    - `Image.Image`: The image with transparency replaced by the background color.
    """
    if img.mode == "RGB":
        return img
    # Composite onto an opaque background; convert first so non-RGBA modes
    # (e.g. "L", "P") can be alpha-composited as well.
    return Image.alpha_composite(
        Image.new("RGBA", img.size, bgcolor), img.convert("RGBA")
    ).convert("RGB")
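
# Illustrative sketch: a semi-transparent RGBA image flattened onto a white
# background ("white" is any color string PIL's color parser accepts).
def _demo_flatten() -> None:
    rgba = Image.new("RGBA", (64, 64), (255, 0, 0, 128))
    flat = flatten(rgba, "white")
    assert flat.mode == "RGB"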

BLUR_KERNEL_SIZE = 15

def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image:
    """#### Convert a tensor to a PIL image.
    #### Args:
    - `img_tensor` (torch.Tensor): The input tensor.
    - `batch_index` (int, optional): The batch index. Defaults to 0.
    #### Returns:
    - `Image.Image`: The converted PIL image.
    """
    # Get the tensor for the specified batch index
    tensor = img_tensor[batch_index]

    # Handle different tensor dimensions.
    # The upscaler outputs in [H, W, C] format after movedim(-3, -1).
    if tensor.dim() == 3:  # [H, W, C] - already in the expected layout
        pass
    elif tensor.dim() == 2:  # [H, W] - grayscale
        pass
    else:
        raise ValueError(f"Unexpected tensor dimensions: {tensor.shape}")

    # Clamp values to the valid range [0, 1] and convert to numpy
    tensor = torch.clamp(tensor, 0.0, 1.0)
    numpy_array = (tensor.cpu().numpy() * 255.0).astype(np.uint8)

    # Handle different channel configurations
    if numpy_array.ndim == 3:
        if numpy_array.shape[2] == 3:
            img = Image.fromarray(numpy_array, "RGB")
        elif numpy_array.shape[2] == 1:
            img = Image.fromarray(numpy_array.squeeze(axis=2), "L")
        elif numpy_array.shape[2] == 4:
            img = Image.fromarray(numpy_array, "RGBA")
        else:
            # Fallback: keep the first 3 channels if there are more than 3,
            # otherwise collapse a single channel to grayscale.
            if numpy_array.shape[2] >= 3:
                img = Image.fromarray(numpy_array[:, :, :3], "RGB")
            else:
                img = Image.fromarray(numpy_array.squeeze(axis=2), "L")
    elif numpy_array.ndim == 2:
        img = Image.fromarray(numpy_array, "L")
    else:
        raise ValueError(f"Cannot convert array with shape {numpy_array.shape} to PIL image")
    return img

def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """#### Convert a PIL image to a tensor.
    #### Args:
    - `image` (Image.Image): The input PIL image.
    #### Returns:
    - `torch.Tensor`: The converted tensor.
    """
    # Convert RGBA to RGB if necessary (upscaler models expect 3 channels)
    if image.mode == "RGBA":
        # Create a white background and use the alpha channel as the paste mask
        background = Image.new("RGB", image.size, (255, 255, 255))
        background.paste(image, mask=image.split()[-1])
        image = background
    elif image.mode != "RGB":
        image = image.convert("RGB")

    # Convert to a numpy array and normalize to [0, 1]
    image_array = np.array(image).astype(np.float32) / 255.0

    # Convert to a tensor and add a batch dimension: [H, W, C] -> [1, H, W, C]
    tensor = torch.from_numpy(image_array).unsqueeze(0)
    return tensor
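
# Round-trip sketch: pil_to_tensor yields a [1, H, W, C] float tensor in [0, 1],
# which is the layout tensor_to_pil expects back.
def _demo_pil_tensor_roundtrip() -> None:
    img = Image.new("RGB", (32, 16), (10, 20, 30))
    tensor = pil_to_tensor(img)
    assert tensor.shape == (1, 16, 32, 3)
    restored = tensor_to_pil(tensor, batch_index=0)
    assert restored.size == (32, 16)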

def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
    """#### Get the coordinates of the white rectangular mask region.
    #### Args:
    - `mask` (Image.Image): The input mask image in 'L' mode.
    - `pad` (int, optional): The padding to apply. Defaults to 0.
    #### Returns:
    - `tuple`: The coordinates of the crop region.
    """
    coordinates = mask.getbbox()
    if coordinates is not None:
        x1, y1, x2, y2 = coordinates
    else:
        x1, y1, x2, y2 = mask.width, mask.height, 0, 0

    # Apply padding
    x1 = max(x1 - pad, 0)
    y1 = max(y1 - pad, 0)
    x2 = min(x2 + pad, mask.width)
    y2 = min(y2 + pad, mask.height)
    return fix_crop_region((x1, y1, x2, y2), (mask.width, mask.height))

def fix_crop_region(region: tuple, image_size: tuple) -> tuple:
    """#### Remove the extra pixel added by the get_crop_region function.
    #### Args:
    - `region` (tuple): The crop region coordinates.
    - `image_size` (tuple): The size of the image.
    #### Returns:
    - `tuple`: The fixed crop region coordinates.
    """
    image_width, image_height = image_size
    x1, y1, x2, y2 = region
    if x2 < image_width:
        x2 -= 1
    if y2 < image_height:
        y2 -= 1
    return x1, y1, x2, y2

def expand_crop(region: tuple, width: int, height: int, target_width: int, target_height: int) -> tuple:
    """#### Expand a crop region to a specified target size.
    #### Args:
    - `region` (tuple): The crop region coordinates.
    - `width` (int): The width of the image.
    - `height` (int): The height of the image.
    - `target_width` (int): The desired width of the crop region.
    - `target_height` (int): The desired height of the crop region.
    #### Returns:
    - `tuple`: The expanded crop region coordinates and the target size.
    """
    x1, y1, x2, y2 = region
    actual_width = x2 - x1
    actual_height = y2 - y1

    # Try to expand the region to the right by half the width difference
    width_diff = target_width - actual_width
    x2 = min(x2 + width_diff // 2, width)
    # Expand to the left by the remaining difference, including any pixels
    # that could not be added on the right
    width_diff = target_width - (x2 - x1)
    x1 = max(x1 - width_diff, 0)
    # Try the right side again with whatever is still missing
    width_diff = target_width - (x2 - x1)
    x2 = min(x2 + width_diff, width)

    # Try to expand the region downward by half the height difference
    height_diff = target_height - actual_height
    y2 = min(y2 + height_diff // 2, height)
    # Expand upward by the remaining difference, including any pixels
    # that could not be added at the bottom
    height_diff = target_height - (y2 - y1)
    y1 = max(y1 - height_diff, 0)
    # Try the bottom again with whatever is still missing
    height_diff = target_height - (y2 - y1)
    y2 = min(y2 + height_diff, height)

    return (x1, y1, x2, y2), (target_width, target_height)
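
# Illustrative sketch: locate a white mask rectangle, pad it, then grow the crop
# to a fixed tile size while staying inside the image bounds.
def _demo_crop_region() -> None:
    mask = Image.new("L", (256, 256), 0)
    mask.paste(255, (100, 120, 140, 160))  # white box marking the masked area
    region = get_crop_region(mask, pad=8)
    (x1, y1, x2, y2), (tile_w, tile_h) = expand_crop(region, mask.width, mask.height, 128, 128)
    assert (x2 - x1, y2 - y1) == (tile_w, tile_h) == (128, 128)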

def crop_cond(cond: list, region: tuple, init_size: tuple, canvas_size: tuple, tile_size: tuple, w_pad: int = 0, h_pad: int = 0) -> list:
    """#### Crop conditioning data to match a specific region.
    #### Args:
    - `cond` (list): The conditioning data.
    - `region` (tuple): The crop region coordinates.
    - `init_size` (tuple): The initial size of the image.
    - `canvas_size` (tuple): The size of the canvas.
    - `tile_size` (tuple): The size of the tile.
    - `w_pad` (int, optional): The width padding. Defaults to 0.
    - `h_pad` (int, optional): The height padding. Defaults to 0.
    #### Returns:
    - `list`: The cropped conditioning data.
    """
    # Simplified implementation: the spatial arguments (region, sizes, padding)
    # are accepted for interface compatibility, but the conditioning entries are
    # copied as-is rather than spatially cropped.
    cropped = []
    for emb, x in cond:
        cond_dict = x.copy()
        cond_entry = [emb, cond_dict]
        cropped.append(cond_entry)
    return cropped