Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Update clip_slider_pipeline.py
Browse files- clip_slider_pipeline.py +120 -39
    	
        clip_slider_pipeline.py
    CHANGED
    
    | @@ -4,6 +4,66 @@ import random | |
| 4 | 
             
            from tqdm import tqdm
         | 
| 5 | 
             
            from constants import SUBJECTS, MEDIUMS
         | 
| 6 | 
             
            from PIL import Image
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 7 |  | 
| 8 | 
             
            class CLIPSlider:
         | 
| 9 | 
             
                def __init__(
         | 
| @@ -49,9 +109,9 @@ class CLIPSlider: | |
| 49 | 
             
                            pos_prompt = f"a {medium} of a {target_word} {subject}"
         | 
| 50 | 
             
                            neg_prompt = f"a {medium} of a {opposite} {subject}"
         | 
| 51 | 
             
                            pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 52 | 
            -
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids. | 
| 53 | 
             
                            neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 54 | 
            -
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids. | 
| 55 | 
             
                            pos = self.pipe.text_encoder(pos_toks).pooler_output
         | 
| 56 | 
             
                            neg = self.pipe.text_encoder(neg_toks).pooler_output
         | 
| 57 | 
             
                            positives.append(pos)
         | 
| @@ -81,7 +141,7 @@ class CLIPSlider: | |
| 81 |  | 
| 82 | 
             
                    with torch.no_grad():
         | 
| 83 | 
             
                        toks = self.pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 84 | 
            -
                                              max_length=self.pipe.tokenizer.model_max_length).input_ids. | 
| 85 | 
             
                    prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
         | 
| 86 |  | 
| 87 | 
             
                    if self.avg_diff_2nd and normalize_scales:
         | 
| @@ -163,18 +223,18 @@ class CLIPSliderXL(CLIPSlider): | |
| 163 | 
             
                            neg_prompt = f"a {medium} of a {opposite} {subject}"
         | 
| 164 |  | 
| 165 | 
             
                            pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 166 | 
            -
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids. | 
| 167 | 
             
                            neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 168 | 
            -
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids. | 
| 169 | 
             
                            pos = self.pipe.text_encoder(pos_toks).pooler_output
         | 
| 170 | 
             
                            neg = self.pipe.text_encoder(neg_toks).pooler_output
         | 
| 171 | 
             
                            positives.append(pos)
         | 
| 172 | 
             
                            negatives.append(neg)
         | 
| 173 |  | 
| 174 | 
             
                            pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 175 | 
            -
                                                         max_length=self.pipe.tokenizer_2.model_max_length).input_ids. | 
| 176 | 
             
                            neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 177 | 
            -
                                                         max_length=self.pipe.tokenizer_2.model_max_length).input_ids. | 
| 178 | 
             
                            pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
         | 
| 179 | 
             
                            neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
         | 
| 180 | 
             
                            positives2.append(pos2)
         | 
| @@ -207,7 +267,7 @@ class CLIPSliderXL(CLIPSlider): | |
| 207 | 
             
                    text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
         | 
| 208 | 
             
                    tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
         | 
| 209 | 
             
                    with torch.no_grad():
         | 
| 210 | 
            -
                        # toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids. | 
| 211 | 
             
                        # prompt_embeds = pipe.text_encoder(toks).last_hidden_state
         | 
| 212 |  | 
| 213 | 
             
                        prompt_embeds_list = []
         | 
| @@ -300,18 +360,18 @@ class CLIPSliderXL_inv(CLIPSlider): | |
| 300 | 
             
                            neg_prompt = f"a {medium} of a {opposite} {subject}"
         | 
| 301 |  | 
| 302 | 
             
                            pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 303 | 
            -
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids. | 
| 304 | 
             
                            neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 305 | 
            -
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids. | 
| 306 | 
             
                            pos = self.pipe.text_encoder(pos_toks).pooler_output
         | 
| 307 | 
             
                            neg = self.pipe.text_encoder(neg_toks).pooler_output
         | 
| 308 | 
             
                            positives.append(pos)
         | 
| 309 | 
             
                            negatives.append(neg)
         | 
| 310 |  | 
| 311 | 
             
                            pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 312 | 
            -
                                                         max_length=self.pipe.tokenizer_2.model_max_length).input_ids. | 
| 313 | 
             
                            neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 314 | 
            -
                                                         max_length=self.pipe.tokenizer_2.model_max_length).input_ids. | 
| 315 | 
             
                            pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
         | 
