"""
MINTIME: Multi-Identity size-iNvariant TIMEsformer for Video Deepfake Detection (TIFS 2024)
Copyright (c) ISTI-CNR and its affiliates.
Modified by Davide Alessandro Coccomini from https://github.com/davide-coccomini/MINTIME-Multi-Identity-size-iNvariant-TIMEsformer-for-Video-Deepfake-Detection
"""
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from statistics import mean
from torch.nn.init import trunc_normal_
import cv2
import numpy as np
from random import random
from .clip import clip
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
# classes
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, *args, **kwargs):
x = self.norm(x)
return self.fn(x, *args, **kwargs)
# time token shift
def shift(t, amt):
if amt == 0:
return t
return F.pad(t, (0, 0, 0, 0, amt, -amt))
class PreTokenShift(nn.Module):
def __init__(self, frames, fn):
super().__init__()
self.frames = frames
self.fn = fn
def forward(self, x, *args, **kwargs):
f, dim = self.frames, x.shape[-1]
cls_x, x = x[:, :1], x[:, 1:]
x = rearrange(x, 'b (f n) d -> b f n d', f = f)
# shift along time frame before and after
dim_chunk = (dim // 3)
chunks = x.split(dim_chunk, dim = -1)
chunks_to_shift, rest = chunks[:3], chunks[3:]
shifted_chunks = tuple(map(lambda args: shift(*args), zip(chunks_to_shift, (-1, 0, 1))))
x = torch.cat((*shifted_chunks, *rest), dim = -1)
x = rearrange(x, 'b f n d -> b (f n) d')
x = torch.cat((cls_x, x), dim = 1)
return self.fn(x, *args, **kwargs)
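# Illustrative note (added commentary, not from the original source): shift()
# pads/crops along the frame axis, so for a toy tensor laid out as
# (batch, frames, patches, dim),
#
#     t = torch.arange(4.).view(1, 4, 1, 1)   # hypothetical example
#     shift(t, 1)    # frame i now holds frame i-1's tokens (frame 0 is zeroed)
#     shift(t, -1)   # frame i now holds frame i+1's tokens (last frame is zeroed)
#
# PreTokenShift splits the channels into thirds and shifts the first three chunks
# by -1, 0 and +1 frames, mixing neighbouring frames before the wrapped
# attention / feed-forward block runs.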
# feedforward
class GEGLU(nn.Module):
def forward(self, x):
x, gates = x.chunk(2, dim = -1)
return x * F.gelu(gates)
class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim)
)
def forward(self, x):
return self.net(x)
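# Added commentary: the first Linear projects to dim * mult * 2 channels so that
# GEGLU can split them into a value half and a gate half (value * GELU(gate)),
# leaving dim * mult channels for the second Linear to project back to dim.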
# attention
def attn(q, k, v):
sim = torch.einsum('b i d, b j d -> b i j', q, k)
attn = sim.softmax(dim = -1)
out = torch.einsum('b i j, b j d -> b i d', attn, v)
return out, attn
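# Added commentary: attn() returns both the attended values and the softmaxed
# attention matrix (shape: batch*heads x queries x keys); Attention.forward uses
# the latter to expose the cls-token attention maps.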
class Attention(nn.Module):
def __init__(self, dim, dim_head = 64, heads = 8, dropout = 0.):
super().__init__()
self.heads = heads
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, einops_from, einops_to, **einops_dims):
h = self.heads
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
q = q * self.scale
        # split off the classification token (the first token)
(cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, :1], t[:, 1:]), (q, k, v))
# let classification token attend to key / values of all patches across time and space
cls_out, cls_attentions = attn(cls_q, k, v)
# rearrange across time or space
q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))
# expand cls token keys and values across time or space and concat
r = q_.shape[0] // cls_k.shape[0]
cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))
k_ = torch.cat((cls_k, k_), dim = 1)
v_ = torch.cat((cls_v, v_), dim = 1)
# attention
out, attentions = attn(q_, k_, v_)
# merge back time or space
out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
# concat back the cls token
out = torch.cat((cls_out, out), dim = 1)
# merge back the heads
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
# combine heads out
return self.to_out(out), cls_attentions
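# Added commentary on how Attention is driven (patterns taken from
# SizeInvariantTimeSformer.forward below):
#
#   time attention:  einops_from='b (f n) d', einops_to='(b n) f d', n=num_patches
#                    -> every spatial position attends across the f frames
#   space attention: einops_from='b (f n) d', einops_to='(b f) n d', f=num_frames
#                    -> every frame attends across its n patches
#
# The cls token is split off first, attends over all tokens, and its keys/values
# are re-attached to every time/space group.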
class SizeInvariantTimeSformer(nn.Module):
def __init__(
self,
*,
require_attention = False
):
super().__init__()
self.dim = 512
self.num_frames = 8
self.max_identities = 1
self.image_size = 224
self.num_classes = 1
self.patch_size = 1
self.num_patches = 196
self.channels = 512
self.depth = 9
self.heads = 8
self.dim_head = 64
self.attn_dropout = 0.
self.ff_dropout = 0.
self.shift_tokens = False
self.enable_size_emb = True
self.enable_pos_emb = True
self.require_attention = require_attention
num_positions = self.num_frames * self.channels
self.to_patch_embedding = nn.Linear(self.channels, self.dim)
self.cls_token = nn.Parameter(torch.randn(1, self.dim))
self.pos_emb = nn.Embedding(num_positions + 1, self.dim)
if self.enable_size_emb:
self.size_emb = nn.Embedding(num_positions + 1, self.dim)
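        # Added commentary: both embedding tables are sized num_frames * channels + 1
        # (= 4097 rows), although forward() only indexes num_frames * num_patches + 1
        # (= 1569) positions. Note also that size_emb is instantiated here but is
        # never added to the tokens in forward(), and enable_pos_emb is not checked.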
self.layers = nn.ModuleList([])
for _ in range(self.depth):
ff = FeedForward(self.dim, dropout = self.ff_dropout)
time_attn = Attention(self.dim, dim_head = self.dim_head, heads = self.heads, dropout = self.attn_dropout)
spatial_attn = Attention(self.dim, dim_head = self.dim_head, heads = self.heads, dropout = self.attn_dropout)
if self.shift_tokens:
time_attn, spatial_attn, ff = map(lambda t: PreTokenShift(self.num_frames, t), (time_attn, spatial_attn, ff))
time_attn, spatial_attn, ff = map(lambda t: PreNorm(self.dim, t), (time_attn, spatial_attn, ff))
self.layers.append(nn.ModuleList([time_attn, spatial_attn, ff]))
self.to_out = nn.Sequential(
nn.LayerNorm(self.dim),
nn.Linear(self.dim, self.num_classes)
)
# Initialization
trunc_normal_(self.pos_emb.weight, std=.02)
trunc_normal_(self.cls_token, std=.02)
if self.enable_size_emb:
trunc_normal_(self.size_emb.weight, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_emb', 'cls_token'}
def forward(self, x):
b, f, c, h, w = x.shape
n = h * w
device = x.device
        x = rearrange(x, 'b f c h w -> b (f h w) c')   # B x F*H*W x C
        tokens = self.to_patch_embedding(x)            # B x 8*14*14 x dim
# Add cls token
cls_token = repeat(self.cls_token, 'n d -> b n d', b = b)
x = torch.cat((cls_token, tokens), dim = 1)
# Positional embedding
x += self.pos_emb(torch.arange(x.shape[1], device=device))
# Time and space attention
for (time_attn, spatial_attn, ff) in self.layers:
y, _ = time_attn(x, 'b (f n) d', '(b n) f d', n = n)
x = x + y
y, _ = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f)
x = x + y
x = ff(x) + x
        cls_token = x[:, 0]
        # Attention maps are not collected in this variant, so the
        # require_attention flag has no effect: both branches return only
        # the classification logits.
        if self.require_attention:
            return self.to_out(cls_token)
        else:
            return self.to_out(cls_token)
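# Usage sketch (illustrative; the shapes follow ViT_B_MINTIME below): the
# transformer consumes pre-extracted feature maps rather than raw frames.
#
#     feats = torch.randn(2, 8, 512, 14, 14)        # B x F x C x H x W (toy input)
#     logits = SizeInvariantTimeSformer()(feats)    # -> shape (2, 1)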
class ViT_B_MINTIME(nn.Module):
def __init__(
self, channel_size=512, class_num=1
):
        super(ViT_B_MINTIME, self).__init__()
        # Note: the channel_size and class_num arguments are currently unused;
        # the head below hard-codes dim = 512 and num_classes = 1.
        self.clip_model, preprocess = clip.load('ViT-B-16')
        self.clip_model = self.clip_model.float()
        self.head = SizeInvariantTimeSformer()
def forward(self, x):
b, t, _, h, w = x.shape
        images = x.view(b * t, 3, h, w)
        # The bundled clip module is expected to return per-patch token features
        # of shape (B*T, 196, C), which are reshaped into 14x14 feature maps.
        sequence_output = self.clip_model.encode_image(images)
        _, _, c = sequence_output.shape
        sequence_output = sequence_output.view(b, t, 14, 14, c)     # B x T x 14 x 14 x C
        sequence_output = sequence_output.permute(0, 1, 4, 2, 3)    # B x T x C x 14 x 14
res = self.head(sequence_output)
return res
if __name__ == '__main__':
    # Quick smoke test: a batch of 4 clips, 8 frames each, at 224x224 resolution.
    model = ViT_B_MINTIME()
    model = model.cuda()
    dummy_input = torch.randn(4, 8, 3, 224, 224)
    dummy_input = dummy_input.cuda()
    print(model(dummy_input))
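    # A CPU-friendly variant of the same smoke test (an illustrative sketch,
    # assuming the bundled clip.load('ViT-B-16') weights can run on CPU):
    #
    #     model = ViT_B_MINTIME().eval()
    #     with torch.no_grad():
    #         logits = model(torch.randn(1, 8, 3, 224, 224))   # -> shape (1, 1)
    #         score = torch.sigmoid(logits)   # assuming the single logit is a
    #                                         # fake/real score trained with BCE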