# NOTE: removed web-page scrape artifact ("Spaces:" / "Runtime error") that preceded the module docstring.
| ''' Towards An End-to-End Framework for Video Inpainting | |
| ''' | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from model.modules.flow_comp import SPyNet | |
| from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment | |
| from model.modules.tfocal_transformer import TemporalFocalTransformerBlock, SoftSplit, SoftComp | |
| from model.modules.spectral_norm import spectral_norm as _spectral_norm | |
class BaseNetwork(nn.Module):
    """Base class for networks: parameter reporting and weight initialization."""

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        """Print the total number of parameters of this network, in millions."""
        if isinstance(self, list):
            self = self[0]
        total_params = sum(p.numel() for p in self.parameters())
        print(
            'Network [%s] was created. Total number of parameters: %.1f million. '
            'To see the architecture, do print(network).' %
            (type(self).__name__, total_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        initialize network's weights
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''

        def init_func(m):
            layer_name = m.__class__.__name__
            if 'InstanceNorm2d' in layer_name:
                # Affine InstanceNorm starts at identity: unit scale, zero shift.
                if getattr(m, 'weight', None) is not None:
                    nn.init.constant_(m.weight.data, 1.0)
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and ('Conv' in layer_name
                                           or 'Linear' in layer_name):
                # Dispatch table instead of an if/elif chain; behavior matches
                # the pix2pix reference implementation linked above.
                # (NOTE(review): 'xavier_uniform' deliberately uses gain=1.0,
                # ignoring the `gain` argument, as in the original.)
                initializers = {
                    'normal': lambda w: nn.init.normal_(w, 0.0, gain),
                    'xavier': lambda w: nn.init.xavier_normal_(w, gain=gain),
                    'xavier_uniform':
                    lambda w: nn.init.xavier_uniform_(w, gain=1.0),
                    'kaiming':
                    lambda w: nn.init.kaiming_normal_(w, a=0, mode='fan_in'),
                    'orthogonal': lambda w: nn.init.orthogonal_(w, gain=gain),
                    # 'none' uses pytorch's default init method
                    'none': lambda w: m.reset_parameters(),
                }
                if init_type not in initializers:
                    raise NotImplementedError(
                        'initialization method [%s] is not implemented' %
                        init_type)
                initializers[init_type](m.weight.data)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)
class Encoder(nn.Module):
    """Frame encoder: downsamples 4x spatially, then fuses the features saved
    before layer 8 back into later grouped convolutions."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.group = [1, 2, 4, 8, 1]
        # (in_channels, out_channels, stride, groups) for each 3x3 conv;
        # every conv is followed by LeakyReLU(0.2).
        conv_specs = [
            (3, 64, 2, 1),
            (64, 64, 1, 1),
            (64, 128, 2, 1),
            (128, 256, 1, 1),
            (256, 384, 1, 1),
            (640, 512, 1, 2),
            (768, 384, 1, 4),
            (640, 256, 1, 8),
            (512, 128, 1, 1),
        ]
        modules = []
        for cin, cout, stride, groups in conv_specs:
            modules.append(
                nn.Conv2d(cin,
                          cout,
                          kernel_size=3,
                          stride=stride,
                          padding=1,
                          groups=groups))
            modules.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.ModuleList(modules)

    def forward(self, x):
        """x: (b*t, 3, H, W) -> (b*t, 128, H // 4, W // 4)."""
        bt, _, height, width = x.size()
        height, width = height // 4, width // 4
        feat = x
        saved = None
        for idx, layer in enumerate(self.layers):
            if idx == 8:
                # snapshot of the 256-channel features before layer 8
                saved = feat
            if idx > 8 and idx % 2 == 0:
                # concatenate the snapshot with the current features
                # group-wise before each of the grouped convolutions
                grp = self.group[(idx - 8) // 2]
                ref = saved.view(bt, grp, -1, height, width)
                cur = feat.view(bt, grp, -1, height, width)
                feat = torch.cat([ref, cur], 2).view(bt, -1, height, width)
            feat = layer(feat)
        return feat
class deconv(nn.Module):
    """Upsampling block: 2x bilinear interpolation followed by a stride-1 conv."""

    def __init__(self,
                 input_channel,
                 output_channel,
                 kernel_size=3,
                 padding=0):
        super().__init__()
        self.conv = nn.Conv2d(input_channel,
                              output_channel,
                              kernel_size=kernel_size,
                              stride=1,
                              padding=padding)

    def forward(self, x):
        upsampled = F.interpolate(x,
                                  scale_factor=2,
                                  mode='bilinear',
                                  align_corners=True)
        return self.conv(upsampled)
class InpaintGenerator(BaseNetwork):
    """End-to-end video inpainting generator.

    Pipeline: encode masked frames -> flow-guided bidirectional feature
    propagation over the local frames -> stacked temporal focal transformer
    blocks -> decode back to RGB frames.
    """

    def __init__(self, init_weights=True):
        super(InpaintGenerator, self).__init__()
        channel = 256
        hidden = 512

        # encoder
        self.encoder = Encoder()

        # decoder: two 2x upsampling stages, back to full resolution, 3-ch RGB
        self.decoder = nn.Sequential(
            deconv(channel // 2, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))

        # feature propagation module
        self.feat_prop_module = BidirectionalPropagation(channel // 2)

        # soft split and soft composition (token <-> feature-map conversion)
        kernel_size = (7, 7)
        padding = (3, 3)
        stride = (3, 3)
        # NOTE(review): hard-coded feature-map size — presumably 240x432
        # input frames / 4; confirm against the training configuration.
        output_size = (60, 108)
        t2t_params = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding,
            'output_size': output_size
        }
        self.ss = SoftSplit(channel // 2,
                            hidden,
                            kernel_size,
                            stride,
                            padding,
                            t2t_param=t2t_params)
        self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size,
                           stride, padding)

        # number of tokens produced by the soft split: the standard
        # conv-output-size formula applied per spatial dimension
        n_vecs = 1
        for i, d in enumerate(kernel_size):
            n_vecs *= int((output_size[i] + 2 * padding[i] -
                           (d - 1) - 1) / stride[i] + 1)

        blocks = []
        depths = 8
        num_heads = [4] * depths
        window_size = [(5, 9)] * depths
        focal_windows = [(5, 9)] * depths
        focal_levels = [2] * depths
        pool_method = "fc"
        for i in range(depths):
            blocks.append(
                TemporalFocalTransformerBlock(dim=hidden,
                                              num_heads=num_heads[i],
                                              window_size=window_size[i],
                                              focal_level=focal_levels[i],
                                              focal_window=focal_windows[i],
                                              n_vecs=n_vecs,
                                              t2t_params=t2t_params,
                                              pool_method=pool_method))
        self.transformer = nn.Sequential(*blocks)

        if init_weights:
            self.init_weights()
            # Need to initial the weights of MSDeformAttn specifically
            for m in self.modules():
                if isinstance(m, SecondOrderDeformableAlignment):
                    m.init_offset()

        # flow completion network
        self.update_spynet = SPyNet()

    def forward_bidirect_flow(self, masked_local_frames):
        """Estimate forward and backward optical flow between adjacent
        masked local frames at 1/4 resolution.

        Args:
            masked_local_frames: (b, l_t, c, h, w) frames
                (presumably in [0, 1] — caller rescales before this).
        Returns:
            Tuple (pred_flows_forward, pred_flows_backward), each of shape
            (b, l_t - 1, 2, h // 4, w // 4).
        """
        b, l_t, c, h, w = masked_local_frames.size()

        # compute forward and backward flows of masked frames
        masked_local_frames = F.interpolate(masked_local_frames.view(
            -1, c, h, w),
                                            scale_factor=1 / 4,
                                            mode='bilinear',
                                            align_corners=True,
                                            recompute_scale_factor=True)
        masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
                                                       w // 4)
        # pair frame t with frame t+1 for the two flow directions
        mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
            -1, c, h // 4, w // 4)
        mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
            -1, c, h // 4, w // 4)
        pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
        pred_flows_backward = self.update_spynet(mlf_2, mlf_1)

        pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
                                                     w // 4)
        pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
                                                       w // 4)

        return pred_flows_forward, pred_flows_backward

    def forward(self, masked_frames, num_local_frames):
        """Inpaint a batch of masked video frames.

        Args:
            masked_frames: (b, t, c, h, w) frames in [-1, 1]; the first
                `num_local_frames` along t are local frames, the rest are
                non-local reference frames.
            num_local_frames: number of local frames l_t.
        Returns:
            output: (b * t, 3, H, W) inpainted frames, tanh-bounded to [-1, 1].
            pred_flows: (forward, backward) flow pair from
                forward_bidirect_flow.
        """
        l_t = num_local_frames
        b, t, ori_c, ori_h, ori_w = masked_frames.size()

        # normalization before feeding into the flow completion module
        # ([-1, 1] -> [0, 1])
        masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
        pred_flows = self.forward_bidirect_flow(masked_local_frames)

        # extracting features and performing the feature propagation on local features
        enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
        _, c, h, w = enc_feat.size()
        local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
        ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
        local_feat = self.feat_prop_module(local_feat, pred_flows[0],
                                           pred_flows[1])
        enc_feat = torch.cat((local_feat, ref_feat), dim=1)

        # content hallucination through stacking multiple temporal focal transformer blocks
        trans_feat = self.ss(enc_feat.view(-1, c, h, w), b)
        trans_feat = self.transformer(trans_feat)
        trans_feat = self.sc(trans_feat, t)
        trans_feat = trans_feat.view(b, t, -1, h, w)
        # residual connection around the transformer stack
        enc_feat = enc_feat + trans_feat

        # decode frames from features
        output = self.decoder(enc_feat.view(b * t, c, h, w))
        output = torch.tanh(output)

        return output, pred_flows
| # ###################################################################### | |
| # Discriminator for Temporal Patch GAN | |
| # ###################################################################### | |
class Discriminator(BaseNetwork):
    """Temporal Patch GAN discriminator built from strided 3D convolutions."""

    def __init__(self,
                 in_channels=3,
                 use_sigmoid=False,
                 use_spectral_norm=True,
                 init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 32
        use_bias = not use_spectral_norm

        # First conv kept explicit: it uses padding=1 (scalar), unlike the
        # (1, 2, 2) padding of the later blocks.
        layers = [
            spectral_norm(
                nn.Conv3d(in_channels=in_channels,
                          out_channels=nf * 1,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=1,
                          bias=use_bias), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Four spectrally-normalized conv blocks with identical geometry,
        # differing only in channel counts.
        for cin, cout in [(nf * 1, nf * 2), (nf * 2, nf * 4),
                          (nf * 4, nf * 4), (nf * 4, nf * 4)]:
            layers.append(
                spectral_norm(
                    nn.Conv3d(cin,
                              cout,
                              kernel_size=(3, 5, 5),
                              stride=(1, 2, 2),
                              padding=(1, 2, 2),
                              bias=use_bias), use_spectral_norm))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final plain conv (no spectral norm, no activation) producing scores.
        layers.append(
            nn.Conv3d(nf * 4,
                      nf * 4,
                      kernel_size=(3, 5, 5),
                      stride=(1, 2, 2),
                      padding=(1, 2, 2)))
        self.conv = nn.Sequential(*layers)

        if init_weights:
            self.init_weights()

    def forward(self, xs):
        """xs: (B, T, C, H, W) video batch -> (B, T', C', H', W') scores."""
        volume = torch.transpose(xs, 1, 2)  # -> (B, C, T, H, W) for Conv3d
        scores = self.conv(volume)
        if self.use_sigmoid:
            scores = torch.sigmoid(scores)
        out = torch.transpose(scores, 1, 2)  # B, T, C, H, W
        return out
def spectral_norm(module, mode=True):
    """Return *module* wrapped with spectral normalization when *mode* is
    truthy; otherwise return it unchanged."""
    if not mode:
        return module
    return _spectral_norm(module)