Spaces:
Sleeping
Sleeping
Evgeny Zhukov
Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
2ba4412
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| import open_clip | |
| from functools import partial | |
| from utils.registry_class import EMBEDMANAGER | |
# Default placeholder string list. Nothing in the visible part of this module
# reads it; presumably callers/configs import it -- TODO confirm.
DEFAULT_PLACEHOLDER_TOKEN = ["*"]

# When progressive_words is enabled, EmbeddingManager.forward exposes
# 1 + progressive_counter // PROGRESSIVE_SCALE vectors per placeholder.
PROGRESSIVE_SCALE = 2000

# Extra placeholder strings appended by EmbeddingManager.__init__ when
# per_image_tokens is set.
# NOTE(review): these 22 single-character strings look like the result of a
# text-encoding mix-up (possibly Hebrew letters decoded with the wrong codec).
# They are runtime strings used as dictionary keys, so they are kept exactly as
# found -- confirm against existing checkpoints/configs before changing them.
per_img_token_list = [
    '讗', '讘', '讙', '讚', '讛', '讜',
    '讝', '讞', '讟', '讬', '讻', '诇',
    '诪', '谞', '住', '注', '驻', '爪',
    '拽', '专', '砖', '转',
]
def get_clip_token_for_string(string):
    """Tokenize ``string`` with ``open_clip.tokenize`` and return one token id.

    The value returned is the element at row 0, column 1 of the tokenizer
    output. Presumably column 0 holds a start-of-text marker, which would make
    this the first token of ``string`` -- TODO confirm against open_clip.
    No check is made that ``string`` maps to a single token.
    """
    return open_clip.tokenize(string)[0, 1]
def get_embedding_for_clip_token(embedder, token):
    """Return ``embedder``'s output for a single ``token``.

    ``token`` gets a leading axis of size one before being passed to the
    callable ``embedder``; element 0 of the result is returned, which removes
    that axis again.
    """
    batched_token = token.unsqueeze(0)
    batched_embedding = embedder(batched_token)
    return batched_embedding[0]