| 316 | 
             
                            neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
         | 
| 317 | 
             
                            positives2.append(pos2)
         | 
| @@ -377,14 +437,14 @@ class CLIPSliderFlux(CLIPSlider): | |
| 377 | 
             
                                                           truncation=True,
         | 
| 378 | 
             
                                                           return_overflowing_tokens=False,
         | 
| 379 | 
             
                                                           return_length=False,
         | 
| 380 | 
            -
                                                           return_tensors="pt",).input_ids. | 
| 381 | 
             
                            neg_toks = self.pipe.tokenizer(neg_prompt,
         | 
| 382 | 
             
                                                           padding="max_length",
         | 
| 383 | 
             
                                                           max_length=self.pipe.tokenizer_max_length,
         | 
| 384 | 
             
                                                           truncation=True,
         | 
| 385 | 
             
                                                           return_overflowing_tokens=False,
         | 
| 386 | 
             
                                                           return_length=False,
         | 
| 387 | 
            -
                                                           return_tensors="pt",).input_ids. | 
| 388 | 
             
                            pos = self.pipe.text_encoder(pos_toks).pooler_output
         | 
| 389 | 
             
                            neg = self.pipe.text_encoder(neg_toks).pooler_output
         | 
| 390 | 
             
                            positives.append(pos)
         | 
| @@ -400,17 +460,22 @@ class CLIPSliderFlux(CLIPSlider): | |
| 400 |  | 
| 401 | 
             
                def generate(self,
         | 
| 402 | 
             
                    prompt = "a photo of a house",
         | 
| 403 | 
            -
                    scale = 2,
         | 
| 404 | 
            -
                    scale_2nd = 2,
         | 
| 405 | 
             
                    seed = 15,
         | 
| 406 | 
             
                    normalize_scales = False,
         | 
| 407 | 
             
                    avg_diff = None,
         | 
| 408 | 
            -
                    avg_diff_2nd = None, | 
|  | |
|  | |
| 409 | 
             
                    **pipeline_kwargs
         | 
| 410 | 
             
                    ):
         | 
| 411 | 
             
                    # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
         | 
| 412 | 
             
                    # if pooler token only [-4,4] work well
         | 
| 413 |  | 
|  | |
|  | |
|  | |
|  | |
| 414 | 
             
                    with torch.no_grad():
         | 
| 415 | 
             
                        text_inputs = self.pipe.tokenizer(
         | 
| 416 | 
             
                            prompt,
         | 
| @@ -423,15 +488,11 @@ class CLIPSliderFlux(CLIPSlider): | |
| 423 | 
             
                        )
         | 
| 424 |  | 
| 425 | 
             
                        text_input_ids = text_inputs.input_ids
         | 
| 426 | 
            -
                         | 
| 427 | 
            -
             | 
| 428 | 
            -
                         | 
| 429 | 
            -
                         | 
| 430 | 
            -
                         | 
| 431 | 
            -
             | 
| 432 | 
            -
                        # Use pooled output of CLIPTextModel
         | 
| 433 | 
            -
             | 
| 434 | 
            -
                        text_inputs = self.pipe.tokenizer_2(
         | 
| 435 | 
             
                            prompt,
         | 
| 436 | 
             
                            padding="max_length",
         | 
| 437 | 
             
                            max_length=512,
         | 
| @@ -440,21 +501,40 @@ class CLIPSliderFlux(CLIPSlider): | |
| 440 | 
             
                            return_overflowing_tokens=False,
         | 
| 441 | 
             
                            return_tensors="pt",
         | 
| 442 | 
             
                        )
         | 
| 443 | 
            -
                         | 
| 444 | 
            -
                         | 
| 445 | 
            -
                         | 
| 446 | 
            -
                         | 
| 447 | 
            -
             | 
| 448 | 
            -
             | 
| 449 | 
            -
             | 
| 450 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 451 |  | 
| 452 | 
            -
                        pooled_prompt_embeds = pooled_prompt_embeds + avg_diff * scale
         | 
| 453 | 
             
                        if avg_diff_2nd is not None:
         | 
| 454 | 
            -
                             | 
|  | |
| 455 |  | 
| 456 | 
             
                        torch.manual_seed(seed)
         | 
| 457 | 
            -
                        images = self.pipe(prompt_embeds= | 
|  | |
| 458 | 
             
                                           **pipeline_kwargs).images
         | 
| 459 |  | 
| 460 | 
             
                    return images[0]
         | 
