import json
import logging
import numbers

import torch

from modules.Device import Device
from modules.cond import cast
from modules.clip.CLIPTextModel import CLIPTextModel


def gen_empty_tokens(special_tokens: dict, length: int) -> list:
    """#### Generate a list of empty tokens.
    #### Args:
        - `special_tokens` (dict): The special tokens.
        - `length` (int): The length of the token list.
    #### Returns:
        - `list`: The list of empty tokens.
    """
    start_token = special_tokens.get("start", None)
    end_token = special_tokens.get("end", None)
    pad_token = special_tokens.get("pad")
    output = []
    if start_token is not None:
        output.append(start_token)
    if end_token is not None:
        output.append(end_token)
    output += [pad_token] * (length - len(output))
    return output
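
# Illustrative example: with CLIP's default special tokens (the same defaults
# used by SDClipModel below), an "empty" prompt of length 5 is a start token,
# an end token, then padding:
#   gen_empty_tokens({"start": 49406, "end": 49407, "pad": 49407}, 5)
#   -> [49406, 49407, 49407, 49407, 49407]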


class ClipTokenWeightEncoder:
    """#### Class representing a CLIP token weight encoder."""

    def encode_token_weights(self, token_weight_pairs: list) -> tuple:
        """#### Encode token weights.
        #### Args:
            - `token_weight_pairs` (list): The token weight pairs.
        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        to_encode = []
        max_token_len = 0
        has_weights = False
        for x in token_weight_pairs:
            tokens = [a[0] for a in x]
            max_token_len = max(len(tokens), max_token_len)
            has_weights = has_weights or not all(a[1] == 1.0 for a in x)
            to_encode.append(tokens)

        sections = len(to_encode)
        if has_weights or sections == 0:
            # Also encode an "empty" prompt so weighted tokens can be
            # interpolated against it below.
            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))

        o = self.encode(to_encode)
        out, pooled = o[:2]

        if pooled is not None:
            first_pooled = pooled[0:1].to(Device.intermediate_device())
        else:
            first_pooled = pooled

        output = []
        for k in range(0, sections):
            z = out[k : k + 1]
            if has_weights:
                z_empty = out[-1]
                for i in range(len(z)):
                    for j in range(len(z[i])):
                        weight = token_weight_pairs[k][j][1]
                        if weight != 1.0:
                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
            output.append(z)

        if len(output) == 0:
            r = (out[-1:].to(Device.intermediate_device()), first_pooled)
        else:
            r = (
                torch.cat(output, dim=-2).to(Device.intermediate_device()),
                first_pooled,
            )

        if len(o) > 2:
            extra = {}
            for k in o[2]:
                v = o[2][k]
                if k == "attention_mask":
                    v = (
                        v[:sections]
                        .flatten()
                        .unsqueeze(dim=0)
                        .to(Device.intermediate_device())
                    )
                extra[k] = v
            r = r + (extra,)
        return r
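
# The weighting loop above interpolates each token's embedding against the
# empty-prompt embedding: z' = z_empty + weight * (z - z_empty). A weight of
# 1.0 leaves the embedding unchanged, 0.0 collapses it to the empty prompt,
# and weights above 1.0 push it further away from the empty prompt.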


class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
    """#### Uses the CLIP transformer encoder for text (from huggingface)."""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version: str = "openai/clip-vit-large-patch14",
        device: str = "cpu",
        max_length: int = 77,
        freeze: bool = True,
        layer: str = "last",
        layer_idx: int = None,
        textmodel_json_config: str = None,
        dtype: torch.dtype = None,
        model_class: type = CLIPTextModel,
        special_tokens: dict = {"start": 49406, "end": 49407, "pad": 49407},
        layer_norm_hidden_state: bool = True,
        enable_attention_masks: bool = False,
        zero_out_masked: bool = False,
        return_projected_pooled: bool = True,
        return_attention_masks: bool = False,
        model_options={},
    ):
        """#### Initialize the SDClipModel.
        #### Args:
            - `version` (str, optional): The version of the model. Defaults to "openai/clip-vit-large-patch14".
            - `device` (str, optional): The device to use. Defaults to "cpu".
            - `max_length` (int, optional): The maximum length of the input. Defaults to 77.
            - `freeze` (bool, optional): Whether to freeze the model parameters. Defaults to True.
            - `layer` (str, optional): The layer to use. Defaults to "last".
            - `layer_idx` (int, optional): The index of the layer. Defaults to None.
            - `textmodel_json_config` (str, optional): The path to the JSON config file. Defaults to None.
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `model_class` (type, optional): The model class. Defaults to CLIPTextModel.
            - `special_tokens` (dict, optional): The special tokens. Defaults to {"start": 49406, "end": 49407, "pad": 49407}.
            - `layer_norm_hidden_state` (bool, optional): Whether to normalize the hidden state. Defaults to True.
            - `enable_attention_masks` (bool, optional): Whether to enable attention masks. Defaults to False.
            - `zero_out_masked` (bool, optional): Whether to zero out masked tokens. Defaults to False.
            - `return_projected_pooled` (bool, optional): Whether to return the projected pooled output. Defaults to True.
            - `return_attention_masks` (bool, optional): Whether to return the attention masks. Defaults to False.
            - `model_options` (dict, optional): Additional model options. Defaults to {}.
        """
        super().__init__()
        assert layer in self.LAYERS
        if textmodel_json_config is None:
            textmodel_json_config = "./_internal/clip/sd1_clip_config.json"
        with open(textmodel_json_config) as f:
            config = json.load(f)

        operations = model_options.get("custom_operations", None)
        if operations is None:
            operations = cast.manual_cast
        self.operations = operations
        self.transformer = model_class(config, dtype, device, self.operations)
        self.num_layers = self.transformer.num_layers

        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = None
        self.special_tokens = special_tokens

        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
        self.enable_attention_masks = enable_attention_masks
        self.zero_out_masked = zero_out_masked
        self.layer_norm_hidden_state = layer_norm_hidden_state
        self.return_projected_pooled = return_projected_pooled
        self.return_attention_masks = return_attention_masks

        if layer == "hidden":
            assert layer_idx is not None
            assert abs(layer_idx) < self.num_layers
            self.set_clip_options({"layer": layer_idx})
        self.options_default = (
            self.layer,
            self.layer_idx,
            self.return_projected_pooled,
        )

    def freeze(self) -> None:
        """#### Freeze the model parameters."""
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def set_clip_options(self, options: dict) -> None:
        """#### Set the CLIP options.
        #### Args:
            - `options` (dict): The options to set.
        """
        layer_idx = options.get("layer", self.layer_idx)
        self.return_projected_pooled = options.get(
            "projected_pooled", self.return_projected_pooled
        )
        if layer_idx is None or abs(layer_idx) > self.num_layers:
            self.layer = "last"
        else:
            self.layer = "hidden"
            self.layer_idx = layer_idx

    def reset_clip_options(self) -> None:
        """#### Reset the CLIP options to default."""
        self.layer = self.options_default[0]
        self.layer_idx = self.options_default[1]
        self.return_projected_pooled = self.options_default[2]
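
    # Example (illustrative): select the penultimate hidden state, commonly
    # called "clip skip" in SD tooling, then restore the defaults afterwards:
    #   model.set_clip_options({"layer": -2})
    #   ...
    #   model.reset_clip_options()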

    def set_up_textual_embeddings(self, tokens: list, current_embeds: torch.nn.Embedding) -> list:
        """#### Set up the textual embeddings.
        #### Args:
            - `tokens` (list): The input tokens.
            - `current_embeds` (torch.nn.Embedding): The current embeddings.
        #### Returns:
            - `list`: The processed tokens.
        """
        out_tokens = []
        next_new_token = token_dict_size = current_embeds.weight.shape[0]
        embedding_weights = []
        for x in tokens:
            tokens_temp = []
            for y in x:
                if isinstance(y, numbers.Integral):
                    tokens_temp += [int(y)]
                else:
                    # Non-integer entries are custom embedding vectors (e.g.
                    # textual inversion); give each a fresh token id appended
                    # to the end of the embedding matrix.
                    if y.shape[0] == current_embeds.weight.shape[1]:
                        embedding_weights += [y]
                        tokens_temp += [next_new_token]
                        next_new_token += 1
                    else:
                        logging.warning(
                            "WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(
                                y.shape[0], current_embeds.weight.shape[1]
                            )
                        )
            while len(tokens_temp) < len(x):
                tokens_temp += [self.special_tokens["pad"]]
            out_tokens += [tokens_temp]
        n = token_dict_size
        if len(embedding_weights) > 0:
            new_embedding = self.operations.Embedding(
                next_new_token + 1,
                current_embeds.weight.shape[1],
                device=current_embeds.weight.device,
                dtype=current_embeds.weight.dtype,
            )
            new_embedding.weight[:token_dict_size] = current_embeds.weight
            for x in embedding_weights:
                new_embedding.weight[n] = x
                n += 1
            self.transformer.set_input_embeddings(new_embedding)
        processed_tokens = []
        for x in out_tokens:
            processed_tokens += [
                list(map(lambda a: n if a == -1 else a, x))
            ]  # The EOS token should always be the largest one
        return processed_tokens
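
    # Illustrative example (hypothetical ids and shapes): for CLIP-L the
    # embedding width is 768 and the base vocabulary has 49408 entries, so a
    # textual-inversion vector mixed into a prompt receives id 49408:
    #   tokens = [[49406, 320, torch.zeros(768), 49407]]
    #   processed = model.set_up_textual_embeddings(tokens, embeds)
    #   # -> [[49406, 320, 49408, 49407]]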

    def forward(self, tokens: list) -> tuple:
        """#### Forward pass of the model.
        #### Args:
            - `tokens` (list): The input tokens.
        #### Returns:
            - `tuple`: The output and the pooled output.
        """
        backup_embeds = self.transformer.get_input_embeddings()
        device = backup_embeds.weight.device
        tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
        tokens = torch.LongTensor(tokens).to(device)

        attention_mask = None
        if (
            self.enable_attention_masks
            or self.zero_out_masked
            or self.return_attention_masks
        ):
            # Mark tokens as attended (1) up to and including the first end
            # token; everything after it stays masked (0).
            attention_mask = torch.zeros_like(tokens)
            end_token = self.special_tokens.get("end", -1)
            for x in range(attention_mask.shape[0]):
                for y in range(attention_mask.shape[1]):
                    attention_mask[x, y] = 1
                    if tokens[x, y] == end_token:
                        break

        attention_mask_model = None
        if self.enable_attention_masks:
            attention_mask_model = attention_mask

        outputs = self.transformer(
            tokens,
            attention_mask_model,
            intermediate_output=self.layer_idx,
            final_layer_norm_intermediate=self.layer_norm_hidden_state,
            dtype=torch.float32,
        )
        self.transformer.set_input_embeddings(backup_embeds)

        if self.layer == "last":
            z = outputs[0].float()
        else:
            z = outputs[1].float()

        if self.zero_out_masked:
            z *= attention_mask.unsqueeze(-1).float()

        pooled_output = None
        if len(outputs) >= 3:
            if (
                not self.return_projected_pooled
                and len(outputs) >= 4
                and outputs[3] is not None
            ):
                pooled_output = outputs[3].float()
            elif outputs[2] is not None:
                pooled_output = outputs[2].float()

        extra = {}
        if self.return_attention_masks:
            extra["attention_mask"] = attention_mask

        if len(extra) > 0:
            return z, pooled_output, extra

        return z, pooled_output

    def encode(self, tokens: list) -> tuple:
        """#### Encode the input tokens.
        #### Args:
            - `tokens` (list): The input tokens.
        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        return self(tokens)

    def load_sd(self, sd: dict):
        """#### Load the state dictionary.
        #### Args:
            - `sd` (dict): The state dictionary.
        #### Returns:
            - The missing and unexpected keys reported by `load_state_dict`.
        """
        return self.transformer.load_state_dict(sd, strict=False)


class SD1ClipModel(torch.nn.Module):
    """#### Class representing the SD1ClipModel."""

    def __init__(
        self,
        device: str = "cpu",
        dtype: torch.dtype = None,
        clip_name: str = "l",
        clip_model: type = SDClipModel,
        **kwargs,
    ):
        """#### Initialize the SD1ClipModel.
        #### Args:
            - `device` (str, optional): The device to use. Defaults to "cpu".
            - `dtype` (torch.dtype, optional): The data type. Defaults to None.
            - `clip_name` (str, optional): The name of the CLIP model. Defaults to "l".
            - `clip_model` (type, optional): The CLIP model class. Defaults to SDClipModel.
            - `**kwargs`: Additional keyword arguments.
        """
        super().__init__()
        self.clip_name = clip_name
        self.clip = "clip_{}".format(self.clip_name)

        self.lowvram_patch_counter = 0
        self.model_loaded_weight_memory = 0

        setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))

    def set_clip_options(self, options: dict) -> None:
        """#### Set the CLIP options.
        #### Args:
            - `options` (dict): The options to set.
        """
        getattr(self, self.clip).set_clip_options(options)

    def reset_clip_options(self) -> None:
        """#### Reset the CLIP options to default."""
        getattr(self, self.clip).reset_clip_options()

    def encode_token_weights(self, token_weight_pairs: dict) -> tuple:
        """#### Encode token weights.
        #### Args:
            - `token_weight_pairs` (dict): The token weight pairs, keyed by clip name.
        #### Returns:
            - `tuple`: The encoded tokens and the pooled output.
        """
        token_weight_pairs = token_weight_pairs[self.clip_name]
        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
        return out, pooled
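

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative): build the text encoder on CPU
    # and encode one hypothetical prompt. Only the start/end ids are CLIP's
    # real special tokens; `320` is a made-up token id standing in for real
    # tokenizer output, and the default config path must exist in this repo.
    model = SD1ClipModel(device="cpu")
    pairs = {"l": [[(49406, 1.0), (320, 1.0), (49407, 1.0)]]}
    cond, pooled = model.encode_token_weights(pairs)
    print(cond.shape, None if pooled is None else pooled.shape)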