Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import os | |
| import json | |
| import itertools | |
| from PIL import Image, ImageDraw, ImageFont | |
def wrap_text(text, max_width, draw, font):
    """
    Wrap the text to fit within the given width by breaking it into lines.

    Words are obtained by splitting ``text`` on single spaces and packed
    greedily: a word is added to the current line as long as the rendered
    width (the right edge reported by ``draw.textbbox``) does not exceed
    ``max_width``.

    Args:
        text: The string to wrap.
        max_width: Maximum allowed rendered line width.
        draw: Object exposing ``textbbox((x, y), text, font=...)``.
        font: Font passed through to ``draw.textbbox``.

    Returns:
        A list of lines. A single word wider than ``max_width`` is kept on a
        line of its own (it cannot be broken further).
    """
    lines = []
    current_line = []
    for word in text.split(' '):
        candidate = current_line + [word]
        line_width = draw.textbbox((0, 0), ' '.join(candidate), font=font)[2]
        # Only flush when there is something to flush: if the current line is
        # empty, the word alone is too wide and must start its own line rather
        # than emitting a spurious empty line before it.
        if line_width > max_width and current_line:
            lines.append(' '.join(current_line))
            current_line = [word]
        else:
            current_line = candidate
    if current_line:
        lines.append(' '.join(current_line))
    return lines
def image_grid_with_titles(imgs, rows, cols, top_titles, left_titles, margin=20):
    """
    Compose images into a titled grid on a white canvas.

    Every image is resized to 256x256 and pasted row by row (row-major
    order). A 50-pixel band above the grid holds one title per column and a
    120-pixel band on the left holds one title per row; titles are wrapped
    with ``wrap_text`` and centred within their band.

    Args:
        imgs: Images to place; must contain exactly ``rows * cols`` items.
        rows: Number of grid rows.
        cols: Number of grid columns.
        top_titles: One title per column.
        left_titles: One title per row.
        margin: Spacing in pixels between cells and around the title bands.

    Returns:
        The composed RGB image.
    """
    assert len(imgs) == rows * cols
    assert len(top_titles) == cols
    assert len(left_titles) == rows

    tiles = [tile.resize((256, 256)) for tile in imgs]
    w, h = tiles[0].size
    title_height = 50
    title_width = 120

    # Distance between consecutive cells and the top-left corner of cell (0, 0).
    pitch_x = w + margin
    pitch_y = h + margin
    origin_x = title_width + margin
    origin_y = title_height + margin

    grid = Image.new(
        'RGB',
        size=(cols * pitch_x + origin_x, rows * pitch_y + origin_y),
        color='white',
    )
    draw = ImageDraw.Draw(grid)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        font = ImageFont.load_default()

    def measure(line):
        # Right and bottom edges of the rendered line, used as its width/height.
        box = draw.textbbox((0, 0), line, font=font)
        return box[2], box[3]

    # Column titles: centred horizontally over each column, vertically in the top band.
    for col, title in enumerate(top_titles):
        title_lines = wrap_text(title, w, draw, font)
        sizes = [measure(line) for line in title_lines]
        y_cursor = (title_height - sum(line_h for _, line_h in sizes)) // 2
        for line, (line_w, line_h) in zip(title_lines, sizes):
            x_cursor = col * pitch_x + origin_x + (w - line_w) // 2
            draw.text((x_cursor, y_cursor), line, fill="black", font=font)
            y_cursor += line_h

    # Row titles: centred horizontally in the left band, vertically beside each row.
    for row, title in enumerate(left_titles):
        title_lines = wrap_text(title, title_width - 10, draw, font)
        sizes = [measure(line) for line in title_lines]
        block_height = sum(line_h for _, line_h in sizes)
        y_cursor = row * pitch_y + title_height + (h - block_height) // 2 + margin
        for line, (line_w, line_h) in zip(title_lines, sizes):
            x_cursor = (title_width - line_w) // 2
            draw.text((x_cursor, y_cursor), line, fill="black", font=font)
            y_cursor += line_h

    # Paste the tiles in row-major order.
    for index, tile in enumerate(tiles):
        row, col = divmod(index, cols)
        grid.paste(tile, box=(col * pitch_x + origin_x, row * pitch_y + origin_y))

    return grid
def _load_image_or_blank(path):
    """Open the image at *path*, or return a blank white 256x256 tile if it is missing."""
    if os.path.exists(path):
        return Image.open(path)
    return Image.new("RGB", (256, 256), color="white")


def _fit_row(row_images, width):
    """Return *row_images* truncated or padded with blank white tiles to exactly *width* items."""
    row = list(row_images[:width])
    while len(row) < width:
        row.append(Image.new("RGB", (256, 256), color="white"))
    return row


