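"""UniCL: a two-tower image-text model.

An image encoder and a text encoder are projected into a shared embedding
space, and a learned temperature (`logit_scale`) scales their cosine
similarities for contrastive image-text matching.
"""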
import logging
import os
import pathlib
import tempfile

import torch
from torch import nn
from timm.models.layers import trunc_normal_

from .image_encoder import build_image_encoder
from .text_encoder import build_text_encoder, build_tokenizer
from .templates import DEFAULT_TEMPLATES
# NOTE: assumed import path -- `get_imnet_embeddings` below needs the ImageNet
# class names and prompt templates from somewhere in this package.
from .imagenet import IMAGENET_CLASSES, IMAGENET_DEFAULT_TEMPLATES

logger = logging.getLogger(__name__)
class UniCLModel(nn.Module):
    def __init__(self, config: dict):
        super().__init__()

        self.conf_lang_encoder = config['MODEL']['TEXT_ENCODER']
        self.tokenizer = build_tokenizer(self.conf_lang_encoder)
        self.text_encoder = build_text_encoder(self.conf_lang_encoder, self.tokenizer, config['VERBOSE'])

        dim_projection = config['MODEL']['DIM_PROJECTION']
        if hasattr(self.text_encoder, 'dim_out'):
            dim_out = self.text_encoder.dim_out
        else:
            # Probe the encoder with a dummy token to discover its hidden size.
            with torch.no_grad():
                dim_out = self.text_encoder(
                    torch.zeros(1, 1, dtype=torch.long)
                )['last_hidden_state'].size(2)

        self.text_projection = nn.Parameter(torch.empty(dim_out, dim_projection))

        self.conf_image_encoder = config['MODEL']['IMAGE_ENCODER']
        self.image_encoder = build_image_encoder(self.conf_image_encoder)
        self.image_projection = nn.Parameter(
            torch.empty(self.image_encoder.dim_out, dim_projection)
        )

        self.logit_scale = nn.Parameter(torch.ones([]))

        trunc_normal_(self.text_projection, std=.02)
        trunc_normal_(self.image_projection, std=.02)
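    # The constructor reads only the config keys shown below; a schematic
    # example (the values are illustrative placeholders, not from the source):
    #
    #     config = {
    #         'VERBOSE': True,
    #         'MODEL': {
    #             'DIM_PROJECTION': 512,
    #             'TEXT_ENCODER': {...},   # consumed by build_tokenizer / build_text_encoder
    #             'IMAGE_ENCODER': {...},  # consumed by build_image_encoder
    #             'PRETRAINED': '',        # path or URL, read in build_unicl_model
    #             'PRETRAINED_LAYERS': ['*'],
    #         },
    #     }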
    def _convert_old_weights(self, model_dict):
        """Rename keys from older checkpoints to the current module names."""
        model_dict_updated = {}
        for k, v in model_dict.items():
            if k.startswith('visual.'):
                model_dict_updated['image_encoder.' + k[7:]] = v
            elif k.startswith('text.'):
                # The text tower is registered as `text_encoder` above.
                model_dict_updated['text_encoder.' + k[5:]] = v
            elif k == 'vision_projection':
                model_dict_updated['image_projection'] = v
            elif k == 'text_projection':
                model_dict_updated['text_projection'] = v
            else:
                model_dict_updated[k] = v
        return model_dict_updated
    def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
        if not os.path.isfile(pretrained):
            logger.warning(f'=> Pretrained model ({pretrained}) is not a file, skip init weight')
            return

        pretrained_dict = torch.load(pretrained, map_location='cpu')
        logger.info(f'=> Loading pretrained model {pretrained}')
        pretrained_dict = self._convert_old_weights(pretrained_dict)
        model_dict = self.state_dict()
        # Keep only keys that exist in the current model.
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict.keys()
        }
        need_init_state_dict = {}
        image_encoder_state_dict = {}
        for k, v in pretrained_dict.items():
            need_init = (
                k.split('.')[0] in pretrained_layers
                or (pretrained_layers and pretrained_layers[0] == '*')
            )
            if need_init:
                if k.startswith('image_encoder.'):
                    image_encoder_state_dict[k] = v
                else:
                    if verbose:
                        logger.info(f'=> init {k} from {pretrained}')
                    need_init_state_dict[k] = v
        self.image_encoder.from_state_dict(image_encoder_state_dict, ['*'], verbose)
        self.load_state_dict(need_init_state_dict, strict=False)
    def no_weight_decay(self):
        no_weight_decay = {'logit_scale'}
        if hasattr(self.text_encoder, 'no_weight_decay'):
            for k in self.text_encoder.no_weight_decay():
                no_weight_decay.add('text_encoder.' + k)
        if hasattr(self.image_encoder, 'no_weight_decay'):
            for k in self.image_encoder.no_weight_decay():
                no_weight_decay.add('image_encoder.' + k)
        return no_weight_decay
    @property
    def dtype(self):
        return self.logit_scale.dtype
    def get_imnet_embeddings(self):
        """One averaged, L2-normalised text embedding per ImageNet class."""
        # Only the first prompt template is used here.
        templates = IMAGENET_DEFAULT_TEMPLATES[:1]
        clss_embeddings = []
        for clss in IMAGENET_CLASSES:
            txts = [template.format(clss) for template in templates]
            tokens = self.tokenizer(
                txts, padding='max_length', truncation=True, max_length=77, return_tensors='pt'
            )
            device = next(self.parameters()).device
            tokens = {key: val.to(device) for key, val in tokens.items()}
            clss_embedding = self.encode_text(tokens)
            clss_embedding = clss_embedding.mean(dim=0)
            clss_embedding /= clss_embedding.norm()
            clss_embeddings.append(clss_embedding)
        imnet_text_embeddings = torch.stack(clss_embeddings, dim=0)
        return imnet_text_embeddings
    def get_text_embeddings(self, texts):
        """One averaged, L2-normalised text embedding per input string."""
        templates = DEFAULT_TEMPLATES[:1]
        clss_embeddings = []
        for clss in texts:
            txts = [template.format(clss) for template in templates]
            tokens = self.tokenizer(
                txts, padding='max_length', truncation=True, max_length=77, return_tensors='pt'
            )
            device = next(self.parameters()).device
            tokens = {key: val.to(device) for key, val in tokens.items()}
            clss_embedding = self.encode_text(tokens)
            clss_embedding = clss_embedding.mean(dim=0)
            clss_embedding /= clss_embedding.norm()
            clss_embeddings.append(clss_embedding)
        text_embeddings = torch.stack(clss_embeddings, dim=0)
        return text_embeddings
    def encode_image(self, image, norm=True):
        x = self.image_encoder.forward_features(image)
        x = x @ self.image_projection

        if norm:
            x = x / x.norm(dim=-1, keepdim=True)
        return x
    def encode_text(self, text, norm=True):
        x = self.text_encoder(**text)
        x = x['last_hidden_state']

        if self.conf_lang_encoder['TOKENIZER'] == 'clip':
            # CLIP-style pooling: take the hidden state at the end-of-text
            # token, which has the largest token id in each sequence.
            x = x[torch.arange(x.size(0)), text['input_ids'].argmax(dim=-1)]
        else:
            # BERT-style pooling: take the first ([CLS]) token.
            x = x[:, 0]

        x = x @ self.text_projection

        if norm:
            x = x / x.norm(dim=-1, keepdim=True)
        return x
    def forward(self, image, text):
        features_image = self.encode_image(image)
        features_text = self.encode_text(text)

        # cosine similarity as logits
        T = self.logit_scale.exp()

        return features_image, features_text, T
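    # Usage sketch (an illustration, not part of the original API): the triple
    # returned by `forward` yields CLIP-style contrastive logits; since both
    # feature sets are already L2-normalised, the matrix product below is a
    # cosine-similarity matrix scaled by the learned temperature T:
    #
    #     features_image, features_text, T = model(image, text)
    #     logits_per_image = T * features_image @ features_text.t()
    #     logits_per_text = logits_per_image.t()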
def build_unicl_model(config, **kwargs):
    model = UniCLModel(config)
    if config['MODEL']['PRETRAINED'] != '':
        pretrained_path = config['MODEL']['PRETRAINED']
        from ..Utils.Utils import is_valid_url, download_file
        if is_valid_url(pretrained_path):
            # Remote checkpoint: download to a temp file before loading.
            with tempfile.TemporaryDirectory() as tmp_path:
                file_local_path = pathlib.Path(tmp_path) / 'base_model.pt'
                download_file(pretrained_path, file_local_path)
                model.from_pretrained(str(file_local_path), config['MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])
        else:
            model.from_pretrained(pretrained_path, config['MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])

    return model
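# A minimal end-to-end sketch, assuming a `config` dict with the keys read
# above; the concrete encoder configs depend on the surrounding project:
#
#     model = build_unicl_model(config).eval()
#     text_emb = model.get_text_embeddings(['a cat', 'a dog'])     # (2, D)
#     with torch.no_grad():
#         img_emb = model.encode_image(images)                     # (B, D)
#     probs = (model.logit_scale.exp() * img_emb @ text_emb.t()).softmax(dim=-1)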