# Spaces: Running on Zero
| from PIL import Image, ImageDraw, ImageFont | |
| import os | |
| import torch | |
| import glob | |
| import matplotlib.pyplot as plt | |
def read_images_in_path(path, size=(512, 512)):
    """Load every PNG/JPEG image found directly inside ``path``.

    Files are selected by extension (``.png``, ``.jpg``, ``.jpeg``). The match
    is case-insensitive so that names such as ``IMG_001.JPG`` are not silently
    skipped. The selected paths are sorted, and each image is converted to RGB
    and resized to ``size``.

    Args:
        path (str): Directory to scan. Sub-directories are not searched.
        size (tuple): Target ``(width, height)`` handed to ``Image.resize``.

    Returns:
        list: Resized RGB images in sorted-path order. Empty when the
        directory contains no file with a matching extension.
    """
    extensions = (".png", ".jpg", ".jpeg")
    image_paths = sorted(
        os.path.join(path, filename)
        for filename in os.listdir(path)
        if filename.lower().endswith(extensions)
    )

    images = []
    for image_path in image_paths:
        # convert() returns a new image, so the file opened by Image.open can
        # be closed as soon as the conversion is done.
        with Image.open(image_path) as source:
            images.append(source.convert("RGB").resize(size))
    return images
def concatenate_images(image_lists, return_list = False):
    """Tile images into a grid, one inner list per grid column.

    ``image_lists[j][i]`` is placed in column ``j``, row ``i``. The tile size
    is taken from ``image_lists[0][0]`` and the number of rows from
    ``len(image_lists[0])``.

    Args:
        image_lists (list): List of columns, each a list of PIL images.
        return_list (bool): When False, return one image containing the whole
            grid. When True, return a list with one horizontal strip per row.

    Returns:
        PIL.Image or list: The full grid, or the list of row strips.
    """
    first_column = image_lists[0]
    row_count = len(first_column)
    column_count = len(image_lists)
    tile_w = first_column[0].width
    tile_h = first_column[0].height
    canvas_w = column_count * tile_w

    if return_list:
        # One canvas per row; every tile sits on the strip's top edge.
        strips = [Image.new('RGB', (canvas_w, tile_h)) for _ in range(row_count)]
        for row, strip in enumerate(strips):
            for col in range(column_count):
                strip.paste(image_lists[col][row], (col * tile_w, 0))
        return strips

    # Single canvas tall enough for all rows.
    grid = Image.new('RGB', (canvas_w, row_count * tile_h))
    for row in range(row_count):
        for col in range(column_count):
            grid.paste(image_lists[col][row], (col * tile_w, row * tile_h))
    return grid
def concatenate_images_single(image_lists):
    """Paste a flat list of images side by side into one horizontal strip.

    The tile size is taken from the first image; each subsequent image is
    pasted one tile-width further to the right.

    Args:
        image_lists (list): PIL images, left to right.

    Returns:
        PIL.Image: New RGB image of width ``len(image_lists) * tile_width``.
    """
    tile_w = image_lists[0].width
    tile_h = image_lists[0].height
    strip = Image.new('RGB', (len(image_lists) * tile_w, tile_h))
    for index, tile in enumerate(image_lists):
        strip.paste(tile, (index * tile_w, 0))
    return strip
def get_captions_for_images(images, device):
    """Generate one caption per image with the BLIP-2 (OPT-2.7b) model.

    The processor and model are loaded on every call. When ``images`` is
    empty the function returns immediately, without importing
    ``transformers`` or loading any weights.

    Args:
        images: Iterable of images accepted by ``Blip2Processor``.
        device: Device the processed inputs are moved to before generation.

    Returns:
        list: Stripped caption strings, in the same order as ``images``.
    """
    images = list(images)
    if not images:
        # Nothing to caption: skip the expensive model load entirely.
        return []

    # Imported lazily so that importing this module does not require transformers.
    from transformers import Blip2Processor, Blip2ForConditionalGeneration

    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # NOTE(review): device_map={"": 0} pins the model to device index 0, while
    # the inputs below are moved to `device`. Presumably callers pass a device
    # that matches GPU 0 - confirm.
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
    )

    res = []
    for image in images:
        inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
        generated_ids = model.generate(**inputs)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        res.append(generated_text)

    # No explicit `del` is needed: the local references to the processor and
    # model are dropped when the function returns.
    return res
