import math

import torch
import torch.nn as nn
import torch.nn.init as init
from transformers import XCLIPVisionConfig, XCLIPVisionModel

from .clip import clip
from .mamba_base import MambaConfig, ResidualBlock

def create_reorder_index(N, device):
    """Return the flat index order for a column-wise snake scan of an N x N grid."""
    new_order = []
    for col in range(N):
        if col % 2 == 0:
            # even columns: walk top-down
            new_order.extend(range(col, N * N, N))
        else:
            # odd columns: walk bottom-up
            new_order.extend(range(col + N * (N - 1), col - 1, -N))
    return torch.tensor(new_order, device=device)
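
# For example, with N=2 the scan visits the 2x2 grid column by column,
# alternating direction: column 0 top-down (0, 2), column 1 bottom-up (3, 1),
# i.e. create_reorder_index(2, 'cpu') -> tensor([0, 2, 3, 1]).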


def reorder_data(data, N):
    """Reorder the patch axis of (B, T, P, C) data with the N x N snake scan."""
    assert isinstance(data, torch.Tensor), "data should be a torch.Tensor"
    device = data.device
    new_order = create_reorder_index(N, device)
    B, t, _, _ = data.shape
    # broadcast the (P,) index to (B, T, P, 1) and gather along the patch axis
    index = new_order.repeat(B, t, 1).unsqueeze(-1)
    reordered_data = torch.gather(data, 2, index.expand_as(data))
    return reordered_data
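
# A toy check (hypothetical values, shape (B=1, T=1, P=4, C=1)):
#   toy = torch.arange(4.0).view(1, 1, 4, 1)
#   reorder_data(toy, 2).flatten().tolist()  # -> [0.0, 2.0, 3.0, 1.0]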


class XCLIP_DeMamba(nn.Module):
    def __init__(self, channel_size=768, class_num=1):
        super().__init__()

        # randomly initialized X-CLIP vision tower (default config: 224x224
        # input, patch size 32, hidden size 768 -> a 7x7 patch grid)
        config = XCLIPVisionConfig()
        self.encoder = XCLIPVisionModel(config)

        channel = 768
        self.fusing_ratios = 1
        self.patch_nums = (14 // self.fusing_ratios) ** 2  # not used by this class
        self.mamba_configs = MambaConfig(d_model=channel)
        self.mamba = ResidualBlock(config=self.mamba_configs)

        # 38400 = 768 * (7*7 + 1): the pooled global feature concatenated
        # with one Mamba feature per patch of the 7x7 grid
        self.fc1 = nn.Linear(38400, class_num)

        # created lazily on the first forward pass, once the flattened
        # feature size is known; build the optimizer after that pass so
        # these parameters are included
        self.fc_norm = None
        self.fc_norm2 = nn.LayerNorm(768)
        self.initialize_weights(self.fc1)
        self.dropout = nn.Dropout(p=0.0)

    def initialize_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, x):
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        outputs = self.encoder(images, output_hidden_states=True)
        # drop the CLS token, keeping one token per image patch
        sequence_output = outputs['last_hidden_state'][:, 1:, :]
        _, _, c = sequence_output.shape

        # global video-level feature: pooled frame features averaged over time
        global_feat = outputs['pooler_output'].reshape(b, t, -1)
        global_feat = global_feat.mean(1)
        global_feat = self.fc_norm2(global_feat)

        sequence_output = sequence_output.view(b, t, -1, c)
        n_tokens = sequence_output.shape[2]
        f_h = int(math.sqrt(n_tokens))

        # split the f_h x f_h patch grid into s x s regions of
        # fusing_ratios x fusing_ratios adjacent patches each
        s = f_h // self.fusing_ratios
        sequence_output = sequence_output.view(b, t, s, self.fusing_ratios, s, self.fusing_ratios, c)
        x = sequence_output.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(b * s * s, t, -1, c)
        b_l = b * s * s

        # snake-scan the patches inside each region, then flatten to one
        # (patch-major, time-minor) sequence per region for the Mamba block
        x = reorder_data(x, self.fusing_ratios)
        x = x.permute(0, 2, 1, 3).contiguous().view(b_l, -1, c)
        res = self.mamba(x)

        video_level_features = res.mean(1)
        video_level_features = video_level_features.view(b, -1)

        if self.fc_norm is None:
            self.fc_norm = nn.LayerNorm(video_level_features.size(-1)).to(video_level_features.device)

        video_level_features = self.fc_norm(video_level_features)
        video_level_features = torch.cat((global_feat, video_level_features), dim=1)

        pred = self.fc1(video_level_features)
        pred = self.dropout(pred)  # p=0.0: a no-op unless the rate is raised

        return pred
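
# Minimal usage sketch (shapes only; assumes 224x224 frames, for which the
# default X-CLIP config yields the 7x7 patch grid that fc1's hard-coded
# 38400-dim input expects):
#   model = XCLIP_DeMamba()
#   frames = torch.randn(2, 8, 3, 224, 224)  # (batch, time, C, H, W)
#   logits = model(frames)                   # -> shape (2, 1)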


class CLIP_DeMamba(nn.Module):
    def __init__(self, channel_size=512, class_num=1):
        super().__init__()
        self.clip_model, _ = clip.load('ViT-B-14')
        self.clip_model = self.clip_model.float()

        channel = 512
        self.fusing_ratios = 2
        # number of fused regions on a 14x14 patch grid: (14 // 2) ** 2 = 49
        self.patch_nums = (14 // self.fusing_ratios) ** 2
        self.mamba_configs = MambaConfig(d_model=channel)
        self.mamba = ResidualBlock(config=self.mamba_configs)

        # one Mamba feature per fused region plus the global feature
        self.fc1 = nn.Linear(channel * (self.patch_nums + 1), class_num)

        self.bn1 = nn.BatchNorm1d(channel)  # defined but unused in forward()
        self.initialize_weights(self.fc1)

    def initialize_weights(self, module):
        for m in module.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

    def forward(self, x):
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        # the vendored CLIP encoder is expected to return per-patch tokens
        # of shape (b*t, n_tokens, c)
        sequence_output = self.clip_model.encode_image(images)
        _, _, c = sequence_output.shape
        sequence_output = sequence_output.view(b, t, -1, c)

        # global feature: mean over all tokens of all frames
        global_feat = sequence_output.reshape(b, -1, c)
        global_feat = global_feat.mean(1)

        n_tokens = sequence_output.shape[2]
        f_h = int(math.sqrt(n_tokens))

        # split the f_h x f_h patch grid into s x s regions of 2x2 adjacent patches
        s = f_h // self.fusing_ratios
        sequence_output = sequence_output.view(b, t, s, self.fusing_ratios, s, self.fusing_ratios, c)
        x = sequence_output.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(b * s * s, t, -1, c)
        b_l = b * s * s

        # snake-scan each region, then flatten to one (patch-major,
        # time-minor) sequence per region for the Mamba block
        x = reorder_data(x, self.fusing_ratios)
        x = x.permute(0, 2, 1, 3).contiguous().view(b_l, -1, c)
        res = self.mamba(x)
        video_level_features = res.mean(1)
        video_level_features = video_level_features.view(b, -1)

        video_level_features = torch.cat((global_feat, video_level_features), dim=1)
        x = self.fc1(video_level_features)

        return x
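
# Shape sketch (assumes the vendored clip.encode_image returns per-patch
# tokens of shape (b*t, n_tokens, 512) on a 14x14 patch grid):
#   frames = torch.randn(2, 8, 3, 224, 224)  # (batch, time, C, H, W)
#   logits = CLIP_DeMamba()(frames)          # -> shape (2, 1)
# With fusing_ratios=2 the 14x14 grid becomes 7x7 regions of 2x2 patches,
# so fc1 sees 512 * (7*7 + 1) = 25600 input features.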


if __name__ == '__main__':
    model = CLIP_DeMamba()
    print(model)