from __future__ import annotations

import json
import math
import random
import sys
from argparse import ArgumentParser
from pathlib import Path

import einops
import k_diffusion as K
import matplotlib.pyplot as plt
import numpy as np
import seaborn
import torch
import torch.nn as nn
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from torch import autocast
from tqdm.auto import tqdm

sys.path.append("./")
from clip_similarity import ClipSimilarity
from edit_dataset import EditDatasetEval

sys.path.append("./stable_diffusion")
from ldm.util import instantiate_from_config


class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
        cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
        cfg_cond = {
            "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
            "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
        }
        out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
        return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
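
# Note on the guidance above: this is the two-scale classifier-free guidance used by
# InstructPix2Pix. The batch of three denoiser calls yields (text+image conditioned,
# image-conditioned only, fully unconditional) predictions; the output shifts the
# unconditional estimate toward the image-conditioned one by image_cfg_scale and
# toward the fully conditioned one by text_cfg_scale. The sweep in main() below uses
# text_cfg_scale = 7.5 and image_cfg_scale from 1.0 to 2.2.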


def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        sd = {
            k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
            for k, v in sd.items()
        }
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    return model
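
# Example usage (a sketch; the config/checkpoint paths below are the defaults assumed
# by main() and must exist locally):
#
#   config = OmegaConf.load("configs/generate.yaml")
#   model = load_model_from_config(config, "checkpoints/instruct-pix2pix-00-22000.ckpt", verbose=True)
#   model.eval().cuda()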


class ImageEditor(nn.Module):
    def __init__(self, config, ckpt, vae_ckpt=None):
        super().__init__()
        config = OmegaConf.load(config)
        self.model = load_model_from_config(config, ckpt, vae_ckpt)
        self.model.eval().cuda()
        self.model_wrap = K.external.CompVisDenoiser(self.model)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
        self.null_token = self.model.get_learned_conditioning([""])

    def forward(
        self,
        image: torch.Tensor,
        edit: str,
        scale_txt: float = 7.5,
        scale_img: float = 1.0,
        steps: int = 100,
    ) -> torch.Tensor:
        assert image.dim() == 3
        assert image.size(1) % 64 == 0
        assert image.size(2) % 64 == 0
        with torch.no_grad(), autocast("cuda"), self.model.ema_scope():
            cond = {
                "c_crossattn": [self.model.get_learned_conditioning([edit])],
                "c_concat": [self.model.encode_first_stage(image[None]).mode()],
            }
            uncond = {
                "c_crossattn": [self.model.get_learned_conditioning([""])],
                "c_concat": [torch.zeros_like(cond["c_concat"][0])],
            }
            extra_args = {
                "uncond": uncond,
                "cond": cond,
                "image_cfg_scale": scale_img,
                "text_cfg_scale": scale_txt,
            }
            sigmas = self.model_wrap.get_sigmas(steps)
            x = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
            x = K.sampling.sample_euler_ancestral(self.model_wrap_cfg, x, sigmas, extra_args=extra_args)
            x = self.model.decode_first_stage(x)[0]
        return x
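
# Example usage (a minimal sketch; "input.jpg" and the edit prompt are placeholders,
# and the preprocessing assumes the model takes a CHW float tensor in [-1, 1] with
# height/width divisible by 64, matching the asserts above):
#
#   editor = ImageEditor("configs/generate.yaml", "checkpoints/instruct-pix2pix-00-22000.ckpt").cuda()
#   pil_image = Image.open("input.jpg").convert("RGB").resize((512, 512))
#   image = 2 * torch.from_numpy(np.array(pil_image)).float().permute(2, 0, 1) / 255 - 1
#   edited = editor(image.cuda(), "make it a watercolor painting", scale_txt=7.5, scale_img=1.2)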


def compute_metrics(
    config,
    model_path,
    vae_ckpt,
    data_path,
    output_path,
    scales_img,
    scales_txt,
    num_samples=5000,
    split="test",
    steps=50,
    res=512,
    seed=0,
):
    editor = ImageEditor(config, model_path, vae_ckpt).cuda()
    clip_similarity = ClipSimilarity().cuda()

    outpath = Path(output_path, f"n={num_samples}_p={split}_s={steps}_r={res}_e={seed}.jsonl")
    Path(output_path).mkdir(parents=True, exist_ok=True)

    for scale_txt in scales_txt:
        for scale_img in scales_img:
            dataset = EditDatasetEval(
                path=data_path,
                split=split,
                res=res,
            )
            assert num_samples <= len(dataset)
            print(f"Processing t={scale_txt}, i={scale_img}")
            torch.manual_seed(seed)
            perm = torch.randperm(len(dataset))

            sim_0_avg = 0
            sim_1_avg = 0
            sim_direction_avg = 0
            sim_image_avg = 0
            count = 0
            i = 0

            pbar = tqdm(total=num_samples)
            while count < num_samples:
                idx = perm[i].item()
                sample = dataset[idx]
                i += 1
                gen = editor(sample["image_0"].cuda(), sample["edit"], scale_txt=scale_txt, scale_img=scale_img, steps=steps)
                sim_0, sim_1, sim_direction, sim_image = clip_similarity(
                    sample["image_0"][None].cuda(), gen[None].cuda(), [sample["input_prompt"]], [sample["output_prompt"]]
                )
                sim_0_avg += sim_0.item()
                sim_1_avg += sim_1.item()
                sim_direction_avg += sim_direction.item()
                sim_image_avg += sim_image.item()
                count += 1
                pbar.update(1)  # one sample processed per iteration
            pbar.close()

            sim_0_avg /= count
            sim_1_avg /= count
            sim_direction_avg /= count
            sim_image_avg /= count

            with open(outpath, "a") as f:
                f.write(f"{json.dumps(dict(sim_0=sim_0_avg, sim_1=sim_1_avg, sim_direction=sim_direction_avg, sim_image=sim_image_avg, num_samples=num_samples, split=split, scale_txt=scale_txt, scale_img=scale_img, steps=steps, res=res, seed=seed))}\n")

    return outpath
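
# Example usage (a sketch; writes one JSON line per (scale_txt, scale_img) pair to the
# returned .jsonl path; num_samples is reduced here purely for illustration):
#
#   metrics_file = compute_metrics(
#       "configs/generate.yaml",
#       "checkpoints/instruct-pix2pix-00-22000.ckpt",
#       None,
#       "data/clip-filtered-dataset/",
#       "analysis/",
#       scales_img=[1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2],
#       scales_txt=[7.5],
#       num_samples=100,
#       steps=50,
#   )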


def plot_metrics(metrics_file, output_path):
    with open(metrics_file, "r") as f:
        data = [json.loads(line) for line in f]

    plt.rcParams.update({"font.size": 11.5})
    seaborn.set_style("darkgrid")
    plt.figure(figsize=(20.5 * 0.7, 10.8 * 0.7), dpi=200)

    x = [d["sim_direction"] for d in data]
    y = [d["sim_image"] for d in data]

    plt.plot(x, y, marker="o", linewidth=2, markersize=4)
    plt.xlabel("CLIP Text-Image Direction Similarity", labelpad=10)
    plt.ylabel("CLIP Image Similarity", labelpad=10)
    plt.savefig(Path(output_path) / Path("plot.pdf"), bbox_inches="tight")
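
# Example usage (a sketch; the filename encodes the settings used by compute_metrics):
#
#   plot_metrics("analysis/n=5000_p=test_s=50_r=512_e=0.jsonl", "analysis/")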


def main():
    parser = ArgumentParser()
    parser.add_argument("--resolution", default=512, type=int)
    parser.add_argument("--steps", default=100, type=int)
    parser.add_argument("--config", default="configs/generate.yaml", type=str)
    parser.add_argument("--output_path", default="analysis/", type=str)
    parser.add_argument("--ckpt", default="checkpoints/instruct-pix2pix-00-22000.ckpt", type=str)
    parser.add_argument("--dataset", default="data/clip-filtered-dataset/", type=str)
    parser.add_argument("--vae-ckpt", default=None, type=str)
    args = parser.parse_args()

    scales_img = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2]
    scales_txt = [7.5]

    metrics_file = compute_metrics(
        args.config,
        args.ckpt,
        args.vae_ckpt,
        args.dataset,
        args.output_path,
        scales_img,
        scales_txt,
        steps=args.steps,
        res=args.resolution,  # pass --resolution through so the flag actually takes effect
    )

    plot_metrics(metrics_file, args.output_path)


if __name__ == "__main__":
    main()
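
# Example CLI invocation (a sketch; the script filename and local paths are assumptions):
#
#   python compute_metrics.py \
#       --ckpt checkpoints/instruct-pix2pix-00-22000.ckpt \
#       --dataset data/clip-filtered-dataset/ \
#       --output_path analysis/ \
#       --steps 100 --resolution 512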