Linoy Tsaban committed
Commit · e1cb97f · 1 Parent(s): d91ace1
Update inversion_utils.py
Files changed: inversion_utils.py (+9 -6)
inversion_utils.py
CHANGED
@@ -7,6 +7,7 @@ import torchvision.transforms as T
 import os
 import yaml
 import numpy as np
+import gradio as gr
 
 
 def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
@@ -129,10 +130,11 @@ def get_variance(model, timestep): #, prev_timestep):
 
 def inversion_forward_process(model, x0,
                               etas = None,
-                              prog_bar = False,
+                              prog_bar = True,
                               prompt = "",
                               cfg_scale = 3.5,
-                              num_inference_steps=50, eps = None):
+                              num_inference_steps=50, eps = None,
+                              progress=gr.Progress()):
 
     if not prompt=="":
         text_embeddings = encode_text(model, prompt)
@@ -155,7 +157,7 @@ def inversion_forward_process(model, x0,
 
     t_to_idx = {int(v):k for k,v in enumerate(timesteps)}
     xt = x0
-    op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
+    op = progress.tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
 
     for t in op:
         idx = t_to_idx[int(t)]
@@ -241,10 +243,11 @@ def inversion_reverse_process(model,
                               etas = 0,
                               prompts = "",
                               cfg_scales = None,
-                              prog_bar = False,
+                              prog_bar = True,
                               zs = None,
                               controller=None,
-                              asyrp = False):
+                              asyrp = False,
+                              progress=gr.Progress()):
 
     batch_size = len(prompts)
 
@@ -259,7 +262,7 @@ def inversion_reverse_process(model,
     timesteps = model.scheduler.timesteps.to(model.device)
 
     xt = xT.expand(batch_size, -1, -1, -1)
-    op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
+    op = progress.tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
 
     t_to_idx = {int(v):k for k,v in enumerate(timesteps[-zs.shape[0]:])}
 
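The edit above swaps the plain tqdm progress bars for Gradio's Progress tracker, so loop progress is surfaced in the Space UI rather than only in the server logs. Below is a minimal, self-contained sketch of that pattern, assuming a hypothetical slow_process handler and a toy Interface; it is illustrative only and not part of this commit:

import time
import gradio as gr

# Gradio injects a Progress tracker when it is declared as a default argument
# of an event handler; progress.tqdm() wraps an iterable so each iteration
# advances the progress bar rendered in the UI.
def slow_process(steps, progress=gr.Progress()):
    total = 0
    for i in progress.tqdm(range(int(steps)), desc="Working"):
        time.sleep(0.05)  # stand-in for one diffusion/inversion step
        total += i
    return total

demo = gr.Interface(fn=slow_process, inputs=gr.Number(value=50), outputs="number")

if __name__ == "__main__":
    demo.launch()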