| @@ -483,6 +563,7 @@ class CLIPSliderFlux(CLIPSlider): | |
| 483 | 
             
                        canvas.paste(im, (640 * i, 0))
         | 
| 484 |  | 
| 485 | 
             
                    return canvas
         | 
|  | |
| 486 | 
             
            class T5SliderFlux(CLIPSlider):
         | 
| 487 |  | 
| 488 | 
             
                def find_latent_direction(self,
         | 
| @@ -509,14 +590,14 @@ class T5SliderFlux(CLIPSlider): | |
| 509 | 
             
                                                             truncation=True,
         | 
| 510 | 
             
                                                             return_length=False,
         | 
| 511 | 
             
                                                             return_overflowing_tokens=False,
         | 
| 512 | 
            -
                                                             max_length=self.pipe.tokenizer_2.model_max_length).input_ids. | 
| 513 | 
             
                            neg_toks = self.pipe.tokenizer_2(neg_prompt,
         | 
| 514 | 
             
                                                             return_tensors="pt",
         | 
| 515 | 
             
                                                             padding="max_length",
         | 
| 516 | 
             
                                                             truncation=True,
         | 
| 517 | 
             
                                                             return_length=False,
         | 
| 518 | 
             
                                                             return_overflowing_tokens=False,
         | 
| 519 | 
            -
                                                             max_length=self.pipe.tokenizer_2.model_max_length).input_ids. | 
| 520 | 
             
                            pos = self.pipe.text_encoder_2(pos_toks, output_hidden_states=False)[0]
         | 
| 521 | 
             
                            neg = self.pipe.text_encoder_2(neg_toks, output_hidden_states=False)[0]
         | 
| 522 | 
             
                            positives.append(pos)
         | 
|  | |
| 4 | 
             
            from tqdm import tqdm
         | 
| 5 | 
             
            from constants import SUBJECTS, MEDIUMS
         | 
| 6 | 
             
            from PIL import Image
         | 
| 7 | 
            +
            import math # For acos, sin
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Slerp (Spherical Linear Interpolation) function
         | 
| 10 | 
            +
            def slerp(v0, v1, t, DOT_THRESHOLD=0.9995):
         | 
| 11 | 
            +
                """
         | 
| 12 | 
            +
                Spherical linear interpolation.
         | 
| 13 | 
            +
                v0, v1: Tensors to interpolate between.
         | 
| 14 | 
            +
                t: Interpolation factor (scalar or tensor).
         | 
| 15 | 
            +
                DOT_THRESHOLD: Threshold for considering vectors collinear.
         | 
| 16 | 
            +
                """
         | 
| 17 | 
            +
                if not isinstance(t, torch.Tensor):
         | 
| 18 | 
            +
                    t = torch.tensor(t, device=v0.device, dtype=v0.dtype)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                # Dot product
         | 
| 21 | 
            +
                dot = torch.sum(v0 * v1 / (torch.norm(v0, dim=-1, keepdim=True) * torch.norm(v1, dim=-1, keepdim=True) + 1e-8), dim=-1, keepdim=True)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                # If vectors are too close, use linear interpolation (LERP)
         | 
| 24 | 
            +
                # This also handles t=0 and t=1 correctly if dot is 1.
         | 
| 25 | 
            +
                # Also, if dot is -1 (opposite), omega is pi.
         | 
| 26 | 
            +
                if torch.any(torch.abs(dot) > DOT_THRESHOLD):
         | 
| 27 | 
            +
                    # For Slerp, if they are too close, omega is small, sin(omega) is small.
         | 
| 28 | 
            +
                    # Fallback to LERP for stability and when vectors are nearly collinear.
         | 
| 29 | 
            +
                    # However, the general Slerp formula handles this if dot is clamped.
         | 
| 30 | 
            +
                    # Let's use the standard formula but ensure stability.
         | 
| 31 | 
            +
                    pass # Continue to Slerp formula with clamping
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                # Clamp dot to prevent NaN from acos due to floating point errors.
         | 
| 34 | 
            +
                dot = torch.clamp(dot, -1.0, 1.0)
         | 
| 35 | 
            +
                omega = torch.acos(dot) # Angle between vectors
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                # Get magnitudes for later linear interpolation of magnitude
         | 
| 38 | 
            +
                mag_v0 = torch.norm(v0, dim=-1, keepdim=True)
         | 
| 39 | 
            +
                mag_v1 = torch.norm(v1, dim=-1, keepdim=True)
         | 
| 40 | 
            +
                
         | 
