import diffusers
import torch
import random
from tqdm import tqdm
from constants import SUBJECTS, MEDIUMS
from PIL import Image
# Slerp (spherical linear interpolation) between two embedding tensors.
def slerp(v0, v1, t, DOT_THRESHOLD=0.9995):
    """
    Spherical linear interpolation.
    v0, v1: tensors to interpolate between.
    t: interpolation factor (scalar or tensor) in [0, 1].
    DOT_THRESHOLD: dot-product threshold above which the vectors are treated
        as collinear and the code falls back to linear interpolation (LERP).
    """
    if not isinstance(t, torch.Tensor):
        t = torch.tensor(t, device=v0.device, dtype=v0.dtype)

    # Cosine of the angle between the (normalized) vectors.
    mag_v0 = torch.norm(v0, dim=-1, keepdim=True)
    mag_v1 = torch.norm(v1, dim=-1, keepdim=True)
    dot = torch.sum(v0 * v1 / (mag_v0 * mag_v1 + 1e-8), dim=-1, keepdim=True)

    # Clamp to prevent NaN from acos due to floating-point error.
    dot = torch.clamp(dot, -1.0, 1.0)
    omega = torch.acos(dot)  # angle between the vectors
    sin_omega = torch.sin(omega)

    # Interpolate the magnitudes linearly; Slerp proper acts on directions only.
    interpolated_mag = (1 - t) * mag_v0 + t * mag_v1
    v0_norm = v0 / (mag_v0 + 1e-8)
    v1_norm = v1 / (mag_v1 + 1e-8)

    # Standard Slerp coefficients, with an epsilon on sin(omega) for stability.
    s0 = torch.sin((1 - t) * omega) / (sin_omega + 1e-8)
    s1 = torch.sin(t * omega) / (sin_omega + 1e-8)

    # Where the vectors are nearly collinear (sin(omega) ~ 0, or |dot| above
    # DOT_THRESHOLD, which also covers the nearly-opposite case), Slerp
    # degenerates; fall back to LERP on the normalized directions.
    use_lerp_fallback = (sin_omega.abs() < 1e-5) | (dot.abs() > DOT_THRESHOLD)
    s0 = torch.where(use_lerp_fallback, 1.0 - t, s0)
    s1 = torch.where(use_lerp_fallback, t, s1)

    result_norm = s0 * v0_norm + s1 * v1_norm
    result = result_norm * interpolated_mag  # re-apply the interpolated magnitude
    return result.to(v0.dtype)
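# A quick sanity check for slerp (a sketch; the values are illustrative, not
# from this repo): halfway between two orthogonal unit vectors, the result
# should be a unit vector at 45 degrees to both.
#
#   v0 = torch.tensor([1.0, 0.0])
#   v1 = torch.tensor([0.0, 1.0])
#   mid = slerp(v0, v1, 0.5)   # ~[0.7071, 0.7071], norm ~1.0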
class CLIPSlider:
    def __init__(
        self,
        sd_pipe,
        device: torch.device,
        target_word: str = "",
        opposite: str = "",
        target_word_2nd: str = "",
        opposite_2nd: str = "",
        iterations: int = 300,
    ):
        self.device = device
        self.pipe = sd_pipe.to(self.device, torch.float16)
        self.iterations = iterations
        if target_word != "" or opposite != "":
            self.avg_diff = self.find_latent_direction(target_word, opposite)
        else:
            self.avg_diff = None
        if target_word_2nd != "" or opposite_2nd != "":
            self.avg_diff_2nd = self.find_latent_direction(target_word_2nd, opposite_2nd)
        else:
            self.avg_diff_2nd = None
    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        # Identify a latent direction by averaging the differences between
        # embeddings of antonym prompts, e.g. target_word="happy", opposite="sad".
        with torch.no_grad():
            positives = []
            negatives = []
            for i in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

            positives = torch.cat(positives, dim=0)
            negatives = torch.cat(negatives, dim=0)
            diffs = positives - negatives
            avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff
    def generate(self,
                 prompt="a photo of a house",
                 scale=2.0,
                 scale_2nd=0.0,  # scale for the 2nd direction when avg_diff_2nd is not None
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,  # whether to normalize the scales when avg_diff_2nd is not None
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        # When editing the full sequence, scales in [-0.3, 0.3] work well
        # (higher if correlation weighting is enabled); when editing only the
        # pooled token, [-4, 4] works well.
        with torch.no_grad():
            toks = self.pipe.tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True,
                                       max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
            prompt_embeds = self.pipe.text_encoder(toks).last_hidden_state

            if self.avg_diff_2nd is not None and normalize_scales:
                denominator = abs(scale) + abs(scale_2nd)
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator
            if only_pooler:
                # toks.argmax() is the position of the EOS token (the highest token id).
                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
                if self.avg_diff_2nd is not None:
                    prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
            else:
                normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                # Similarity of every token to the EOS token, used to weight how
                # strongly each position is pushed along the direction.
                sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
                standard_weights = torch.ones_like(weights)
                weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                # Alternative soft weighting: weights = torch.sigmoid((weights - 0.5) * 7)
                prompt_embeds = prompt_embeds + (
                    weights * self.avg_diff[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
                if self.avg_diff_2nd is not None:
                    prompt_embeds += weights * self.avg_diff_2nd[None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd

            torch.manual_seed(seed)
            images = self.pipe(prompt_embeds=prompt_embeds, **pipeline_kwargs).images
        return images
    def spectrum(self,
                 prompt="a photo of a house",
                 low_scale=-2,
                 low_scale_2nd=-2,
                 high_scale=2,
                 high_scale_2nd=2,
                 steps=5,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        images = []
        for i in range(steps):
            # Sweep both scales linearly from low to high across the strip.
            scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
            scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
            image = self.generate(prompt, scale, scale_2nd, seed, only_pooler,
                                  normalize_scales, correlation_weight_factor, **pipeline_kwargs)
            images.append(image[0])

        canvas = Image.new('RGB', (640 * steps, 640))
        for i, im in enumerate(images):
            canvas.paste(im, (640 * i, 0))
        return canvas
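# Example usage of CLIPSlider (a minimal sketch; the checkpoint name and device
# are assumptions, not part of this module):
#
#   pipe = diffusers.StableDiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
#   slider = CLIPSlider(pipe, torch.device("cuda"),
#                       target_word="happy", opposite="sad", iterations=100)
#   strip = slider.spectrum("a photo of a dog", low_scale=-4, high_scale=4,
#                           steps=5, only_pooler=True)
#   strip.save("happy_sad_spectrum.png")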
class CLIPSliderXL(CLIPSlider):

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        # Identify one latent direction per text encoder by averaging the
        # differences between embeddings of antonym prompts, e.g. "happy" vs. "sad".
        with torch.no_grad():
            positives = []
            negatives = []
            positives2 = []
            negatives2 = []
            for i in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"

                pos_toks = self.pipe.tokenizer(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                               max_length=self.pipe.tokenizer.model_max_length).input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

                pos_toks2 = self.pipe.tokenizer_2(pos_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                neg_toks2 = self.pipe.tokenizer_2(neg_prompt, return_tensors="pt", padding="max_length", truncation=True,
                                                  max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                pos2 = self.pipe.text_encoder_2(pos_toks2).text_embeds
                neg2 = self.pipe.text_encoder_2(neg_toks2).text_embeds
                positives2.append(pos2)
                negatives2.append(neg2)

            positives = torch.cat(positives, dim=0)
            negatives = torch.cat(negatives, dim=0)
            diffs = positives - negatives
            avg_diff = diffs.mean(0, keepdim=True)

            positives2 = torch.cat(positives2, dim=0)
            negatives2 = torch.cat(negatives2, dim=0)
            diffs2 = positives2 - negatives2
            avg_diff2 = diffs2.mean(0, keepdim=True)
        return (avg_diff, avg_diff2)
    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        # When editing the full sequence, scales in [-0.3, 0.3] work well
        # (higher if correlation weighting is enabled); when editing only the
        # pooled token, [-4, 4] works well.
        text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
        tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
        with torch.no_grad():
            prompt_embeds_list = []
            for i, text_encoder in enumerate(text_encoders):
                tokenizer = tokenizers[i]
                text_inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                toks = text_inputs.input_ids
                prompt_embeds = text_encoder(
                    toks.to(text_encoder.device),
                    output_hidden_states=True,
                )
                # We are only interested in the pooled output of the final text encoder.
                pooled_prompt_embeds = prompt_embeds[0]
                prompt_embeds = prompt_embeds.hidden_states[-2]

                if self.avg_diff_2nd is not None and normalize_scales:
                    denominator = abs(scale) + abs(scale_2nd)
                    scale = scale / denominator
                    scale_2nd = scale_2nd / denominator
                if only_pooler:
                    # Use the direction matching this encoder's hidden size.
                    prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff[i] * scale
                    if self.avg_diff_2nd is not None:
                        prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd[i] * scale_2nd
                else:
                    normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                    sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                    if i == 0:
                        # First encoder (CLIP ViT-L): hidden size 768.
                        weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 768)
                        standard_weights = torch.ones_like(weights)
                        weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                        prompt_embeds = prompt_embeds + (weights * self.avg_diff[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale)
                        if self.avg_diff_2nd is not None:
                            prompt_embeds += (weights * self.avg_diff_2nd[0][None, :].repeat(1, self.pipe.tokenizer.model_max_length, 1) * scale_2nd)
                    else:
                        # Second encoder (OpenCLIP ViT-bigG): hidden size 1280.
                        weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
                        standard_weights = torch.ones_like(weights)
                        weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                        prompt_embeds = prompt_embeds + (weights * self.avg_diff[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale)
                        if self.avg_diff_2nd is not None:
                            prompt_embeds += (weights * self.avg_diff_2nd[1][None, :].repeat(1, self.pipe.tokenizer_2.model_max_length, 1) * scale_2nd)

                bs_embed, seq_len, _ = prompt_embeds.shape
                prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
                prompt_embeds_list.append(prompt_embeds)

            # Concatenate both encoders' sequences along the feature dimension.
            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
            pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)

            torch.manual_seed(seed)
            images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                               **pipeline_kwargs).images
        return images
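# Example usage of CLIPSliderXL (a minimal sketch; the checkpoint name is an
# assumption):
#
#   pipe = diffusers.StableDiffusionXLPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
#   slider = CLIPSliderXL(pipe, torch.device("cuda"),
#                         target_word="modern", opposite="ancient", iterations=100)
#   images = slider.generate("a photo of a house", scale=0.2, scale_2nd=0.0,
#                            num_inference_steps=25)
#   images[0].save("modern_house.png")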
class CLIPSliderXL_inv(CLIPSliderXL):
    # Direction-finding is identical to CLIPSliderXL (inherited); only
    # generation differs: the latent edit is delegated to a custom pipeline
    # that accepts the directions and scales directly.

    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        with torch.no_grad():
            torch.manual_seed(seed)
            images = self.pipe(editing_prompt=prompt,
                               avg_diff=self.avg_diff, avg_diff_2nd=self.avg_diff_2nd,
                               scale=scale, scale_2nd=scale_2nd,
                               **pipeline_kwargs).images
        return images
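# CLIPSliderXL_inv expects a custom pipeline whose __call__ accepts the
# editing_prompt / avg_diff / scale kwargs shown above (not the stock diffusers
# SDXL pipeline). A hypothetical sketch, assuming such a pipeline instance
# (`custom_inv_pipe`) exists:
#
#   slider = CLIPSliderXL_inv(custom_inv_pipe, torch.device("cuda"),
#                             target_word="winter", opposite="summer")
#   images = slider.generate("a photo of a forest", scale=1.0, scale_2nd=0.0)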
class CLIPSliderFlux(CLIPSlider):

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str,
                              num_iterations: int = None):
        # Identify a latent direction by averaging the differences between CLIP
        # pooled embeddings of antonym prompts, e.g. "happy" vs. "sad".
        if num_iterations is not None:
            iterations = num_iterations
        else:
            iterations = self.iterations

        with torch.no_grad():
            positives = []
            negatives = []
            for i in tqdm(range(iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self.pipe.tokenizer(pos_prompt,
                                               padding="max_length",
                                               max_length=self.pipe.tokenizer_max_length,
                                               truncation=True,
                                               return_overflowing_tokens=False,
                                               return_length=False,
                                               return_tensors="pt").input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer(neg_prompt,
                                               padding="max_length",
                                               max_length=self.pipe.tokenizer_max_length,
                                               truncation=True,
                                               return_overflowing_tokens=False,
                                               return_length=False,
                                               return_tensors="pt").input_ids.to(self.device)
                pos = self.pipe.text_encoder(pos_toks).pooler_output
                neg = self.pipe.text_encoder(neg_toks).pooler_output
                positives.append(pos)
                negatives.append(neg)

            positives = torch.cat(positives, dim=0)
            negatives = torch.cat(negatives, dim=0)
            diffs = positives - negatives
            avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff
    def generate(self,
                 prompt="a photo of a house",
                 scale=2.0,
                 seed=15,
                 normalize_scales=False,
                 avg_diff=None,
                 avg_diff_2nd=None,
                 use_slerp: bool = False,
                 max_strength_for_slerp_endpoint: float = 0.0,
                 **pipeline_kwargs
                 ):
        # Scales around [-4, 4] work well for the pooled embedding.
        # Remove slider-specific kwargs before passing the rest to the pipeline;
        # scale_2nd is consumed here, not by the Flux pipeline.
        scale_2nd = pipeline_kwargs.pop("scale_2nd", 0.0)
        with torch.no_grad():
            # First text encoder (CLIP): pooled prompt embedding.
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds_out = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
            original_pooled_prompt_embeds = prompt_embeds_out.pooler_output.to(dtype=self.pipe.text_encoder.dtype, device=self.device)

            # Second text encoder (T5 for FLUX): non-pooled, per-token sequence output.
            text_inputs_2 = self.pipe.tokenizer_2(
                prompt,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
                return_tensors="pt",
            )
            toks_2 = text_inputs_2.input_ids
            prompt_embeds_seq_2 = self.pipe.text_encoder_2(toks_2.to(self.device), output_hidden_states=False)[0]
            prompt_embeds_seq_2 = prompt_embeds_seq_2.to(dtype=self.pipe.text_encoder_2.dtype, device=self.device)

            modified_pooled_embeds = original_pooled_prompt_embeds.clone()
            if avg_diff is not None:
                if use_slerp and max_strength_for_slerp_endpoint != 0.0:
                    if scale != 0:
                        # Slerp from the original embedding toward the embedding
                        # shifted by the maximum strength; |scale| sets how far
                        # along the arc to travel.
                        slerp_t_val = min(abs(scale) / max_strength_for_slerp_endpoint, 1.0)
                        v0 = original_pooled_prompt_embeds.float()
                        if scale > 0:
                            v_end_target = original_pooled_prompt_embeds + max_strength_for_slerp_endpoint * avg_diff
                        else:
                            v_end_target = original_pooled_prompt_embeds - max_strength_for_slerp_endpoint * avg_diff
                        modified_pooled_embeds = slerp(v0, v_end_target.float(), slerp_t_val).to(original_pooled_prompt_embeds.dtype)
                else:
                    # Plain linear shift along the direction.
                    modified_pooled_embeds = modified_pooled_embeds + avg_diff * scale
            if avg_diff_2nd is not None:
                modified_pooled_embeds += avg_diff_2nd * scale_2nd

            torch.manual_seed(seed)
            images = self.pipe(prompt_embeds=prompt_embeds_seq_2,
                               pooled_prompt_embeds=modified_pooled_embeds,
                               **pipeline_kwargs).images
        return images[0]
    def spectrum(self,
                 prompt="a photo of a house",
                 low_scale=-2,
                 low_scale_2nd=-2,
                 high_scale=2,
                 high_scale_2nd=2,
                 steps=5,
                 seed=15,
                 normalize_scales=False,
                 **pipeline_kwargs
                 ):
        images = []
        for i in range(steps):
            scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
            scale_2nd = low_scale_2nd + (high_scale_2nd - low_scale_2nd) * i / (steps - 1)
            # generate() has a different signature here than in the base class,
            # so pass everything by keyword; it returns a single PIL image.
            image = self.generate(prompt, scale=scale, seed=seed,
                                  normalize_scales=normalize_scales,
                                  avg_diff=self.avg_diff, avg_diff_2nd=self.avg_diff_2nd,
                                  scale_2nd=scale_2nd, **pipeline_kwargs)
            images.append(image.resize((512, 512)))

        canvas = Image.new('RGB', (512 * steps, 512))
        for i, im in enumerate(images):
            canvas.paste(im, (512 * i, 0))
        return canvas
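# Example usage of CLIPSliderFlux (a minimal sketch; the checkpoint name is an
# assumption, and note that generate() takes the direction explicitly):
#
#   pipe = diffusers.FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
#   slider = CLIPSliderFlux(pipe, torch.device("cuda"),
#                           target_word="cute", opposite="scary", iterations=100)
#   image = slider.generate("a photo of a cat", scale=2.0,
#                           avg_diff=slider.avg_diff, num_inference_steps=4)
#   image.save("cute_cat.png")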
class T5SliderFlux(CLIPSlider):

    def find_latent_direction(self,
                              target_word: str,
                              opposite: str):
        # Identify a latent direction in the T5 sequence embeddings by averaging
        # the differences between antonym prompts, e.g. "happy" vs. "sad".
        with torch.no_grad():
            positives = []
            negatives = []
            for i in tqdm(range(self.iterations)):
                medium = random.choice(MEDIUMS)
                subject = random.choice(SUBJECTS)
                pos_prompt = f"a {medium} of a {target_word} {subject}"
                neg_prompt = f"a {medium} of a {opposite} {subject}"
                pos_toks = self.pipe.tokenizer_2(pos_prompt,
                                                 return_tensors="pt",
                                                 padding="max_length",
                                                 truncation=True,
                                                 return_length=False,
                                                 return_overflowing_tokens=False,
                                                 max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                neg_toks = self.pipe.tokenizer_2(neg_prompt,
                                                 return_tensors="pt",
                                                 padding="max_length",
                                                 truncation=True,
                                                 return_length=False,
                                                 return_overflowing_tokens=False,
                                                 max_length=self.pipe.tokenizer_2.model_max_length).input_ids.to(self.device)
                pos = self.pipe.text_encoder_2(pos_toks, output_hidden_states=False)[0]
                neg = self.pipe.text_encoder_2(neg_toks, output_hidden_states=False)[0]
                positives.append(pos)
                negatives.append(neg)

            positives = torch.cat(positives, dim=0)
            negatives = torch.cat(negatives, dim=0)
            diffs = positives - negatives
            avg_diff = diffs.mean(0, keepdim=True)
        return avg_diff
    def generate(self,
                 prompt="a photo of a house",
                 scale=2,
                 scale_2nd=2,
                 seed=15,
                 only_pooler=False,
                 normalize_scales=False,
                 correlation_weight_factor=1.0,
                 **pipeline_kwargs
                 ):
        # When editing the full sequence, scales in [-0.3, 0.3] work well
        # (higher if correlation weighting is enabled); when editing only the
        # pooled token, [-4, 4] works well.
        with torch.no_grad():
            # First text encoder (CLIP): pooled embedding, left unedited here.
            text_inputs = self.pipe.tokenizer(
                prompt,
                padding="max_length",
                max_length=77,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = self.pipe.text_encoder(text_input_ids.to(self.device), output_hidden_states=False)
            # Use the pooled output of the CLIPTextModel.
            prompt_embeds = prompt_embeds.pooler_output
            pooled_prompt_embeds = prompt_embeds.to(dtype=self.pipe.text_encoder.dtype, device=self.device)

            # Second text encoder (T5): per-token sequence embedding; this is
            # what the slider edits.
            text_inputs = self.pipe.tokenizer_2(
                prompt,
                padding="max_length",
                max_length=512,
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
                return_tensors="pt",
            )
            toks = text_inputs.input_ids
            prompt_embeds = self.pipe.text_encoder_2(toks.to(self.device), output_hidden_states=False)[0]
            dtype = self.pipe.text_encoder_2.dtype
            prompt_embeds = prompt_embeds.to(dtype=dtype, device=self.device)

            if self.avg_diff_2nd is not None and normalize_scales:
                denominator = abs(scale) + abs(scale_2nd)
                scale = scale / denominator
                scale_2nd = scale_2nd / denominator
            if only_pooler:
                prompt_embeds[:, toks.argmax()] = prompt_embeds[:, toks.argmax()] + self.avg_diff * scale
                if self.avg_diff_2nd is not None:
                    prompt_embeds[:, toks.argmax()] += self.avg_diff_2nd * scale_2nd
            else:
                normed_prompt_embeds = prompt_embeds / prompt_embeds.norm(dim=-1, keepdim=True)
                sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, prompt_embeds.shape[2])
                standard_weights = torch.ones_like(weights)
                weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
                prompt_embeds = prompt_embeds + (weights * self.avg_diff * scale)
                if self.avg_diff_2nd is not None:
                    prompt_embeds += (weights * self.avg_diff_2nd * scale_2nd)

            torch.manual_seed(seed)
            images = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                               **pipeline_kwargs).images
        return images
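# Example usage of T5SliderFlux (a minimal sketch; the checkpoint name is an
# assumption):
#
#   pipe = diffusers.FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
#   slider = T5SliderFlux(pipe, torch.device("cuda"),
#                         target_word="detailed", opposite="minimalist",
#                         iterations=100)
#   images = slider.generate("a photo of a house", scale=0.2,
#                            num_inference_steps=4)
#   images[0].save("detailed_house.png")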