# Spaces: Running on Zero
| import abc | |
| import torch | |
| from typing import Tuple, List | |
| from einops import rearrange | |
class AttentionControl(abc.ABC):
    """Base class for controllers that observe attention maps layer by layer.

    Attention processors call the controller once per hooked attention layer.
    The controller keeps a per-step layer counter (``cur_att_layer``) and a
    step counter (``cur_step``).  After ``num_att_layers + num_uncond_att_layers``
    calls the layer counter wraps to 0, the step counter advances and
    :meth:`between_steps` is invoked.

    ``num_att_layers`` starts at -1 and is expected to be set externally once
    the processors have been registered.
    """

    def step_callback(self, x_t):
        """Hook for per-step latent post-processing; returns ``x_t`` unchanged."""
        return x_t

    def between_steps(self):
        """Hook run after the last layer of each step; a no-op by default."""
        return

    @property
    def num_uncond_att_layers(self):
        """Number of leading calls per step that are not passed to :meth:`forward`.

        This must be a property: ``__call__`` compares it with, and adds it to,
        integer counters.
        """
        return 0

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Process one attention map; subclasses must override this."""
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Pass ``attn`` to :meth:`forward` (after the uncond layers) and advance counters."""
        if self.cur_att_layer >= self.num_uncond_att_layers:
            self.forward(attn, is_cross, place_in_unet)
        self.cur_att_layer += 1
        # One call per hooked layer completes a step.
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()

    def reset(self):
        """Reset the step and layer counters; ``num_att_layers`` is left unchanged."""
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1  # placeholder until processors are registered
        self.cur_att_layer = 0