def find_and_plot_images(
    directory,
    output_file,
    recursive=True,
    figsize=(15, 15),
    image_formats=("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tiff"),
    max_cols=None,
):
    """
    Finds all images in the specified directory (optionally recursively)
    and saves them in a single figure with their filenames.

    Paths containing "noise.jpg" or "results.jpg" are skipped. Paths ending in
    "original.jpg" are plotted first, then those ending in
    "reconstruction.jpg", then everything else in sorted order.

    Parameters:
        directory (str): Path to the directory.
        output_file (str): Path to save the resulting figure (e.g., 'output.png').
        recursive (bool): Whether to search directories recursively.
        figsize (tuple): Size of the resulting figure.
        image_formats (tuple): Glob patterns of the image files to look for.
        max_cols (int, optional): Maximum number of images per row. The
            default (None) keeps all images in a single row.

    Returns:
        None. When no image is found, a message is printed and no file is written.
    """
    # Gather all image file paths
    pattern = "**/" if recursive else ""
    images = []
    for fmt in image_formats:
        images.extend(glob.glob(os.path.join(directory, pattern + fmt), recursive=recursive))

    # Filter out noise and result images
    images = [image for image in images if "noise.jpg" not in image and "results.jpg" not in image]

    # Move "original" to the front, followed by "reconstruction" and then the rest.
    # False sorts before True, so a matching suffix pulls the path forward.
    images = sorted(
        images,
        key=lambda x: (not x.endswith("original.jpg"), not x.endswith("reconstruction.jpg"), x)
    )

    if not images:
        print("No images found!")
        return

    # Lay the images out on a grid
    num_images = len(images)
    cols = num_images if max_cols is None else max(1, min(num_images, max_cols))
    rows = (num_images + cols - 1) // cols  # Ceiling division

    # squeeze=False always yields a 2-D array of axes, so the single-image
    # case needs no special handling.
    fig, axs = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axs = axs.flatten()

    for ax, image_path in zip(axs, images):
        # Close each image file once matplotlib has read the pixel data.
        with Image.open(image_path) as img:
            ax.imshow(img)
        ax.axis('off')  # Remove axes
        ax.set_title(os.path.basename(image_path), fontsize=8)  # Add filename

    # Hide any remaining empty axes
    for ax in axs[num_images:]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(output_file, bbox_inches='tight', dpi=300)  # Save the figure to the file
    plt.close(fig)  # Close the figure to free up memory
    print(f"Figure saved to {output_file}")
def add_label_to_image(image, label):
    """
    Adds a label to the lower-right corner of an image.

    The image is modified in place and the same object is returned.

    Args:
        image (PIL.Image): Image to add the label to.
        label (str): Text to add as a label.

    Returns:
        PIL.Image: Image with the added label (the object passed in).
    """
    # Create a drawing context. For RGB images the "RGBA" drawing mode makes
    # ImageDraw blend fills that carry an alpha value into the picture; without
    # it the background rectangle below is painted fully opaque.
    if image.mode == "RGB":
        draw = ImageDraw.Draw(image, "RGBA")
    else:
        draw = ImageDraw.Draw(image)

    # Font size scales with the image; clamp to 1 so tiny images do not
    # request a zero-sized font.
    font_size = max(1, int(min(image.size) * 0.05))
    try:
        font = ImageFont.truetype("fonts/arial.ttf", font_size)  # Replace with a font path if needed
    except OSError:
        font = ImageFont.load_default()  # Fallback to default font if arial.ttf is not found

    # Measure text size using textbbox
    text_bbox = draw.textbbox((0, 0), label, font=font)  # (left, top, right, bottom)
    text_width = text_bbox[2] - text_bbox[0]
    text_height = text_bbox[3] - text_bbox[1]

    # Position the text in the lower-right corner with some padding
    padding = 10
    position = (image.width - text_width - padding, image.height - text_height - padding)

    # Add a semi-transparent background for the label
    draw.rectangle(
        [
            (position[0] - padding, position[1] - padding),
            (position[0] + text_width + padding, position[1] + text_height + padding)
        ],
        fill=(0, 0, 0, 150)  # Black with transparency
    )

    # Draw the label
    draw.text(position, label, fill="white", font=font)
    return image
def crop_center_square_and_resize(img, size, output_path=None):
    """
    Crops the largest centered square out of an image and resizes it.

    Args:
        img (PIL.Image): Image to crop.
        size (tuple): Target (width, height) passed to the image's resize().
        output_path (str, optional): Path to save the result. If None (or
            empty), the result is not saved.

    Returns:
        Image: The cropped and resized image.
    """
    width, height = img.size

    # The square's edge is the shorter image side; center it on both axes.
    edge = min(width, height)
    x0 = (width - edge) // 2
    y0 = (height - edge) // 2
    box = (x0, y0, x0 + edge, y0 + edge)

    result = img.crop(box).resize(size)

    if output_path:
        result.save(output_path)
    return result