### demo.py
# Define model classes for inference.
###
from collections import OrderedDict
import json
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix

from lavila.data import datasets
from lavila.data.video_transforms import Permute, SpatialCrop, TemporalCrop
from lavila.models import models
from lavila.models.tokenizer import (MyBertTokenizer, MyDistilBertTokenizer,
                                     MyGPT2Tokenizer, SimpleTokenizer)
from lavila.models.utils import inflate_positional_embeds
from lavila.utils.config import load_cfg
from lavila.utils.evaluation_charades import charades_map
from lavila.utils.evaluation import get_mean_accuracy


class VideoModel(nn.Module):
    """Base model for video understanding based on the LaViLa architecture."""

    def __init__(self, config):
        """Initialize the model.

        Parameters:
            config: path to a YAML configuration file.
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.tokenizer = self.get_tokenizer()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()
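
    # Rebuild the architecture recorded in the pretrained checkpoint and load its weights.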
    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')
        ckpt = torch.load(ckpt_path, map_location='cpu')

        # Strip the 'module.' prefix left by DistributedDataParallel.
        state_dict = OrderedDict()
        for k, v in ckpt['state_dict'].items():
            state_dict[k.replace('module.', '')] = v

        old_args = vars(ckpt['args'])
        arch = old_args.get('model', 'CLIP_OPENAI_TIMESFORMER_BASE')
        self.arch = arch
        cfg['model']['arch'] = arch
        cfg['model']['norm_embed'] = old_args.get('norm_embed', True)
        print("=> creating model: {}".format(arch))
        model = getattr(models, arch)(
            pretrained=old_args.get('load_visual_pretrained', None),
            pretrained2d=old_args.get('load_visual_pretrained', None) is not None,
            text_use_cls_token=old_args.get('use_cls_token', False),
            project_embed_dim=old_args.get('project_embed_dim', 256),
            timesformer_gated_xattn=False,
            num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
            model_cfg=cfg['model'],
        )
        model.logit_scale.requires_grad = False

        if torch.cuda.is_available():
            model.cuda()

        if ('TIMESFORMER' in arch or 'EGOVLP' in arch) and cfg['model'].get('inflat_posemb', True):
            # Inflate the positional embeddings when the frame count differs from pretraining.
            print('=> inflating PE in models due to different frame numbers')
            state_dict = inflate_positional_embeds(
                model.state_dict(), state_dict,
                num_frames=cfg['model'].get('num_frames', cfg['data']['clip_length']),
                load_temporal_fix='bilinear',
            )

        model.load_state_dict(state_dict, strict=True)
        print("=> loaded resume checkpoint '{}' (epoch {})".format(ckpt_path, ckpt['epoch']))
        return model
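
    # Freeze all parameters and switch to evaluation mode; the model is used for inference only.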
    def eval(self):
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()
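
    # Pick the text tokenizer matching the text encoder named in the architecture string.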
    def get_tokenizer(self):
        arch = self.arch
        if arch.endswith('DISTILBERT_BASE'):
            tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
        elif arch.endswith('BERT_BASE'):
            tokenizer = MyBertTokenizer('bert-base-uncased')
        elif arch.endswith('BERT_LARGE'):
            tokenizer = MyBertTokenizer('bert-large-uncased')
        elif arch.endswith('GPT2'):
            tokenizer = MyGPT2Tokenizer('gpt2')
        elif arch.endswith('GPT2_MEDIUM'):
            tokenizer = MyGPT2Tokenizer('gpt2-medium')
        elif arch.endswith('GPT2_LARGE'):
            tokenizer = MyGPT2Tokenizer('gpt2-large')
        elif arch.endswith('GPT2_XL'):
            tokenizer = MyGPT2Tokenizer('gpt2-xl')
        else:
            print("Using SimpleTokenizer because of model '{}'. "
                  "Please check if this is what you want.".format(arch))
            tokenizer = SimpleTokenizer()
        return tokenizer


class VideoCLSModel(VideoModel):
    """Video model for video classification tasks (Charades-Ego, EGTEA)."""

    def __init__(self, config):
        super(VideoCLSModel, self).__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def gen_label_map(self):
        labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json')
        if os.path.isfile(labelmap):
            print(f"=> Loading label map from {labelmap}")
            with open(labelmap, 'r') as f:
                meta = json.load(f)
            labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
        else:
            from lavila.utils.preprocess import generate_label_map
            labels, mapping_vn2act = generate_label_map(self.dataset)
            meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
            meta_dir = f'meta/{self.dataset}'
            os.makedirs(meta_dir, exist_ok=True)
            with open(f'{meta_dir}/label_map.json', 'w') as f:
                json.dump(meta, f)
            print(f"=> Label map generated and saved to {meta_dir}/label_map.json")
        return labels, mapping_vn2act
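
    # Build the validation transform and dataloader for the configured dataset.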
    def load_data(self, idx=None):
        print("=> Creating dataset")
        cfg, dataset = self.cfg, self.dataset
        data_cfg = cfg['data']
        crop_size = 224 if '336PX' not in self.arch else 336
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615000001],
                std=[68.5005327, 66.6321579, 70.32316305],
            ),
        ])
        if idx is None:
            metadata_val = data_cfg['metadata_val']
        else:
            metadata_val = data_cfg['metadata_val'].format(idx)
        if dataset in ['charades_ego', 'egtea']:
            val_dataset = datasets.VideoClassyDataset(
                dataset, data_cfg['root'], metadata_val,
                transform=val_transform, is_training=False,
                label_mapping=self.mapping_vn2act, is_trimmed=False,
                num_clips=1, clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
                sparse_sample=data_cfg['sparse_sample'],
            )
        else:
            raise NotImplementedError
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False,
        )
        return val_loader
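
    # Encode every class label with the text encoder; embeddings are averaged over the
    # prompt templates (and over all entries when a label is a list), then re-normalized.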
    def get_text_features(self):
        print('=> Extracting text features')
        text_features = []
        for label in self.labels:
            if isinstance(label, list):
                texts = [tmpl.format(lbl) for tmpl in self.templates for lbl in label]
            else:
                texts = [tmpl.format(label) for tmpl in self.templates]
            texts = self.tokenizer(texts)
            if isinstance(texts, tuple):
                # BERT-style tokenizers return both token ids and an attention mask.
                texts, masks = texts
                texts = texts.cuda(non_blocking=True)
                masks = masks.cuda(non_blocking=True)
            else:
                texts = texts.cuda(non_blocking=True)
                masks = None
            texts = texts.view(-1, 77).contiguous()
            masks = masks.view(-1, 77).contiguous() if masks is not None else None
            if masks is not None:
                class_embeddings, _ = self.model.encode_text(texts, attention_mask=masks)
            else:
                class_embeddings, _ = self.model.encode_text(texts)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            # Average over templates, then re-normalize.
            class_embeddings = class_embeddings.mean(dim=0)
            class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
            text_features.append(class_embeddings)
        text_features = torch.stack(text_features, dim=0)
        return text_features
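
    # Zero-shot classification: compare normalized video embeddings against the precomputed
    # text embeddings via cosine similarity, with a softmax over the classes.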
    def forward(self, idx=None):
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        all_outputs = []
        all_targets = []
        for i, values in enumerate(val_loader):
            images = values[0]
            target = values[1]
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            # Encode images.
            image_features, _ = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # Cosine similarity as logits.
            logits_per_image = image_features @ self.text_features.t()
            logits_per_image = torch.softmax(logits_per_image, dim=1)
            all_outputs.append(logits_per_image.cpu())
            all_targets.append(target.cpu())
        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)
        return all_outputs, all_targets
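
    # Return human-readable predicted vs. ground-truth action names for a single clip,
    # keeping only the top predictions up to a small cumulative-probability cutoff.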
    def predict(self, idx=0):
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        # Number of predictions to keep: where the sorted class probabilities first accumulate
        # more than 0.055 of the mass; clamp to at least 1 so the slice below never selects
        # every class when the top score alone already exceeds the cutoff.
        sel = np.where(np.cumsum(sorted(preds[0].tolist(), reverse=True)) > 0.055)[0][0]
        sel = max(int(sel), 1)
        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([x[0] for x in pred_action])
        gt_action = sorted([x[0] for x in gt_action])
        return pred_action, gt_action
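
    # Dataset-specific metrics: video-level mAP for Charades-Ego (multi-label),
    # mean class accuracy and top-1 accuracy for EGTEA.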
    def evaluate(self):
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, m_aps = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError


def main():
    lavila = VideoCLSModel("configs/charades_ego/zeroshot.yml")
    egovpa = VideoCLSModel("configs/charades_ego/egovpa.yml")
    lavila.evaluate()
    egovpa.evaluate()
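    # Example (hypothetical usage): predict actions for a single clip whose metadata file
    # index is `idx`; this assumes `metadata_val` in the config contains a '{}' placeholder
    # that load_data() fills in:
    #   pred_action, gt_action = egovpa.predict(idx=0)
    #   print(pred_action, gt_action)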


if __name__ == '__main__':
    main()