import math

import torch
import torch.nn as nn
import torch.nn.init as init
from transformers import XCLIPVisionConfig, XCLIPVisionModel

from .clip import clip
from .mamba_base import MambaConfig, ResidualBlock


def create_reorder_index(N, device):
    """Build a serpentine (column-wise, alternating-direction) scan order
    for an N x N grid flattened row-major."""
    new_order = []
    for col in range(N):
        if col % 2 == 0:
            # Even columns: traverse top to bottom.
            new_order.extend(range(col, N * N, N))
        else:
            # Odd columns: traverse bottom to top.
            new_order.extend(range(col + N * (N - 1), col - 1, -N))
    return torch.tensor(new_order, device=device)


def reorder_data(data, N):
    """Reorder the patch dimension of a (B, T, N*N, C) tensor into the
    serpentine scan order produced by create_reorder_index."""
    assert isinstance(data, torch.Tensor), "data should be a torch.Tensor"
    device = data.device
    new_order = create_reorder_index(N, device)
    B, t, _, _ = data.shape
    index = new_order.repeat(B, t, 1).unsqueeze(-1)
    reordered_data = torch.gather(data, 2, index.expand_as(data))
    return reordered_data


class XCLIP_DeMamba(nn.Module):
    def __init__(self, channel_size=768, class_num=1):
        super(XCLIP_DeMamba, self).__init__()
        # self.encoder = XCLIPVisionModel.from_pretrained("GenVideo/pretrained_weights/xclip")
        # my code for training from scratch
        config = XCLIPVisionConfig()
        self.encoder = XCLIPVisionModel(config)
        channel = 768
        self.fusing_ratios = 1
        self.patch_nums = (14 // self.fusing_ratios) ** 2
        self.mamba_configs = MambaConfig(d_model=channel)
        self.mamba = ResidualBlock(config=self.mamba_configs)
        # self.fc1 = nn.Linear((self.patch_nums + 1) * channel, class_num)
        # my code: 38400 = (49 + 1) * 768, i.e. the 7x7 = 49 patch tokens of the
        # default XCLIPVisionConfig (224px input, patch size 32) plus one
        # 768-dim global feature.
        self.fc1 = nn.Linear(38400, class_num)
        # self.fc_norm = nn.LayerNorm(self.patch_nums * channel)
        # my code: created lazily in forward() once the feature size is known.
        self.fc_norm = None
        self.fc_norm2 = nn.LayerNorm(768)
        self.initialize_weights(self.fc1)
        self.dropout = nn.Dropout(p=0.0)

    def initialize_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, x):
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        outputs = self.encoder(images, output_hidden_states=True)
        # Drop the CLS token; keep only the patch tokens.
        sequence_output = outputs['last_hidden_state'][:, 1:, :]
        _, _, c = sequence_output.shape
        # Global video feature: mean of the per-frame pooled features.
        global_feat = outputs['pooler_output'].reshape(b, t, -1)
        global_feat = global_feat.mean(1)
        global_feat = self.fc_norm2(global_feat)
        sequence_output = sequence_output.view(b, t, -1, c)
        _, _, n, _ = sequence_output.shape
        f_h = f_w = int(math.sqrt(n))  # square patch grid
        s = f_h // self.fusing_ratios
        # Split the f_h x f_w patch grid into s x s sub-grids of size
        # fusing_ratios x fusing_ratios, then scan each sub-grid over time.
        sequence_output = sequence_output.view(b, t, self.fusing_ratios, s, self.fusing_ratios, s, c)
        x = sequence_output.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(b * s * s, t, -1, c)
        b_l = b * s * s
        x = reorder_data(x, self.fusing_ratios)
        x = x.permute(0, 2, 1, 3).contiguous().view(b_l, -1, c)
        res = self.mamba(x)
        video_level_features = res.mean(1)
        video_level_features = video_level_features.view(b, -1)
        # my code: lazy LayerNorm. Note that parameters created here are not
        # visible to an optimizer constructed before the first forward pass.
        if self.fc_norm is None:
            self.fc_norm = nn.LayerNorm(video_level_features.size(-1)).to(video_level_features.device)
        video_level_features = self.fc_norm(video_level_features)
        video_level_features = torch.cat((global_feat, video_level_features), dim=1)
        pred = self.fc1(video_level_features)
        pred = self.dropout(pred)
        return pred
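# Illustrative sketch (an addition, not from the original file): the
# serpentine scan order for N = 2. A 2x2 grid flattened row-major as
# [0, 1, 2, 3] is visited down column 0, then up column 1:
#
#     >>> create_reorder_index(2, torch.device('cpu'))
#     tensor([0, 2, 3, 1])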
class CLIP_DeMamba(nn.Module):
    def __init__(self, channel_size=512, class_num=1):
        super(CLIP_DeMamba, self).__init__()
        self.clip_model, preprocess = clip.load('ViT-B-14')
        self.clip_model = self.clip_model.float()
        channel = 512
        self.fusing_ratios = 2
        self.patch_nums = (14 // self.fusing_ratios) ** 2
        self.mamba_configs = MambaConfig(d_model=channel)
        self.mamba = ResidualBlock(config=self.mamba_configs)
        # 512 * (49 + 1): 7x7 = 49 fused patch features plus one global feature.
        self.fc1 = nn.Linear(channel * (self.patch_nums + 1), class_num)
        self.bn1 = nn.BatchNorm1d(channel)
        self.initialize_weights(self.fc1)

    def initialize_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, x):
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        sequence_output = self.clip_model.encode_image(images)
        _, _, c = sequence_output.shape
        sequence_output = sequence_output.view(b, t, -1, c)
        # Global video feature: mean over all frames and patch tokens.
        global_feat = sequence_output.reshape(b, -1, c)
        global_feat = global_feat.mean(1)
        _, _, n, _ = sequence_output.shape
        f_h = f_w = int(math.sqrt(n))  # square patch grid
        s = f_h // self.fusing_ratios
        # Split the patch grid into s x s sub-grids, then scan each sub-grid
        # over time in serpentine order before the Mamba block.
        sequence_output = sequence_output.view(b, t, self.fusing_ratios, s, self.fusing_ratios, s, c)
        x = sequence_output.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(b * s * s, t, -1, c)
        b_l = b * s * s
        x = reorder_data(x, self.fusing_ratios)
        x = x.permute(0, 2, 1, 3).contiguous().view(b_l, -1, c)
        res = self.mamba(x)
        video_level_features = res.mean(1)
        video_level_features = video_level_features.view(b, -1)
        video_level_features = torch.cat((global_feat, video_level_features), dim=1)
        x = self.fc1(video_level_features)
        return x


if __name__ == '__main__':
    model = CLIP_DeMamba()
    print(model)
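    # Illustrative smoke test (an addition, not from the original file):
    # XCLIP_DeMamba is built from a randomly initialized XCLIPVisionConfig,
    # so no pretrained weights are needed. The input shape
    # (batch=2, frames=8, 3, 224, 224) is an assumed example, and the
    # expected output assumes ResidualBlock preserves (batch, length, dim).
    xclip_model = XCLIP_DeMamba().eval()
    dummy = torch.randn(2, 8, 3, 224, 224)
    with torch.no_grad():
        pred = xclip_model(dummy)
    print(pred.shape)  # expected: torch.Size([2, 1])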