| import torch | |
| from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification | |
| import torch.nn as nn | |
| import torchvision | |
| import time | |
| from .mamba_base import MambaConfig, ResidualBlock | |
| def create_reorder_index(N, device): | |
| new_order = [] | |
| for col in range(N): | |
| if col % 2 == 0: | |
| new_order.extend(range(col, N*N, N)) | |
| else: | |
| new_order.extend(range(col + N*(N-1), col-1, -N)) | |
| return torch.tensor(new_order, device=device) | |
| def reorder_data(data, N): | |
| assert isinstance(data, torch.Tensor), "data should be a torch.Tensor" | |
| device = data.device | |
| new_order = create_reorder_index(N, device) | |
| B, t, _, _ = data.shape | |
| index = new_order.repeat(B, t, 1).unsqueeze(-1) | |
| reordered_data = torch.gather(data, 2, index.expand_as(data)) | |
| return reordered_data | |
| class Videomae_Net(nn.Module): | |
| def __init__( | |
| self, channel_size=512, dropout=0.2, class_num=1 | |
| ): | |
| super(Videomae_Net, self).__init__() | |
| self.model = VideoMAEForVideoClassification.from_pretrained("/ossfs/workspace/GenVideo/pretrained_weights/videomae") | |
| self.fc1 = nn.Linear(768, class_num) | |
| self.bn1 = nn.BatchNorm1d(768) | |
| self._init_params() | |
| def _init_params(self): | |
| nn.init.xavier_normal_(self.fc1.weight) | |
| nn.init.constant_(self.fc1.bias, 0) | |
| def forward(self, x): | |
| x = self.model.videomae(x) | |
| sequence_output = x[0] | |
| print(sequence_output.shape) | |
| if self.model.fc_norm is not None: | |
| sequence_output = self.model.fc_norm(sequence_output.mean(1)) | |
| else: | |
| sequence_output = sequence_output[:, 0] | |
| x = self.bn1(sequence_output) | |
| x = self.fc1(x) | |
| return x | |
| if __name__ == '__main__': | |
| model = Videomae_Net() | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| model = model.to(device) | |
| model.eval() | |
| input_data = torch.randn(1, 16, 3, 224, 224).to(device) | |
| model(input_data) | |