def create_grids(config):
    """
    Build and save comparison grids of base images, concept images and samples.

    For every base image (and, with several concept directories, every
    combination of images from all but the last concept directory) one grid
    is written to ``<output_base_dir>/grids``. Each grid has one row per
    image in the last concept directory; a row shows the base image, the
    concept image(s), and the generated samples found in the matching
    sample directory under ``output_base_dir``.

    Sample directories are looked up as:
      * single concept:    ``{base_image}_to_{last_image}``
      * multiple concepts: ``{base_image}_to_{fixed1}_..._{fixedN}_{last_image}``

    Args:
        config: Mapping with the keys
            ``"input_dir_base"`` (directory of base images),
            ``"input_dirs_concepts"`` (list of concept image directories),
            ``"output_base_dir"`` (where sample directories live and where
            the ``grids`` directory is created), and optionally
            ``"num_samples"`` (number of sample columns, default 4).
    """
    num_samples = config.get("num_samples", 4)
    concept_dirs = config["input_dirs_concepts"]
    output_base_dir = config["output_base_dir"]
    input_dir_base = config["input_dir_base"]

    output_grid_dir = os.path.join(output_base_dir, "grids")
    os.makedirs(output_grid_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort so that grid rows
    # and combinations are laid out deterministically.
    base_images = sorted(os.listdir(input_dir_base))

    single_concept = len(concept_dirs) == 1
    fixed_concepts = concept_dirs[:-1]
    last_concept_dir = concept_dirs[-1]
    last_concept_images = sorted(os.listdir(last_concept_dir))
    fixed_concept_images = [sorted(os.listdir(concept_dir)) for concept_dir in fixed_concepts]

    if single_concept:
        concept_titles = ["Concept 1"]
    else:
        concept_titles = [f"Concept {i+1}" for i in range(len(fixed_concepts))] + ["Last Concept"]
    top_titles = ["Base Image"] + concept_titles + ["Samples"] + [""] * (num_samples - 1)
    left_titles = [""] * len(last_concept_images)
    num_cols = len(top_titles)

    for base_image in base_images:
        base_image_path = os.path.join(input_dir_base, base_image)

        # With a single concept there are no fixed concepts, so product()
        # yields exactly one empty combination and one grid per base image.
        for fixed_combination in itertools.product(*fixed_concept_images):
            if single_concept:
                sample_dir_prefix = f"{base_image}_to_"
                grid_name = f"grid_base_{base_image}_concept1.png"
            else:
                sample_dir_prefix = f"{base_image}_to_" + "_".join(fixed_combination) + "_"
                grid_name = f"grid_base_{base_image}_combo_{'_'.join(map(str, fixed_combination))}.png"

            # Leading cells shared by every row of this grid.
            fixed_images = [_load_image_or_blank(base_image_path)]
            for concept_dir, concept_image in zip(fixed_concepts, fixed_combination):
                fixed_images.append(_load_image_or_blank(os.path.join(concept_dir, concept_image)))

            images = []
            for last_image in last_concept_images:
                row_images = fixed_images + [
                    _load_image_or_blank(os.path.join(last_concept_dir, last_image))
                ]

                # Add generated samples for the current row, if any exist.
                sample_dir = os.path.join(output_base_dir, sample_dir_prefix + last_image)
                if os.path.isdir(sample_dir):
                    row_images.extend(
                        _load_image_or_blank(os.path.join(sample_dir, sample_image))
                        for sample_image in sorted(os.listdir(sample_dir))
                    )

                # Pad or truncate per row (not once at the end) so a missing,
                # short or oversized sample directory cannot shift later rows
                # out of their columns or break the rows * cols invariant.
                images.extend(_fit_row(row_images, num_cols))

            grid = image_grid_with_titles(
                imgs=images,
                rows=len(left_titles),
                cols=num_cols,
                top_titles=top_titles,
                left_titles=left_titles,
            )

            grid_save_path = os.path.join(output_grid_dir, grid_name)
            grid.save(grid_save_path)
            print(f"Grid saved at {grid_save_path}")
if __name__ == "__main__":
    # Command-line entry point: read a JSON configuration and build the grids.
    parser = argparse.ArgumentParser(description="Create image grids based on a configuration file.")
    parser.add_argument("config_path", type=str, help="Path to the configuration JSON file.")
    args = parser.parse_args()

    # Load the configuration. JSON text is UTF-8, so decode it explicitly
    # instead of relying on the platform's locale-dependent default encoding.
    with open(args.config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Default to four sample columns when the configuration does not say.
    config.setdefault("num_samples", 4)

    create_grids(config)