Update pipeline_semantic_stable_diffusion_img2img_solver.py #9
opened by linoyts (HF Staff)
- app.py +17 -14
- pipeline_semantic_stable_diffusion_img2img_solver.py +6 -5
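
In brief: this change promotes `text_cross_attention_maps` from pipeline-internal state to a value that is passed around explicitly. `app.py` now holds it in a `gr.State` and threads it through `sample`, `reconstruct`, and `edit`, and the pipeline call accepts it as an argument and returns it alongside `attention_store`, so each Gradio session carries its own copy.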
app.py CHANGED

@@ -35,9 +35,9 @@ def caption_image(input_image):
     generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
     return generated_caption, generated_caption
 
-def sample(zs, wts, attention_store, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
+def sample(zs, wts, attention_store, text_cross_attention_maps, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
     latents = wts[-1].expand(1, -1, -1, -1)
-    img, attention_store = pipe(
+    img, attention_store, text_cross_attention_maps = pipe(
         prompt=prompt_tar,
         init_latents=latents,
         guidance_scale=cfg_scale_tar,
@@ -45,10 +45,10 @@ def sample(zs, wts, attention_store, prompt_tar="", cfg_scale_tar=15, skip=36, e
         # num_inference_steps=steps,
         # use_ddpm=True,
         # wts=wts.value,
-        attention_store = attention_store,
+        attention_store = attention_store, text_cross_attention_maps=text_cross_attention_maps,
         zs=zs,
     )
-    return img.images[0], attention_store
+    return img.images[0], attention_store, text_cross_attention_maps
 
 
 def reconstruct(
@@ -59,6 +59,7 @@ def reconstruct(
     wts,
     zs,
     attention_store,
+    text_cross_attention_maps,
     do_reconstruction,
     reconstruction,
     reconstruct_button,
@@ -79,8 +80,8 @@ def reconstruct(
     ): # if image caption was not changed, run actual reconstruction
         tar_prompt = ""
         latents = wts[-1].expand(1, -1, -1, -1)
-        reconstruction, attention_store = sample(
-            zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
+        reconstruction, attention_store, text_cross_attention_maps = sample(
+            zs, wts, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps,prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
         )
         do_reconstruction = False
         return (
@@ -130,7 +131,7 @@ def load_and_invert(
 ## SEGA ##
 
 def edit(input_image,
-         wts, zs, attention_store,
+         wts, zs, attention_store, text_cross_attention_maps,
          tar_prompt,
          image_caption,
          steps,
@@ -197,27 +198,27 @@ def edit(input_image,
         )
 
         latnets = wts[-1].expand(1, -1, -1, -1)
-        sega_out, attention_store = pipe(prompt=tar_prompt,
+        sega_out, attention_store, text_cross_attention_maps = pipe(prompt=tar_prompt,
                             init_latents=latnets,
                             guidance_scale = tar_cfg_scale,
                             # num_images_per_prompt=1,
                             # num_inference_steps=steps,
                             # use_ddpm=True,
                             # wts=wts.value,
-                            zs=zs, attention_store=attention_store, **editing_args)
+                            zs=zs, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, **editing_args)
 
-        return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
+        return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
 
 
     else: # if sega concepts were not added, performs regular ddpm sampling
 
         if do_reconstruction: # if ddpm sampling wasn't computed
-            pure_ddpm_img, attention_store = sample(zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
+            pure_ddpm_img, attention_store, text_cross_attention_maps = sample(zs, wts, attention_store=attention_store, text_cross_attention_maps=text_cross_attention_maps, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
             reconstruction = pure_ddpm_img
             do_reconstruction = False
-            return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
+            return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
 
-        return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
+        return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, text_cross_attention_maps, do_inversion, show_share_button
 
 
 def randomize_seed_fn(seed, is_random):
@@ -461,6 +462,7 @@ with gr.Blocks(css="style.css") as demo:
     wts = gr.State()
     zs = gr.State()
     attention_store=gr.State()
+    text_cross_attention_maps = gr.State()
     reconstruction = gr.State()
     do_inversion = gr.State(value=True)
     do_reconstruction = gr.State(value=True)
@@ -697,6 +699,7 @@ with gr.Blocks(css="style.css") as demo:
         fn=edit,
         inputs=[input_image,
                 wts, zs, attention_store,
+                text_cross_attention_maps,
                 tar_prompt,
                 image_caption,
                 steps,
@@ -716,7 +719,7 @@ with gr.Blocks(css="style.css") as demo:
 
 
         ],
-        outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs,attention_store, do_inversion, share_btn_container])
+        outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs,attention_store, text_cross_attention_maps, do_inversion, share_btn_container])
     # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
 
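
For orientation, here is a minimal, hypothetical sketch (not the Space's actual code; the handler body and component names are placeholders) of the `gr.State` threading pattern the `app.py` changes follow: every handler that consumes a per-session value must also return it, and the `inputs=`/`outputs=` lists must mirror the function's parameters and return tuple in order.

```python
import gradio as gr

def edit(image, attention_store, text_cross_attention_maps):
    # A real handler would call the pipeline here, which now returns the
    # updated store and maps alongside the image:
    #   img, attention_store, text_cross_attention_maps = pipe(...)
    return image, attention_store, text_cross_attention_maps

with gr.Blocks() as demo:
    # Per-session state: each browser session gets its own copy, so
    # concurrent users cannot overwrite each other's attention data.
    attention_store = gr.State()
    text_cross_attention_maps = gr.State()
    input_image = gr.Image()
    output_image = gr.Image()
    run = gr.Button("Edit")
    # inputs/outputs must list the State objects in the same order as the
    # handler's parameters and return values.
    run.click(
        fn=edit,
        inputs=[input_image, attention_store, text_cross_attention_maps],
        outputs=[output_image, attention_store, text_cross_attention_maps],
    )
```

The bug-prone spot is exactly this mirroring, which is why the diff touches so many places at once: `text_cross_attention_maps` had to be added to the `edit` signature, to every return tuple, and to both the `inputs=` and `outputs=` lists.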
pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED

@@ -500,6 +500,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
         use_cross_attn_mask: bool = False,
         # Attention store (just for visualization purposes)
         attention_store = None,
+        text_cross_attention_maps = None,
         attn_store_steps: Optional[List[int]] = [],
         store_averaged_over_steps: bool = True,
         use_intersect_mask: bool = False,
@@ -755,10 +756,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        self.text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
+        text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
         if enable_edit_guidance:
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
-            self.text_cross_attention_maps += \
+            text_cross_attention_maps += \
                 ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
         else:
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
@@ -920,11 +921,11 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
                     if use_cross_attn_mask:
                         out = attention_store.aggregate_attention(
                             attention_maps=attention_store.step_store,
-                            prompts=self.text_cross_attention_maps,
+                            prompts=text_cross_attention_maps,
                             res=16,
                             from_where=["up", "down"],
                             is_cross=True,
-                            select=self.text_cross_attention_maps.index(editing_prompt[c]),
+                            select=text_cross_attention_maps.index(editing_prompt[c]),
                         )
                         attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
 
@@ -1105,7 +1106,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
         if not return_dict:
             return (image, has_nsfw_concept), attention_store
 
-        return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store
+        return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store, text_cross_attention_maps
 
     def encode_text(self, prompts):
         text_inputs = self.tokenizer(
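
As an illustration of the bookkeeping in the pipeline hunks above (the variable names follow the diff, but the prompt values are made up), `text_cross_attention_maps` is just the list of prompts whose cross-attention maps were recorded, source prompt first, and `select` picks an edit concept's maps by its position in that list:

```python
# Illustrative values; in the pipeline these come from the caller.
org_prompt = "a photo of a cat"          # source prompt
editing_prompt = ["sunglasses", "hat"]   # SEGA edit concepts

# Built exactly as in the diff: the source prompt, then each concept.
text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
text_cross_attention_maps += [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt

# For edit concept c, aggregate_attention is told which recorded prompt's
# maps to read via that prompt's index in the list:
c = 1
select = text_cross_attention_maps.index(editing_prompt[c])
print(text_cross_attention_maps)  # ['a photo of a cat', 'sunglasses', 'hat']
print(select)                     # 2
```

Returning this list from the pipeline call (rather than caching it on the pipeline object) is what lets `app.py` hold it in a `gr.State` between calls.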