| 41 | 
            +
                interpolated_mag = (1 - t) * mag_v0 + t * mag_v1
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                # Normalize v0 and v1 for pure Slerp on direction
         | 
| 44 | 
            +
                v0_norm = v0 / (mag_v0 + 1e-8)
         | 
| 45 | 
            +
                v1_norm = v1 / (mag_v1 + 1e-8)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                # If sin_omega is very small, vectors are nearly collinear.
         | 
| 48 | 
            +
                # LERP on normalized vectors is a good approximation.
         | 
| 49 | 
            +
                # Then re-apply interpolated magnitude.
         | 
| 50 | 
            +
                sin_omega = torch.sin(omega)
         | 
| 51 | 
            +
                
         | 
| 52 | 
            +
                # Condition for LERP fallback (nearly collinear)
         | 
| 53 | 
            +
                # Using a small epsilon for sin_omega
         | 
| 54 | 
            +
                use_lerp_fallback = sin_omega.abs() < 1e-5 
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                s0 = torch.sin((1 - t) * omega) / (sin_omega + 1e-8) # Add epsilon to sin_omega for stability
         | 
| 57 | 
            +
                s1 = torch.sin(t * omega) / (sin_omega + 1e-8)       # Add epsilon to sin_omega for stability
         | 
| 58 | 
            +
                
         | 
| 59 | 
            +
                # For elements where LERP fallback is needed
         | 
| 60 | 
            +
                s0[use_lerp_fallback] = 1.0 - t
         | 
| 61 | 
            +
                s1[use_lerp_fallback] = t
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                result_norm = s0 * v0_norm + s1 * v1_norm
         | 
| 64 | 
            +
                result = result_norm * interpolated_mag # Re-apply interpolated magnitude
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                return result.to(v0.dtype)
         | 
| 67 |  | 
| 68 | 
             
            class CLIPSlider:
         | 
| 69 | 
             
                def __init__(
         | 
|  | |
| 109 | 
             
                            pos_prompt = f"a {medium} of a {target_word} {subject}"
         | 
| 110 | 
             
                            neg_prompt = f"a {medium} of a {opposite} {subject}"
         | 
| 111 | 
             
                            pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 112 | 
            +
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
         | 
| 113 | 
             
                            neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 114 | 
            +
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
         | 
| 115 | 
             
                            pos = self.pipe.text_encoder(pos_toks).pooler_output
         | 
| 116 | 
             
                            neg = self.pipe.text_encoder(neg_toks).pooler_output
         | 
| 117 | 
             
                            positives.append(pos)
         | 
|  | |
| 141 |  | 
| 142 | 
             
                    with torch.no_grad():
         | 
| 143 | 
             
                        toks = self.pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 144 | 
            +
                                              max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
         | 
| 145 | 
             
                    prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state
         | 
| 146 |  | 
| 147 | 
             
                    if self.avg_diff_2nd and normalize_scales:
         | 
|  | |
| 223 | 
             
                            neg_prompt = f"a {medium} of a {opposite} {subject}"
         | 
| 224 |  | 
| 225 | 
             
                            pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 226 | 
            +
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
         | 
| 227 | 
             
                            neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 228 | 
            +
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
         | 
| 229 | 
             
                            pos = self.pipe.text_encoder(pos_toks).pooler_output
         | 
| 230 | 
             
                            neg = self.pipe.text_encoder(neg_toks).pooler_output
         | 
| 231 | 
             
                            positives.append(pos)
         | 
| 232 | 
             
                            negatives.append(neg)
         | 
| 233 |  | 
| 234 | 
             
                            pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 235 | 
            +
                                                         max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
         | 
| 236 | 
             
                            neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 237 | 
            +
                                                         max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
         | 
| 238 | 
             
                            pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
         | 
| 239 | 
             
                            neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
         | 
| 240 | 
             
                            positives2.append(pos2)
         | 
|  | |
| 267 | 
             
                    text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
         | 
| 268 | 
             
                    tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
         | 
| 269 | 
             
                    with torch.no_grad():
         | 
| 270 | 
            +
                        # toks = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77).input_ids.to(self.device)
         | 
| 271 | 
             
                        # prompt_embeds = pipe.text_encoder(toks).last_hidden_state
         | 
| 272 |  | 
| 273 | 
             
                        prompt_embeds_list = []
         | 
|  | |
| 360 | 
             
                            neg_prompt = f"a {medium} of a {opposite} {subject}"
         | 
| 361 |  | 
| 362 | 
             
                            pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 363 | 
            +
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
         | 
| 364 | 
             
                            neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 365 | 
            +
                                                      max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
         | 
| 366 | 
             
                            pos = self.pipe.text_encoder(pos_toks).pooler_output
         | 
| 367 | 
             
                            neg = self.pipe.text_encoder(neg_toks).pooler_output
         | 
| 368 | 
             
                            positives.append(pos)
         | 
| 369 | 
             
                            negatives.append(neg)
         | 
| 370 |  | 
| 371 | 
             
                            pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 372 | 
            +
                                                         max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
         | 
| 373 | 
             
                            neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
         | 
| 374 | 
            +
                                                         max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
         | 
| 375 | 
             
                            pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
         | 
| 376 | 
             
                            neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
         | 
| 377 | 
             
                            positives2.append(pos2)
         | 
|  | |
| 437 | 
             
                                                           truncation=True,
         | 
| 438 | 
             
                                                           return_overflowing_tokens=False,
         | 
| 439 | 
             
                                                           return_length=False,
         | 
| 440 | 
            +
                                                           return_tensors="pt",).input_ids.to(self.device)
         | 
