""" F3Net: Fusion, Feedback and Focus for Salient Object Detection @ AAAI'2020 Copyright (c) University of Chinese Academy of Sciences and its affiliates. Modified by Jun Wei from https://github.com/weijun88/F3Net """ import os import sys import numpy as np import torch import torchvision import torch.nn as nn import torch.nn.functional as F import torch.utils.model_zoo as model_zoo pretrained_settings = { 'xception': { 'imagenet': { 'dir': '/ossfs/workspace/aigc_video/weights/xception-b5690688.pth', 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth', 'input_space': 'RGB', 'input_size': [3, 299, 299], 'input_range': [0, 1], 'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5], 'num_classes': 1000, 'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 } } } class F3Net(nn.Module): """ Implementation is mainly referenced from https://github.com/yyk-wew/F3Net """ def __init__(self, num_classes: int=2, img_width: int=299, img_height: int=299, LFS_window_size: int=10, LFS_M: int=6) -> None: super(F3Net, self).__init__() assert img_width == img_height self.img_size = img_width self.num_classes = num_classes self._LFS_window_size = LFS_window_size self._LFS_M = LFS_M self.fad_head = FAD_Head(self.img_size) self.lfs_head = LFS_Head(self.img_size, self._LFS_window_size, self._LFS_M) self.fad_excep = self._init_xcep_fad() self.lfs_excep = self._init_xcep_lfs() self.mix_block7 = MixBlock(c_in=728, width=19, height=19) self.mix_block12 = MixBlock(c_in=1024, width=10, height=10) self.excep_forwards = ['conv1', 'bn1', 'relu', 'conv2', 'bn2', 'relu', 'block1', 'block2', 'block3', 'block4', 'block5', 'block6', 'block7', 'block8', 'block9', 'block10' , 'block11', 'block12', 'conv3', 'bn3', 'relu', 'conv4', 'bn4'] # classifier self.relu = nn.ReLU(inplace=True) self.fc = nn.Linear(4096, num_classes) self.dp = nn.Dropout(p=0.2) def _init_xcep_fad(self): fad_excep = return_pytorch04_xception(True) conv1_data = fad_excep.conv1.weight.data # let new conv1 use old param to balance the network fad_excep.conv1 = nn.Conv2d(12, 32, 3, 2, 0, bias=False) for i in range(4): fad_excep.conv1.weight.data[:, i*3:(i+1)*3, :, :] = conv1_data / 4.0 return fad_excep def _init_xcep_lfs(self): lfs_excep = return_pytorch04_xception(True) conv1_data = lfs_excep.conv1.weight.data # let new conv1 use old param to balance the network lfs_excep.conv1 = nn.Conv2d(self._LFS_M, 32, 3, 1, 0, bias=False) for i in range(int(self._LFS_M / 3)): lfs_excep.conv1.weight.data[:, i*3:(i+1)*3, :, :] = conv1_data / float(self._LFS_M / 3.0) return lfs_excep def _features(self, x_fad, x_fls): for forward_func in self.excep_forwards: x_fad = getattr(self.fad_excep, forward_func)(x_fad) x_fls = getattr(self.lfs_excep, forward_func)(x_fls) if forward_func == 'block7': x_fad, x_fls = self.mix_block7(x_fad, x_fls) if forward_func == 'block12': x_fad, x_fls = self.mix_block12(x_fad, x_fls) return x_fad, x_fls def _norm_feature(self, x): x = self.relu(x) x = F.adaptive_avg_pool2d(x, (1,1)) x = x.view(x.size(0), -1) return x def forward(self, x): fad_input = self.fad_head(x) lfs_input = self.lfs_head(x) x_fad, x_fls = self._features(fad_input, lfs_input) x_fad = self._norm_feature(x_fad) x_fls = self._norm_feature(x_fls) x_cat = torch.cat((x_fad, x_fls), dim=1) x_drop = self.dp(x_cat) logit = self.fc(x_drop) return logit class SeparableConv2d(nn.Module): def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False): super(SeparableConv2d,self).__init__() self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias) self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias) def forward(self,x): x = self.conv1(x) x = self.pointwise(x) return x class Block(nn.Module): def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True): super(Block, self).__init__() if out_filters != in_filters or strides!=1: self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False) self.skipbn = nn.BatchNorm2d(out_filters) else: self.skip=None self.relu = nn.ReLU(inplace=True) rep=[] filters=in_filters if grow_first: rep.append(self.relu) rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) rep.append(nn.BatchNorm2d(out_filters)) filters = out_filters for i in range(reps-1): rep.append(self.relu) rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False)) rep.append(nn.BatchNorm2d(filters)) if not grow_first: rep.append(self.relu) rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False)) rep.append(nn.BatchNorm2d(out_filters)) if not start_with_relu: rep = rep[1:] else: rep[0] = nn.ReLU(inplace=False) if strides != 1: rep.append(nn.MaxPool2d(3,strides,1)) self.rep = nn.Sequential(*rep) def forward(self,inp): x = self.rep(inp) if self.skip is not None: skip = self.skip(inp) skip = self.skipbn(skip) else: skip = inp x+=skip return x class Xception(nn.Module): """ Xception optimized for the ImageNet dataset, as specified in https://arxiv.org/pdf/1610.02357.pdf """ def __init__(self, num_classes=1000): """ Constructor Args: num_classes: number of classes """ super(Xception, self).__init__() self.num_classes = num_classes self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False) self.bn1 = nn.BatchNorm2d(32) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(32,64,3,bias=False) self.bn2 = nn.BatchNorm2d(64) #do relu here self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True) self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True) self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True) self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True) self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True) self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True) self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True) self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True) self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True) self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True) self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True) self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False) self.conv3 = SeparableConv2d(1024,1536,3,1,1) self.bn3 = nn.BatchNorm2d(1536) #do relu here self.conv4 = SeparableConv2d(1536,2048,3,1,1) self.bn4 = nn.BatchNorm2d(2048) def xception(num_classes=1000, pretrained='imagenet'): model = Xception(num_classes=num_classes) if pretrained: settings = pretrained_settings['xception'][pretrained] assert num_classes == settings['num_classes'], \ "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes) model = Xception(num_classes=num_classes) model.load_state_dict(settings['dir']) model.input_space = settings['input_space'] model.input_size = settings['input_size'] model.input_range = settings['input_range'] model.mean = settings['mean'] model.std = settings['std'] return model def return_pytorch04_xception(pretrained=True): model = xception(pretrained=False) if pretrained: state_dict = torch.load( '/ossfs/workspace/GenVideo/weights/xception-b5690688.pth') for name, weights in state_dict.items(): if 'pointwise' in name: state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) model.load_state_dict(state_dict, strict=False) return model class Filter(nn.Module): def __init__(self, size, band_start, band_end, use_learnable=True, norm=False): super(Filter, self).__init__() self.use_learnable = use_learnable self.base = nn.Parameter(torch.tensor(generate_filter(band_start, band_end, size)), requires_grad=False) if self.use_learnable: self.learnable = nn.Parameter(torch.randn(size, size), requires_grad=True) self.learnable.data.normal_(0., 0.1) # Todo # self.learnable = nn.Parameter(torch.rand((size, size)) * 0.2 - 0.1, requires_grad=True) self.norm = norm if norm: self.ft_num = nn.Parameter(torch.sum(torch.tensor(generate_filter(band_start, band_end, size))), requires_grad=False) def forward(self, x): if self.use_learnable: filt = self.base + norm_sigma(self.learnable) else: filt = self.base if self.norm: y = x * filt / self.ft_num else: y = x * filt return y class FAD_Head(nn.Module): def __init__(self, size): super(FAD_Head, self).__init__() # init DCT matrix self._DCT_all = nn.Parameter(torch.tensor(DCT_mat(size)).float(), requires_grad=False) self._DCT_all_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(size)).float(), 0, 1), requires_grad=False) # define base filters and learnable # 0 - 1/16 || 1/16 - 1/8 || 1/8 - 1 low_filter = Filter(size, 0, size // 16) middle_filter = Filter(size, size // 16, size // 8) high_filter = Filter(size, size // 8, size) all_filter = Filter(size, 0, size * 2) self.filters = nn.ModuleList([low_filter, middle_filter, high_filter, all_filter]) def forward(self, x): # DCT x_freq = self._DCT_all @ x @ self._DCT_all_T # [N, 3, 299, 299] # 4 kernel y_list = [] for i in range(4): x_pass = self.filters[i](x_freq) # [N, 3, 299, 299] y = self._DCT_all_T @ x_pass @ self._DCT_all # [N, 3, 299, 299] y_list.append(y) out = torch.cat(y_list, dim=1) # [N, 12, 299, 299] return out class LFS_Head(nn.Module): def __init__(self, size, window_size, M): super(LFS_Head, self).__init__() self.window_size = window_size self._M = M # init DCT matrix self._DCT_patch = nn.Parameter(torch.tensor(DCT_mat(window_size)).float(), requires_grad=False) self._DCT_patch_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(window_size)).float(), 0, 1), requires_grad=False) self.unfold = nn.Unfold(kernel_size=(window_size, window_size), stride=2, padding=4) # init filters self.filters = nn.ModuleList([Filter(window_size, window_size * 2. / M * i, window_size * 2. / M * (i+1), norm=True) for i in range(M)]) def forward(self, x): # turn RGB into Gray x_gray = 0.299*x[:,0,:,:] + 0.587*x[:,1,:,:] + 0.114*x[:,2,:,:] x = x_gray.unsqueeze(1) # rescale to 0 - 255 x = (x + 1.) * 122.5 # calculate size N, C, W, H = x.size() S = self.window_size size_after = int((W - S + 8)/2) + 1 assert size_after == 149 # sliding window unfold and DCT x_unfold = self.unfold(x) # [N, C * S * S, L] L:block num L = x_unfold.size()[2] x_unfold = x_unfold.transpose(1, 2).reshape(N, L, C, S, S) # [N, L, C, S, S] x_dct = self._DCT_patch @ x_unfold @ self._DCT_patch_T # M kernels filtering y_list = [] for i in range(self._M): # y = self.filters[i](x_dct) # [N, L, C, S, S] # y = torch.abs(y) # y = torch.sum(y, dim=[2,3,4]) # [N, L] # y = torch.log10(y + 1e-15) y = torch.abs(x_dct) y = torch.log10(y + 1e-15) y = self.filters[i](y) y = torch.sum(y, dim=[2,3,4]) y = y.reshape(N, size_after, size_after).unsqueeze(dim=1) # [N, 1, 149, 149] y_list.append(y) out = torch.cat(y_list, dim=1) # [N, M, 149, 149] return out class MixBlock(nn.Module): def __init__(self, c_in, width, height): super(MixBlock, self).__init__() self.FAD_query = nn.Conv2d(c_in, c_in, (1,1)) self.LFS_query = nn.Conv2d(c_in, c_in, (1,1)) self.FAD_key = nn.Conv2d(c_in, c_in, (1,1)) self.LFS_key = nn.Conv2d(c_in, c_in, (1,1)) self.softmax = nn.Softmax(dim=-1) self.relu = nn.ReLU() self.FAD_gamma = nn.Parameter(torch.zeros(1)) self.LFS_gamma = nn.Parameter(torch.zeros(1)) self.FAD_conv = nn.Conv2d(c_in, c_in, (1,1), groups=c_in) self.FAD_bn = nn.BatchNorm2d(c_in) self.LFS_conv = nn.Conv2d(c_in, c_in, (1,1), groups=c_in) self.LFS_bn = nn.BatchNorm2d(c_in) def forward(self, x_FAD, x_LFS): B, C, W, H = x_FAD.size() assert W == H q_FAD = self.FAD_query(x_FAD).view(-1, W, H) # [BC, W, H] q_LFS = self.LFS_query(x_LFS).view(-1, W, H) M_query = torch.cat([q_FAD, q_LFS], dim=2) # [BC, W, 2H] k_FAD = self.FAD_key(x_FAD).view(-1, W, H).transpose(1, 2) # [BC, H, W] k_LFS = self.LFS_key(x_LFS).view(-1, W, H).transpose(1, 2) M_key = torch.cat([k_FAD, k_LFS], dim=1) # [BC, 2H, W] energy = torch.bmm(M_query, M_key) #[BC, W, W] attention = self.softmax(energy).view(B, C, W, W) att_LFS = x_LFS * attention * (torch.sigmoid(self.LFS_gamma) * 2.0 - 1.0) y_FAD = x_FAD + self.FAD_bn(self.FAD_conv(att_LFS)) att_FAD = x_FAD * attention * (torch.sigmoid(self.FAD_gamma) * 2.0 - 1.0) y_LFS = x_LFS + self.LFS_bn(self.LFS_conv(att_FAD)) return y_FAD, y_LFS def DCT_mat(size): m = [[ (np.sqrt(1./size) if i == 0 else np.sqrt(2./size)) * np.cos((j + 0.5) * np.pi * i / size) for j in range(size)] for i in range(size)] return m def generate_filter(start, end, size): return [[0. if i + j > end or i + j <= start else 1. for j in range(size)] for i in range(size)] def norm_sigma(x): return 2. * torch.sigmoid(x) - 1. class Det_F3_Net(nn.Module): def __init__(self): super(Det_F3_Net, self).__init__() self.f3net = F3Net(num_classes=1) def forward(self, x): b, t, _, h, w = x.shape images = x.view(b * t, 3, h, w) sequence_output = self.f3net(images) sequence_output = sequence_output.view(b, t, -1) sequence_output = sequence_output.mean(1) return sequence_output if __name__ == '__main__': model = F3Net() print(model)