import math
import os
import random
import warnings

import albumentations
import cv2
import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader

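# Note: `download_oss_file` is called below to fetch frame JPEGs from object
# storage, but it is neither defined nor imported in this file. It is assumed
# to be provided elsewhere in the project and to return the decoded frame as
# an H x W x C numpy array.
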
def crop_center_by_percentage(image, percentage):
    """Crop `percentage` of the longer side off each end of the image.

    Landscape images lose their left/right borders; portrait (or square)
    images lose their top/bottom borders.
    """
    height, width = image.shape[:2]

    if width > height:
        # Trim `percentage` of the width from both the left and right sides.
        left_pixels = int(width * percentage)
        right_pixels = int(width * percentage)
        start_x = left_pixels
        end_x = width - right_pixels
        cropped_image = image[:, start_x:end_x]
    else:
        # Trim `percentage` of the height from both the top and bottom.
        up_pixels = int(height * percentage)
        down_pixels = int(height * percentage)
        start_y = up_pixels
        end_y = height - down_pixels
        cropped_image = image[start_y:end_y, :]

    return cropped_image

class Ours_Dataset_train(Dataset):
    def __init__(self, index_list=None, df=None):
        self.index_list = index_list
        self.df = df
        # Split the rows into the two label pools so each epoch can be
        # rebalanced by undersampling the larger pool.
        self.positive_indices = df[df['label'] == 1].index.tolist()
        self.negative_indices = df[df['label'] == 0].index.tolist()
        self.balanced_indices = []
        self.resample()

    def resample(self):
        # Undersample the majority class so both classes contribute the same
        # number of videos, then shuffle the combined index list.
        min_samples = min(len(self.positive_indices), len(self.negative_indices))
        self.balanced_indices.clear()
        self.balanced_indices.extend(random.sample(self.positive_indices, min_samples))
        self.balanced_indices.extend(random.sample(self.negative_indices, min_samples))
        random.shuffle(self.balanced_indices)

    def __getitem__(self, idx):
        real_idx = self.balanced_indices[idx]
        row = self.df.iloc[real_idx]
        video_id = row['content_path']
        label = row['label']
        # frame_seq is stored in the CSV as a stringified list of frame indices.
        frame_list = eval(row['frame_seq'])
        label_onehot = [0] * 2
        select_frame_nums = 8

        # Build a per-sample augmentation pipeline: always resize, apply a
        # random subset of photometric/geometric perturbations, then normalize
        # with ImageNet statistics.
        aug_list = [
            albumentations.Resize(224, 224)
        ]

        if random.random() < 0.5:
            aug_list.append(albumentations.HorizontalFlip(p=1.0))
        if random.random() < 0.5:
            quality_score = random.randint(50, 100)
            aug_list.append(albumentations.ImageCompression(quality_lower=quality_score, quality_upper=quality_score))
        if random.random() < 0.3:
            aug_list.append(albumentations.GaussNoise(p=1.0))
        if random.random() < 0.3:
            aug_list.append(albumentations.GaussianBlur(blur_limit=(3, 5), p=1.0))
        if random.random() < 0.001:
            aug_list.append(albumentations.ToGray(p=1.0))

        aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0))
        trans = albumentations.Compose(aug_list)

        if len(frame_list) >= select_frame_nums:
            # Enough frames: take a random contiguous clip of 8 frames.
            start_frame = random.randint(0, len(frame_list) - select_frame_nums)
            select_frames = frame_list[start_frame:start_frame + select_frame_nums]
            frames = []
            for x in select_frames:
                while True:
                    try:
                        temp_image_path = video_id + '/' + str(x) + '.jpg'
                        image = download_oss_file('GenVideo/' + temp_image_path)
                        if video_id.startswith("real/youku"):
                            # Youku frames are cropped more aggressively (15% per side).
                            image = crop_center_by_percentage(image, 0.15)
                        break
                    except Exception:
                        # Fall back to a neighbouring frame index if this one cannot be
                        # loaded; re-raise if neither neighbour is available so the loop
                        # cannot spin forever.
                        if x + 1 < len(frame_list):
                            x = x + 1
                        elif x - 1 >= 0:
                            x = x - 1
                        else:
                            raise
                augmented = trans(image=image)
                image = augmented["image"]
                frames.append(image.transpose(2, 0, 1)[np.newaxis, :])
        else:
            # Too few frames: load them all and zero-pad up to 8.
            pad_num = select_frame_nums - len(frame_list)
            frames = []
            for x in frame_list:
                temp_image_path = video_id + '/' + str(x) + '.jpg'
                image = download_oss_file('GenVideo/' + temp_image_path)
                if video_id.startswith("real/youku"):
                    image = crop_center_by_percentage(image, 0.15)
                augmented = trans(image=image)
                image = augmented["image"]
                frames.append(image.transpose(2, 0, 1)[np.newaxis, :])
            for i in range(pad_num):
                # Zero padding in float32 so the stacked array stays float32.
                frames.append(np.zeros((224, 224, 3), dtype=np.float32).transpose(2, 0, 1)[np.newaxis, :])

        label_onehot[int(label)] = 1
        # Stack to (T, C, H, W) and add a leading clip dimension -> (1, T, C, H, W).
        frames = np.concatenate(frames, 0)
        frames = torch.tensor(frames[np.newaxis, :])
        label_onehot = torch.FloatTensor(label_onehot)
        binary_label = torch.FloatTensor([int(label)])

        return self.index_list[idx], frames, label_onehot, binary_label

    def __len__(self):
        return len(self.balanced_indices)

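
# Usage note (a sketch, not taken from the original training code): since the
# balanced index list is drawn once in __init__, a training loop would
# typically call `train_loader.dataset.resample()` at the start of each epoch
# to draw a fresh class-balanced subsample, e.g.:
#
#     for epoch in range(num_epochs):
#         train_loader.dataset.resample()
#         for vid_idx, frames, label_onehot, binary_label in train_loader:
#             ...
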
class Ours_Dataset_val(data.Dataset):
    def __init__(self, cfg, index_list=None, df=None):
        self.index_list = index_list
        self.cfg = cfg
        self.df = df
        self.frame_dir = df['image_path'].tolist()

    def __getitem__(self, idx):
        # Validation pipeline: always resize to 224x224, optionally apply the
        # single perturbation ("attack") selected by cfg['task'], then
        # normalize with ImageNet statistics.
        aug_list = [
            albumentations.Resize(224, 224),
        ]

        if self.cfg['task'] == 'JPEG_Compress_Attack':
            aug_list.append(albumentations.JpegCompression(quality_lower=35, quality_upper=35, p=1.0))
        if self.cfg['task'] == 'FLIP_Attack':
            if random.random() < 0.5:
                aug_list.append(albumentations.HorizontalFlip(p=1.0))
            else:
                aug_list.append(albumentations.VerticalFlip(p=1.0))
        if self.cfg['task'] == 'CROP_Attack':
            # Crop a random 160-208 px window from a random top-left offset,
            # then resize back to 224x224.
            random_crop_x = random.randint(0, 16)
            random_crop_y = random.randint(0, 16)
            crop_width = random.randint(160, 208)
            crop_height = random.randint(160, 208)
            aug_list.append(albumentations.Crop(x_min=random_crop_x, y_min=random_crop_y,
                                                x_max=random_crop_x + crop_width, y_max=random_crop_y + crop_height))
            aug_list.append(albumentations.Resize(224, 224))
        if self.cfg['task'] == 'Color_Attack':
            # Jitter exactly one colour property (brightness, contrast,
            # saturation or hue) by 0.5.
            index = random.choice([i for i in range(4)])
            dicts = {0: [0.5, 0, 0, 0], 1: [0, 0.5, 0, 0], 2: [0, 0, 0.5, 0], 3: [0, 0, 0, 0.5]}
            brightness, contrast, saturation, hue = dicts[index]
            aug_list.append(albumentations.ColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue))
        if self.cfg['task'] == 'Gaussian_Attack':
            aug_list.append(albumentations.GaussianBlur(blur_limit=(7, 7), p=1.0))

        aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0))
        trans = albumentations.Compose(aug_list)

        df_v = self.df.loc[self.index_list[idx]]
        video_id = df_v['content_path']
        activity_id = df_v['activity_id']
        label = df_v['label']
        label_onehot = [0] * 2
        frame_list = eval(df_v['frame_seq'])

        select_frame_nums = 8

        if len(frame_list) >= select_frame_nums:
            # Enough frames: take a random contiguous clip of 8 frames.
            start_frame = random.randint(0, len(frame_list) - select_frame_nums)
            select_frames = frame_list[start_frame:start_frame + select_frame_nums]
            frames = []
            for x in select_frames:
                while True:
                    try:
                        temp_image_path = video_id + '/' + str(x) + '.jpg'
                        image = download_oss_file('GenVideo/' + temp_image_path)
                        image = crop_center_by_percentage(image, 0.1)
                        break
                    except Exception:
                        # Fall back to a neighbouring frame index if this one cannot be
                        # loaded; re-raise if neither neighbour is available so the loop
                        # cannot spin forever.
                        if x + 1 < len(frame_list):
                            x = x + 1
                        elif x - 1 >= 0:
                            x = x - 1
                        else:
                            raise
                augmented = trans(image=image)
                image = augmented["image"]
                frames.append(image.transpose(2, 0, 1)[np.newaxis, :])
        else:
            # Too few frames: load them all and zero-pad up to 8.
            pad_num = select_frame_nums - len(frame_list)
            frames = []
            for x in frame_list:
                temp_image_path = video_id + '/' + str(x) + '.jpg'
                image = download_oss_file('GenVideo/' + temp_image_path)
                image = crop_center_by_percentage(image, 0.1)
                augmented = trans(image=image)
                image = augmented["image"]
                frames.append(image.transpose(2, 0, 1)[np.newaxis, :])
            for i in range(pad_num):
                # Zero padding in float32 so the stacked array stays float32.
                frames.append(np.zeros((224, 224, 3), dtype=np.float32).transpose(2, 0, 1)[np.newaxis, :])

        label_onehot[int(label)] = 1
        frames = np.concatenate(frames, 0)
        frames = torch.tensor(frames[np.newaxis, :])
        label_onehot = torch.FloatTensor(label_onehot)
        binary_label = torch.FloatTensor([int(label)])

        return self.index_list[idx], frames, label_onehot, binary_label, video_id

    def __len__(self):
        return len(self.index_list)


def generate_dataset_loader(cfg):
    df_train = pd.read_csv('GenVideo/datasets/train.csv')

    # Select the validation CSV according to the evaluation task; any task not
    # listed here (e.g. the perturbation "attack" tasks) falls back to the
    # out-of-distribution validation split.
    if cfg['task'] == 'normal':
        df_val = pd.read_csv('GenVideo/datasets/val_id.csv')
    elif cfg['task'] == 'robust_compress':
        df_val = pd.read_csv('GenVideo/datasets/com_28.csv')
    elif cfg['task'] == 'Image_Water_Attack':
        df_val = pd.read_csv('GenVideo/datasets/imgwater.csv')
    elif cfg['task'] == 'Text_Water_Attack':
        df_val = pd.read_csv('GenVideo/datasets/textwater.csv')
    elif cfg['task'] == 'one2many':
        df_val = pd.read_csv('GenVideo/datasets/val_ood.csv')
        # One-to-many protocol: train on a single generator's fake videos plus
        # all real videos.
        subset_prefixes = {
            'pika': "fake/pika",
            'SEINE': "fake/SEINE",
            'OpenSora': "fake/OpenSora",
            'Latte': "fake/Latte",
        }
        if cfg['train_sub_set'] in subset_prefixes:
            prefixes = [subset_prefixes[cfg['train_sub_set']], "real"]
            video_condition = df_train['content_path'].str.startswith(prefixes[0])
            for prefix in prefixes[1:]:
                video_condition |= df_train['content_path'].str.startswith(prefix)
            df_train = df_train[video_condition]
    else:
        df_val = pd.read_csv('GenVideo/datasets/val_ood.csv')

    df_train.reset_index(drop=True, inplace=True)
    df_val.reset_index(drop=True, inplace=True)

    index_val = df_val.index.tolist()
    val_dataset = Ours_Dataset_val(cfg, index_val, df_val)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=cfg['val_batch_size'], shuffle=False,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=False
    )

    index_train = df_train.index.tolist()
    train_dataset = Ours_Dataset_train(index_train, df_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg['train_batch_size'], shuffle=True,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=True
    )

    print("******* Training Video IDs", len(index_train), " Training Batch size ", cfg['train_batch_size'], " *******")
    print("******* Testing Video IDs", len(index_val), " Testing Batch size ", cfg['val_batch_size'], " *******")

    return train_loader, val_loader
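

# Minimal usage sketch (not part of the original module). The cfg values below
# are illustrative; running this for real requires the GenVideo CSV files and
# the external `download_oss_file` helper to be available.
if __name__ == "__main__":
    cfg = {
        'task': 'normal',            # which validation split / attack to use
        'train_sub_set': 'all',      # only consulted for the 'one2many' task
        'train_batch_size': 4,
        'val_batch_size': 4,
        'num_workers': 2,
    }
    train_loader, val_loader = generate_dataset_loader(cfg)
    for vid_idx, frames, label_onehot, binary_label in train_loader:
        # frames is collated to shape (batch, 1, 8, 3, 224, 224).
        print(frames.shape, label_onehot.shape, binary_label.shape)
        break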