Linoy Tsaban
commited on
Commit
·
8f7289c
1
Parent(s):
998e5bc
Update app.py
Browse files
app.py
CHANGED
|
@@ -132,7 +132,23 @@ def edit(input_image,
|
|
| 132 |
neg_guidance_1, neg_guidance_2, neg_guidance_3,
|
| 133 |
threshold_1, threshold_2, threshold_3,
|
| 134 |
do_reconstruction,
|
| 135 |
-
reconstruction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
if edit_concept_1 != "" or edit_concept_2 != "" or edit_concept_3 != "":
|
| 138 |
editing_args = dict(
|
|
@@ -151,7 +167,7 @@ def edit(input_image,
|
|
| 151 |
num_inference_steps=steps,
|
| 152 |
use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
|
| 153 |
|
| 154 |
-
return sega_out.images[0], reconstruct_button.update(visible=True), do_reconstruction, reconstruction
|
| 155 |
|
| 156 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
| 157 |
|
|
@@ -159,9 +175,9 @@ def edit(input_image,
|
|
| 159 |
pure_ddpm_img = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
| 160 |
reconstruction = gr.State(value=pure_ddpm_img)
|
| 161 |
do_reconstruction = False
|
| 162 |
-
return pure_ddpm_img, reconstruct_button.update(visible=False), do_reconstruction, reconstruction
|
| 163 |
|
| 164 |
-
return reconstruction.value, reconstruct_button.update(visible=False), do_reconstruction, reconstruction
|
| 165 |
|
| 166 |
|
| 167 |
def randomize_seed_fn(seed, randomize_seed):
|
|
@@ -635,21 +651,7 @@ with gr.Blocks(css="style.css") as demo:
|
|
| 635 |
#add_concept_button.click(fn = update_display_concept, inputs=sega_concepts_counter,
|
| 636 |
# outputs= [row2, row2_advanced, row3, row3_advanced, add_concept_button, sega_concepts_counter], queue = False)
|
| 637 |
|
| 638 |
-
run_button.click(
|
| 639 |
-
fn=load_and_invert,
|
| 640 |
-
inputs=[input_image,
|
| 641 |
-
do_inversion,
|
| 642 |
-
seed, randomize_seed,
|
| 643 |
-
wts, zs,
|
| 644 |
-
src_prompt,
|
| 645 |
-
tar_prompt,
|
| 646 |
-
steps,
|
| 647 |
-
src_cfg_scale,
|
| 648 |
-
skip,
|
| 649 |
-
tar_cfg_scale
|
| 650 |
-
],
|
| 651 |
-
outputs=[wts, zs, do_inversion, inversion_progress],
|
| 652 |
-
).success(
|
| 653 |
fn=edit,
|
| 654 |
inputs=[input_image,
|
| 655 |
wts, zs,
|
|
@@ -661,10 +663,16 @@ with gr.Blocks(css="style.css") as demo:
|
|
| 661 |
guidnace_scale_1,guidnace_scale_2,guidnace_scale_3,
|
| 662 |
warmup_1, warmup_2, warmup_3,
|
| 663 |
neg_guidance_1, neg_guidance_2, neg_guidance_3,
|
| 664 |
-
threshold_1, threshold_2, threshold_3, do_reconstruction, reconstruction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
|
| 666 |
],
|
| 667 |
-
outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction])
|
| 668 |
# .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
|
| 669 |
|
| 670 |
|
|
|
|
| 132 |
neg_guidance_1, neg_guidance_2, neg_guidance_3,
|
| 133 |
threshold_1, threshold_2, threshold_3,
|
| 134 |
do_reconstruction,
|
| 135 |
+
reconstruction,
|
| 136 |
+
|
| 137 |
+
# for inversion in case it needs to be re computed (and avoid delay):
|
| 138 |
+
do_inversion,
|
| 139 |
+
seed,
|
| 140 |
+
randomize_seed,
|
| 141 |
+
src_prompt,
|
| 142 |
+
src_cfg_scale):
|
| 143 |
+
|
| 144 |
+
if do_inversion or randomize_seed:
|
| 145 |
+
x0 = load_512(input_image, device=device)
|
| 146 |
+
# invert and retrieve noise maps and latent
|
| 147 |
+
zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
|
| 148 |
+
wts = gr.State(value=wts_tensor)
|
| 149 |
+
zs = gr.State(value=zs_tensor)
|
| 150 |
+
do_inversion = False
|
| 151 |
+
|
| 152 |
|
| 153 |
if edit_concept_1 != "" or edit_concept_2 != "" or edit_concept_3 != "":
|
| 154 |
editing_args = dict(
|
|
|
|
| 167 |
num_inference_steps=steps,
|
| 168 |
use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
|
| 169 |
|
| 170 |
+
return sega_out.images[0], reconstruct_button.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion
|
| 171 |
|
| 172 |
else: # if sega concepts were not added, performs regular ddpm sampling
|
| 173 |
|
|
|
|
| 175 |
pure_ddpm_img = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
|
| 176 |
reconstruction = gr.State(value=pure_ddpm_img)
|
| 177 |
do_reconstruction = False
|
| 178 |
+
return pure_ddpm_img, reconstruct_button.update(visible=False), do_reconstruction, reconstruction wts, zs, do_inversion
|
| 179 |
|
| 180 |
+
return reconstruction.value, reconstruct_button.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion
|
| 181 |
|
| 182 |
|
| 183 |
def randomize_seed_fn(seed, randomize_seed):
|
|
|
|
| 651 |
#add_concept_button.click(fn = update_display_concept, inputs=sega_concepts_counter,
|
| 652 |
# outputs= [row2, row2_advanced, row3, row3_advanced, add_concept_button, sega_concepts_counter], queue = False)
|
| 653 |
|
| 654 |
+
run_button.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 655 |
fn=edit,
|
| 656 |
inputs=[input_image,
|
| 657 |
wts, zs,
|
|
|
|
| 663 |
guidnace_scale_1,guidnace_scale_2,guidnace_scale_3,
|
| 664 |
warmup_1, warmup_2, warmup_3,
|
| 665 |
neg_guidance_1, neg_guidance_2, neg_guidance_3,
|
| 666 |
+
threshold_1, threshold_2, threshold_3, do_reconstruction, reconstruction,
|
| 667 |
+
do_inversion,
|
| 668 |
+
seed,
|
| 669 |
+
randomize_seed,
|
| 670 |
+
src_prompt,
|
| 671 |
+
src_cfg_scale
|
| 672 |
+
|
| 673 |
|
| 674 |
],
|
| 675 |
+
outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs, do_inversion])
|
| 676 |
# .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
|
| 677 |
|
| 678 |
|