################################################################################
# Copyright (C) 2023 Jiayi Guo, Xingqian Xu, Manushree Vasu - All Rights Reserved #
################################################################################

import gradio as gr
import os
import os.path as osp
import PIL
from PIL import Image
import numpy as np
from collections import OrderedDict
from easydict import EasyDict as edict
from functools import partial
import torch
import torchvision.transforms as tvtrans
import time
import argparse
import json
import hashlib
import copy
from tqdm import tqdm
from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
from app_utils import auto_dropdown
from huggingface_hub import hf_hub_download
import spaces
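
# Note: `spaces` is only relevant when this demo runs on Hugging Face ZeroGPU hardware;
# in that setting the GPU-heavy entry points (wrapper.run_imintp, run_iminvs, run_imedit)
# would normally be wrapped with the `@spaces.GPU` decorator. Here it is a plain import.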

version = "Smooth Diffusion Demo v1.0"

refresh_symbol = "\U0001f504"  # 🔄
recycle_symbol = '\U0000267b'  # ♻

##############
# model_book #
##############

choices = edict()

choices.diffuser = OrderedDict([
    ['SD-v1-5', "runwayml/stable-diffusion-v1-5"],
    ['OJ-v4',   "prompthero/openjourney-v4"],
    ['RR-v2',   "SG161222/Realistic_Vision_V2.0"],
])

choices.lora = OrderedDict([
    ['empty', ""],
    ['Smooth-LoRA-v1', hf_hub_download('shi-labs/smooth-diffusion-lora', 'smooth_lora.safetensors')],
])

choices.scheduler = OrderedDict([
    ['DDIM', DDIMScheduler],
])
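
# Illustrative only: other diffusers schedulers could be registered here in the same way,
# e.g. `from diffusers import DPMSolverMultistepScheduler` followed by
# `choices.scheduler['DPM++'] = DPMSolverMultistepScheduler`; the demo ships with DDIM only.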

choices.inversion = OrderedDict([
    ['NTI', 'NTI'],
    ['DDIM w/o text', 'DDIM w/o text'],
    ['DDIM', 'DDIM'],
])

default = edict()
default.diffuser = 'RR-v2'
default.scheduler = 'DDIM'
default.lora = 'Smooth-LoRA-v1'
default.inversion = 'NTI'
default.step = 50
default.cfg_scale = 7.5
default.framen = 24
default.fps = 16
default.nullinv_inner_step = 10
default.threshold = 0.8
default.variation = 0.8

##########
# helper #
##########

def lerp(t, v0, v1):
    if isinstance(t, float):
        return v0*(1-t) + v1*t
    elif isinstance(t, (list, np.ndarray)):
        return [v0*(1-ti) + v1*ti for ti in t]

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    # mostly copied from
    # https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    v0_unit = v0 / np.linalg.norm(v0)
    v1_unit = v1 / np.linalg.norm(v1)
    dot = np.sum(v0_unit * v1_unit)
    if np.abs(dot) > DOT_THRESHOLD:
        return lerp(t, v0, v1)
    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    # Angle at timestep t
    if isinstance(t, float):
        tlist = [t]
    elif isinstance(t, (list, np.ndarray)):
        tlist = t
    v2_list = []
    for ti in tlist:
        theta_t = theta_0 * ti
        sin_theta_t = np.sin(theta_t)
        # Finish the slerp algorithm
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1
        v2_list.append(v2)
    if isinstance(t, float):
        return v2_list[0]
    else:
        return v2_list
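
# Why slerp here: for high-dimensional Gaussian noise latents, slerp keeps intermediate
# points near the same hypersphere as the endpoints, while lerp would shrink the
# midpoint's norm by roughly a factor of sqrt(2). Illustrative example:
#   zs = slerp(np.linspace(0.0, 1.0, 5), np.random.randn(4096), np.random.randn(4096))
# returns five vectors whose norms stay close to those of the two endpoints.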

def offset_resize(image, width=512, height=512, left=0, right=0, top=0, bottom=0):
    # Crop by the given offsets, center-crop to a square, then resize to (width, height).
    image = np.array(image)[:, :, :3]
    h, w, c = image.shape
    left = min(left, w-1)
    right = min(right, w - left - 1)
    top = min(top, h - left - 1)
    bottom = min(bottom, h - top - 1)
    image = image[top:h-bottom, left:w-right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = Image.fromarray(image).resize((width, height))
    return image

def auto_dtype_device_shape(tlist, v0, v1, func,):
    # Adapt `func` (lerp/slerp over flat float64 numpy arrays) to torch tensors or numpy
    # arrays of arbitrary shape, restoring the original shape, dtype and device afterwards.
    vshape = v0.shape
    assert v0.shape == v1.shape
    assert isinstance(tlist, (list, np.ndarray))
    if isinstance(v0, torch.Tensor):
        is_torch = True
        dtype, device = v0.dtype, v0.device
        v0 = v0.to('cpu').numpy().astype(float).flatten()
        v1 = v1.to('cpu').numpy().astype(float).flatten()
    else:
        is_torch = False
        dtype = v0.dtype
        assert isinstance(v0, np.ndarray)
        assert isinstance(v1, np.ndarray)
        v0 = v0.astype(float).flatten()
        v1 = v1.astype(float).flatten()
    r = func(tlist, v0, v1)
    if is_torch:
        r = [torch.Tensor(ri).view(*vshape).to(dtype).to(device) for ri in r]
    else:
        r = [ri.astype(dtype) for ri in r]
    return r

auto_lerp = partial(auto_dtype_device_shape, func=lerp)
auto_slerp = partial(auto_dtype_device_shape, func=slerp)
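
# Illustrative usage of the wrappers above: both take a list/array of interpolation
# weights and two same-shaped tensors (or arrays), and return a list of results restored
# to the original shape, dtype and device, e.g.
#   xt_list  = auto_slerp(np.linspace(0.0, 1.0, 24), xt0, xt1)    # noise latents
#   emb_list = auto_lerp(np.linspace(0.0, 1.0, 24), emb0, emb1)   # text embeddings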

def frames2mp4(vpath, frames, fps):
    import moviepy.editor as mpy
    frames = [np.array(framei) for framei in frames]
    clip = mpy.ImageSequenceClip(frames, fps=fps)
    clip.write_videofile(vpath, fps=fps)

def negseed_to_rndseed(seed):
    # A negative seed means "draw a random one".
    if seed < 0:
        seed = np.random.randint(0, np.iinfo(np.uint32).max-100)
    return seed

def regulate_image(pilim):
    # Snap the resolution to the nearest multiple of 64, as required by the VAE.
    w, h = pilim.size
    w = int(round(w/64)) * 64
    h = int(round(h/64)) * 64
    return pilim.resize([w, h], resample=PIL.Image.BILINEAR)

def txt_to_emb(model, prompt):
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",)
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    return text_embeddings

def hash_pilim(pilim):
    hasha = hashlib.md5(pilim.tobytes()).hexdigest()
    return hasha

def hash_cfgdict(cfgdict):
    hashb = hashlib.md5(json.dumps(cfgdict, sort_keys=True).encode('utf-8')).hexdigest()
    return hashb

def remove_earliest_file(path, max_allowance=500, remove_ratio=0.1, ext=None):
    # Keep a cache folder below `max_allowance` files by deleting the oldest ones.
    if len(os.listdir(path)) <= max_allowance:
        return
    def get_mtime(fname):
        return osp.getmtime(osp.join(path, fname))
    if ext is None:
        flist = sorted(os.listdir(path), key=get_mtime)
    else:
        flist = [fi for fi in os.listdir(path) if fi.endswith(ext)]
        flist = sorted(flist, key=get_mtime)
    exceedn = max(len(flist)-max_allowance, 0)
    removen = int(max_allowance*remove_ratio)
    removen = max(1, removen) + exceedn
    for fi in flist[0:removen]:
        os.remove(osp.join(path, fi))

def remove_decoupled_file(path, exta='.mp4', extb='.json'):
    # Drop cache files whose counterpart with the other extension is missing.
    tag_a = [osp.splitext(fi)[0] for fi in os.listdir(path) if fi.endswith(exta)]
    tag_b = [osp.splitext(fi)[0] for fi in os.listdir(path) if fi.endswith(extb)]
    tag_a_extra = set(tag_a) - set(tag_b)
    tag_b_extra = set(tag_b) - set(tag_a)
    [os.remove(osp.join(path, tagi+exta)) for tagi in tag_a_extra]
    [os.remove(osp.join(path, tagi+extb)) for tagi in tag_b_extra]

def t2i_core(model, xt, emb, nemb, step=30, cfg_scale=7.5, return_list=False):
    # Classifier-free-guided DDIM sampling from latent `xt`, driven by text embedding
    # `emb` and negative/null embedding `nemb` (either shared or one per step).
    from nulltxtinv_wrapper import diffusion_step, latent2image
    model.scheduler.set_timesteps(step)
    xi = xt
    emb = txt_to_emb(model, "") if emb is None else emb
    nemb = txt_to_emb(model, "") if nemb is None else nemb
    if return_list:
        xi_list = [xi.clone()]
    for i, t in enumerate(tqdm(model.scheduler.timesteps)):
        embi = emb[i] if isinstance(emb, list) else emb
        nembi = nemb[i] if isinstance(nemb, list) else nemb
        context = torch.cat([nembi, embi])
        xi = diffusion_step(model, xi, context, t, cfg_scale, low_resource=False)
        if return_list:
            xi_list.append(xi.clone())
    x0 = xi
    im = latent2image(model.vae, x0, return_type='pil')
    if return_list:
        return im, xi_list
    else:
        return im
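
# Illustrative call pattern for t2i_core (names are hypothetical, assuming a loaded
# StableDiffusionPipeline `pipe` generating at 512x512):
#   xt = torch.randn(1, 4, 64, 64, device=pipe.device, dtype=pipe.unet.dtype)
#   im = t2i_core(pipe, xt, txt_to_emb(pipe, "a photo of a dog"), None, step=50)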

########
# main #
########

class wrapper(object):
    def __init__(self,
                 fp16=False,
                 tag_diffuser=None,
                 tag_lora=None,
                 tag_scheduler=None,):
        self.device = "cuda"  # if torch.cuda.is_available() else "cpu"
        if fp16:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.load_all(tag_diffuser, tag_lora, tag_scheduler)
        self.image_latent_dim = 4
        self.batchsize = 8
        self.seed = {}
        self.cache_video_folder = "temp/video"
        self.cache_video_maxn = 500
        self.cache_image_folder = "temp/image"
        self.cache_image_maxn = 500
        self.cache_inverse_folder = "temp/inverse"
        self.cache_inverse_maxn = 500

    def load_all(self, tag_diffuser, tag_lora, tag_scheduler):
        self.load_diffuser_lora(tag_diffuser, tag_lora)
        self.load_scheduler(tag_scheduler)
        return tag_diffuser, tag_lora, tag_scheduler

    def load_diffuser_lora(self, tag_diffuser, tag_lora):
        self.net = StableDiffusionPipeline.from_pretrained(
            choices.diffuser[tag_diffuser], torch_dtype=self.torch_dtype).to(self.device)
        self.net.safety_checker = None
        if tag_lora != 'empty':
            self.net.unet.load_attn_procs(
                choices.lora[tag_lora], use_safetensors=True,)
        self.tag_diffuser = tag_diffuser
        self.tag_lora = tag_lora
        return tag_diffuser, tag_lora

    def load_scheduler(self, tag_scheduler):
        self.net.scheduler = choices.scheduler[tag_scheduler].from_config(self.net.scheduler.config)
        self.tag_scheduler = tag_scheduler
        return tag_scheduler

    def reset_seed(self, which='ltintp'):
        return -1

    def recycle_seed(self, which='ltintp'):
        if which not in self.seed:
            return self.reset_seed(which=which)
        else:
            return self.seed[which]

    ##########
    # helper #
    ##########

    def precheck_model(self, tag_diffuser, tag_lora, tag_scheduler):
        # Reload weights only when the requested diffuser/LoRA/scheduler differs
        # from what is currently loaded.
        if (tag_diffuser != self.tag_diffuser) or (tag_lora != self.tag_lora):
            self.load_all(tag_diffuser, tag_lora, tag_scheduler)
        if tag_scheduler != self.tag_scheduler:
            self.load_scheduler(tag_scheduler)

    ########
    # main #
    ########

    def ddiminv(self, img, cfgdict):
        txt, step, cfg_scale = cfgdict['txt'], cfgdict['step'], cfgdict['cfg_scale']
        from nulltxtinv_wrapper import NullInversion
        null_inversion_model = NullInversion(self.net, step, cfg_scale)
        with torch.no_grad():
            emb = txt_to_emb(self.net, txt)
            nemb = txt_to_emb(self.net, "")
            xt = null_inversion_model.ddim_invert(img, txt)
        data = {
            'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
            'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
            'xt': xt, 'emb': emb, 'nemb': nemb,}
        return data

    def nullinv_or_loadcache(self, img, cfgdict, force_reinvert=False):
        # Null-text inversion is slow, so results are cached on disk and reused
        # unless `force_reinvert` is set.
        hashstr = hash_pilim(img) + "--" + hash_cfgdict(cfgdict)
        cdir = self.cache_inverse_folder
        cfname = osp.join(cdir, hashstr+'.pth')
        if osp.isfile(cfname) and (not force_reinvert):
            cache_data = torch.load(cfname)
            dtype = next(self.net.unet.parameters()).dtype
            device = next(self.net.unet.parameters()).device
            cache_data['xt'] = cache_data['xt'].to(device=device, dtype=dtype)
            cache_data['emb'] = cache_data['emb'].to(device=device, dtype=dtype)
            cache_data['nemb'] = [
                nembi.to(device=device, dtype=dtype)
                for nembi in cache_data['nemb']]
            return cache_data
        else:
            txt, step, cfg_scale = cfgdict['txt'], cfgdict['step'], cfgdict['cfg_scale']
            inner_step = cfgdict['inner_step']
            from nulltxtinv_wrapper import NullInversion
            null_inversion_model = NullInversion(self.net, step, cfg_scale)
            with torch.no_grad():
                emb = txt_to_emb(self.net, txt)
                xt, nemb = null_inversion_model.null_invert(img, txt, num_inner_steps=inner_step)
            cache_data = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt.to('cpu'),
                'emb' : emb.to('cpu'),
                'nemb' : [nembi.to('cpu') for nembi in nemb],}
            os.makedirs(cdir, exist_ok=True)
            remove_earliest_file(cdir, max_allowance=self.cache_inverse_maxn)
            torch.save(cache_data, cfname)
            data = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt, 'emb' : emb, 'nemb' : nemb,}
            return data
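
    # Cache files above are named <md5(image)>--<md5(config json)>.pth, so changing the
    # prompt, step count, CFG scale, inner step, or the loaded diffuser/LoRA produces a
    # different key and triggers a fresh inversion rather than a cache hit.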

    def nullinvdual_or_loadcachedual(self, img0, img1, cfgdict, force_reinvert=False):
        # Same caching scheme as above, but both endpoint images are inverted jointly
        # so they share a single set of optimized null-text embeddings.
        hashstr = hash_pilim(img0) + "--" + hash_pilim(img1) + "--" + hash_cfgdict(cfgdict)
        cdir = self.cache_inverse_folder
        cfname = osp.join(cdir, hashstr+'.pth')
        if osp.isfile(cfname) and (not force_reinvert):
            cache_data = torch.load(cfname)
            dtype = next(self.net.unet.parameters()).dtype
            device = next(self.net.unet.parameters()).device
            cache_data['xt0'] = cache_data['xt0'].to(device=device, dtype=dtype)
            cache_data['xt1'] = cache_data['xt1'].to(device=device, dtype=dtype)
            cache_data['emb0'] = cache_data['emb0'].to(device=device, dtype=dtype)
            cache_data['emb1'] = cache_data['emb1'].to(device=device, dtype=dtype)
            cache_data['nemb'] = [
                nembi.to(device=device, dtype=dtype)
                for nembi in cache_data['nemb']]
            cache_data_a = copy.deepcopy(cache_data)
            cache_data_a['xt'] = cache_data_a.pop('xt0')
            cache_data_a['emb'] = cache_data_a.pop('emb0')
            cache_data_a.pop('xt1'); cache_data_a.pop('emb1')
            cache_data_b = cache_data
            cache_data_b['xt'] = cache_data_b.pop('xt1')
            cache_data_b['emb'] = cache_data_b.pop('emb1')
            cache_data_b.pop('xt0'); cache_data_b.pop('emb0')
            return cache_data_a, cache_data_b
        else:
            txt0, txt1, step, cfg_scale, inner_step = \
                cfgdict['txt0'], cfgdict['txt1'], cfgdict['step'], \
                cfgdict['cfg_scale'], cfgdict['inner_step']
            from nulltxtinv_wrapper import NullInversion
            null_inversion_model = NullInversion(self.net, step, cfg_scale)
            with torch.no_grad():
                emb0 = txt_to_emb(self.net, txt0)
                emb1 = txt_to_emb(self.net, txt1)
                xt0, xt1, nemb = null_inversion_model.null_invert_dual(
                    img0, img1, txt0, txt1, num_inner_steps=inner_step)
            cache_data = {
                'step' : step, 'cfg_scale' : cfg_scale,
                'txt0' : txt0, 'txt1' : txt1,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt0' : xt0.to('cpu'), 'xt1' : xt1.to('cpu'),
                'emb0' : emb0.to('cpu'), 'emb1' : emb1.to('cpu'),
                'nemb' : [nembi.to('cpu') for nembi in nemb],}
            os.makedirs(cdir, exist_ok=True)
            remove_earliest_file(cdir, max_allowance=self.cache_inverse_maxn)
            torch.save(cache_data, cfname)
            data0 = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt0,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt0, 'emb' : emb0, 'nemb' : nemb,}
            data1 = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt1,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt1, 'emb' : emb1, 'nemb' : nemb,}
            return data0, data1

    def image_inversion(
            self, img, txt,
            cfg_scale, step,
            inversion, inner_step, force_reinvert):
        from nulltxtinv_wrapper import text2image_ldm
        if inversion == 'DDIM w/o text':
            txt = ''
        if not inversion == 'NTI':
            data = self.ddiminv(img, {'txt':txt, 'step':step, 'cfg_scale':cfg_scale,})
        else:
            data = self.nullinv_or_loadcache(
                img, {'txt':txt, 'step':step,
                      'cfg_scale':cfg_scale, 'inner_step':inner_step,
                      'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)
        if inversion == 'NTI':
            img_inv, _ = text2image_ldm(
                self.net, [txt], step, cfg_scale,
                latent=data['xt'], uncond_embeddings=data['nemb'])
        else:
            img_inv, _ = text2image_ldm(
                self.net, [txt], step, cfg_scale,
                latent=data['xt'], uncond_embeddings=None)
        return img_inv

    def image_editing(
            self, img, txt_0, txt_1,
            cfg_scale, step, thresh,
            inversion, inner_step, force_reinvert):
        from nulltxtinv_wrapper import text2image_ldm_imedit
        if inversion == 'DDIM w/o text':
            txt_0 = ''
        if not inversion == 'NTI':
            data = self.ddiminv(img, {'txt':txt_0, 'step':step, 'cfg_scale':cfg_scale,})
            img_edited, _ = text2image_ldm_imedit(
                self.net, thresh, [txt_0], [txt_1], step, cfg_scale,
                latent=data['xt'], uncond_embeddings=None)
        else:
            data = self.nullinv_or_loadcache(
                img, {'txt':txt_0, 'step':step,
                      'cfg_scale':cfg_scale, 'inner_step':inner_step,
                      'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)
            img_edited, _ = text2image_ldm_imedit(
                self.net, thresh, [txt_0], [txt_1], step, cfg_scale,
                latent=data['xt'], uncond_embeddings=data['nemb'])
        return img_edited

    def general_interpolation(
            self, xset0, xset1,
            cfg_scale, step, tlist,):
        # Interpolation recipe: slerp the inverted noise latents, lerp the text
        # embeddings, and lerp the (possibly per-step) null-text embeddings, then
        # decode the frames in batches of self.batchsize.
        xt0, emb0, nemb0 = xset0['xt'], xset0['emb'], xset0['nemb']
        xt1, emb1, nemb1 = xset1['xt'], xset1['emb'], xset1['nemb']
        framen = len(tlist)
        xt_list = auto_slerp(tlist, xt0, xt1)
        emb_list = auto_lerp(tlist, emb0, emb1)
        if isinstance(nemb0, list) and isinstance(nemb1, list):
            assert len(nemb0) == len(nemb1)
            nemb_list = [auto_lerp(tlist, e0, e1) for e0, e1 in zip(nemb0, nemb1)]
            nemb_islist = True
        else:
            nemb_list = auto_lerp(tlist, nemb0, nemb1)
            nemb_islist = False
        im_list = []
        for frameidx in range(0, len(xt_list), self.batchsize):
            xt_batch = [xt_list[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
            xt_batch = torch.cat(xt_batch, dim=0)
            emb_batch = [emb_list[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
            emb_batch = torch.cat(emb_batch, dim=0)
            if nemb_islist:
                nemb_batch = []
                for nembi in nemb_list:
                    nembi_batch = [nembi[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
                    nembi_batch = torch.cat(nembi_batch, dim=0)
                    nemb_batch.append(nembi_batch)
            else:
                nemb_batch = [nemb_list[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
                nemb_batch = torch.cat(nemb_batch, dim=0)
            im = t2i_core(
                self.net, xt_batch, emb_batch, nemb_batch, step, cfg_scale)
            im_list += im if isinstance(im, list) else [im]
        return im_list

    def run_iminvs(
            self, img, text,
            cfg_scale, step,
            force_resize, width, height,
            inversion, inner_step, force_reinvert,
            tag_diffuser, tag_lora, tag_scheduler, ):
        self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
        if force_resize:
            img = offset_resize(img, width, height)
        else:
            img = regulate_image(img)
        recon_output = self.image_inversion(
            img, text, cfg_scale, step,
            inversion, inner_step, force_reinvert)
        idir = self.cache_image_folder
        os.makedirs(idir, exist_ok=True)
        remove_earliest_file(idir, max_allowance=self.cache_image_maxn)
        sname = "time{}_iminvs_{}_{}".format(
            int(time.time()), self.tag_diffuser, self.tag_lora,)
        ipath = osp.join(idir, sname+'.png')
        recon_output.save(ipath)
        return [recon_output]

    def run_imedit(
            self, img, txt_0, txt_1,
            threshold, cfg_scale, step,
            force_resize, width, height,
            inversion, inner_step, force_reinvert,
            tag_diffuser, tag_lora, tag_scheduler, ):
        self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
        if force_resize:
            img = offset_resize(img, width, height)
        else:
            img = regulate_image(img)
        edited_img = self.image_editing(
            img, txt_0, txt_1, cfg_scale, step, threshold,
            inversion, inner_step, force_reinvert)
        idir = self.cache_image_folder
        os.makedirs(idir, exist_ok=True)
        remove_earliest_file(idir, max_allowance=self.cache_image_maxn)
        sname = "time{}_imedit_{}_{}".format(
            int(time.time()), self.tag_diffuser, self.tag_lora,)
        ipath = osp.join(idir, sname+'.png')
        edited_img.save(ipath)
        return [edited_img]

    def run_imintp(
            self,
            img0, img1, txt0, txt1,
            cfg_scale, step,
            framen, fps,
            force_resize, width, height,
            inversion, inner_step, force_reinvert,
            tag_diffuser, tag_lora, tag_scheduler,):
        self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
        if txt1 == '':
            txt1 = txt0
        if force_resize:
            img0 = offset_resize(img0, width, height)
            img1 = offset_resize(img1, width, height)
        else:
            img0 = regulate_image(img0)
            img1 = regulate_image(img1)
        if inversion == 'DDIM':
            data0 = self.ddiminv(img0, {'txt':txt0, 'step':step, 'cfg_scale':cfg_scale,})
            data1 = self.ddiminv(img1, {'txt':txt1, 'step':step, 'cfg_scale':cfg_scale,})
        elif inversion == 'DDIM w/o text':
            data0 = self.ddiminv(img0, {'txt':"", 'step':step, 'cfg_scale':cfg_scale,})
            data1 = self.ddiminv(img1, {'txt':"", 'step':step, 'cfg_scale':cfg_scale,})
        else:
            data0, data1 = self.nullinvdual_or_loadcachedual(
                img0, img1, {'txt0':txt0, 'txt1':txt1, 'step':step,
                             'cfg_scale':cfg_scale, 'inner_step':inner_step,
                             'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)
        tlist = np.linspace(0.0, 1.0, framen)
        iminv0 = t2i_core(self.net, data0['xt'], data0['emb'], data0['nemb'], step, cfg_scale)
        iminv1 = t2i_core(self.net, data1['xt'], data1['emb'], data1['nemb'], step, cfg_scale)
        frames = self.general_interpolation(data0, data1, cfg_scale, step, tlist)
        vdir = self.cache_video_folder
        os.makedirs(vdir, exist_ok=True)
        remove_earliest_file(vdir, max_allowance=self.cache_video_maxn)
        sname = "time{}_imintp_{}_{}_framen{}_fps{}".format(
            int(time.time()), self.tag_diffuser, self.tag_lora, framen, fps)
        vpath = osp.join(vdir, sname+'.mp4')
        frames2mp4(vpath, frames, fps)
        jpath = osp.join(vdir, sname+'.json')
        cfgdict = {
            "method" : "image_interpolation",
            "txt0" : txt0, "txt1" : txt1,
            "cfg_scale" : cfg_scale, "step" : step,
            "framen" : framen, "fps" : fps,
            "force_resize" : force_resize, "width" : width, "height" : height,
            "inversion" : inversion, "inner_step" : inner_step,
            "force_reinvert" : force_reinvert,
            "tag_diffuser" : tag_diffuser, "tag_lora" : tag_lora, "tag_scheduler" : tag_scheduler,}
        with open(jpath, 'w') as f:
            json.dump(cfgdict, f, indent=4)
        return frames, vpath, [iminv0, iminv1]

#################
# get examples  #
#################

cache_examples = False

def get_imintp_example():
    case = [
        [
            'assets/images/interpolation/cityview1.png',
            'assets/images/interpolation/cityview2.png',
            'A city view',],
        [
            'assets/images/interpolation/woman1.png',
            'assets/images/interpolation/woman2.png',
            'A woman face',],
        [
            'assets/images/interpolation/land1.png',
            'assets/images/interpolation/land2.png',
            'A beautiful landscape',],
        [
            'assets/images/interpolation/dog1.png',
            'assets/images/interpolation/dog2.png',
            'A realistic dog',],
        [
            'assets/images/interpolation/church1.png',
            'assets/images/interpolation/church2.png',
            'A church',],
        [
            'assets/images/interpolation/rabbit1.png',
            'assets/images/interpolation/rabbit2.png',
            'A cute rabbit',],
        [
            'assets/images/interpolation/horse1.png',
            'assets/images/interpolation/horse2.png',
            'A robot horse',],
    ]
    return case

def get_iminvs_example():
    case = [
        [
            'assets/images/inversion/000000560011.jpg',
            'A mouse is next to a keyboard on a desk',],
        [
            'assets/images/inversion/000000029596.jpg',
            'A room with a couch, table set with dinnerware and a television.',],
    ]
    return case

def get_imedit_example():
    case = [
        [
            'assets/images/editing/rabbit.png',
            'A rabbit is eating a watermelon on the table',
            'A cat is eating a watermelon on the table',
            0.7,],
        [
            'assets/images/editing/cake.png',
            'A chocolate cake with cream on it',
            'A chocolate cake with strawberries on it',
            0.9,],
        [
            'assets/images/editing/banana.png',
            'A banana on the table',
            'A banana and an apple on the table',
            0.8,],
    ]
    return case

#################
# sub interface #
#################

def interface_imintp(wrapper_obj):
    with gr.Row():
        with gr.Column():
            img0 = gr.Image(label="Image Input 0", type='pil', elem_id='customized_imbox')
        with gr.Column():
            img1 = gr.Image(label="Image Input 1", type='pil', elem_id='customized_imbox')
        with gr.Column():
            video_output = gr.Video(label="Video Result", format='mp4', elem_id='customized_imbox')
    with gr.Row():
        with gr.Column():
            txt0 = gr.Textbox(label='Text Input', lines=1, placeholder="Input prompt...", )
        with gr.Column():
            with gr.Row():
                inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
                inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
                force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
    with gr.Row():
        with gr.Column():
            with gr.Row():
                framen = gr.Slider(label="Frame Number", minimum=8, maximum=default.framen, value=default.framen, step=1)
                fps = gr.Slider(label="Video FPS", minimum=4, maximum=default.fps, value=default.fps, step=4)
            with gr.Row():
                button_run = gr.Button("Run")
        with gr.Column():
            with gr.Accordion('Frame Results', open=False):
                frame_output = gr.Gallery(label="Frames", elem_id='customized_imbox')
            with gr.Accordion("Inversion Results", open=False):
                inv_output = gr.Gallery(label="Inversion Results", elem_id='customized_imbox')
    with gr.Accordion('Advanced Settings', open=False):
        with gr.Row():
            tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
            tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
            tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
        with gr.Row():
            cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
            step = gr.Number(default.step, label="Step", precision=0)
        with gr.Row():
            force_resize = gr.Checkbox(label="Force Resize", value=True)
            inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
            inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)
        with gr.Row():
            txt1 = gr.Textbox(label='Optional Different Text Input for Image Input 1', lines=1, placeholder="Input prompt...", )

    tag_diffuser.change(
        wrapper_obj.load_all,
        inputs=[tag_diffuser, tag_lora, tag_scheduler],
        outputs=[tag_diffuser, tag_lora, tag_scheduler],)
    tag_lora.change(
        wrapper_obj.load_all,
        inputs=[tag_diffuser, tag_lora, tag_scheduler],
        outputs=[tag_diffuser, tag_lora, tag_scheduler],)
    tag_scheduler.change(
        wrapper_obj.load_scheduler,
        inputs=[tag_scheduler],
        outputs=[tag_scheduler],)
    button_run.click(
        wrapper_obj.run_imintp,
        inputs=[img0, img1, txt0, txt1,
                cfg_scale, step,
                framen, fps,
                force_resize, inp_width, inp_height,
                inversion, inner_step, force_reinvert,
                tag_diffuser, tag_lora, tag_scheduler,],
        outputs=[frame_output, video_output, inv_output])
    gr.Examples(
        label='Examples',
        examples=get_imintp_example(),
        fn=wrapper_obj.run_imintp,
        inputs=[img0, img1, txt0,],
        outputs=[frame_output, video_output, inv_output],
        cache_examples=cache_examples,)

def interface_iminvs(wrapper_obj):
    with gr.Row():
        image_input = gr.Image(label="Image input", type='pil', elem_id='customized_imbox')
        recon_output = gr.Gallery(label="Reconstruction output", elem_id='customized_imbox')
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label='Text Input', lines=1, placeholder="Input prompt...", )
            with gr.Row():
                button_run = gr.Button("Run")
        with gr.Column():
            with gr.Row():
                inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
                inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
                force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
    with gr.Accordion('Advanced Settings', open=False):
        with gr.Row():
            tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
            tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
            tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
        with gr.Row():
            cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
            step = gr.Number(default.step, label="Step", precision=0)
        with gr.Row():
            force_resize = gr.Checkbox(label="Force Resize", value=True)
            inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
            inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)

    tag_diffuser.change(
        wrapper_obj.load_all,
        inputs=[tag_diffuser, tag_lora, tag_scheduler],
        outputs=[tag_diffuser, tag_lora, tag_scheduler],)
    tag_lora.change(
        wrapper_obj.load_all,
        inputs=[tag_diffuser, tag_lora, tag_scheduler],
        outputs=[tag_diffuser, tag_lora, tag_scheduler],)
    tag_scheduler.change(
        wrapper_obj.load_scheduler,
        inputs=[tag_scheduler],
        outputs=[tag_scheduler],)
    button_run.click(
        wrapper_obj.run_iminvs,
        inputs=[image_input, prompt,
                cfg_scale, step,
                force_resize, inp_width, inp_height,
                inversion, inner_step, force_reinvert,
                tag_diffuser, tag_lora, tag_scheduler,],
        outputs=[recon_output])
    gr.Examples(
        label='Examples',
        examples=get_iminvs_example(),
        fn=wrapper_obj.run_iminvs,
        inputs=[image_input, prompt,],
        outputs=[recon_output],
        cache_examples=cache_examples,)

def interface_imedit(wrapper_obj):
    with gr.Row():
        image_input = gr.Image(label="Image input", type='pil', elem_id='customized_imbox')
        edited_output = gr.Gallery(label="Edited output", elem_id='customized_imbox')
    with gr.Row():
        with gr.Column():
            prompt_0 = gr.Textbox(label='Source Text', lines=1, placeholder="Source prompt...", )
            prompt_1 = gr.Textbox(label='Target Text', lines=1, placeholder="Target prompt...", )
            with gr.Row():
                button_run = gr.Button("Run")
        with gr.Column():
            with gr.Row():
                inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
                inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
                force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
                threshold = gr.Slider(label="Threshold", minimum=0, maximum=1, value=default.threshold, step=0.1)
    with gr.Accordion('Advanced Settings', open=False):
        with gr.Row():
            tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
            tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
            tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
        with gr.Row():
            cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
            step = gr.Number(default.step, label="Step", precision=0)
        with gr.Row():
            force_resize = gr.Checkbox(label="Force Resize", value=True)
            inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
            inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)

    tag_diffuser.change(
        wrapper_obj.load_all,
        inputs=[tag_diffuser, tag_lora, tag_scheduler],
        outputs=[tag_diffuser, tag_lora, tag_scheduler],)
    tag_lora.change(
        wrapper_obj.load_all,
        inputs=[tag_diffuser, tag_lora, tag_scheduler],
        outputs=[tag_diffuser, tag_lora, tag_scheduler],)
    tag_scheduler.change(
        wrapper_obj.load_scheduler,
        inputs=[tag_scheduler],
        outputs=[tag_scheduler],)
    button_run.click(
        wrapper_obj.run_imedit,
        inputs=[image_input, prompt_0, prompt_1,
                threshold, cfg_scale, step,
                force_resize, inp_width, inp_height,
                inversion, inner_step, force_reinvert,
                tag_diffuser, tag_lora, tag_scheduler,],
        outputs=[edited_output])
    gr.Examples(
        label='Examples',
        examples=get_imedit_example(),
        fn=wrapper_obj.run_imedit,
        inputs=[image_input, prompt_0, prompt_1, threshold,],
        outputs=[edited_output],
        cache_examples=cache_examples,)

#############
# Interface #
#############

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--port', type=int, default=None)
    args = parser.parse_args()

    from app_utils import css_empty, css_version_4_11_0
    # css = css_empty
    css = css_version_4_11_0

    wrapper_obj = wrapper(
        fp16=False,
        tag_diffuser=default.diffuser,
        tag_lora=default.lora,
        tag_scheduler=default.scheduler)

    with gr.Blocks(css=css) as demo:
        gr.HTML(
            """
            <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
            <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
                {}
            </h1>
            <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
                <b>Smooth Diffusion</b> is a new category of diffusion models that is simultaneously high-performing and smooth. <br>
                Our method formally introduces latent space smoothness to diffusion models like Stable Diffusion. This smoothness dramatically aids in: 1) improving the continuity of transitions in image interpolation, 2) reducing approximation errors in image inversion, and 3) better preserving unedited contents in image editing.
            </h2>
            <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                <a href="https://www.jiayiguo.net/" target="_blank">Jiayi Guo</a>, <a href="https://www.linkedin.com/in/xingqian-xu-97b46526/" target="_blank">Xingqian Xu</a>,
                <a href="https://scholar.google.com/citations?user=oM9rnYQAAAAJ&hl=en" target="_blank">Yifan Pu</a>, <a href="https://scholar.google.com/citations?user=Yibz_asAAAAJ&hl=en" target="_blank">Zanlin Ni</a>,
                <a href="https://scholar.google.com/citations?user=-hwGMHcAAAAJ&hl=en" target="_blank">Chaofei Wang</a>, <a href="https://in.linkedin.com/in/v-manushree" target="_blank">Manushree Vasu</a>,
                <a href="https://www.au.tsinghua.edu.cn/info/1103/1553.htm" target="_blank">Shiji Song</a>, <a href="https://www.gaohuang.net/" target="_blank">Gao Huang</a>
                and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a>
                [<a href="https://arxiv.org/abs/2312.04410" style="color:blue;">arXiv</a>]
                [<a href="https://github.com/SHI-Labs/Smooth-Diffusion" style="color:blue;">GitHub</a>]
            </h3>
            </div>
            """.format(version))
        with gr.Tab('Image Interpolation'):
            interface_imintp(wrapper_obj)
        with gr.Tab('Image Inversion'):
            interface_iminvs(wrapper_obj)
        with gr.Tab('Image Editing'):
            interface_imedit(wrapper_obj)
        gr.Markdown(r"""
            If you find our work helpful, please **star ⭐** the <a href='https://github.com/SHI-Labs/Smooth-Diffusion' target='_blank'>GitHub repo</a>. Thanks for your support!

            [](https://github.com/SHI-Labs/Smooth-Diffusion)

            ---

            **Citation**
            <br>
            If our work is useful for your research, please consider citing:
            ```bibtex
            @InProceedings{guo2024smooth,
                title={Smooth Diffusion: Crafting Smooth Latent Spaces in Diffusion Models},
                author={Jiayi Guo and Xingqian Xu and Yifan Pu and Zanlin Ni and Chaofei Wang and Manushree Vasu and Shiji Song and Gao Huang and Humphrey Shi},
                booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
                year={2024}
            }
            ```
            """)

    demo.queue()
    demo.launch(server_port=args.port)
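
# Illustrative local run (assumes the repo's companion files `app_utils.py` and
# `nulltxtinv_wrapper.py` plus the assets/ folder are present alongside this script):
#   python app.py --port 7860
# The first launch downloads the base diffuser and the Smooth-LoRA weights from the Hub.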