revert taking avg_diff out from attributes
clip_slider_pipeline.py (+24 -27)
@@ -17,8 +17,8 @@ class CLIPSlider:
             iterations: int = 300,
     ):
 
-
-        self.pipe = sd_pipe
+        self.device = device
+        self.pipe = sd_pipe.to(self.device)
         self.iterations = iterations
         if target_word != "" or opposite != "":
             self.avg_diff = self.find_latent_direction(target_word, opposite)
@@ -73,8 +73,6 @@ class CLIPSlider:
                  only_pooler = False,
                  normalize_scales = False, # whether to normalize the scales when avg_diff_2nd is not None
                  correlation_weight_factor = 1.0,
-                 avg_diff = None,
-                 avg_diff_2nd = None,
                  **pipeline_kwargs
                  ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -85,14 +83,14 @@ class CLIPSlider:
                               max_length=self.pipe.tokenizer.model_max_length).input_ids.cuda()
         prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
 
-        if avg_diff_2nd and normalize_scales:
+        if self.avg_diff_2nd and normalize_scales:
             denominator = abs(scale) + abs(scale_2nd)
             scale = scale / denominator
             scale_2nd = scale_2nd / denominator
         if only_pooler:
-            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff * scale
-            if avg_diff_2nd:
-                prompt_embeds[:, toks.argmax()] += avg_diff_2nd * scale_2nd
+            prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
+            if self.avg_diff_2nd:
+                prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
         else:
             normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
             sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -104,9 +102,9 @@ class CLIPSlider:
 
             # weights = torch.sigmoid((weights-0.5)*7)
             prompt_embeds = prompt_embeds + (
-                weights * avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
-            if avg_diff_2nd:
-                prompt_embeds += weights * avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
+                weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+            if self.avg_diff_2nd:
+                prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd
 
 
         torch.manual_seed(seed)
@@ -200,8 +198,6 @@ class CLIPSliderXL(CLIPSlider):
                  only_pooler = False,
                  normalize_scales = False,
                  correlation_weight_factor = 1.0,
-                 avg_diff = None,
-                 avg_diff_2nd= None,
                  **pipeline_kwargs
                  ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -236,15 +232,16 @@ class CLIPSliderXL(CLIPSlider):
             pooled_prompt_embeds = prompt_embeds[0]
             prompt_embeds = prompt_embeds.hidden_states[-2]
 
-            if avg_diff_2nd and normalize_scales:
+            if self.avg_diff_2nd and normalize_scales:
                 denominator = abs(scale) + abs(scale_2nd)
                 scale = scale / denominator
                 scale_2nd = scale_2nd / denominator
             if only_pooler:
-                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
-                if avg_diff_2nd:
-                    prompt_embeds[:, toks.argmax()] += avg_diff_2nd[0] * scale_2nd
+                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
+                if self.avg_diff_2nd:
+                    prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[0] * scale_2nd
             else:
+                print(self.avg_diff)
                 normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                 sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
 
@@ -254,18 +251,18 @@ class CLIPSliderXL(CLIPSlider):
                     standard_weights = torch.ones_like(weights)
 
                     weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                    prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
-                    if avg_diff_2nd:
-                        prompt_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
+                    prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+                    if self.avg_diff_2nd:
+                        prompt_embeds += (weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
                 else:
                     weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
 
                     standard_weights = torch.ones_like(weights)
 
                     weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                    prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
-                    if avg_diff_2nd:
-                        prompt_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
+                    prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
+                    if self.avg_diff_2nd:
+                        prompt_embeds += (weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)
 
             bs_embed, seq_len, _ = prompt_embeds.shape
             prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
@@ -328,7 +325,7 @@ class CLIPSlider3(CLIPSlider):
         positives2 = torch.cat(positives2, dim=0)
         negatives2 = torch.cat(negatives2, dim=0)
         diffs2 = positives2 - negatives2
-        avg_diff2 = diffs2.mean(0, keepdim=True)
+        avg_diff2 = diffs2.mean(0, keepdim=True)
         return (avg_diff, avg_diff2)
 
     def generate(self,
@@ -386,7 +383,7 @@ class CLIPSlider3(CLIPSlider):
             t5_prompt_embed_shape = prompt_embeds.shape[-1]
 
             if only_pooler:
-                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + avg_diff[0] * scale
+                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[0] * scale
             else:
                 normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                 sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
@@ -396,14 +393,14 @@ class CLIPSlider3(CLIPSlider):
                     standard_weights = torch.ones_like(weights)
 
                     weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                    prompt_embeds = prompt_embeds + (weights * avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
+                    prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
                 else:
                     weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
 
                     standard_weights = torch.ones_like(weights)
 
                     weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                    prompt_embeds = prompt_embeds + (weights * avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
+                    prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
 
             bs_embed, seq_len, _ = prompt_embeds.shape
             prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
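
After this revert, the latent direction is an instance attribute again: __init__ computes self.avg_diff from target_word/opposite via find_latent_direction, and generate reads self.avg_diff / self.avg_diff_2nd instead of accepting them as keyword arguments. A minimal usage sketch of that flow follows; the SDXL checkpoint name, the example prompt, and the exact positional signature of generate() are illustrative assumptions, not something this diff confirms.

import torch
from diffusers import StableDiffusionXLPipeline
from clip_slider_pipeline import CLIPSliderXL

# Hypothetical usage after this commit: the direction is found once in
# __init__ and stored as self.avg_diff, so generate() needs only a prompt
# and a scale. A positive scale slides toward target_word, a negative one
# toward opposite.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
slider = CLIPSliderXL(
    sd_pipe=pipe,                  # stored and moved to `device` in __init__
    device=torch.device("cuda"),
    target_word="happy",
    opposite="sad",
)
image = slider.generate("a photo of a person", scale=1.5)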
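
For context on what self.avg_diff holds, the CLIPSlider3 hunk above shows the pattern used by find_latent_direction: encode a batch of prompts built around the target word and around its opposite, subtract the paired embeddings, and average the differences into one direction vector. A standalone sketch of that computation; the encode callable and the prompt templates are stand-ins, not the Space's actual ones.

import random
import torch

TEMPLATES = ["a photo of a {} person", "a portrait of a {} person"]  # stand-ins

def find_direction(encode, target_word: str, opposite: str,
                   iterations: int = 300) -> torch.Tensor:
    # `encode` stands in for a text-encoder call returning a (1, dim)
    # embedding; varying the template each iteration is what makes
    # averaging over `iterations` samples meaningful.
    positives, negatives = [], []
    for _ in range(iterations):
        template = random.choice(TEMPLATES)
        positives.append(encode(template.format(target_word)))
        negatives.append(encode(template.format(opposite)))
    positives = torch.cat(positives, dim=0)  # (iterations, dim)
    negatives = torch.cat(negatives, dim=0)
    diffs = positives - negatives            # one difference per prompt pair
    return diffs.mean(0, keepdim=True)       # (1, dim): the slider direction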