Spaces:
Running
on
Zero
Running
on
Zero
| from typing import * | |
| import os | |
| os.environ['TOKENIZERS_PARALLELISM'] = 'true' | |
| import torch | |
| from transformers import AutoTokenizer, CLIPTextModel | |
| from ....utils import dist_utils | |
class TextConditionedMixin:
    """
    Mixin for text-conditioned models.

    Prompts are encoded with a pretrained CLIP text encoder that is loaded
    lazily on the first call to `encode_text`. The embedding of the empty
    prompt is cached and supplied as the negative condition (`neg_cond`).

    Args:
        text_cond_model: Name or path handed to `from_pretrained` for both
            the CLIP text model and its tokenizer.
    """

    def __init__(self, *args, text_cond_model: str = 'openai/clip-vit-large-patch14', **kwargs):
        # All remaining arguments are forwarded to the next class in the MRO.
        super().__init__(*args, **kwargs)
        self.text_cond_model_name = text_cond_model
        self.text_cond_model = None  # the model is init lazily

    def _init_text_cond_model(self) -> None:
        """
        Initialize the text conditioning model.

        Loads the encoder and tokenizer, switches the encoder to eval mode,
        moves it to CUDA, and caches the embedding of the empty prompt under
        the 'null_cond' key.
        """
        # NOTE(review): local_master_first presumably lets one process per node
        # populate the download cache before the others - confirm in dist_utils.
        with dist_utils.local_master_first():
            model = CLIPTextModel.from_pretrained(self.text_cond_model_name)
            tokenizer = AutoTokenizer.from_pretrained(self.text_cond_model_name)
        model.eval()
        model = model.cuda()
        self.text_cond_model = {
            'model': model,
            'tokenizer': tokenizer,
        }
        # encode_text reads self.text_cond_model, so the dict has to be
        # assigned before the null condition can be computed.
        self.text_cond_model['null_cond'] = self.encode_text([''])

    def encode_text(self, text: List[str]) -> torch.Tensor:
        """
        Encode the text.

        Args:
            text: Non-empty list of prompt strings.

        Returns:
            The encoder's `last_hidden_state` for the prompts, each padded or
            truncated to 77 tokens. The tensor is detached from autograd.

        Raises:
            AssertionError: If `text` is not a non-empty list of strings.
        """
        # Check every element (and reject an empty list) so malformed input
        # fails here with a clear message instead of an IndexError or an
        # opaque tokenizer error.
        assert (
            isinstance(text, list)
            and len(text) > 0
            and all(isinstance(item, str) for item in text)
        ), "TextConditionedMixin only supports list of strings as cond"
        if self.text_cond_model is None:
            self._init_text_cond_model()
        model = self.text_cond_model['model']
        encoding = self.text_cond_model['tokenizer'](
            text,
            max_length=77,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # Follow the encoder's device rather than hard-coding it; the model is
        # moved to CUDA at init, so this matches the previous `.cuda()` call.
        tokens = encoding['input_ids'].to(model.device)
        # The encoder is only used as a fixed feature extractor here (eval
        # mode, not registered as a submodule), so skip building an autograd
        # graph through it. This also keeps the cached null condition from
        # holding a graph alive.
        with torch.no_grad():
            embeddings = model(input_ids=tokens).last_hidden_state
        return embeddings

    def get_cond(self, cond, **kwargs):
        """
        Get the conditioning data.

        Encodes `cond` (a list of prompt strings), sets `kwargs['neg_cond']`
        to the cached empty-prompt embedding repeated once per prompt
        (replacing any value passed in), and delegates to the parent
        `get_cond`.
        """
        cond = self.encode_text(cond)
        # encode_text above guarantees the lazy model dict (and 'null_cond') exists.
        kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
        cond = super().get_cond(cond, **kwargs)
        return cond

    def get_inference_cond(self, cond, **kwargs):
        """
        Get the conditioning data for inference.

        Same preparation as `get_cond`, but delegates to the parent
        `get_inference_cond`.
        """
        cond = self.encode_text(cond)
        # encode_text above guarantees the lazy model dict (and 'null_cond') exists.
        kwargs['neg_cond'] = self.text_cond_model['null_cond'].repeat(cond.shape[0], 1, 1)
        cond = super().get_inference_cond(cond, **kwargs)
        return cond