Commit 4065064 · Linoy Tsaban committed · Parent: 45e73ca

Update pipeline_semantic_stable_diffusion_img2img_solver.py
pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED
@@ -500,6 +500,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
         use_cross_attn_mask: bool = False,
         # Attention store (just for visualization purposes)
         attention_store = None,
+        text_cross_attention_maps = None,
         attn_store_steps: Optional[List[int]] = [],
         store_averaged_over_steps: bool = True,
         use_intersect_mask: bool = False,
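Note on the signature above: the commit threads a new `text_cross_attention_maps` keyword through the call (it is populated in the next hunk). Separately, the pre-existing `attn_store_steps: Optional[List[int]] = []` default is a mutable default argument, evaluated once and shared across calls. A minimal sketch of the usual `None`-default idiom, offered as a suggestion rather than as part of this commit:

from typing import List, Optional

# Sketch of the None-default idiom for list-valued keyword arguments.
# `attn_store_steps` mirrors the parameter in the hunk above; this
# stand-alone function is hypothetical.
def store_steps(attn_store_steps: Optional[List[int]] = None) -> List[int]:
    # Resolve the default inside the call so each invocation gets a fresh list.
    return [] if attn_store_steps is None else attn_store_steps

print(store_steps())         # []
print(store_steps([0, 10]))  # [0, 10]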
@@ -755,10 +756,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-
+        text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
         if enable_edit_guidance:
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
-
+            text_cross_attention_maps += \
                 ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
         else:
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
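The list built in this hunk records, in order, which prompt produced each chunk of the concatenated embedding batch: the source prompt first, then each editing concept when edit guidance is enabled. A minimal sketch of that construction, with hypothetical values for `org_prompt` and `editing_prompt`:

# Minimal sketch of how text_cross_attention_maps is assembled
# (the example prompt values are made up; the logic mirrors the hunk above).
org_prompt = "a photo of a cat"
editing_prompt = ["sunglasses", "hat"]
enable_edit_guidance = True

text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
if enable_edit_guidance:
    text_cross_attention_maps += (
        [editing_prompt] if isinstance(editing_prompt, str) else editing_prompt
    )

print(text_cross_attention_maps)
# ['a photo of a cat', 'sunglasses', 'hat']

The ordering matters because the cross-attention aggregation below indexes into this list to find the maps belonging to a given edit concept.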
@@ -920,11 +921,11 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
             if use_cross_attn_mask:
                 out = attention_store.aggregate_attention(
                     attention_maps=attention_store.step_store,
-                    prompts=
+                    prompts=text_cross_attention_maps,
                     res=16,
                     from_where=["up", "down"],
                     is_cross=True,
-                    select=
+                    select=text_cross_attention_maps.index(editing_prompt[c]),
                 )
                 attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]]  # 0 -> startoftext
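Because the prompt list is now a local value rather than pipeline state, `select` can recover the position of the current edit concept `c` with a plain `list.index` lookup, and the final slice drops token 0, the `<|startoftext|>` token. A toy illustration, with made-up tensor shapes standing in for the aggregated attention maps:

import torch

# Toy illustration of the select lookup and the token slice above.
# The tensor shape (res, res, batch, tokens) is an assumption; the real
# maps come from attention_store.aggregate_attention.
text_cross_attention_maps = ["a photo of a cat", "sunglasses", "hat"]
editing_prompt = ["sunglasses", "hat"]
num_edit_tokens = [1, 1]  # token count per edit prompt, special tokens excluded

c = 0  # index of the edit concept currently being masked
select = text_cross_attention_maps.index(editing_prompt[c])
print(select)  # 1 -> the maps aggregated for "sunglasses"

out = torch.rand(16, 16, 1, 77)  # stand-in for the aggregated maps
attn_map = out[:, :, :, 1 : 1 + num_edit_tokens[c]]  # skip <|startoftext|> at 0
print(attn_map.shape)  # torch.Size([16, 16, 1, 1])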
@@ -1105,7 +1106,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
         if not return_dict:
             return (image, has_nsfw_concept), attention_store

-        return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store
+        return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store, text_cross_attention_maps

     def encode_text(self, prompts):
         text_inputs = self.tokenizer(
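With `return_dict` left at its default, callers now unpack three values instead of two; the `return_dict=False` branch above is unchanged and still returns only two, so the prompt list is not surfaced on that path. A hedged sketch of the new calling contract, with a stub simulating the pipeline and `SimpleNamespace` standing in for the real `SemanticStableDiffusionPipelineOutput`:

from types import SimpleNamespace

# Stub of the pipeline's return contract after this commit; the pipeline
# call itself is simulated with placeholder values.
def run_pipeline(return_dict: bool = True):
    image, has_nsfw_concept = ["<PIL image>"], [False]
    attention_store = SimpleNamespace(step_store={})
    text_cross_attention_maps = ["a photo of a cat", "sunglasses"]
    if not return_dict:
        # This branch is unchanged by the commit: only two values.
        return (image, has_nsfw_concept), attention_store
    output = SimpleNamespace(images=image, nsfw_content_detected=has_nsfw_concept)
    return output, attention_store, text_cross_attention_maps

output, store, prompts = run_pipeline()
print(output.images, prompts)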