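"""Compose concept images onto base images with IP-Adapter and SDXL.

For each base image and each combination of concept images, the concept CLIP
embeddings are paired with per-concept SVD projection matrices and composed
onto the base embedding to generate edited samples under `output_base_dir`.
"""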
import argparse
import gc
import json
import os

import numpy as np
import open_clip
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from PIL import Image

from IP_Adapter.ip_adapter import IPAdapterXL
from create_grids import create_grids
from perform_swap import compute_dataset_embeds_svd, get_modified_images_embeds_composition
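
# The config is a JSON file shaped like the sketch below. The keys are the
# ones main() reads; the paths and values are hypothetical placeholders.
#
# {
#     "input_dir_base": "data/base_images",
#     "input_dirs_concepts": ["data/concept_a", "data/concept_b"],
#     "all_embeds_paths": ["embeds/concept_a.npy", "embeds/concept_b.npy"],
#     "ranks": [30, 30],
#     "output_base_dir": "outputs",
#     "prompt": null,
#     "scale": 1.0,
#     "seed": 420,
#     "num_samples": 4
# }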


def save_images(output_dir, image_list):
    os.makedirs(output_dir, exist_ok=True)
    for i, img in enumerate(image_list):
        img.save(os.path.join(output_dir, f"sample_{i + 1}.png"))


def get_image_embeds(pil_image, model, preprocess, device):
    # Preprocess to a tensor and add a batch dimension.
    image = preprocess(pil_image).unsqueeze(0)
    with torch.no_grad():
        embeds = model.encode_image(image.to(device))
    return embeds.cpu().numpy()


def process_combo(
    image_embeds_base,
    image_names_base,
    concept_embeds,
    concept_names,
    projection_matrices,
    ip_model,
    output_base_dir,
    num_samples=4,
    seed=420,
    prompt=None,
    scale=1.0,
):
    for base_embed, base_name in zip(image_embeds_base, image_names_base):
        # Iterate over the Cartesian product of per-concept image indices,
        # i.e. every combination of one image from each concept.
        for combo_indices in np.ndindex(*(len(embeds) for embeds in concept_embeds)):
            concept_combo_names = [concept_names[c][idx] for c, idx in enumerate(combo_indices)]
            combo_dir = os.path.join(
                output_base_dir,
                f"{base_name}_to_" + "_".join(concept_combo_names),
            )
            if os.path.exists(combo_dir):
                print(f"Directory {combo_dir} already exists. Skipping...")
                continue
            # Pair each selected concept embedding with its projection matrix.
            projections_data = [
                {
                    "embed": concept_embeds[c][idx],
                    "projection_matrix": projection_matrices[c],
                }
                for c, idx in enumerate(combo_indices)
            ]
            modified_images = get_modified_images_embeds_composition(
                base_embed,
                projections_data,
                ip_model,
                prompt=prompt,
                scale=scale,
                num_samples=num_samples,
                seed=seed,
            )
            save_images(combo_dir, modified_images)
            # Release GPU memory between combinations.
            del modified_images
            torch.cuda.empty_cache()
            gc.collect()


def main(config_path, should_create_grids):
    with open(config_path, "r") as f:
        config = json.load(f)

    # Fill in optional config values. With a text prompt, a lower default
    # scale (0.6) leaves room for the prompt to steer the generation.
    config.setdefault("prompt", None)
    config.setdefault("scale", 1.0 if config["prompt"] is None else 0.6)
    config.setdefault("seed", 420)
    config.setdefault("num_samples", 4)

    base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        add_watermarker=False,
    )

    # Load the SDXL IP-Adapter checkpoint (ViT-H image encoder variant).
    image_encoder_repo = "h94/IP-Adapter"
    image_encoder_subfolder = "models/image_encoder"
    ip_ckpt = hf_hub_download("h94/IP-Adapter", subfolder="sdxl_models", filename="ip-adapter_sdxl_vit-h.bin")
    device = "cuda"
    ip_model = IPAdapterXL(pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device)

    # OpenCLIP model used to embed the input images (inference only).
    model, _, preprocess = open_clip.create_model_and_transforms("hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    model.to(device)
    model.eval()

    # Embed every base image in the input directory.
    image_files_base = [
        os.path.join(config["input_dir_base"], f)
        for f in os.listdir(config["input_dir_base"])
        if f.lower().endswith(("png", "jpg", "jpeg"))
    ]
    image_embeds_base = []
    image_names_base = []
    for path in image_files_base:
        image_names_base.append(os.path.basename(path))
        image_embeds_base.append(get_image_embeds(Image.open(path).convert("RGB"), model, preprocess, device))

    # For each of the n concepts: embed its example images and build a
    # projection matrix from its precomputed embedding set via rank-truncated SVD.
    concept_dirs = config["input_dirs_concepts"]
    concept_embeds = []
    concept_names = []
    projection_matrices = []
    for concept_dir, embeds_path, rank in zip(concept_dirs, config["all_embeds_paths"], config["ranks"]):
        image_files = [
            os.path.join(concept_dir, f)
            for f in os.listdir(concept_dir)
            if f.lower().endswith(("png", "jpg", "jpeg"))
        ]
        embeds = []
        names = []
        for path in image_files:
            names.append(os.path.basename(path))
            embeds.append(get_image_embeds(Image.open(path).convert("RGB"), model, preprocess, device))
        concept_embeds.append(embeds)
        concept_names.append(names)

        with open(embeds_path, "rb") as f:
            all_embeds_in = np.load(f)
        projection_matrices.append(compute_dataset_embeds_svd(all_embeds_in, rank))

    # Process all base-image / concept combinations.
    process_combo(
        image_embeds_base,
        image_names_base,
        concept_embeds,
        concept_names,
        projection_matrices,
        ip_model,
        config["output_base_dir"],
        num_samples=config["num_samples"],
        seed=config["seed"],
        prompt=config["prompt"],
        scale=config["scale"],
    )

    # Optionally assemble the generated samples into comparison grids.
    if should_create_grids:
        create_grids(config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process images using embeddings and configurations.")
    parser.add_argument("--config", type=str, required=True, help="Path to the configuration JSON file.")
    parser.add_argument("--create_grids", action="store_true", help="Enable grid creation.")
    args = parser.parse_args()
    main(args.config, args.create_grids)
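
# Example invocation (the script and config filenames are placeholders):
#   python run_composition.py --config configs/example.json --create_grids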