# Based on https://raw.githubusercontent.com/okotaku/diffusers/feature/reference_only_control/examples/community/stable_diffusion_reference.py
# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
import torch.fft as fft
import torch.nn.functional as F
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.unet_2d_blocks import (
    CrossAttnDownBlock2D,
    CrossAttnUpBlock2D,
    DownBlock2D,
    UpBlock2D,
)
from diffusers.utils import PIL_INTERPOLATION, logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import UniPCMultistepScheduler
        >>> from diffusers.utils import load_image

        >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")

        >>> pipe = StableDiffusionReferencePipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5",
        ...     safety_checker=None,
        ...     torch_dtype=torch.float16,
        ... ).to("cuda:0")
        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

        >>> result_img = pipe(
        ...     ref_image=input_image,
        ...     prompt="1girl",
        ...     num_inference_steps=20,
        ...     reference_attn=True,
        ...     reference_adain=True,
        ... ).images[0]
        >>> result_img.show()
        ```
"""
def torch_dfs(model: torch.nn.Module):
    """Return `model` and all of its sub-modules via a depth-first traversal."""
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result
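
# Hedged usage sketch (not part of the original file): `torch_dfs` flattens a
# module tree so specific block types can be filtered out, as done further
# below for `BasicTransformerBlock`. The tiny container here is an
# illustrative assumption.
def _example_torch_dfs():
    net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    modules = torch_dfs(net)
    assert len(modules) == 3  # the Sequential itself, the Linear, and the ReLU
    return modules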
def add_freq_feature(feature1, feature2, ref_ratio):
    """Blend the Fourier magnitude of a reference feature into a target feature.

    feature1: reference feature (n, c, h, w)
    feature2: target feature (n, c, h, w)
    ref_ratio: larger ratio mixes in more of the reference magnitude
    """
    # Convert features to float32 (if not already) for compatibility with fft operations
    data_type = feature2.dtype
    feature1 = feature1.to(torch.float32)
    feature2 = feature2.to(torch.float32)
    # Compute the Fourier transforms of both features
    spectrum1 = fft.fftn(feature1, dim=(-2, -1))
    spectrum2 = fft.fftn(feature2, dim=(-2, -1))
    # Extract the magnitude from the reference, and magnitude/phase from the target
    magnitude1 = torch.abs(spectrum1)
    magnitude2 = torch.abs(spectrum2)
    phase2 = torch.angle(spectrum2)
    # Interpolate the magnitudes; the target keeps its own phase
    magnitude2.mul_(1.0 - ref_ratio).add_(magnitude1 * ref_ratio)
    # Combine magnitude and phase information
    mixed_spectrum = torch.polar(magnitude2, phase2)
    # Compute the inverse Fourier transform to get the mixed feature
    mixed_feature = fft.ifftn(mixed_spectrum, dim=(-2, -1))
    del feature1, feature2, spectrum1, spectrum2, magnitude1, magnitude2, phase2, mixed_spectrum
    # Take the real part (the imaginary residue is fft round-off) and restore dtype
    return mixed_feature.real.to(data_type)
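
# Hedged sketch (not part of the original pipeline): `add_freq_feature` swaps
# in a fraction of the reference magnitude spectrum while keeping the target's
# phase, so shape and dtype round-trip. The tensor sizes are illustrative
# assumptions.
def _example_add_freq_feature():
    ref = torch.randn(1, 4, 8, 8, dtype=torch.float16)
    tgt = torch.randn(1, 4, 8, 8, dtype=torch.float16)
    mixed = add_freq_feature(ref, tgt, ref_ratio=0.3)
    assert mixed.shape == tgt.shape and mixed.dtype == tgt.dtype
    # With ref_ratio=0.0 the target magnitude is untouched, so the feature is
    # recovered up to fft round-off.
    same = add_freq_feature(ref, tgt, ref_ratio=0.0)
    assert torch.allclose(same.float(), tgt.float(), atol=1e-2)
    return mixed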
def save_ref_feature(feature, mask):
    """Zero out everything outside the reference mask.

    feature: (n, c, h, w)
    mask: (n, 1, h, w)
    returns: (n, c, h, w)
    """
    return feature * mask
def mix_ref_feature(feature, ref_fea_bank, cfg=True, ref_scale=0.0, dim3=False):
    """Mix banked reference features into `feature` in the frequency domain.

    feature: (n, l, c) if `dim3` else (n, c, h, w)
    ref_fea_bank: list of (n, c, h, w) reference features
    cfg: whether the batch is doubled for classifier-free guidance
    returns: same shape as `feature`
    """
    if cfg:
        # duplicate the bank so it covers both halves of the CFG batch
        ref_fea = torch.cat(ref_fea_bank + ref_fea_bank, dim=0)
    else:
        ref_fea = torch.cat(ref_fea_bank, dim=0)
    if dim3:
        feature = feature.permute(0, 2, 1).view(ref_fea.shape)
    mixed_feature = add_freq_feature(ref_fea, feature, ref_scale)
    if dim3:
        mixed_feature = mixed_feature.view(
            ref_fea.shape[0], ref_fea.shape[1], -1
        ).permute(0, 2, 1)
    del ref_fea
    del feature
    return mixed_feature
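
# Hedged sketch (illustrative shapes, not from the original file): with
# cfg=True the single-entry bank is duplicated along the batch dimension to
# cover the unconditional/conditional halves, and dim3=True accepts
# (n, l, c) attention-style features with l = h * w.
def _example_mix_ref_feature():
    bank = [torch.randn(1, 4, 8, 8)]
    feat = torch.randn(2, 64, 4)  # n = 2 for CFG, l = 8 * 8, c = 4
    mixed = mix_ref_feature(feat, bank, cfg=True, ref_scale=0.5, dim3=True)
    assert mixed.shape == feat.shape
    return mixed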
def mix_norm_feature(
    x,
    inpaint_mask,
    mean_bank,
    var_bank,
    do_classifier_free_guidance,
    style_fidelity,
    uc_mask,
    eps=1e-6,
):
    """Re-normalise the masked region of `x` with accumulated reference statistics.

    x: input feature (n, c, h, w)
    inpaint_mask: mask of the region to inpaint (n, 1, h, w)
    """
    # Get the inpainting region and only mix this region.
    scale_ratio = inpaint_mask.shape[2] / x.shape[2]
    this_inpaint_mask = F.interpolate(
        inpaint_mask.to(x.device), scale_factor=1 / scale_ratio
    )
    this_inpaint_mask = this_inpaint_mask.repeat(x.shape[0], x.shape[1], 1, 1).bool()
    masked_x = (
        x[this_inpaint_mask]
        .detach()
        .clone()
        .view(x.shape[0], x.shape[1], -1, 1)
    )
    var, mean = torch.var_mean(masked_x, dim=(2, 3), keepdim=True, correction=0)
    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
    mean_acc = sum(mean_bank) / float(len(mean_bank))
    var_acc = sum(var_bank) / float(len(var_bank))
    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
    # AdaIN-style swap: normalise with the current statistics, de-normalise
    # with the accumulated reference statistics.
    x_uc = (((masked_x - mean) / std) * std_acc) + mean_acc
    x_c = x_uc.clone()
    if do_classifier_free_guidance and style_fidelity > 0:
        x_c[uc_mask] = masked_x[uc_mask]
    masked_x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc
    x[this_inpaint_mask] = masked_x.view(-1)
    return x
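
# Hedged sketch (not part of the original pipeline): `mix_norm_feature`
# re-normalises only the masked region of `x` with reference statistics. The
# single-entry banks and the all-ones mask below are illustrative assumptions,
# with CFG disabled so `uc_mask` is unused.
def _example_mix_norm_feature():
    x = torch.randn(1, 4, 8, 8)
    mask = torch.ones(1, 1, 8, 8)  # inpaint everywhere for the demo
    var, mean = torch.var_mean(
        torch.randn(1, 4, 64, 1), dim=(2, 3), keepdim=True, correction=0
    )
    out = mix_norm_feature(
        x, mask, [mean], [var],
        do_classifier_free_guidance=False, style_fidelity=0.0, uc_mask=None,
    )
    assert out.shape == x.shape
    return out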
class StableDiffusionReferencePipeline:
    def prepare_ref_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        if not isinstance(image, torch.Tensor):
            if isinstance(image, PIL.Image.Image):
                image = [image]
            if isinstance(image[0], PIL.Image.Image):
                images = []
                for image_ in image:
                    image_ = image_.convert("RGB")
                    image_ = image_.resize(
                        (width, height), resample=PIL_INTERPOLATION["lanczos"]
                    )
                    image_ = np.array(image_)
                    image_ = image_[None, :]
                    images.append(image_)
                image = images
                image = np.concatenate(image, axis=0)
                image = np.array(image).astype(np.float32) / 255.0
                image = (image - 0.5) / 0.5
                image = image.transpose(0, 3, 1, 2)
                image = torch.from_numpy(image)
            elif isinstance(image[0], torch.Tensor):
                image = torch.cat(image, dim=0)
        image_batch_size = image.shape[0]
        if image_batch_size == 1:
            repeat_by = batch_size
        else:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        image = image.repeat_interleave(repeat_by, dim=0)
        image = image.to(device=device, dtype=dtype)
        if do_classifier_free_guidance and not guess_mode:
            image = torch.cat([image] * 2)
        return image
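
    # Hedged usage sketch (names outside this file are assumptions):
    # `prepare_ref_image` mirrors ControlNet image preprocessing, e.g.
    #
    #   ref = pipe.prepare_ref_image(
    #       image=pil_image, width=512, height=512, batch_size=1,
    #       num_images_per_prompt=1, device=pipe.device, dtype=torch.float16,
    #       do_classifier_free_guidance=True,
    #   )
    #
    # which yields a (2, 3, 512, 512) tensor in [-1, 1], with the batch
    # doubled for classifier-free guidance.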
    def prepare_ref_latents(
        self,
        refimage,
        batch_size,
        dtype,
        device,
        generator,
        do_classifier_free_guidance,
    ):
        refimage = refimage.to(device=device, dtype=dtype)
        # encode the reference image into latent space so we can concatenate it to the latents
        if isinstance(generator, list):
            ref_image_latents = [
                self.vae.encode(refimage[i : i + 1]).latent_dist.sample(
                    generator=generator[i]
                )
                for i in range(batch_size)
            ]
            ref_image_latents = torch.cat(ref_image_latents, dim=0)
        else:
            ref_image_latents = self.vae.encode(refimage).latent_dist.sample(
                generator=generator
            )
        ref_image_latents = self.vae.config.scaling_factor * ref_image_latents

        # duplicate ref_image_latents for each generation per prompt, using mps friendly method
        if ref_image_latents.shape[0] < batch_size:
            if not batch_size % ref_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            ref_image_latents = ref_image_latents.repeat(
                batch_size // ref_image_latents.shape[0], 1, 1, 1
            )
        ref_image_latents = (
            torch.cat([ref_image_latents] * 2)
            if do_classifier_free_guidance
            else ref_image_latents
        )

        # aligning device to prevent device errors when concating it with the latent model input
        ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
        return ref_image_latents
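
    # Hedged usage sketch (variable names assumed): encode a (1, 3, 512, 512)
    # reference image into VAE latents scaled by `vae.config.scaling_factor`
    # (0.18215 for SD 1.5), doubled along batch when CFG is on:
    #
    #   ref_latents = pipe.prepare_ref_latents(
    #       ref_image, batch_size=1, dtype=torch.float16, device=pipe.device,
    #       generator=None, do_classifier_free_guidance=True,
    #   )  # -> (2, 4, 64, 64)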
    def check_ref_input(self, reference_attn, reference_adain):
        assert (
            reference_attn or reference_adain
        ), "`reference_attn` or `reference_adain` must be True."
    def redefine_ref_model(
        self, model, reference_attn, reference_adain, model_type="unet"
    ):
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
        ):
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(
                    hidden_states,
                    timestep,
                    class_labels,
                    hidden_dtype=hidden_states.dtype,
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)
            # 1. Self-Attention
            cross_attention_kwargs = (
                cross_attention_kwargs if cross_attention_kwargs is not None else {}
            )
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states
                    if self.only_cross_attention
                    else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
                if self.MODE == "write":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        # Map the reference mask onto this block's token grid.
                        scale_ratio = (
                            (self.ref_mask.shape[2] * self.ref_mask.shape[3])
                            / norm_hidden_states.shape[1]
                        ) ** 0.5
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(norm_hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        resize_norm_hidden_states = norm_hidden_states.view(
                            norm_hidden_states.shape[0],
                            this_ref_mask.shape[2],
                            this_ref_mask.shape[3],
                            -1,
                        ).permute(0, 3, 1, 2)
                        ref_scale = 1.0
                        resize_norm_hidden_states = F.interpolate(
                            resize_norm_hidden_states,
                            scale_factor=ref_scale,
                            mode="bilinear",
                        )
                        this_ref_mask = F.interpolate(
                            this_ref_mask, scale_factor=ref_scale
                        )
                        self.fea_bank.append(
                            save_ref_feature(resize_norm_hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            resize_norm_hidden_states.shape[0],
                            resize_norm_hidden_states.shape[1],
                            1,
                            1,
                        ).bool()
                        # Bank only the tokens inside the reference mask.
                        masked_norm_hidden_states = (
                            resize_norm_hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(
                                resize_norm_hidden_states.shape[0],
                                resize_norm_hidden_states.shape[1],
                                -1,
                            )
                        )
                        masked_norm_hidden_states = masked_norm_hidden_states.permute(
                            0, 2, 1
                        )
                        self.bank.append(masked_norm_hidden_states)
                        del masked_norm_hidden_states
                        del this_ref_mask
                        del resize_norm_hidden_states
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states
                        if self.only_cross_attention
                        else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )
                if self.MODE == "read":
                    if self.attention_auto_machine_weight > self.attn_weight:
                        # Blend reference frequencies into the query features,
                        # then attend over [query tokens; banked reference tokens].
                        freq_norm_hidden_states = mix_ref_feature(
                            norm_hidden_states,
                            self.fea_bank,
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                            dim3=True,
                        )
                        self.fea_bank.clear()
                        this_bank = torch.cat(self.bank + self.bank, dim=0)
                        ref_hidden_states = torch.cat(
                            (freq_norm_hidden_states, this_bank), dim=1
                        )
                        del this_bank
                        self.bank.clear()
                        attn_output_uc = self.attn1(
                            freq_norm_hidden_states,
                            encoder_hidden_states=ref_hidden_states,
                            **cross_attention_kwargs,
                        )
                        del ref_hidden_states
                        attn_output_c = attn_output_uc.clone()
                        if self.do_classifier_free_guidance and self.style_fidelity > 0:
                            attn_output_c[self.uc_mask] = self.attn1(
                                norm_hidden_states[self.uc_mask],
                                encoder_hidden_states=norm_hidden_states[self.uc_mask],
                                **cross_attention_kwargs,
                            )
                        attn_output = (
                            self.style_fidelity * attn_output_c
                            + (1.0 - self.style_fidelity) * attn_output_uc
                        )
                        self.bank.clear()
                        self.fea_bank.clear()
                        del attn_output_c
                        del attn_output_uc
                    else:
                        attn_output = self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=encoder_hidden_states
                            if self.only_cross_attention
                            else None,
                            attention_mask=attention_mask,
                            **cross_attention_kwargs,
                        )
                        self.bank.clear()
                        self.fea_bank.clear()
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            hidden_states = attn_output + hidden_states

            if self.attn2 is not None:
                norm_hidden_states = (
                    self.norm2(hidden_states, timestep)
                    if self.use_ada_layer_norm
                    else self.norm2(hidden_states)
                )
                # 2. Cross-Attention
                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 3. Feed-forward
            norm_hidden_states = self.norm3(hidden_states)
            if self.use_ada_layer_norm_zero:
                norm_hidden_states = (
                    norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
                )
            ff_output = self.ff(norm_hidden_states)
            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output
            hidden_states = ff_output + hidden_states
            return hidden_states
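
        # Note on the protocol (hedged summary, inferred from the code above):
        # the hacked blocks run twice per denoising step. A "write" pass over
        # the reference latents fills `bank`/`fea_bank` (and the mean/var banks
        # for AdaIN below) from the masked reference region; the subsequent
        # "read" pass over the generation latents consumes and clears them.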
        def hacked_mid_forward(self, *args, **kwargs):
            x = self.original_forward(*args, **kwargs)
            if self.MODE == "write":
                if self.gn_auto_machine_weight >= self.gn_weight:
                    # masked var/mean statistics
                    scale_ratio = self.ref_mask.shape[2] / x.shape[2]
                    this_ref_mask = F.interpolate(
                        self.ref_mask.to(x.device), scale_factor=1 / scale_ratio
                    )
                    self.fea_bank.append(save_ref_feature(x, this_ref_mask))
                    this_ref_mask = this_ref_mask.repeat(
                        x.shape[0], x.shape[1], 1, 1
                    ).bool()
                    masked_x = (
                        x[this_ref_mask]
                        .detach()
                        .clone()
                        .view(x.shape[0], x.shape[1], -1, 1)
                    )
                    var, mean = torch.var_mean(
                        masked_x, dim=(2, 3), keepdim=True, correction=0
                    )
                    # duplicate the stats so they cover both halves of the CFG batch
                    self.mean_bank.append(torch.cat([mean] * 2, dim=0))
                    self.var_bank.append(torch.cat([var] * 2, dim=0))
            if self.MODE == "read":
                if (
                    self.gn_auto_machine_weight >= self.gn_weight
                    and len(self.mean_bank) > 0
                    and len(self.var_bank) > 0
                ):
                    x = mix_ref_feature(
                        x,
                        self.fea_bank,
                        cfg=self.do_classifier_free_guidance,
                        ref_scale=self.ref_scale,
                    )
                    self.fea_bank = []
                    x = mix_norm_feature(
                        x,
                        self.inpaint_mask,
                        self.mean_bank,
                        self.var_bank,
                        self.do_classifier_free_guidance,
                        self.style_fidelity,
                        self.uc_mask,
                    )
                self.mean_bank = []
                self.var_bank = []
            return x
        def hack_CrossAttnDownBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            # TODO(Patrick, William) - attention mask is not used
            output_states = ()
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                hidden_states = resnet(hidden_states, temb)
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean statistics after the resnet
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank0.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank0.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank0.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank0) > 0
                        and len(self.var_bank0) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank0[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            self.mean_bank0[i],
                            self.var_bank0[i],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean statistics after the attention block
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            self.mean_bank[i],
                            self.var_bank[i],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
                output_states = output_states + (hidden_states,)
            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank0 = []
                self.fea_bank = []
            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)
            return hidden_states, output_states
        def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
            output_states = ()
            for i, resnet in enumerate(self.resnets):
                hidden_states = resnet(hidden_states, temb)
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean statistics
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            self.mean_bank[i],
                            self.var_bank[i],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
                output_states = output_states + (hidden_states,)
            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank = []
            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)
            return hidden_states, output_states
        def hacked_CrossAttnUpBlock2D_forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            # TODO(Patrick, William) - attention mask is not used
            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean statistics after the resnet
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank0.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank0.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank0.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank0) > 0
                        and len(self.var_bank0) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank0[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            self.mean_bank0[i],
                            self.var_bank0[i],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean statistics after the attention block
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            self.mean_bank[i],
                            self.var_bank[i],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
            if self.MODE == "read":
                self.mean_bank0 = []
                self.var_bank0 = []
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank = []
                self.fea_bank0 = []
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)
            return hidden_states
        def hacked_UpBlock2D_forward(
            self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None
        ):
            for i, resnet in enumerate(self.resnets):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)
                if self.MODE == "write":
                    if self.gn_auto_machine_weight >= self.gn_weight:
                        # masked var/mean statistics
                        scale_ratio = self.ref_mask.shape[2] / hidden_states.shape[2]
                        this_ref_mask = F.interpolate(
                            self.ref_mask.to(hidden_states.device),
                            scale_factor=1 / scale_ratio,
                        )
                        self.fea_bank.append(
                            save_ref_feature(hidden_states, this_ref_mask)
                        )
                        this_ref_mask = this_ref_mask.repeat(
                            hidden_states.shape[0], hidden_states.shape[1], 1, 1
                        ).bool()
                        masked_hidden_states = (
                            hidden_states[this_ref_mask]
                            .detach()
                            .clone()
                            .view(hidden_states.shape[0], hidden_states.shape[1], -1, 1)
                        )
                        var, mean = torch.var_mean(
                            masked_hidden_states, dim=(2, 3), keepdim=True, correction=0
                        )
                        self.mean_bank.append(torch.cat([mean] * 2, dim=0))
                        self.var_bank.append(torch.cat([var] * 2, dim=0))
                if self.MODE == "read":
                    if (
                        self.gn_auto_machine_weight >= self.gn_weight
                        and len(self.mean_bank) > 0
                        and len(self.var_bank) > 0
                    ):
                        hidden_states = mix_ref_feature(
                            hidden_states,
                            [self.fea_bank[i]],
                            cfg=self.do_classifier_free_guidance,
                            ref_scale=self.ref_scale,
                        )
                        hidden_states = mix_norm_feature(
                            hidden_states,
                            self.inpaint_mask,
                            self.mean_bank[i],
                            self.var_bank[i],
                            self.do_classifier_free_guidance,
                            self.style_fidelity,
                            self.uc_mask,
                        )
            if self.MODE == "read":
                self.mean_bank = []
                self.var_bank = []
                self.fea_bank = []
            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)
            return hidden_states
        if model_type == "unet":
            if reference_attn:
                attn_modules = [
                    module
                    for module in torch_dfs(model)
                    if isinstance(module, BasicTransformerBlock)
                ]
                attn_modules = sorted(
                    attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
                )
                for i, module in enumerate(attn_modules):
                    module._original_inner_forward = module.forward
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, BasicTransformerBlock
                    )
                    module.bank = []
                    module.fea_bank = []
                    module.attn_weight = float(i) / float(len(attn_modules))
                    module.attention_auto_machine_weight = (
                        self.attention_auto_machine_weight
                    )
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
                    module.ref_scale = self.ref_scale
            else:
                attn_modules = None
            if reference_adain:
                gn_modules = [model.mid_block]
                model.mid_block.gn_weight = 0
                down_blocks = model.down_blocks
                for w, module in enumerate(down_blocks):
                    module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
                    gn_modules.append(module)
                up_blocks = model.up_blocks
                for w, module in enumerate(up_blocks):
                    module.gn_weight = float(w) / float(len(up_blocks))
                    gn_modules.append(module)
                for i, module in enumerate(gn_modules):
                    if getattr(module, "original_forward", None) is None:
                        module.original_forward = module.forward
                    if i == 0:
                        # mid_block
                        module.forward = hacked_mid_forward.__get__(
                            module, torch.nn.Module
                        )
                    # CrossAttnDownBlock2D / CrossAttnUpBlock2D hooks exist
                    # above but are intentionally left disabled here.
                    elif isinstance(module, DownBlock2D):
                        module.forward = hacked_DownBlock2D_forward.__get__(
                            module, DownBlock2D
                        )
                    elif isinstance(module, UpBlock2D):
                        module.forward = hacked_UpBlock2D_forward.__get__(
                            module, UpBlock2D
                        )
                    module.mean_bank0 = []
                    module.var_bank0 = []
                    module.fea_bank0 = []
                    module.mean_bank = []
                    module.var_bank = []
                    module.fea_bank = []
                    module.attention_auto_machine_weight = (
                        self.attention_auto_machine_weight
                    )
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
                    module.inpaint_mask = self.inpaint_mask
                    module.ref_scale = self.ref_scale
            else:
                gn_modules = None
        elif model_type == "controlnet":
            model = model.nets[-1]  # only hack the inpainting controlnet
            if reference_attn:
                attn_modules = [
                    module
                    for module in torch_dfs(model)
                    if isinstance(module, BasicTransformerBlock)
                ]
                attn_modules = sorted(
                    attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
                )
                for i, module in enumerate(attn_modules):
                    module._original_inner_forward = module.forward
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, BasicTransformerBlock
                    )
                    module.bank = []
                    module.fea_bank = []
                    # unlike the unet (float(i) / len(attn_modules)), every
                    # controlnet block gets weight 0.0, so all participate
                    module.attn_weight = 0.0
                    module.attention_auto_machine_weight = (
                        self.attention_auto_machine_weight
                    )
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
                    module.ref_scale = self.ref_scale
            else:
                attn_modules = None
            if reference_adain:
                # a controlnet has no up blocks, so only mid and down blocks
                # are hooked
                gn_modules = [model.mid_block]
                model.mid_block.gn_weight = 0
                down_blocks = model.down_blocks
                for w, module in enumerate(down_blocks):
                    module.gn_weight = 1.0 - float(w) / float(len(down_blocks))
                    gn_modules.append(module)
                for i, module in enumerate(gn_modules):
                    if getattr(module, "original_forward", None) is None:
                        module.original_forward = module.forward
                    if i == 0:
                        # mid_block
                        module.forward = hacked_mid_forward.__get__(
                            module, torch.nn.Module
                        )
                    # CrossAttnDownBlock2D hooks are intentionally left
                    # disabled here as well.
                    elif isinstance(module, DownBlock2D):
                        module.forward = hacked_DownBlock2D_forward.__get__(
                            module, DownBlock2D
                        )
                    module.mean_bank = []
                    module.var_bank = []
                    module.fea_bank = []
                    module.attention_auto_machine_weight = (
                        self.attention_auto_machine_weight
                    )
                    module.gn_auto_machine_weight = self.gn_auto_machine_weight
                    module.do_classifier_free_guidance = (
                        self.do_classifier_free_guidance
                    )
                    module.uc_mask = self.uc_mask
                    module.style_fidelity = self.style_fidelity
                    module.ref_mask = self.ref_mask
                    module.inpaint_mask = self.inpaint_mask
                    module.ref_scale = self.ref_scale
            else:
                gn_modules = None
        return attn_modules, gn_modules
    def change_module_mode(self, mode, attn_modules, gn_modules):
        if attn_modules is not None:
            for module in attn_modules:
                module.MODE = mode
        if gn_modules is not None:
            for module in gn_modules:
                module.MODE = mode
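
# Hedged end-to-end sketch (the denoising loop and names other than the
# methods above are assumptions; the actual sampling loop is not in this file):
#
#   pipe.check_ref_input(reference_attn=True, reference_adain=True)
#   attn_modules, gn_modules = pipe.redefine_ref_model(
#       pipe.unet, reference_attn=True, reference_adain=True, model_type="unet"
#   )
#   for t in timesteps:
#       # write pass: bank masked reference features/statistics
#       pipe.change_module_mode("write", attn_modules, gn_modules)
#       pipe.unet(ref_latent_input, t, encoder_hidden_states=prompt_embeds)
#       # read pass: consume the banks while denoising the generation latents
#       pipe.change_module_mode("read", attn_modules, gn_modules)
#       noise_pred = pipe.unet(
#           latent_input, t, encoder_hidden_states=prompt_embeds
#       ).sample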