# Kalpit
# feat: Add model files with LFS
# d39b279
import torch
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification
import torch.nn as nn
import torchvision
import time
from .mamba_base import MambaConfig, ResidualBlock
def create_reorder_index(N, device):
    """Return a column-wise "snake" permutation of the indices ``0 .. N*N-1``.

    Viewing the ``N * N`` positions as a row-major grid (index ``row * N + col``),
    columns are visited left to right. Even-numbered columns are walked
    top-to-bottom and odd-numbered columns bottom-to-top.

    Args:
        N: Side length of the square grid.
        device: Device on which the returned index tensor is allocated.

    Returns:
        A 1-D tensor of length ``N * N`` containing the permuted flat indices.
    """
    snake = [
        row * N + col
        for col in range(N)
        # Alternate the walking direction from one column to the next.
        for row in (range(N) if col % 2 == 0 else reversed(range(N)))
    ]
    return torch.tensor(snake, device=device)
def reorder_data(data, N):
    """Permute dim 2 of ``data`` into the snake order from ``create_reorder_index``.

    The same permutation is applied to every slice along the other dims, i.e.
    ``out[:, :, i] == data[:, :, order[i]]`` where ``order = create_reorder_index(N, ...)``.

    Args:
        data: Tensor with at least 3 dims whose size along dim 2 is ``N * N``.
            Presumably laid out as (batch, time, tokens, channels); confirm
            against callers.
        N: Side length of the square token grid.

    Returns:
        A new tensor with the same shape, dtype and device as ``data``.

    Raises:
        TypeError: If ``data`` is not a ``torch.Tensor``.
        ValueError: If ``data`` has fewer than 3 dims or ``data.shape[2] != N * N``.
    """
    # Raise explicitly rather than ``assert`` so the check survives ``python -O``.
    if not isinstance(data, torch.Tensor):
        raise TypeError("data should be a torch.Tensor")
    # Check the token count up front. With N == 1 a broadcast single-element
    # index would otherwise silently copy token 0 into every position.
    if data.dim() < 3 or data.shape[2] != N * N:
        raise ValueError(
            f"expected data with >= 3 dims and data.shape[2] == N*N == {N * N}, "
            f"got shape {tuple(data.shape)}"
        )
    new_order = create_reorder_index(N, data.device)
    # index_select along dim 2 gathers the same rows for every other index,
    # without building a per-(batch, time) copy of the permutation.
    return data.index_select(2, new_order)
class Videomae_Net(nn.Module):
    """Video classifier built on a pretrained VideoMAE backbone.

    The wrapped model's own classifier head is not used. Pooled backbone
    features are passed through ``BatchNorm1d`` and a freshly initialized
    ``Linear`` layer.

    Args:
        channel_size: Unused; kept for signature compatibility.
        dropout: Unused; kept for signature compatibility (no dropout is applied).
        class_num: Number of output logits.
        pretrained_path: Location passed to
            ``VideoMAEForVideoClassification.from_pretrained``. Defaults to the
            previously hard-coded checkpoint directory.
    """

    def __init__(
        self,
        channel_size=512,
        dropout=0.2,
        class_num=1,
        pretrained_path="/ossfs/workspace/GenVideo/pretrained_weights/videomae",
    ):
        super().__init__()
        self.model = VideoMAEForVideoClassification.from_pretrained(pretrained_path)
        # Size the head from the loaded checkpoint instead of assuming a
        # 768-wide backbone, so other VideoMAE variants also work.
        hidden_size = self.model.config.hidden_size
        self.fc1 = nn.Linear(hidden_size, class_num)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self._init_params()

    def _init_params(self):
        """Initialize the head: Xavier-normal weights and a zero bias for ``fc1``."""
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.constant_(self.fc1.bias, 0)

    def forward(self, x):
        """Compute logits for a batch of clips.

        Args:
            x: Pixel tensor fed to ``self.model.videomae``. Presumably shaped
                (batch, num_frames, channels, height, width), as in the
                ``__main__`` smoke test; confirm against the data pipeline.

        Returns:
            Tensor of shape (batch, class_num).
        """
        outputs = self.model.videomae(x)
        sequence_output = outputs[0]
        # Pooling: if the wrapped model has an ``fc_norm``, mean over dim 1 and
        # normalize; otherwise take the first token.
        if self.model.fc_norm is not None:
            pooled = self.model.fc_norm(sequence_output.mean(1))
        else:
            pooled = sequence_output[:, 0]
        return self.fc1(self.bn1(pooled))
if __name__ == '__main__':
    # Smoke test: build the model and run a single forward pass on random input.
    # NOTE: this module uses a relative import (``from .mamba_base ...``), so it
    # must be run as part of its package (``python -m <package>.<module>``), not
    # as a plain script file.
    model = Videomae_Net()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    # Presumably (batch, num_frames, channels, height, width) for VideoMAE.
    input_data = torch.randn(1, 16, 3, 224, 224, device=device)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        logits = model(input_data)
    print(logits.shape)