class EmbeddingManager(nn.Module):
    """Learnable embeddings for placeholder strings, spliced into text embeddings.

    For every placeholder string the manager keeps
      * its token id (``string_to_token_dict``),
      * a trainable tensor with one row per vector (``string_to_param_dict``),
      * and, when an initializer word was given, a frozen copy of the initial
        rows (``initial_embeddings``) used by ``embedding_to_coarse_loss``.
    """

    def __init__(
        self,
        embedder,
        placeholder_strings=None,
        initializer_words=None,
        per_image_tokens=False,
        num_vectors_per_token=1,
        progressive_words=False,
        temporal_prompt_length=1,
        token_dim=1024,
        **kwargs
    ):
        """Create one trainable embedding per placeholder string.

        Args:
            embedder: object exposing ``model.token_embedding``; that layer is
                used to look up the embeddings of ``initializer_words``.
            placeholder_strings: strings to manage. ``None`` means "no
                placeholders". The caller's list is never modified.
            initializer_words: optional sequence; the i-th word initializes the
                i-th placeholder. Placeholders without a word are initialized
                with ``torch.rand``.
            per_image_tokens: when true, ``per_img_token_list`` is appended to
                the placeholder strings.
            num_vectors_per_token: number of embedding rows per placeholder.
            progressive_words: when true, ``forward`` reveals the rows of a
                multi-vector placeholder gradually (see ``PROGRESSIVE_SCALE``).
            temporal_prompt_length: accepted for interface compatibility; not
                used by this class.
            token_dim: row width of randomly initialized embeddings.
            **kwargs: ignored.
        """
        super().__init__()

        self.string_to_token_dict = {}
        self.string_to_param_dict = nn.ParameterDict()
        self.initial_embeddings = nn.ParameterDict()  # These should not be optimized

        self.progressive_words = progressive_words
        self.progressive_counter = 0
        self.max_vectors_per_token = num_vectors_per_token

        # NOTE(review): if token_embedding is an nn.Module, .cpu() moves the
        # embedder's own layer in place. Kept because callers may rely on it.
        token_embedding = embedder.model.token_embedding.cpu()

        # Work on a private copy: the declared default (None) must not crash,
        # and extending the caller's list would leak per-image tokens into it.
        placeholder_strings = list(placeholder_strings) if placeholder_strings is not None else []
        if per_image_tokens:
            placeholder_strings.extend(per_img_token_list)

        for idx, placeholder_string in enumerate(placeholder_strings):
            token = get_clip_token_for_string(placeholder_string)

            if initializer_words and idx < len(initializer_words):
                init_word_token = get_clip_token_for_string(initializer_words[idx])
                with torch.no_grad():
                    init_word_embedding = get_embedding_for_clip_token(token_embedding, init_word_token)
                init_rows = init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1)
                token_params = nn.Parameter(init_rows, requires_grad=True)
                # Clone so the frozen reference does not share storage with the
                # trainable rows (otherwise optimizer steps would change both).
                self.initial_embeddings[placeholder_string] = nn.Parameter(init_rows.clone(), requires_grad=False)
            else:
                token_params = nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))

            self.string_to_token_dict[placeholder_string] = token
            self.string_to_param_dict[placeholder_string] = token_params

    def forward(
        self,
        tokenized_text,
        embedded_text,
    ):
        """Overwrite placeholder positions of ``embedded_text`` with learned rows.

        ``tokenized_text`` must be 2-D; ``embedded_text`` is indexed with the
        same (row, column) positions (presumably batch x sequence x dim --
        TODO confirm). Both tensors are modified in place; in the multi-vector
        case ``tokenized_text`` is rewritten too, so that it stays aligned with
        the inserted rows. ``embedded_text`` is returned.
        """
        _, n = tokenized_text.shape  # also enforces a 2-D token tensor
        device = tokenized_text.device

        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
            placeholder_token = placeholder_token.to(device)

            if self.max_vectors_per_token == 1:
                # One vector per token: plain replacement, sequence length unchanged.
                placeholder_idx = torch.where(tokenized_text == placeholder_token)
                embedded_text[placeholder_idx] = placeholder_embedding
                continue

            # Several vectors per token: insert rows and keep indices consistent.
            if self.progressive_words:
                # The counter advances once per placeholder per call, even when
                # the placeholder does not occur in this batch.
                self.progressive_counter += 1
                max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
            else:
                max_step_tokens = self.max_vectors_per_token

            num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)

            placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token)
            if placeholder_rows.nelement() == 0:
                continue

            # Handle the right-most occurrences first so that an insertion never
            # shifts a position that still has to be processed.
            sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
            sorted_rows = placeholder_rows[sort_idx]

            inserted_tokens = placeholder_token.repeat(num_vectors_for_token)
            inserted_embeddings = placeholder_embedding[:num_vectors_for_token]

            for row, col in zip(sorted_rows, sorted_cols):
                # Replace the single placeholder by num_vectors_for_token entries
                # and cut the row back to its original length n.
                new_token_row = torch.cat(
                    [tokenized_text[row][:col], inserted_tokens, tokenized_text[row][col + 1:]], dim=0
                )[:n]
                new_embed_row = torch.cat(
                    [embedded_text[row][:col], inserted_embeddings, embedded_text[row][col + 1:]], dim=0
                )[:n]

                embedded_text[row] = new_embed_row
                tokenized_text[row] = new_token_row

        return embedded_text

    def forward_with_text_img(
        self,
        tokenized_text,
        embedded_text,
        embedded_img,
    ):
        """Add ``embedded_img`` and the learned embedding at placeholder positions.

        Unlike ``forward``, the existing text embedding is kept and the two
        terms are added to it. ``embedded_text`` is modified in place and
        returned. ``embedded_img`` must broadcast against the selected rows.
        """
        device = tokenized_text.device

        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
            placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
            embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + embedded_img + placeholder_embedding

        return embedded_text

    def forward_with_text(
        self,
        tokenized_text,
        embedded_text
    ):
        """Add the learned embedding to the text embedding at placeholder positions.

        ``embedded_text`` is modified in place and returned.
        """
        device = tokenized_text.device

        for placeholder_string, placeholder_token in self.string_to_token_dict.items():
            placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
            placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
            embedded_text[placeholder_idx] = embedded_text[placeholder_idx] + placeholder_embedding

        return embedded_text

    def save(self, ckpt_path):
        """Write the token mapping and the trainable embeddings to ``ckpt_path``.

        The ``ParameterDict`` object itself is stored (not a state dict); the
        format is kept unchanged so existing checkpoints and loaders still match.
        """
        torch.save({"string_to_token": self.string_to_token_dict,
                    "string_to_param": self.string_to_param_dict}, ckpt_path)

    def load(self, ckpt_path):
        """Merge a checkpoint written by ``save`` into this manager.

        Entries already present under the same string are overwritten; other
        entries are left as they are.
        """
        # NOTE(security): torch.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location='cpu')

        # Read both sections before mutating anything, so a malformed checkpoint
        # fails without leaving the manager half-updated.
        string_to_token = ckpt["string_to_token"]
        string_to_param = ckpt["string_to_param"]

        self.string_to_token_dict.update(string_to_token)
        for string, param in string_to_param.items():
            # Accept plain tensors as well: wrap them so they are registered as
            # parameters regardless of how the checkpoint stored them.
            if not isinstance(param, nn.Parameter):
                param = nn.Parameter(param)
            self.string_to_param_dict[string] = param

    def get_embedding_norms_squared(self):
        """Return the squared L2 norm of every stored embedding row.

        The rows of all placeholders are concatenated along the first axis; an
        empty 1-D tensor is returned when no embeddings are stored.
        """
        params = list(self.string_to_param_dict.values())
        if not params:
            return torch.zeros(0)

        all_params = torch.cat(params, dim=0)  # total_rows x embedding_dim
        return (all_params * all_params).sum(dim=-1)  # total_rows

    def embedding_parameters(self):
        """Return an iterator over the trainable placeholder embeddings."""
        return self.string_to_param_dict.parameters()

    def embedding_to_coarse_loss(self):
        """Penalize drift of the embeddings away from their initializer rows.

        For each placeholder that had an initializer word, adds
        ``delta @ delta.T / num_embeddings`` with ``delta = optimized - initial``.
        The result is the float ``0.`` when there are no initializers; otherwise
        it is a square matrix with one row per vector (1 x 1 for single-vector
        placeholders), not a scalar.
        """
        loss = 0.
        num_embeddings = len(self.initial_embeddings)

        for key in self.initial_embeddings:
            optimized = self.string_to_param_dict[key]
            coarse = self.initial_embeddings[key].to(optimized.device)
            delta = optimized - coarse
            loss = loss + delta @ delta.T / num_embeddings

        return loss