class AttentionStore(AttentionControl):
    """Controller that records every attention map it receives.

    Maps are appended to ``step_store`` under ``"<place>_<cross|self>"`` keys
    (e.g. ``"up_self"``).  When a step completes, the finished store becomes
    ``attention_store``.  If ``save_global_store`` is set, its maps are also
    summed element-wise into ``global_store`` so they can be averaged over steps.
    """

    @staticmethod
    def get_empty_store():
        """Return a new store with an empty list per UNet location and attention type.

        Declared static because it takes no ``self`` but is invoked as
        ``self.get_empty_store()``.
        """
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Append ``attn`` to the current step's list for its location/type and return it."""
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # NOTE: the size filter ``attn.shape[1] <= 32 ** 2`` (to avoid memory
        # overhead) is disabled, so every map is kept.
        self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Publish the finished step and, if enabled, add it to the global accumulator."""
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if not self.global_store:
                    # Take private copies: the accumulator is updated in place on
                    # later steps and must not mutate maps still referenced
                    # elsewhere (e.g. through ``attention_store``).
                    self.global_store = {
                        key: [item.detach().clone() for item in items]
                        for key, items in self.step_store.items()
                    }
                else:
                    for key, accumulated in self.global_store.items():
                        step_items = self.step_store[key]
                        for i in range(len(accumulated)):
                            accumulated[i] += step_items[i].detach()
        # Start the next step with a fresh, empty store.
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the maps of the most recently completed step (no averaging is done)."""
        return self.attention_store

    def get_average_global_attention(self):
        """Return ``global_store`` divided element-wise by the number of completed steps.

        Returns an empty dict when nothing was accumulated (e.g. with
        ``save_global_store=False``) instead of raising ``KeyError``.
        """
        return {key: [item / self.cur_step for item in items]
                for key, items in self.global_store.items()}

    def reset(self):
        """Reset the counters and drop all recorded maps."""
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, save_global_store=False):
        '''
        Initialize an empty AttentionStore.
        :param save_global_store: if True, also accumulate each completed step's maps
            into ``global_store`` (averaged by ``get_average_global_attention``).
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        # NOTE(review): not read anywhere in the visible code; kept for compatibility.
        self.curr_step_index = 0
class AttentionStoreProcessor:
    """Standard attention processor that also reports its attention probabilities.

    The computation follows the usual processor pipeline: optional spatial norm,
    flattening of 4-D inputs, optional group norm, q/k/v projections, softmax
    attention scores, output projection and dropout, optional residual and
    output rescaling.  The softmax probabilities, reshaped to ``(b, h, i, j)``,
    are passed to ``attnstore`` as self-attention (``is_cross=False``) at
    ``place_in_unet``.
    """

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D (b, c, h, w) inputs are flattened to (b, h*w, c) and restored at the end.
        was_spatial = hidden_states.ndim == 4
        if was_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key_proj = attn.to_k(context)
        value_proj = attn.to_v(context)
        query, key, value = (attn.head_to_batch_dim(t) for t in (query, key_proj, value_proj))

        probs = attn.get_attention_scores(query, key, attention_mask)
        # Report the map split back into per-sample heads: (b*h, i, j) -> (b, h, i, j).
        self.attnstore(rearrange(probs, '(b h) i j -> b h i j', b=batch_size), False, self.place_in_unet)

        out = attn.batch_to_head_dim(torch.bmm(probs, value))
        linear_proj, dropout = attn.to_out[0], attn.to_out[1]
        out = dropout(linear_proj(out))

        if was_spatial:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            out = out + residual

        return out / attn.rescale_output_factor
class AttentionFlipCtrlProcessor:
    """Attention processor that attends with a flipped reference attention map.

    Each call takes the map that ``attnstore_ref`` recorded for the same
    self-attention layer and flips it along its query and key axes
    (``torch.flip(..., dims=(-2, -1))``).  That map is used instead of the
    layer's own attention, or blended with it when ``ref_weight < 1``.  The map
    actually used is reported to ``attnstore`` as self-attention at
    ``place_in_unet``.

    :param attnstore: controller that receives the maps used here; its
        ``cur_att_layer`` counter selects the reference layer.
    :param attnstore_ref: controller whose ``attention_store`` holds the reference
        maps as ``(b, h, i, j)`` tensors under ``"<place>_self"`` keys.
    :param place_in_unet: ``"down"``, ``"mid"`` or ``"up"``.
    :param ref_weight: weight of the flipped reference map; the layer's own map
        gets ``1 - ref_weight``.  The default ``1.0`` uses the reference map alone
        (the original behavior) and skips computing this layer's own scores.
    """

    def __init__(self, attnstore, attnstore_ref, place_in_unet, ref_weight=1.0):
        super().__init__()
        self.attnstore = attnstore
        # Attribute spelling kept as-is for compatibility with existing readers.
        self.attnrstore_ref = attnstore_ref
        self.place_in_unet = place_in_unet
        self.ref_weight = ref_weight

    def _ref_layer_index(self):
        """Return the current layer's index within the reference list for ``place_in_unet``.

        ``attnstore.cur_att_layer`` counts calls across the whole UNet.  The
        reference layers of earlier locations are subtracted, assuming calls
        arrive in down -> mid -> up order.
        """
        ref_store = self.attnrstore_ref.attention_store
        offset = 0
        if self.place_in_unet in ("mid", "up"):
            offset += len(ref_store["down_self"])
        if self.place_in_unet == "up":
            offset += len(ref_store["mid_self"])
        return self.attnstore.cur_att_layer - offset

    @staticmethod
    def _own_attention_probs(attn, hidden_states, encoder_hidden_states, attention_mask,
                             sequence_length, batch_size):
        """Compute this layer's own softmax attention probabilities, shaped ``(b*h, i, j)``."""
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states))
        return attn.get_attention_scores(query, key, attention_mask)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states))

        # Must run before self.attnstore(...) below, which advances cur_att_layer.
        ref_probs = self.attnrstore_ref.attention_store[f"{self.place_in_unet}_self"][self._ref_layer_index()]
        ref_probs = torch.flip(rearrange(ref_probs, 'b h i j -> (b h) i j'), dims=(-2, -1))

        if self.ref_weight == 1.0:
            # The own map would be weighted by 0: skip computing it.  This saves the
            # QK^T/softmax work and keeps NaN/Inf in it out of the output.
            attention_probs = ref_probs
        else:
            own_probs = self._own_attention_probs(
                attn, hidden_states, encoder_hidden_states, attention_mask, sequence_length, batch_size
            )
            attention_probs = (1.0 - self.ref_weight) * own_probs + self.ref_weight * ref_probs
        # bmm requires matching dtypes; the reference map may come from another pass.
        attention_probs = attention_probs.to(value.dtype)

        self.attnstore(rearrange(attention_probs, '(b h) i j -> b h i j', b=batch_size), False, self.place_in_unet)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
def register_temporal_self_attention_control(unet, controller):
    """Install an ``AttentionStoreProcessor`` on each temporal self-attention layer.

    A processor is replaced only if its name ends with
    ``"temporal_transformer_blocks.0.attn1.processor"`` and it lives under
    ``down_blocks``, ``mid_block`` or ``up_blocks``.  Every other processor is
    passed back unchanged, so the dict given to ``set_attn_processor`` covers
    every attention layer.  ``controller.num_att_layers`` is set to the number
    of replaced processors.

    :param unet: model exposing ``attn_processors`` and ``set_attn_processor``
        (presumably a diffusers UNet — verify against caller).
    :param controller: callable controller (e.g. an ``AttentionStore``) that
        receives each recorded attention map.
    """
    # Read once: the property may rebuild the mapping on every access.
    existing = unet.attn_processors
    attn_procs = {}
    temporal_self_att_count = 0
    for name, processor in existing.items():
        if not name.endswith("temporal_transformer_blocks.0.attn1.processor"):
            attn_procs[name] = processor
            continue
        if name.startswith("mid_block"):
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            place_in_unet = "down"
        else:
            # Unknown location: keep the current processor rather than dropping the
            # entry, which would leave the processor dict incomplete.
            attn_procs[name] = processor
            continue
        temporal_self_att_count += 1
        attn_procs[name] = AttentionStoreProcessor(
            attnstore=controller, place_in_unet=place_in_unet
        )
    unet.set_attn_processor(attn_procs)
    controller.num_att_layers = temporal_self_att_count
def register_temporal_self_attention_flip_control(unet, controller, controller_ref):
    """Install an ``AttentionFlipCtrlProcessor`` on each temporal self-attention layer.

    A processor is replaced only if its name ends with
    ``"temporal_transformer_blocks.0.attn1.processor"`` and it lives under
    ``down_blocks``, ``mid_block`` or ``up_blocks``.  Every other processor is
    passed back unchanged, so the dict given to ``set_attn_processor`` covers
    every attention layer.  ``controller.num_att_layers`` is set to the number
    of replaced processors.

    :param unet: model exposing ``attn_processors`` and ``set_attn_processor``
        (presumably a diffusers UNet — verify against caller).
    :param controller: controller that receives the attention maps used by the
        flipped processors.
    :param controller_ref: controller whose stored maps are flipped and reused.
    """
    # Read once: the property may rebuild the mapping on every access.
    existing = unet.attn_processors
    attn_procs = {}
    temporal_self_att_count = 0
    for name, processor in existing.items():
        if not name.endswith("temporal_transformer_blocks.0.attn1.processor"):
            attn_procs[name] = processor
            continue
        if name.startswith("mid_block"):
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            place_in_unet = "down"
        else:
            # Unknown location: keep the current processor rather than dropping the
            # entry, which would leave the processor dict incomplete.
            attn_procs[name] = processor
            continue
        temporal_self_att_count += 1
        attn_procs[name] = AttentionFlipCtrlProcessor(
            attnstore=controller, attnstore_ref=controller_ref, place_in_unet=place_in_unet
        )
    unet.set_attn_processor(attn_procs)
    controller.num_att_layers = temporal_self_att_count