### demo.py
# Define model classes for inference.
###
import json

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from einops import rearrange
from transformers import BertTokenizer
from torchvision import transforms
from torchvision.transforms._transforms_video import (
    NormalizeVideo,
)

from svitt.model import SViTT
from svitt.config import load_cfg, setup_config
from svitt.base_dataset import read_frames_cv2_egoclip


class VideoModel(nn.Module):
    """ Base model for video understanding based on SViTT architecture. """

    def __init__(self, config):
        """ Initializes the model.

        Parameters:
            config: config file
        """
        super().__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_gpu else "cpu")
        if use_gpu:
            self.model = self.model.to(self.device)
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')
        if cfg['model'].get('config', False):
            config_path = cfg['model']['config']
        else:
            raise Exception('no model config found')

        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)

        print(f"Loading checkpoint from {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]

        # fix for zero-shot evaluation: the checkpoint stores text-encoder weights
        # under a "bert." prefix, so copy them to the un-prefixed keys the model expects
        for key in list(state_dict.keys()):
            if "bert" in key:
                encoder_key = key.replace("bert.", "")
                state_dict[encoder_key] = state_dict[key]

        if torch.cuda.is_available():
            model.cuda()
        model.load_state_dict(state_dict, strict=False)
        return model

    def eval(self):
        # Freeze all parameters and put the underlying model in inference mode.
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()


class VideoCLSModel(VideoModel):
    """ Video model for video classification tasks (Charades-Ego, EGTEA). """

    def __init__(self, config, sample_videos):
        super().__init__(config)
        self.sample_videos = sample_videos
        self.video_transform = self.init_video_transform()

    #def load_data(self, idx=None):
    #    filename = f"{self.cfg['data']['root']}/{idx}/tensors.pt"
    #    return torch.load(filename)

    def init_video_transform(self,
                             input_res=224,
                             center_crop=256,
                             norm_mean=(0.485, 0.456, 0.406),
                             norm_std=(0.229, 0.224, 0.225),
                             ):
        print('Video Transform is used!')
        normalize = NormalizeVideo(mean=norm_mean, std=norm_std)
        return transforms.Compose(
            [
                transforms.Resize(center_crop),
                transforms.CenterCrop(center_crop),
                transforms.Resize(input_res),
                normalize,
            ]
        )

    def load_data(self, idx):
        num_frames = self.model_cfg.video_input.num_frames
        video_paths = self.sample_videos[idx]
        clips = [None] * len(video_paths)
        for i, path in enumerate(video_paths):
            imgs = read_frames_cv2_egoclip(path, num_frames, 'uniform')
            imgs = imgs.transpose(0, 1)  # swap to the layout expected by the video transform
            imgs = self.video_transform(imgs)
            imgs = imgs.transpose(0, 1)  # swap back to the original layout
            clips[i] = imgs
        return torch.stack(clips)

    def load_meta(self, idx=None):
        filename = f"{self.cfg['data']['root']}/{idx}/meta.json"
        with open(filename, "r") as f:
            meta = json.load(f)
        return meta

    def get_text_features(self, text):
        print('=> Extracting text features')
        embeddings = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.model_cfg.max_txt_l.video,
            return_tensors="pt",
        ).to(self.device)
        _, class_embeddings = self.model.encode_text(embeddings)
        return class_embeddings

    def forward(self, idx, text=None):
        print('=> Start forwarding')
        meta = self.load_meta(idx)
        clips = self.load_data(idx)
        if text is None:
            text = meta["text"][4:]
        text_features = self.get_text_features(text)
        target = meta["correct"]

        # encode images
        pooled_image_feat_all = []
        for i in range(clips.shape[0]):
            images = clips[i, :].unsqueeze(0).to(self.device)
            bsz = images.shape[0]
            _, pooled_image_feat, *outputs = self.model.encode_image(images)
            if pooled_image_feat.ndim == 3:
                pooled_image_feat = rearrange(pooled_image_feat, '(b k) n d -> b (k n) d', b=bsz)
            else:
                pooled_image_feat = rearrange(pooled_image_feat, '(b k) d -> b k d', b=bsz)
            pooled_image_feat_all.append(pooled_image_feat)
        pooled_image_feat_all = torch.cat(pooled_image_feat_all, dim=0)

        # score the encoded clips against the candidate texts and return the best-matching index
        similarity = self.model.get_sim(pooled_image_feat_all, text_features)[0]
        return similarity.argmax(), target

    def predict(self, idx, text=None):
        output, target = self.forward(idx, text)
        return output.cpu().numpy(), target
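

# Example usage: a minimal sketch, not part of the original file. The config path and
# the sample_videos mapping below are hypothetical placeholders; the config must be
# readable by load_cfg and provide the 'model' (pretrain, config) and 'data' (dataset,
# root) entries used above, and each sample index needs a "<data.root>/<idx>/meta.json"
# with its candidate "text" list and the "correct" entry used as the target.
if __name__ == "__main__":
    sample_videos = {
        0: ["videos/sample0_clip0.mp4", "videos/sample0_clip1.mp4"],  # placeholder clip paths
    }
    model = VideoCLSModel("configs/demo.yml", sample_videos)

    # Zero-shot prediction: rank the candidate texts for sample 0 against its clips.
    pred, target = model.predict(0)
    print(f"predicted index: {pred}, target: {target}")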