Spaces:
Running
on
Zero
Running
on
Zero
Update clip_slider_pipeline.py
Browse files- clip_slider_pipeline.py +30 -20
clip_slider_pipeline.py
CHANGED
|
@@ -10,17 +10,20 @@ class CLIPSlider:
|
|
| 10 |
self,
|
| 11 |
sd_pipe,
|
| 12 |
device: torch.device,
|
| 13 |
-
target_word: str,
|
| 14 |
-
opposite: str,
|
| 15 |
target_word_2nd: str = "",
|
| 16 |
opposite_2nd: str = "",
|
| 17 |
iterations: int = 300,
|
| 18 |
):
|
| 19 |
|
| 20 |
self.device = device
|
| 21 |
-
self.pipe = sd_pipe.to(self.device)
|
| 22 |
self.iterations = iterations
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
| 24 |
if target_word_2nd != "" or opposite_2nd != "":
|
| 25 |
self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
|
| 26 |
else:
|
|
@@ -29,12 +32,15 @@ class CLIPSlider:
|
|
| 29 |
|
| 30 |
def find_latent_direction(self,
|
| 31 |
target_word:str,
|
| 32 |
-
opposite:str):
|
| 33 |
|
| 34 |
# lets identify a latent direction by taking differences between opposites
|
| 35 |
# target_word = "happy"
|
| 36 |
# opposite = "sad"
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
with torch.no_grad():
|
| 40 |
positives = []
|
|
@@ -70,6 +76,8 @@ class CLIPSlider:
|
|
| 70 |
only_pooler = False,
|
| 71 |
normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
|
| 72 |
correlation_weight_factor = 1.0,
|
|
|
|
|
|
|
| 73 |
**pipeline_kwargs
|
| 74 |
):
|
| 75 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
|
@@ -80,14 +88,14 @@ class CLIPSlider:
|
|
| 80 |
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
| 81 |
prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
|
| 82 |
|
| 83 |
-
if
|
| 84 |
denominator = abs(scale) + abs(scale_2nd)
|
| 85 |
scale = scale / denominator
|
| 86 |
scale_2nd = scale_2nd / denominator
|
| 87 |
if only_pooler:
|
| 88 |
-
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] +
|
| 89 |
-
if
|
| 90 |
-
prompt_embeds[:, toks.argmax()] +=
|
| 91 |
else:
|
| 92 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
| 93 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
|
@@ -99,9 +107,9 @@ class CLIPSlider:
|
|
| 99 |
|
| 100 |
# weights = torch.sigmoid((weights-0.5)*7)
|
| 101 |
prompt_embeds = prompt_embeds + (
|
| 102 |
-
weights *
|
| 103 |
-
if
|
| 104 |
-
prompt_embeds += weights *
|
| 105 |
|
| 106 |
|
| 107 |
torch.manual_seed(seed)
|
|
@@ -399,6 +407,8 @@ class T5SliderFlux(CLIPSlider):
|
|
| 399 |
only_pooler = False,
|
| 400 |
normalize_scales = False,
|
| 401 |
correlation_weight_factor = 1.0,
|
|
|
|
|
|
|
| 402 |
**pipeline_kwargs
|
| 403 |
):
|
| 404 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
|
@@ -438,14 +448,14 @@ class T5SliderFlux(CLIPSlider):
|
|
| 438 |
dtype = self.pipe.text_encoder_2.dtype
|
| 439 |
prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
|
| 440 |
print("1", prompt_embeds.shape)
|
| 441 |
-
if
|
| 442 |
denominator = abs(scale) + abs(scale_2nd)
|
| 443 |
scale = scale / denominator
|
| 444 |
scale_2nd = scale_2nd / denominator
|
| 445 |
if only_pooler:
|
| 446 |
-
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] +
|
| 447 |
-
if
|
| 448 |
-
prompt_embeds[:, toks.argmax()] +=
|
| 449 |
else:
|
| 450 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
| 451 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
|
@@ -457,11 +467,11 @@ class T5SliderFlux(CLIPSlider):
|
|
| 457 |
|
| 458 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 459 |
prompt_embeds = prompt_embeds + (
|
| 460 |
-
weights *
|
| 461 |
print("2", prompt_embeds.shape)
|
| 462 |
-
if
|
| 463 |
prompt_embeds += (
|
| 464 |
-
weights *
|
| 465 |
|
| 466 |
torch.manual_seed(seed)
|
| 467 |
images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
|
|
|
|
| 10 |
self,
|
| 11 |
sd_pipe,
|
| 12 |
device: torch.device,
|
| 13 |
+
target_word: str = "",
|
| 14 |
+
opposite: str = "",
|
| 15 |
target_word_2nd: str = "",
|
| 16 |
opposite_2nd: str = "",
|
| 17 |
iterations: int = 300,
|
| 18 |
):
|
| 19 |
|
| 20 |
self.device = device
|
| 21 |
+
self.pipe = sd_pipe.to(self.device, torch.float16)
|
| 22 |
self.iterations = iterations
|
| 23 |
+
if target_word != "" or opposite != "":
|
| 24 |
+
self.avg_diff = self.find_latent_direction(target_word, opposite)
|
| 25 |
+
else:
|
| 26 |
+
self.avg_diff = None
|
| 27 |
if target_word_2nd != "" or opposite_2nd != "":
|
| 28 |
self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
|
| 29 |
else:
|
|
|
|
| 32 |
|
| 33 |
def find_latent_direction(self,
|
| 34 |
target_word:str,
|
| 35 |
+
opposite:str, num_iterations: int = None):
|
| 36 |
|
| 37 |
# lets identify a latent direction by taking differences between opposites
|
| 38 |
# target_word = "happy"
|
| 39 |
# opposite = "sad"
|
| 40 |
+
if num_iterations is not None:
|
| 41 |
+
iterations = num_iterations
|
| 42 |
+
else:
|
| 43 |
+
iterations = self.iterations
|
| 44 |
|
| 45 |
with torch.no_grad():
|
| 46 |
positives = []
|
|
|
|
| 76 |
only_pooler = False,
|
| 77 |
normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
|
| 78 |
correlation_weight_factor = 1.0,
|
| 79 |
+
avg_diff = None,
|
| 80 |
+
avg_diff_2nd = None,
|
| 81 |
**pipeline_kwargs
|
| 82 |
):
|
| 83 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
|
|
|
| 88 |
max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
|
| 89 |
prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
|
| 90 |
|
| 91 |
+
if avg_diff_2nd and normalize_scales:
|
| 92 |
denominator = abs(scale) + abs(scale_2nd)
|
| 93 |
scale = scale / denominator
|
| 94 |
scale_2nd = scale_2nd / denominator
|
| 95 |
if only_pooler:
|
| 96 |
+
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
|
| 97 |
+
if avg_diff_2nd:
|
| 98 |
+
prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
|
| 99 |
else:
|
| 100 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
| 101 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
|
|
|
| 107 |
|
| 108 |
# weights = torch.sigmoid((weights-0.5)*7)
|
| 109 |
prompt_embeds = prompt_embeds + (
|
| 110 |
+
weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
|
| 111 |
+
if avg_diff_2nd:
|
| 112 |
+
prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
|
| 113 |
|
| 114 |
|
| 115 |
torch.manual_seed(seed)
|
|
|
|
| 407 |
only_pooler = False,
|
| 408 |
normalize_scales = False,
|
| 409 |
correlation_weight_factor = 1.0,
|
| 410 |
+
avg_diff = None,
|
| 411 |
+
avg_diff_2nd = None,
|
| 412 |
**pipeline_kwargs
|
| 413 |
):
|
| 414 |
# if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
|
|
|
|
| 448 |
dtype = self.pipe.text_encoder_2.dtype
|
| 449 |
prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)
|
| 450 |
print("1", prompt_embeds.shape)
|
| 451 |
+
if avg_diff_2nd and normalize_scales:
|
| 452 |
denominator = abs(scale) + abs(scale_2nd)
|
| 453 |
scale = scale / denominator
|
| 454 |
scale_2nd = scale_2nd / denominator
|
| 455 |
if only_pooler:
|
| 456 |
+
prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
|
| 457 |
+
if avg_diff_2nd:
|
| 458 |
+
prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
|
| 459 |
else:
|
| 460 |
normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
|
| 461 |
sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
|
|
|
|
| 467 |
|
| 468 |
weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
|
| 469 |
prompt_embeds = prompt_embeds + (
|
| 470 |
+
weights * avg_diff * scale)
|
| 471 |
print("2", prompt_embeds.shape)
|
| 472 |
+
if avg_diff_2nd:
|
| 473 |
prompt_embeds += (
|
| 474 |
+
weights * avg_diff_2nd * scale_2nd)
|
| 475 |
|
| 476 |
torch.manual_seed(seed)
|
| 477 |
images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
|