| 441 | 
             
                            neg_toks = self.pipe.tokenizer(neg_prompt,
         | 
| 442 | 
             
                                                           padding="max_length",
         | 
| 443 | 
             
                                                           max_length=self.pipe.tokenizer_max_length,
         | 
| 444 | 
             
                                                           truncation=True,
         | 
| 445 | 
             
                                                           return_overflowing_tokens=False,
         | 
| 446 | 
             
                                                           return_length=False,
         | 
| 447 | 
            +
                                                           return_tensors="pt",).input_ids.to(self.device)
         | 
| 448 | 
             
                            pos = self.pipe.text_encoder(pos_toks).pooler_output
         | 
| 449 | 
             
                            neg = self.pipe.text_encoder(neg_toks).pooler_output
         | 
| 450 | 
             
                            positives.append(pos)
         | 
|  | |
| 460 |  | 
| 461 | 
             
                def generate(self,
         | 
| 462 | 
             
                    prompt = "a photo of a house",
         | 
| 463 | 
            +
                    scale = 2.0,
         | 
|  | |
| 464 | 
             
                    seed = 15,
         | 
| 465 | 
             
                    normalize_scales = False,
         | 
| 466 | 
             
                    avg_diff = None,
         | 
| 467 | 
            +
                    avg_diff_2nd = None,
         | 
| 468 | 
            +
                    use_slerp: bool = False, 
         | 
| 469 | 
            +
                    max_strength_for_slerp_endpoint: float = 0.0,
         | 
| 470 | 
             
                    **pipeline_kwargs
         | 
| 471 | 
             
                    ):
         | 
| 472 | 
             
                    # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
         | 
| 473 | 
             
                    # if pooler token only [-4,4] work well
         | 
| 474 |  | 
| 475 | 
            +
                    # Remove slider-specific kwargs before passing to the pipeline
         | 
| 476 | 
            +
                    pipeline_kwargs.pop('use_slerp', None)
         | 
| 477 | 
            +
                    pipeline_kwargs.pop('max_strength_for_slerp_endpoint', None)
         | 
| 478 | 
            +
             | 
| 479 | 
             
                    with torch.no_grad():
         | 
| 480 | 
             
                        text_inputs = self.pipe.tokenizer(
         | 
| 481 | 
             
                            prompt,
         | 
|  | |
| 488 | 
             
                        )
         | 
| 489 |  | 
| 490 | 
             
                        text_input_ids = text_inputs.input_ids
         | 
| 491 | 
            +
                        prompt_embeds_out = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
         | 
| 492 | 
            +
                        original_pooled_prompt_embeds = prompt_embeds_out.pooler_output.to(dtype=self.pipe.text_encoder.dtype, device=self.device)
         | 
| 493 | 
            +
                        
         | 
| 494 | 
            +
                        # For the second text encoder (T5-like for FLUX)
         | 
| 495 | 
            +
                        text_inputs_2 = self.pipe.tokenizer_2(
         | 
|  | |
|  | |
|  | |
|  | |
| 496 | 
             
                            prompt,
         | 
| 497 | 
             
                            padding="max_length",
         | 
| 498 | 
             
                            max_length=512,
         | 
|  | |
| 501 | 
             
                            return_overflowing_tokens=False,
         | 
| 502 | 
             
                            return_tensors="pt",
         | 
| 503 | 
             
                        )
         | 
| 504 | 
            +
                        toks_2 = text_inputs_2.input_ids
         | 
