"""
MINTIME: Multi-Identity size-iNvariant TIMEsformer for Video Deepfake Detection @ TIFS'2024

Copyright (c) ISTI-CNR and its affiliates.

Modified by Davide Alessandro Coccomini from
https://github.com/davide-coccomini/MINTIME-Multi-Identity-size-iNvariant-TIMEsformer-for-Video-Deepfake-Detection
"""

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch.nn.init import trunc_normal_

from .clip import clip


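# Generic transformer helpers and building blocks used by the SizeInvariantTimeSformer head below.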
def exists(val):
    return val is not None


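# PreNorm applies LayerNorm to the input before calling the wrapped module (attention or
# feed-forward), i.e. the pre-norm transformer arrangement; extra args are forwarded untouched.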
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)


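# Token shift: `shift` rolls a tensor along the frame axis via padding, and PreTokenShift
# shifts one third of the channels to the next frame, keeps one third in place and shifts
# one third to the previous frame before calling the wrapped module, so each position mixes
# features from the previous, current and next frame. It is only used when shift_tokens=True.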
def shift(t, amt):
    if amt == 0:
        return t
    return F.pad(t, (0, 0, 0, 0, amt, -amt))


class PreTokenShift(nn.Module):
    def __init__(self, frames, fn):
        super().__init__()
        self.frames = frames
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        f, dim = self.frames, x.shape[-1]
        cls_x, x = x[:, :1], x[:, 1:]
        x = rearrange(x, 'b (f n) d -> b f n d', f = f)

        dim_chunk = (dim // 3)
        chunks = x.split(dim_chunk, dim = -1)
        chunks_to_shift, rest = chunks[:3], chunks[3:]
        shifted_chunks = tuple(map(lambda args: shift(*args), zip(chunks_to_shift, (-1, 0, 1))))
        x = torch.cat((*shifted_chunks, *rest), dim = -1)

        x = rearrange(x, 'b f n d -> b (f n) d')
        x = torch.cat((cls_x, x), dim = 1)
        return self.fn(x, *args, **kwargs)


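# GEGLU splits the last dimension in half and uses one half, passed through GELU, to gate
# the other. FeedForward is the standard transformer MLP with this gating: it projects to
# dim * mult * 2 so that GEGLU leaves dim * mult features before the output projection.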
class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)


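# Plain dot-product attention over pre-scaled queries: returns both the attended values and
# the attention map, so callers can expose the CLS-token attention for inspection.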
def attn(q, k, v):
    sim = torch.einsum('b i d, b j d -> b i j', q, k)
    attn = sim.softmax(dim = -1)
    out = torch.einsum('b i j, b j d -> b i d', attn, v)
    return out, attn


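# Divided space-time attention (TimeSformer-style). The CLS token attends over the whole
# sequence, while the remaining tokens are regrouped via the einops patterns passed to
# forward() so that they attend either across frames (temporal step) or within a frame
# (spatial step); the CLS key/value are broadcast into every group so each token can still
# attend to it.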
class Attention(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8, dropout = 0.):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, einops_from, einops_to, **einops_dims):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        q = q * self.scale

        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))

        cls_out, cls_attentions = attn(cls_q, k, v)

        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))

        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim = 1)
        v_ = torch.cat((cls_v, v_), dim = 1)

        out, attentions = attn(q_, k_, v_)

        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        out = torch.cat((cls_out, out), dim = 1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)

        return self.to_out(out), cls_attentions


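# SizeInvariantTimeSformer: the MINTIME video head, here hard-wired to CLIP ViT-B/16
# features (channels = dim = 512, 8 frames, 14 x 14 = 196 patch tokens, one output logit).
# Two notes on the configuration as written: the positional table is sized from
# num_frames * channels (4096 + 1 entries), larger than the 8 * 196 + 1 positions actually
# indexed in forward(), and size_emb is created when enable_size_emb is True but is never
# applied in this simplified forward pass.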
class SizeInvariantTimeSformer(nn.Module):
    def __init__(
        self,
        *,
        require_attention = False
    ):
        super().__init__()
        self.dim = 512
        self.num_frames = 8
        self.max_identities = 1
        self.image_size = 224
        self.num_classes = 1
        self.patch_size = 1
        self.num_patches = 196
        self.channels = 512
        self.depth = 9
        self.heads = 8
        self.dim_head = 64
        self.attn_dropout = 0.
        self.ff_dropout = 0.
        self.shift_tokens = False
        self.enable_size_emb = True
        self.enable_pos_emb = True
        self.require_attention = require_attention

        num_positions = self.num_frames * self.channels
        self.to_patch_embedding = nn.Linear(self.channels, self.dim)
        self.cls_token = nn.Parameter(torch.randn(1, self.dim))
        self.pos_emb = nn.Embedding(num_positions + 1, self.dim)

        if self.enable_size_emb:
            self.size_emb = nn.Embedding(num_positions + 1, self.dim)

        self.layers = nn.ModuleList([])
        for _ in range(self.depth):
            ff = FeedForward(self.dim, dropout = self.ff_dropout)
            time_attn = Attention(self.dim, dim_head = self.dim_head, heads = self.heads, dropout = self.attn_dropout)
            spatial_attn = Attention(self.dim, dim_head = self.dim_head, heads = self.heads, dropout = self.attn_dropout)
            if self.shift_tokens:
                time_attn, spatial_attn, ff = map(lambda t: PreTokenShift(self.num_frames, t), (time_attn, spatial_attn, ff))

            time_attn, spatial_attn, ff = map(lambda t: PreNorm(self.dim, t), (time_attn, spatial_attn, ff))
            self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff]))

        self.to_out = nn.Sequential(
            nn.LayerNorm(self.dim),
            nn.Linear(self.dim, self.num_classes)
        )

        trunc_normal_(self.pos_emb.weight, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        if self.enable_size_emb:
            trunc_normal_(self.size_emb.weight, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_emb', 'cls_token'}

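    # Forward: flatten the (frames, H, W) feature grid into a token sequence, prepend the
    # CLS token, add learned positional embeddings, then for each of the depth blocks run
    # temporal attention (tokens at the same spatial location attend across frames),
    # spatial attention (tokens within the same frame attend to each other) and a
    # feed-forward layer, all with residual connections. The prediction is read from the
    # CLS token.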
    def forward(self, x):
        b, f, c, h, w = x.shape
        n = h * w
        device = x.device

        x = rearrange(x, 'b f c h w -> b (f h w) c')
        tokens = self.to_patch_embedding(x)

        cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
        x = torch.cat((cls_token, tokens), dim = 1)

        x += self.pos_emb(torch.arange(x.shape[1], device=device))

        attentions = []
        for (time_attn, spatial_attn, ff) in self.layers:
            y, _ = time_attn(x, 'b (f n) d', '(b n) f d', n = n)
            x = x + y
            y, cls_attn = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f)
            x = x + y
            x = ff(x) + x
            if self.require_attention:
                # Keep the per-layer CLS attention maps from the spatial attention step.
                attentions.append(cls_attn)

        cls_token = x[:, 0]

        if self.require_attention:
            return self.to_out(cls_token), attentions
        else:
            return self.to_out(cls_token)


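# Minimal usage sketch for the head alone (shapes follow the hard-coded config above; this
# is an illustration, not part of the training pipeline):
#     head = SizeInvariantTimeSformer()
#     feats = torch.randn(2, 8, 512, 14, 14)   # (batch, frames, channels, height, width)
#     logits = head(feats)                     # -> (2, 1)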
class ViT_B_MINTIME(nn.Module):
    def __init__(
        self, channel_size=512, class_num=1
    ):
        super(ViT_B_MINTIME, self).__init__()
        self.clip_model, preprocess = clip.load('ViT-B-16')
        self.clip_model = self.clip_model.float()
        self.head = SizeInvariantTimeSformer()

    def forward(self, x):
        b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        # The bundled CLIP model is expected to return per-patch tokens of shape
        # (b * t, 196, 512) from encode_image (not the pooled CLS embedding of stock CLIP),
        # so they can be reshaped into 14 x 14 feature maps for the TimeSformer head.
        sequence_output = self.clip_model.encode_image(images)
        _, _, c = sequence_output.shape
        sequence_output = sequence_output.view(b, t, 14, 14, c)
        sequence_output = sequence_output.permute(0, 1, 4, 2, 3)

        res = self.head(sequence_output)

        return res


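# Quick smoke test (requires a CUDA device): with the default configuration the printed
# logits should have shape (4, 1), one deepfake score per clip.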
if __name__ == '__main__':
    model = ViT_B_MINTIME()
    model = model.cuda()
    dummy_input = torch.randn(4, 8, 3, 224, 224)
    dummy_input = dummy_input.cuda()
    print(model(dummy_input))