# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from PIL import Image
from torchvision import transforms as TF

def load_and_preprocess_images(image_path_list, mode="crop"):
    """
    A quick-start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching,
    but our model can also work well with inputs of different shapes.

    Args:
        image_path_list (list): List of paths to image files.
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to 518px and center-crops height if needed.
            - "pad": Preserves all pixels by making the largest dimension 518px
              and padding the smaller dimension to reach a square shape.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W).

    Raises:
        ValueError: If the input list is empty or if mode is invalid.

    Notes:
        - Images with different dimensions will be padded with white (value=1.0).
        - A warning is printed when images have different shapes.
        - When mode="crop": width is set to 518px while maintaining aspect ratio,
          and height is center-cropped if larger than 518px.
        - When mode="pad": the largest dimension is set to 518px while maintaining
          aspect ratio, and the smaller dimension is padded to reach a square
          shape (518x518).
        - Dimensions are adjusted to be divisible by 14 for compatibility with
          model requirements.
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = 518  # target edge length, matching the docstring
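    # Note: 518 = 14 * 37, so the target edge is itself a multiple of 14, and
    # the rounding below keeps every resized dimension divisible by 14 -
    # presumably (an assumption here) because the backbone consumes 14x14 patches.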
    # First process all images and collect their shapes
    for image_path in image_path_list:
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto a white background:
        if img.mode == "RGBA":
            # Create white background
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            # Alpha composite onto the white background
            img = Image.alpha_composite(background, img)

        # Now convert to "RGB" (this step assigns white for transparent areas)
        img = img.convert("RGB")
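        # img is now guaranteed to be 3-channel RGB regardless of the source
        # mode (grayscale, palette, or RGBA), matching the (N, 3, H, W) contract.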
        width, height = img.size

        if mode == "pad":
            # Make the largest dimension 518px while maintaining aspect ratio
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / 14) * 14  # Make divisible by 14
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / 14) * 14  # Make divisible by 14
        else:  # mode == "crop"
            # Original behavior: set width to 518px
            new_width = target_size
            # Calculate height maintaining aspect ratio, divisible by 14
            new_height = round(height * (new_width / width) / 14) * 14
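        # Worked example: a 1024x768 input gives new_width = 518 and
        # new_height = round(768 * 518 / 1024 / 14) * 14 = 28 * 14 = 392;
        # pad mode then pads that 392px height up to 518 below.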
        # Resize with new dimensions (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor with values in [0, 1]

        # Center crop height if it's larger than 518 (only in crop mode)
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]

        # For pad mode, pad to make a square of target_size x target_size
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                # Pad with white (value=1.0)
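                # (torch.nn.functional.pad reads its pad tuple from the last
                # dimension backward: (left, right) on W, then (top, bottom) on H.)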
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)
    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")

        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images if necessary
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images
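    # torch.stack requires every tensor to share an identical shape; the padding
    # above guarantees this when the input shapes differed.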
    images = torch.stack(images)  # stack into a single (N, 3, H, W) batch

    # Ensure correct shape when single image
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)

    return images
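

# Example usage (a minimal sketch; the file paths below are hypothetical):
if __name__ == "__main__":
    example_paths = ["examples/frame_000.png", "examples/frame_001.png"]
    batch = load_and_preprocess_images(example_paths, mode="pad")
    print(batch.shape)  # pad mode yields torch.Size([2, 3, 518, 518])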