# seetrails_aigvdet_v2.0.0 / dataloader.py
import ast
import os
import random

import albumentations
import cv2
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
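# `download_oss_file` is used below but defined elsewhere in this repo (an OSS
# client wrapper); the import below is a guess at its location. As a fallback,
# a minimal sketch assuming the bucket is mirrored on local disk under
# $OSS_ROOT keeps this file importable on its own.
try:
    from oss_utils import download_oss_file  # hypothetical module name
except ImportError:
    OSS_ROOT = os.environ.get('OSS_ROOT', '.')

    def download_oss_file(key):
        # Read a frame from the local mirror and return it as an RGB array,
        # which is what the albumentations pipelines below expect.
        image = cv2.imread(os.path.join(OSS_ROOT, key))
        if image is None:
            raise FileNotFoundError(key)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)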
def crop_center_by_percentage(image, percentage):
    """Crop `percentage` of the longer axis from each side of the frame.

    Landscape frames are trimmed on the left/right, portrait (or square)
    frames on the top/bottom.
    """
    height, width = image.shape[:2]
    if width > height:
        margin = int(width * percentage)
        return image[:, margin:width - margin]
    margin = int(height * percentage)
    return image[margin:height - margin, :]
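# e.g. crop_center_by_percentage(frame, 0.15) on a 1920x1080 frame trims 288
# columns from each side, keeping the central 1344x1080 region.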
class Ours_Dataset_train(Dataset):
    """Class-balanced training dataset over the GenVideo frame index."""

    def __init__(self, index_list=None, df=None):
        self.index_list = index_list
        self.df = df
        self.positive_indices = df[df['label'] == 1].index.tolist()
        self.negative_indices = df[df['label'] == 0].index.tolist()
        self.balanced_indices = []
        self.resample()

    def resample(self):
        # Ensure each epoch uses a balanced dataset: downsample the majority
        # class to the minority count, then shuffle to mix the two classes.
        min_samples = min(len(self.positive_indices), len(self.negative_indices))
        self.balanced_indices.clear()
        self.balanced_indices.extend(random.sample(self.positive_indices, min_samples))
        self.balanced_indices.extend(random.sample(self.negative_indices, min_samples))
        random.shuffle(self.balanced_indices)
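    # resample() is intended to be called once per epoch from the training
    # loop, e.g. `train_loader.dataset.resample()` before iterating, so every
    # epoch draws a fresh balanced subsample.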
    def __getitem__(self, idx):
        real_idx = self.balanced_indices[idx]
        row = self.df.iloc[real_idx]
        video_id = row['content_path']
        label = row['label']
        # 'frame_seq' is stored as a stringified list of frame ids; parse it
        # safely instead of using eval().
        frame_list = ast.literal_eval(row['frame_seq'])
        label_onehot = [0] * 2
        select_frame_nums = 8
        # Build a per-sample augmentation pipeline; each perturbation is gated
        # by its own coin flip here, so the transform itself gets p=1.0.
        aug_list = [albumentations.Resize(224, 224)]
        if random.random() < 0.5:
            aug_list.append(albumentations.HorizontalFlip(p=1.0))
        if random.random() < 0.5:
            quality_score = random.randint(50, 100)
            aug_list.append(albumentations.ImageCompression(
                quality_lower=quality_score, quality_upper=quality_score, p=1.0))
        if random.random() < 0.3:
            aug_list.append(albumentations.GaussNoise(p=1.0))
        if random.random() < 0.3:
            aug_list.append(albumentations.GaussianBlur(blur_limit=(3, 5), p=1.0))
        if random.random() < 0.001:
            aug_list.append(albumentations.ToGray(p=1.0))
        aug_list.append(albumentations.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
            max_pixel_value=255.0, p=1.0))
        trans = albumentations.Compose(aug_list)
        if len(frame_list) >= select_frame_nums:
            # Sample a random contiguous window of 8 frames.
            start_frame = random.randint(0, len(frame_list) - select_frame_nums)
            select_frames = frame_list[start_frame:start_frame + select_frame_nums]
            frames = []
            for x in select_frames:
                # Bounded retry on neighbouring frames if a download fails,
                # instead of looping forever.
                for _ in range(len(frame_list)):
                    try:
                        temp_image_path = video_id + '/' + str(x) + '.jpg'
                        image = download_oss_file('GenVideo/' + temp_image_path)
                        if video_id.startswith("real/youku"):
                            # Youku frames carry side banners; crop 15% off each side.
                            image = crop_center_by_percentage(image, 0.15)
                        break
                    except Exception:
                        if x + 1 < len(frame_list):
                            x = x + 1
                        elif x - 1 >= 0:
                            x = x - 1
                else:
                    raise RuntimeError(f'could not load any frame for {video_id}')
                augmented = trans(image=image)
                image = augmented["image"]
                frames.append(image.transpose(2, 0, 1)[np.newaxis, :])
        else:
            # Fewer than 8 frames available: use them all, then zero-pad the
            # clip up to 8.
            pad_num = select_frame_nums - len(frame_list)
            frames = []
            for x in frame_list:
                temp_image_path = video_id + '/' + str(x) + '.jpg'
                image = download_oss_file('GenVideo/' + temp_image_path)
                if video_id.startswith("real/youku"):
                    image = crop_center_by_percentage(image, 0.15)
                augmented = trans(image=image)
                image = augmented["image"]
                frames.append(image.transpose(2, 0, 1)[np.newaxis, :])
            for i in range(pad_num):
                # float32 to match the normalized frames, so the concatenated
                # clip does not silently upcast to float64.
                frames.append(np.zeros((1, 3, 224, 224), dtype=np.float32))
        label_onehot[int(label)] = 1
        frames = np.concatenate(frames, 0)
        frames = torch.tensor(frames[np.newaxis, :])
        label_onehot = torch.FloatTensor(label_onehot)
        binary_label = torch.FloatTensor([int(label)])
        # idx indexes the balanced list, so return the underlying dataframe
        # index as the sample identifier.
        return real_idx, frames, label_onehot, binary_label

    def __len__(self):
        return len(self.balanced_indices)
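# Each training sample's `frames` tensor has shape (1, 8, 3, 224, 224): a
# singleton clip dim, 8 frames, CHW layout at 224x224; the DataLoader then
# prepends the batch dim.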
class Ours_Dataset_val(Dataset):
    """Validation dataset; optionally applies one robustness attack per cfg['task']."""

    def __init__(self, cfg, index_list=None, df=None):
        self.index_list = index_list
        self.cfg = cfg
        self.df = df
        self.frame_dir = df['image_path'].tolist()  # kept for reference; unused below
    def __getitem__(self, idx):
        # Assemble the evaluation-time transform; at most one attack applies,
        # selected by cfg['task'].
        aug_list = [albumentations.Resize(224, 224)]
        if self.cfg['task'] == 'JPEG_Compress_Attack':
            aug_list.append(albumentations.JpegCompression(quality_lower=35, quality_upper=35, p=1.0))
        elif self.cfg['task'] == 'FLIP_Attack':
            if random.random() < 0.5:
                aug_list.append(albumentations.HorizontalFlip(p=1.0))
            else:
                aug_list.append(albumentations.VerticalFlip(p=1.0))
        elif self.cfg['task'] == 'CROP_Attack':
            # Random crop between 160x160 and 208x208, then resize back to 224.
            random_crop_x = random.randint(0, 16)
            random_crop_y = random.randint(0, 16)
            crop_width = random.randint(160, 208)
            crop_height = random.randint(160, 208)
            aug_list.append(albumentations.Crop(
                x_min=random_crop_x, y_min=random_crop_y,
                x_max=random_crop_x + crop_width, y_max=random_crop_y + crop_height))
            aug_list.append(albumentations.Resize(224, 224))
        elif self.cfg['task'] == 'Color_Attack':
            # Jitter exactly one colour property at strength 0.5.
            dicts = {0: [0.5, 0, 0, 0], 1: [0, 0.5, 0, 0], 2: [0, 0, 0.5, 0], 3: [0, 0, 0, 0.5]}
            brightness, contrast, saturation, hue = dicts[random.randrange(4)]
            aug_list.append(albumentations.ColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=1.0))
        elif self.cfg['task'] == 'Gaussian_Attack':
            aug_list.append(albumentations.GaussianBlur(blur_limit=(7, 7), p=1.0))
        aug_list.append(albumentations.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
            max_pixel_value=255.0, p=1.0))
        trans = albumentations.Compose(aug_list)
        df_v = self.df.loc[self.index_list[idx]]
        video_id = df_v['content_path']
        activity_id = df_v['activity_id']
        label = df_v['label']
        label_onehot = [0] * 2
        frame_list = ast.literal_eval(df_v['frame_seq'])
        select_frame_nums = 8
        if len(frame_list) >= select_frame_nums:
            # Sample a random contiguous window of 8 frames.
            start_frame = random.randint(0, len(frame_list) - select_frame_nums)
            select_frames = frame_list[start_frame:start_frame + select_frame_nums]
            frames = []
            for x in select_frames:
                # Bounded retry on neighbouring frames if a download fails,
                # instead of looping forever.
                for _ in range(len(frame_list)):
                    try:
                        temp_image_path = video_id + '/' + str(x) + '.jpg'
                        image = download_oss_file('GenVideo/' + temp_image_path)
                        image = crop_center_by_percentage(image, 0.1)
                        break
                    except Exception:
                        if x + 1 < len(frame_list):
                            x = x + 1
                        elif x - 1 >= 0:
                            x = x - 1
                else:
                    raise RuntimeError(f'could not load any frame for {video_id}')
                augmented = trans(image=image)
                image = augmented["image"]
                frames.append(image.transpose(2, 0, 1)[np.newaxis, :])
        else:
            # Fewer than 8 frames available: use them all, then zero-pad the
            # clip up to 8 (float32 to match the normalized frames).
            pad_num = select_frame_nums - len(frame_list)
            frames = []
            for x in frame_list:
                temp_image_path = video_id + '/' + str(x) + '.jpg'
                image = download_oss_file('GenVideo/' + temp_image_path)
                image = crop_center_by_percentage(image, 0.1)
                augmented = trans(image=image)
                image = augmented["image"]
                frames.append(image.transpose(2, 0, 1)[np.newaxis, :])
            for i in range(pad_num):
                frames.append(np.zeros((1, 3, 224, 224), dtype=np.float32))
        label_onehot[int(label)] = 1
        frames = np.concatenate(frames, 0)
        frames = torch.tensor(frames[np.newaxis, :])
        label_onehot = torch.FloatTensor(label_onehot)
        binary_label = torch.FloatTensor([int(label)])
        return self.index_list[idx], frames, label_onehot, binary_label, video_id

    def __len__(self):
        return len(self.index_list)
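# Note: the validation pipeline still draws a random start frame and random
# attack parameters per call, so fix the random seeds if a reproducible
# evaluation run is needed.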
def generate_dataset_loader(cfg):
    df_train = pd.read_csv('GenVideo/datasets/train.csv')
    # Pick the validation split for the requested task; unrecognized tasks
    # fall back to the OOD split.
    val_csvs = {
        'normal': 'GenVideo/datasets/val_id.csv',
        'robust_compress': 'GenVideo/datasets/com_28.csv',
        'Image_Water_Attack': 'GenVideo/datasets/imgwater.csv',
        'Text_Water_Attack': 'GenVideo/datasets/textwater.csv',
        'one2many': 'GenVideo/datasets/val_ood.csv',
    }
    df_val = pd.read_csv(val_csvs.get(cfg['task'], 'GenVideo/datasets/val_ood.csv'))
    # Optionally restrict training to one generator's fakes plus all real videos.
    if cfg['train_sub_set'] in ('pika', 'SEINE', 'OpenSora', 'Latte'):
        prefixes = ['fake/' + cfg['train_sub_set'], 'real']
        video_condition = df_train['content_path'].str.startswith(prefixes[0])
        for prefix in prefixes[1:]:
            video_condition |= df_train['content_path'].str.startswith(prefix)
        df_train = df_train[video_condition]
    df_train.reset_index(drop=True, inplace=True)
    df_val.reset_index(drop=True, inplace=True)
    index_val = df_val.index.tolist()
    val_dataset = Ours_Dataset_val(cfg, index_val, df_val)
    val_loader = DataLoader(
        val_dataset, batch_size=cfg['val_batch_size'], shuffle=False,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=False,
    )
    index_train = df_train.index.tolist()
    train_dataset = Ours_Dataset_train(index_train, df_train)
    train_loader = DataLoader(
        train_dataset, batch_size=cfg['train_batch_size'], shuffle=True,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=True,
    )
    print(f"******* Training Video IDs {len(index_train)}  Training Batch size {cfg['train_batch_size']} *******")
    print(f"******* Testing Video IDs {len(index_val)}  Testing Batch size {cfg['val_batch_size']} *******")
    return train_loader, val_loader
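# Minimal usage sketch; the cfg keys mirror the lookups above, and the values
# are placeholders:
#
#   cfg = {
#       'task': 'normal',
#       'train_sub_set': 'all',   # any value outside the four names keeps the full train set
#       'train_batch_size': 8,
#       'val_batch_size': 8,
#       'num_workers': 4,
#   }
#   train_loader, val_loader = generate_dataset_loader(cfg)
#   for real_idx, frames, label_onehot, binary_label in train_loader:
#       ...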