import torch
import torch.nn as nn
from transformers import XCLIPVisionModel
class XCLIP(nn.Module):
    """Video-level classifier built on a pretrained X-CLIP vision encoder."""

    def __init__(self, channel_size=512, dropout=0.2, class_num=1):
        super(XCLIP, self).__init__()
        # Per-frame encoder; 768 is the hidden size of the X-CLIP ViT-B vision tower.
        self.backbone = XCLIPVisionModel.from_pretrained("GenVideo/pretrained_weights/xclip")
        self.fc_norm = nn.LayerNorm(768)
        self.dropout = nn.Dropout(dropout)  # regularization before the classification head
        # channel_size is kept for API compatibility but unused; features are 768-dim.
        self.head = nn.Linear(768, class_num)
    def forward(self, x):
        # x: (batch, frames, channels, height, width)
        b, t, _, h, w = x.shape
        # Fold time into the batch so every frame is encoded independently.
        images = x.view(b * t, 3, h, w)
        outputs = self.backbone(images, output_hidden_states=True)
        # pooler_output holds one 768-dim embedding per frame; unfold back to (b, t, 768).
        sequence_output = outputs['pooler_output'].reshape(b, t, -1)
        # Mean-pool across frames and normalize to get a single video-level feature.
        video_level_features = self.fc_norm(sequence_output.mean(1))
        pred = self.head(self.dropout(video_level_features))
        return pred
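

# Minimal usage sketch (an illustration, not part of the original module): runs one
# forward pass on random frames to show the expected input/output shapes. It assumes
# the local checkpoint directory above exists; if it does not, the public HuggingFace
# checkpoint "microsoft/xclip-base-patch32" can be substituted in from_pretrained().
if __name__ == "__main__":
    model = XCLIP(dropout=0.2, class_num=1)
    model.eval()
    # 2 videos x 8 frames of 224x224 RGB, matching the (b, t, c, h, w) layout forward() expects.
    dummy = torch.randn(2, 8, 3, 224, 224)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # torch.Size([2, 1]): one logit per video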