Spaces:
Runtime error
Runtime error
Commit
·
a722e19
1
Parent(s):
f34dfad
Update app.py
Browse files
app.py
CHANGED
|
@@ -119,8 +119,8 @@ def load_and_invert(
|
|
| 119 |
skip=skip,
|
| 120 |
eta=1.0,
|
| 121 |
)
|
| 122 |
-
wts =
|
| 123 |
-
zs =
|
| 124 |
do_inversion = False
|
| 125 |
|
| 126 |
return wts, zs, do_inversion, gr.update(visible=False)
|
|
@@ -173,8 +173,8 @@ def edit(input_image,
|
|
| 173 |
skip = skip,
|
| 174 |
eta = 1.0,
|
| 175 |
)
|
| 176 |
-
wts =
|
| 177 |
-
zs =
|
| 178 |
do_inversion = False
|
| 179 |
|
| 180 |
if image_caption.lower() == tar_prompt.lower(): # if image caption was not changed, run pure sega
|
|
@@ -194,7 +194,7 @@ def edit(input_image,
|
|
| 194 |
use_intersect_mask=use_intersect_mask
|
| 195 |
)
|
| 196 |
|
| 197 |
-
latnets = wts
|
| 198 |
sega_out = pipe(prompt=tar_prompt,
|
| 199 |
init_latents=latnets,
|
| 200 |
guidance_scale = tar_cfg_scale,
|
|
@@ -202,7 +202,7 @@ def edit(input_image,
|
|
| 202 |
# num_inference_steps=steps,
|
| 203 |
# use_ddpm=True,
|
| 204 |
# wts=wts.value,
|
| 205 |
-
zs=zs
|
| 206 |
|
| 207 |
return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
| 208 |
|
|
@@ -210,12 +210,12 @@ def edit(input_image,
|
|
| 210 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
| 211 |
|
| 212 |
if do_reconstruction: # if ddpm sampling wasn't computed
|
| 213 |
-
pure_ddpm_img = sample(zs
|
| 214 |
-
reconstruction =
|
| 215 |
do_reconstruction = False
|
| 216 |
return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
| 217 |
|
| 218 |
-
return reconstruction
|
| 219 |
|
| 220 |
|
| 221 |
def randomize_seed_fn(seed, is_random):
|
|
@@ -872,6 +872,5 @@ with gr.Blocks(css="style.css") as demo:
|
|
| 872 |
cache_examples=True
|
| 873 |
)
|
| 874 |
|
| 875 |
-
|
| 876 |
demo.queue()
|
| 877 |
demo.launch()
|
|
|
|
| 119 |
skip=skip,
|
| 120 |
eta=1.0,
|
| 121 |
)
|
| 122 |
+
wts = wts_tensor
|
| 123 |
+
zs = zs_tensor
|
| 124 |
do_inversion = False
|
| 125 |
|
| 126 |
return wts, zs, do_inversion, gr.update(visible=False)
|
|
|
|
| 173 |
skip = skip,
|
| 174 |
eta = 1.0,
|
| 175 |
)
|
| 176 |
+
wts = wts_tensor
|
| 177 |
+
zs = zs_tensor
|
| 178 |
do_inversion = False
|
| 179 |
|
| 180 |
if image_caption.lower() == tar_prompt.lower(): # if image caption was not changed, run pure sega
|
|
|
|
| 194 |
use_intersect_mask=use_intersect_mask
|
| 195 |
)
|
| 196 |
|
| 197 |
+
latnets = wts[-1].expand(1, -1, -1, -1)
|
| 198 |
sega_out = pipe(prompt=tar_prompt,
|
| 199 |
init_latents=latnets,
|
| 200 |
guidance_scale = tar_cfg_scale,
|
|
|
|
| 202 |
# num_inference_steps=steps,
|
| 203 |
# use_ddpm=True,
|
| 204 |
# wts=wts.value,
|
| 205 |
+
zs=zs, **editing_args)
|
| 206 |
|
| 207 |
return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
| 208 |
|
|
|
|
| 210 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
| 211 |
|
| 212 |
if do_reconstruction: # if ddpm sampling wasn't computed
|
| 213 |
+
pure_ddpm_img = sample(zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
| 214 |
+
reconstruction = pure_ddpm_img
|
| 215 |
do_reconstruction = False
|
| 216 |
return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
| 217 |
|
| 218 |
+
return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
|
| 219 |
|
| 220 |
|
| 221 |
def randomize_seed_fn(seed, is_random):
|
|
|
|
| 872 |
cache_examples=True
|
| 873 |
)
|
| 874 |
|
|
|
|
| 875 |
demo.queue()
|
| 876 |
demo.launch()
|