| 505 | 
            +
                        # This is the non-pooled, sequence output for the second encoder
         | 
| 506 | 
            +
                        prompt_embeds_seq_2 = self.pipe.text_encoder_2(toks_2.to(self.device), output_hidden_states=False)[0]
         | 
| 507 | 
            +
                        prompt_embeds_seq_2 = prompt_embeds_seq_2.to(dtype=self.pipe.text_encoder_2.dtype, device=self.device)
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                        modified_pooled_embeds = original_pooled_prompt_embeds.clone()
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                        if avg_diff is not None:
         | 
| 512 | 
            +
                            if use_slerp and max_strength_for_slerp_endpoint != 0.0:
         | 
| 513 | 
            +
                                # Slerp logic
         | 
| 514 | 
            +
                                slerp_t_val = 0.0
         | 
| 515 | 
            +
                                if max_strength_for_slerp_endpoint != 0: 
         | 
| 516 | 
            +
                                    slerp_t_val = abs(scale) / max_strength_for_slerp_endpoint
         | 
| 517 | 
            +
                                slerp_t_val = min(slerp_t_val, 1.0) 
         | 
| 518 | 
            +
             | 
| 519 | 
            +
                                if scale == 0:
         | 
| 520 | 
            +
                                    pass
         | 
| 521 | 
            +
                                else:
         | 
| 522 | 
            +
                                    v0 = original_pooled_prompt_embeds.float() 
         | 
| 523 | 
            +
                                    if scale > 0:
         | 
| 524 | 
            +
                                        v_end_target = original_pooled_prompt_embeds + max_strength_for_slerp_endpoint * avg_diff
         | 
| 525 | 
            +
                                    else: 
         | 
| 526 | 
            +
                                        v_end_target = original_pooled_prompt_embeds - max_strength_for_slerp_endpoint * avg_diff
         | 
| 527 | 
            +
                                    modified_pooled_embeds = slerp(v0, v_end_target.float(), slerp_t_val).to(original_pooled_prompt_embeds.dtype)
         | 
| 528 | 
            +
                            else:
         | 
| 529 | 
            +
                                modified_pooled_embeds = modified_pooled_embeds + avg_diff * scale
         | 
| 530 |  | 
|  | |
| 531 | 
             
                        if avg_diff_2nd is not None:
         | 
| 532 | 
            +
                            scale_2nd_val = pipeline_kwargs.get("scale_2nd", 0.0) 
         | 
| 533 | 
            +
                            modified_pooled_embeds += avg_diff_2nd * scale_2nd_val
         | 
| 534 |  | 
| 535 | 
             
                        torch.manual_seed(seed)
         | 
| 536 | 
            +
                        images = self.pipe(prompt_embeds=prompt_embeds_seq_2, 
         | 
| 537 | 
            +
                                           pooled_prompt_embeds=modified_pooled_embeds, 
         | 
| 538 | 
             
                                           **pipeline_kwargs).images
         | 
| 539 |  | 
| 540 | 
             
                    return images[0]
         | 
|  | |
| 563 | 
             
                        canvas.paste(im, (640 * i, 0))
         | 
| 564 |  | 
| 565 | 
             
                    return canvas
         | 
| 566 | 
            +
             | 
| 567 | 
             
            class T5SliderFlux(CLIPSlider):
         | 
| 568 |  | 
| 569 | 
             
                def find_latent_direction(self,
         | 
|  | |
| 590 | 
             
                                                             truncation=True,
         | 
| 591 | 
             
                                                             return_length=False,
         | 
| 592 | 
             
                                                             return_overflowing_tokens=False,
         | 
| 593 | 
            +
                                                             max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
         | 
| 594 | 
             
                            neg_toks = self.pipe.tokenizer_2(neg_prompt,
         | 
| 595 | 
             
                                                             return_tensors="pt",
         | 
| 596 | 
             
                                                             padding="max_length",
         | 
| 597 | 
             
                                                             truncation=True,
         | 
| 598 | 
             
                                                             return_length=False,
         | 
| 599 | 
             
                                                             return_overflowing_tokens=False,
         | 
| 600 | 
            +
                                                             max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
         | 
| 601 | 
             
                            pos = self.pipe.text_encoder_2(pos_toks, output_hidden_states=False)[0]
         | 
| 602 | 
             
                            neg = self.pipe.text_encoder_2(neg_toks, output_hidden_states=False)[0]
         | 
| 603 | 
             
                            positives.append(pos)
         | 
