"""
F3-Net: Thinking in Frequency - Face Forgery Detection by Mining
Frequency-aware Clues @ ECCV'2020

Copyright (c) University of Chinese Academy of Sciences and its affiliates.

Implementation is mainly referenced from https://github.com/yyk-wew/F3Net
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


pretrained_settings = {
    'xception': {
        'imagenet': {
            'dir': '/ossfs/workspace/aigc_video/weights/xception-b5690688.pth',
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000,
            'scale': 0.8975
        }
    }
}


class F3Net(nn.Module):
    """
    Two-branch frequency model (FAD + LFS) with cross-attention mixing.
    Implementation is mainly referenced from https://github.com/yyk-wew/F3Net
    """
    def __init__(self,
                 num_classes: int = 2,
                 img_width: int = 299,
                 img_height: int = 299,
                 LFS_window_size: int = 10,
                 LFS_M: int = 6) -> None:
        super(F3Net, self).__init__()
        assert img_width == img_height
        self.img_size = img_width
        self.num_classes = num_classes
        self._LFS_window_size = LFS_window_size
        self._LFS_M = LFS_M

        # Frequency heads: FAD (frequency-aware decomposition) and
        # LFS (local frequency statistics).
        self.fad_head = FAD_Head(self.img_size)
        self.lfs_head = LFS_Head(self.img_size, self._LFS_window_size, self._LFS_M)

        # One Xception backbone per branch; conv1 is re-initialized for the
        # branch-specific number of input channels.
        self.fad_excep = self._init_xcep_fad()
        self.lfs_excep = self._init_xcep_lfs()

        # Cross-attention blocks applied after block7 (19x19 feature maps)
        # and block12 (10x10 feature maps).
        self.mix_block7 = MixBlock(c_in=728, width=19, height=19)
        self.mix_block12 = MixBlock(c_in=1024, width=10, height=10)
        self.excep_forwards = ['conv1', 'bn1', 'relu', 'conv2', 'bn2', 'relu',
                               'block1', 'block2', 'block3', 'block4', 'block5', 'block6',
                               'block7', 'block8', 'block9', 'block10', 'block11', 'block12',
                               'conv3', 'bn3', 'relu', 'conv4', 'bn4']

        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(4096, num_classes)  # 2048-d features from each branch
        self.dp = nn.Dropout(p=0.2)

|
    def _init_xcep_fad(self):
        fad_excep = return_pytorch04_xception(True)
        conv1_data = fad_excep.conv1.weight.data

        # The FAD head outputs 12 channels (4 frequency bands x 3 colors), so
        # tile the pretrained 3-channel conv1 weights four times, scaled to
        # keep the activation magnitude comparable.
        fad_excep.conv1 = nn.Conv2d(12, 32, 3, 2, 0, bias=False)
        for i in range(4):
            fad_excep.conv1.weight.data[:, i*3:(i+1)*3, :, :] = conv1_data / 4.0
        return fad_excep

    def _init_xcep_lfs(self):
        lfs_excep = return_pytorch04_xception(True)
        conv1_data = lfs_excep.conv1.weight.data

        # The LFS head outputs M channels on a 149x149 map, so use stride 1
        # and tile the pretrained weights over M/3 channel groups.
        lfs_excep.conv1 = nn.Conv2d(self._LFS_M, 32, 3, 1, 0, bias=False)
        for i in range(int(self._LFS_M / 3)):
            lfs_excep.conv1.weight.data[:, i*3:(i+1)*3, :, :] = conv1_data / float(self._LFS_M / 3.0)
        return lfs_excep

|
    def _features(self, x_fad, x_lfs):
        # Run both Xception branches layer by layer, exchanging information
        # through the cross-attention mix blocks after block7 and block12.
        for forward_func in self.excep_forwards:
            x_fad = getattr(self.fad_excep, forward_func)(x_fad)
            x_lfs = getattr(self.lfs_excep, forward_func)(x_lfs)
            if forward_func == 'block7':
                x_fad, x_lfs = self.mix_block7(x_fad, x_lfs)
            if forward_func == 'block12':
                x_fad, x_lfs = self.mix_block12(x_fad, x_lfs)
        return x_fad, x_lfs

    def _norm_feature(self, x):
        # ReLU + global average pooling to a flat feature vector.
        x = self.relu(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        return x

|
    def forward(self, x):
        fad_input = self.fad_head(x)
        lfs_input = self.lfs_head(x)
        x_fad, x_lfs = self._features(fad_input, lfs_input)
        x_fad = self._norm_feature(x_fad)
        x_lfs = self._norm_feature(x_lfs)
        x_cat = torch.cat((x_fad, x_lfs), dim=1)
        x_drop = self.dp(x_cat)
        logit = self.fc(x_drop)
        return logit


|
class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()
        # Depthwise convolution followed by a 1x1 pointwise convolution.
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation,
                               groups=in_channels, bias=bias)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


|
class Block(nn.Module):
    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        # 1x1 projection on the residual path whenever the shape changes.
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            # The first ReLU must not be in-place: it sees the block input,
            # which is still needed intact for the skip connection.
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x += skip
        return x


|
class Xception(nn.Module):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/pdf/1610.02357.pdf

    Note: no forward() is defined here; F3Net drives these layers one by one
    through its `excep_forwards` list.
    """
    def __init__(self, num_classes=1000):
        """ Constructor
        Args:
            num_classes: number of classes
        """
        super(Xception, self).__init__()
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Entry flow
        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # Middle flow
        self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Exit flow
        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)


|
def xception(num_classes=1000, pretrained='imagenet'):
    model = Xception(num_classes=num_classes)
    if pretrained:
        settings = pretrained_settings['xception'][pretrained]
        assert num_classes == settings['num_classes'], \
            "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)

        # load_state_dict expects a state dict, not a path string. Note that
        # callers in this file use pretrained=False and load weights through
        # return_pytorch04_xception(), which also fixes the pointwise shapes.
        model.load_state_dict(torch.load(settings['dir']), strict=False)

        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']
        model.mean = settings['mean']
        model.std = settings['std']

    return model


|
def return_pytorch04_xception(pretrained=True):
    model = xception(pretrained=False)
    if pretrained:
        state_dict = torch.load(
            '/ossfs/workspace/GenVideo/weights/xception-b5690688.pth')
        # The checkpoint stores pointwise weights as (out, in); nn.Conv2d
        # expects (out, in, 1, 1), so add the two spatial dimensions.
        for name, weights in state_dict.items():
            if 'pointwise' in name:
                state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
        model.load_state_dict(state_dict, strict=False)

    return model


|
class Filter(nn.Module):
    def __init__(self, size, band_start, band_end, use_learnable=True, norm=False):
        super(Filter, self).__init__()
        self.use_learnable = use_learnable

        # Fixed binary band-pass mask in the DCT domain.
        self.base = nn.Parameter(torch.tensor(generate_filter(band_start, band_end, size)), requires_grad=False)
        if self.use_learnable:
            # Learnable residual mask, squashed to (-1, 1) by norm_sigma.
            self.learnable = nn.Parameter(torch.randn(size, size), requires_grad=True)
            self.learnable.data.normal_(0., 0.1)

        self.norm = norm
        if norm:
            # Number of pass-band coefficients, used to average the response.
            self.ft_num = nn.Parameter(torch.sum(torch.tensor(generate_filter(band_start, band_end, size))), requires_grad=False)

    def forward(self, x):
        if self.use_learnable:
            filt = self.base + norm_sigma(self.learnable)
        else:
            filt = self.base

        if self.norm:
            y = x * filt / self.ft_num
        else:
            y = x * filt
        return y


|
class FAD_Head(nn.Module):
    def __init__(self, size):
        super(FAD_Head, self).__init__()

        # Full-image DCT basis and its transpose.
        self._DCT_all = nn.Parameter(torch.tensor(DCT_mat(size)).float(), requires_grad=False)
        self._DCT_all_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(size)).float(), 0, 1), requires_grad=False)

        # Low / middle / high frequency bands plus an all-pass filter.
        low_filter = Filter(size, 0, size // 16)
        middle_filter = Filter(size, size // 16, size // 8)
        high_filter = Filter(size, size // 8, size)
        all_filter = Filter(size, 0, size * 2)

        self.filters = nn.ModuleList([low_filter, middle_filter, high_filter, all_filter])

    def forward(self, x):
        # Transform to the DCT domain: (N, 3, H, W).
        x_freq = self._DCT_all @ x @ self._DCT_all_T

        # Filter each band, invert the DCT, and stack: (N, 12, H, W).
        y_list = []
        for i in range(4):
            x_pass = self.filters[i](x_freq)
            y = self._DCT_all_T @ x_pass @ self._DCT_all
            y_list.append(y)
        out = torch.cat(y_list, dim=1)
        return out


|
class LFS_Head(nn.Module):
    def __init__(self, size, window_size, M):
        super(LFS_Head, self).__init__()

        self.window_size = window_size
        self._M = M

        # Patch-level DCT basis and its transpose.
        self._DCT_patch = nn.Parameter(torch.tensor(DCT_mat(window_size)).float(), requires_grad=False)
        self._DCT_patch_T = nn.Parameter(torch.transpose(torch.tensor(DCT_mat(window_size)).float(), 0, 1), requires_grad=False)

        self.unfold = nn.Unfold(kernel_size=(window_size, window_size), stride=2, padding=4)

        # M band-pass filters that tile the patch spectrum.
        self.filters = nn.ModuleList(
            [Filter(window_size, window_size * 2. / M * i, window_size * 2. / M * (i + 1), norm=True)
             for i in range(M)])

    def forward(self, x):
        # Convert to grayscale.
        x_gray = 0.299 * x[:, 0, :, :] + 0.587 * x[:, 1, :, :] + 0.114 * x[:, 2, :, :]
        x = x_gray.unsqueeze(1)

        # Rescale from [-1, 1] to roughly [0, 255].
        x = (x + 1.) * 122.5

        N, C, W, H = x.size()
        S = self.window_size
        size_after = int((W - S + 8) / 2) + 1
        assert size_after == 149

        # Extract overlapping SxS patches and apply the DCT to each.
        x_unfold = self.unfold(x)
        L = x_unfold.size()[2]
        x_unfold = x_unfold.transpose(1, 2).reshape(N, L, C, S, S)
        x_dct = self._DCT_patch @ x_unfold @ self._DCT_patch_T

        # Log-magnitude spectrum of each patch (computed once for all bands).
        x_dct = torch.log10(torch.abs(x_dct) + 1e-15)

        # Per-band average energy, reshaped into an M-channel map.
        y_list = []
        for i in range(self._M):
            y = self.filters[i](x_dct)
            y = torch.sum(y, dim=[2, 3, 4])
            y = y.reshape(N, size_after, size_after).unsqueeze(dim=1)
            y_list.append(y)
        out = torch.cat(y_list, dim=1)
        return out


|
class MixBlock(nn.Module):
    """Cross-attention block that exchanges information between the FAD and LFS branches."""
    def __init__(self, c_in, width, height):
        super(MixBlock, self).__init__()
        self.FAD_query = nn.Conv2d(c_in, c_in, (1, 1))
        self.LFS_query = nn.Conv2d(c_in, c_in, (1, 1))

        self.FAD_key = nn.Conv2d(c_in, c_in, (1, 1))
        self.LFS_key = nn.Conv2d(c_in, c_in, (1, 1))

        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()

        # Gates start at zero, so the block is initially an identity mapping.
        self.FAD_gamma = nn.Parameter(torch.zeros(1))
        self.LFS_gamma = nn.Parameter(torch.zeros(1))

        self.FAD_conv = nn.Conv2d(c_in, c_in, (1, 1), groups=c_in)
        self.FAD_bn = nn.BatchNorm2d(c_in)
        self.LFS_conv = nn.Conv2d(c_in, c_in, (1, 1), groups=c_in)
        self.LFS_bn = nn.BatchNorm2d(c_in)

    def forward(self, x_FAD, x_LFS):
        B, C, W, H = x_FAD.size()
        assert W == H

        # Joint attention map: (B*C, W, 2H) x (B*C, 2H, W) -> (B*C, W, W).
        q_FAD = self.FAD_query(x_FAD).view(-1, W, H)
        q_LFS = self.LFS_query(x_LFS).view(-1, W, H)
        M_query = torch.cat([q_FAD, q_LFS], dim=2)

        k_FAD = self.FAD_key(x_FAD).view(-1, W, H).transpose(1, 2)
        k_LFS = self.LFS_key(x_LFS).view(-1, W, H).transpose(1, 2)
        M_key = torch.cat([k_FAD, k_LFS], dim=1)

        energy = torch.bmm(M_query, M_key)
        attention = self.softmax(energy).view(B, C, W, W)

        # Each branch receives a gated, attended copy of the other branch.
        att_LFS = x_LFS * attention * (torch.sigmoid(self.LFS_gamma) * 2.0 - 1.0)
        y_FAD = x_FAD + self.FAD_bn(self.FAD_conv(att_LFS))

        att_FAD = x_FAD * attention * (torch.sigmoid(self.FAD_gamma) * 2.0 - 1.0)
        y_LFS = x_LFS + self.LFS_bn(self.LFS_conv(att_FAD))
        return y_FAD, y_LFS


|
def DCT_mat(size):
    # Orthonormal DCT-II basis matrix.
    m = [[(np.sqrt(1. / size) if i == 0 else np.sqrt(2. / size)) * np.cos((j + 0.5) * np.pi * i / size)
          for j in range(size)] for i in range(size)]
    return m


def generate_filter(start, end, size):
    # Binary band-pass mask over DCT coefficients: keep entries with
    # start < i + j <= end (anti-diagonal frequency bands).
    return [[0. if i + j > end or i + j <= start else 1. for j in range(size)] for i in range(size)]


def norm_sigma(x):
    # Squash to the range (-1, 1).
    return 2. * torch.sigmoid(x) - 1.


|
class Det_F3_Net(nn.Module):
    """Video-level wrapper: runs F3Net per frame and averages the logits."""
    def __init__(self):
        super(Det_F3_Net, self).__init__()
        self.f3net = F3Net(num_classes=1)

    def forward(self, x):
        # x: (B, T, 3, H, W) video clips.
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        sequence_output = self.f3net(images)
        sequence_output = sequence_output.view(b, t, -1)
        sequence_output = sequence_output.mean(1)

        return sequence_output

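
# Hedged usage sketch for the video-level wrapper (assumptions: clips are
# (B, T, 3, 299, 299) tensors scaled to [-1, 1], and the pretrained Xception
# checkpoint is available at the hardcoded path above):
#   det = Det_F3_Net()
#   clip = torch.randn(1, 8, 3, 299, 299).clamp_(-1., 1.)
#   score = det(clip)  # (1, 1) clip-level logit
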
if __name__ == '__main__':
    model = F3Net()
    print(model)
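
    # Hedged smoke test (assumptions: inputs are 299x299 RGB tensors scaled
    # to [-1, 1], as implied by the LFS head's `(x + 1.) * 122.5` rescaling;
    # building F3Net requires the hardcoded Xception checkpoint path).
    dummy = torch.randn(2, 3, 299, 299).clamp_(-1., 1.)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 2])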