import ast
import os
import random

import albumentations
import cv2
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


def get_train_transforms():
    """Defines the probabilistic augmentations for training."""
    return albumentations.Compose([
        albumentations.Resize(224, 224),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.ImageCompression(quality_lower=50, quality_upper=100, p=0.5),
        albumentations.GaussNoise(p=0.3),
        albumentations.GaussianBlur(blur_limit=(3, 5), p=0.3),
        albumentations.ToGray(p=0.01),
        albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                                 max_pixel_value=255.0, p=1.0),
    ])
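
# Illustrative usage (not part of the pipeline): the composed transform takes a
# uint8 HWC RGB array and returns a normalized float32 array in the same layout,
# resized to 224x224:
#
#   aug = get_train_transforms()
#   dummy = np.zeros((256, 256, 3), dtype=np.uint8)
#   out = aug(image=dummy)['image']   # float32, shape (224, 224, 3)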


def get_val_transforms(cfg):
    """Defines augmentations for validation, applying the attack named in cfg['task']."""
    aug_list = [albumentations.Resize(224, 224)]

    task = cfg.get('task', 'normal')

    if task == 'JPEG_Compress_Attack':
        # Fixed quality 35 so the compression attack strength is reproducible.
        aug_list.append(albumentations.ImageCompression(quality_lower=35, quality_upper=35, p=1.0))
    elif task == 'FLIP_Attack':
        # p=1.0 so every validation clip is actually flipped, matching the
        # deterministic style of the other attack tasks.
        aug_list.append(albumentations.HorizontalFlip(p=1.0))
    elif task == 'CROP_Attack':
        aug_list.append(albumentations.RandomCrop(height=192, width=192, p=1.0))
        aug_list.append(albumentations.Resize(224, 224))
    elif task == 'Color_Attack':
        aug_list.append(albumentations.ColorJitter(p=1.0))
    elif task == 'Gaussian_Attack':
        aug_list.append(albumentations.GaussianBlur(blur_limit=(7, 7), p=1.0))

    aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                                             max_pixel_value=255.0, p=1.0))
    return albumentations.Compose(aug_list)
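
# Illustrative: each attack task maps to one deterministic corruption, e.g.
#
#   aug = get_val_transforms({'task': 'CROP_Attack'})
#   # pipeline: Resize(224) -> RandomCrop(192) -> Resize(224) -> Normalize
#
# and the default task 'normal' applies only Resize + Normalize.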


class VideoDataset(Dataset):
    """
    A PyTorch Dataset that loads fixed-length video frame sequences from a DataFrame.
    Balances the two classes by undersampling the majority class; call resample()
    at the start of each epoch to draw a fresh balanced subset.
    """
    def __init__(self, df, index_list, base_data_path, transform=None, select_frame_nums=8):
        self.df = df
        self.index_list = index_list
        self.base_data_path = base_data_path
        self.transform = transform
        self.select_frame_nums = select_frame_nums

        self.positive_indices = self.df[self.df['label'] == 1].index.tolist()
        self.negative_indices = self.df[self.df['label'] == 0].index.tolist()

        self.balanced_indices = []
        self.resample()

    def resample(self):
        """Rebuilds a class-balanced index list by undersampling the majority class."""
        min_samples = min(len(self.positive_indices), len(self.negative_indices))
        self.balanced_indices.clear()
        self.balanced_indices.extend(random.sample(self.positive_indices, min_samples))
        self.balanced_indices.extend(random.sample(self.negative_indices, min_samples))
        random.shuffle(self.balanced_indices)

    def __len__(self):
        return len(self.balanced_indices)

    def __getitem__(self, idx):
        real_idx = self.balanced_indices[idx]
        row = self.df.iloc[real_idx]

        video_id = row['content_path']
        label = int(row['label'])
        # frame_seq is stored as a stringified Python list; parse it safely.
        frame_list = ast.literal_eval(row['frame_seq'])

        frames = []

        # Sample a random contiguous window of select_frame_nums frames.
        if len(frame_list) >= self.select_frame_nums:
            start_index = random.randint(0, len(frame_list) - self.select_frame_nums)
            selected_frames = frame_list[start_index : start_index + self.select_frame_nums]
        else:
            selected_frames = frame_list

        for frame_path in selected_frames:
            try:
                image = cv2.imread(frame_path)
                if image is None:
                    raise ValueError("Failed to load")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except Exception:
                # Fall back to a black frame if the file is missing or unreadable.
                image = np.zeros((224, 224, 3), dtype=np.uint8)

            if self.transform:
                image = self.transform(image=image)['image']

            # HWC -> CHW, with a leading axis for concatenation along time.
            frames.append(image.transpose(2, 0, 1)[np.newaxis, :])

        # Zero-pad short clips up to select_frame_nums frames.
        pad_num = self.select_frame_nums - len(frames)
        if pad_num > 0:
            for _ in range(pad_num):
                frames.append(np.zeros((1, 3, 224, 224)))

        frames_tensor = np.concatenate(frames, axis=0)
        frames_tensor = torch.from_numpy(frames_tensor).float().unsqueeze(0)

        label_onehot = torch.zeros(2)
        label_onehot[label] = 1.0
        binary_label = torch.FloatTensor([label])

        # Report the index of the row actually sampled, not the loader position,
        # so the returned index matches frames_tensor after balancing.
        original_index = self.index_list[real_idx]
        return original_index, frames_tensor, label_onehot, binary_label
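
# Illustrative training-loop usage (loop names are placeholders): the balanced
# subset is drawn once in __init__, so redraw it between epochs:
#
#   for epoch in range(num_epochs):
#       train_loader.dataset.resample()   # fresh balanced undersample
#       for batch in train_loader:
#           ...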


class VideoDatasetVal(Dataset):
    """Validation counterpart of VideoDataset: no class balancing, and each item
    additionally returns its video_id for per-video bookkeeping."""
    def __init__(self, df, index_list, base_data_path, transform=None, select_frame_nums=8):
        self.df = df
        self.index_list = index_list
        self.base_data_path = base_data_path
        self.transform = transform
        self.select_frame_nums = select_frame_nums

    def __len__(self):
        return len(self.index_list)

    def __getitem__(self, idx):
        real_idx = self.index_list[idx]
        row = self.df.iloc[real_idx]

        video_id = row['content_path']
        label = int(row['label'])
        frame_list = ast.literal_eval(row['frame_seq'])

        frames = []
        # Take a fixed, centered window so validation is deterministic
        # across runs (unlike the random window used for training).
        if len(frame_list) >= self.select_frame_nums:
            start_index = (len(frame_list) - self.select_frame_nums) // 2
            selected_frames = frame_list[start_index : start_index + self.select_frame_nums]
        else:
            selected_frames = frame_list

        for frame_path in selected_frames:
            try:
                image = cv2.imread(frame_path)
                if image is None:
                    raise ValueError("Failed to load")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except Exception:
                image = np.zeros((224, 224, 3), dtype=np.uint8)

            if self.transform:
                image = self.transform(image=image)['image']

            frames.append(image.transpose(2, 0, 1)[np.newaxis, :])

        pad_num = self.select_frame_nums - len(frames)
        if pad_num > 0:
            for _ in range(pad_num):
                frames.append(np.zeros((1, 3, 224, 224)))

        frames_tensor = np.concatenate(frames, axis=0)
        frames_tensor = torch.from_numpy(frames_tensor).float().unsqueeze(0)

        label_onehot = torch.zeros(2)
        label_onehot[label] = 1.0
        binary_label = torch.FloatTensor([label])

        return self.index_list[idx], frames_tensor, label_onehot, binary_label, video_id
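
# Shape note (illustrative, assuming the default collate): with batch size B and
# select_frame_nums T = 8, a validation batch unpacks as
#
#   index, frames, label_onehot, binary_label, video_id = batch
#   # frames: (B, 1, T, 3, 224, 224); label_onehot: (B, 2); binary_label: (B, 1)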


def generate_dataset_loader(cfg):
    """
    Builds the train and validation DataLoaders for the task configured in cfg.
    """
    df_train = pd.read_csv('/home/kalpit/workspace/aigc/repos/DeMamba/csv/veo_train.csv')

    task = cfg.get('task', 'normal')
    if task == 'normal':
        df_val = pd.read_csv('GenVideo/datasets/val_id.csv')
    elif task == 'robust_compress':
        df_val = pd.read_csv('GenVideo/datasets/com_28.csv')
    else:
        df_val = pd.read_csv('/home/kalpit/workspace/aigc/repos/DeMamba/csv/veo_test.csv')

    # Optionally restrict training to one fake-generator subset plus all real videos.
    if cfg.get('train_sub_set'):
        prefixes = (f"fake/{cfg['train_sub_set']}", "real")
        condition = df_train['content_path'].str.startswith(prefixes)
        df_train = df_train[condition]

    df_train = df_train.reset_index(drop=True)
    df_val = df_val.reset_index(drop=True)

    index_train = df_train.index.tolist()
    index_val = df_val.index.tolist()

    base_data_path = 'GenVideo'

    train_dataset = VideoDataset(
        df=df_train,
        index_list=index_train,
        base_data_path=base_data_path,
        transform=get_train_transforms(),
        select_frame_nums=8
    )

    val_dataset = VideoDatasetVal(
        df=df_val,
        index_list=index_val,
        base_data_path=base_data_path,
        transform=get_val_transforms(cfg),
        select_frame_nums=8
    )

    train_loader = DataLoader(
        train_dataset, batch_size=cfg['train_batch_size'], shuffle=True,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=True
    )

    val_loader = DataLoader(
        val_dataset, batch_size=cfg['val_batch_size'], shuffle=False,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=False
    )

    print(f"******* Training Videos {len(index_train)}, Batch size {cfg['train_batch_size']} *******")
    print(f"******* Testing Videos {len(index_val)}, Batch size {cfg['val_batch_size']} *******")

    return train_loader, val_loader
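

if __name__ == "__main__":
    # Minimal smoke test. The cfg keys below are exactly the ones this module
    # reads; the values are illustrative placeholders, not the project's settings.
    cfg = {
        'task': 'normal',          # or one of the *_Attack tasks handled above
        'train_sub_set': None,     # optionally a generator name under fake/
        'train_batch_size': 4,
        'val_batch_size': 4,
        'num_workers': 2,
    }
    train_loader, val_loader = generate_dataset_loader(cfg)
    print(len(train_loader), len(val_loader))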