# new_dataloader.py
import torch
import cv2
import random
import ast
import pandas as pd
import albumentations
import numpy as np
from torch.utils.data import Dataset, DataLoader
# =========== TRANSFORMATION HELPERS ===========
def get_train_transforms():
    """Defines the probabilistic augmentations for training."""
    return albumentations.Compose([
        albumentations.Resize(224, 224),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.ImageCompression(quality_lower=50, quality_upper=100, p=0.5),
        albumentations.GaussNoise(p=0.3),
        albumentations.GaussianBlur(blur_limit=(3, 5), p=0.3),
        albumentations.ToGray(p=0.01),
        albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0),
    ])

def get_val_transforms(cfg):
    """Defines augmentations for validation, handling different attack tasks from the config."""
    aug_list = [albumentations.Resize(224, 224)]
    task = cfg.get('task', 'normal')  # .get() avoids a KeyError when 'task' is absent
    if task == 'JPEG_Compress_Attack':
        # JpegCompression is a deprecated alias; ImageCompression matches the train transforms
        aug_list.append(albumentations.ImageCompression(quality_lower=35, quality_upper=35, p=1.0))
    elif task == 'FLIP_Attack':
        aug_list.append(albumentations.HorizontalFlip(p=0.5))  # original used a random choice; a 50% flip is the common equivalent
    elif task == 'CROP_Attack':
        aug_list.append(albumentations.RandomCrop(height=192, width=192, p=1.0))
        aug_list.append(albumentations.Resize(224, 224))
    elif task == 'Color_Attack':
        aug_list.append(albumentations.ColorJitter(p=1.0))
    elif task == 'Gaussian_Attack':
        aug_list.append(albumentations.GaussianBlur(blur_limit=(7, 7), p=1.0))
    aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0))
    return albumentations.Compose(aug_list)
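
# Usage sketch (assumptions: albumentations' dict-based call signature and a
# hypothetical frame file name; not part of the original pipeline):
#
#   tfm = get_val_transforms({'task': 'CROP_Attack'})
#   bgr = cv2.imread('frame_0001.jpg')                # hypothetical path
#   rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
#   out = tfm(image=rgb)['image']                     # normalized float32 HxWxC, 224x224
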
# =========== TRAINING DATASET ===========
class VideoDataset(Dataset):
    """
    A PyTorch Dataset for loading video frame sequences based on a DataFrame.
    Handles class balancing for each epoch.
    """
    def __init__(self, df, index_list, base_data_path, transform=None, select_frame_nums=8):
        self.df = df
        self.index_list = index_list
        self.base_data_path = base_data_path
        self.transform = transform
        self.select_frame_nums = select_frame_nums
        self.positive_indices = self.df[self.df['label'] == 1].index.tolist()
        self.negative_indices = self.df[self.df['label'] == 0].index.tolist()
        self.balanced_indices = []
        self.resample()

    def resample(self):
        """Draw a fresh balanced subset: undersample the majority class to the minority size."""
        min_samples = min(len(self.positive_indices), len(self.negative_indices))
        self.balanced_indices.clear()
        self.balanced_indices.extend(random.sample(self.positive_indices, min_samples))
        self.balanced_indices.extend(random.sample(self.negative_indices, min_samples))
        random.shuffle(self.balanced_indices)

    def __len__(self):
        return len(self.balanced_indices)
    def __getitem__(self, idx):
        real_idx = self.balanced_indices[idx]
        row = self.df.iloc[real_idx]
        label = int(row['label'])
        # frame_seq is stored as a stringified list; literal_eval is safer than eval
        frame_list = ast.literal_eval(row['frame_seq'])

        # Sample a contiguous window of select_frame_nums frames (or take all if too few)
        frames = []
        if len(frame_list) >= self.select_frame_nums:
            start_index = random.randint(0, len(frame_list) - self.select_frame_nums)
            selected_frames = frame_list[start_index : start_index + self.select_frame_nums]
        else:
            selected_frames = frame_list

        for frame_path in selected_frames:
            try:
                image = cv2.imread(frame_path)
                if image is None:
                    raise ValueError(f"Failed to load {frame_path}")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except Exception:
                # Fall back to a black frame so one bad file cannot kill the worker
                image = np.zeros((224, 224, 3), dtype=np.uint8)
            if self.transform:
                image = self.transform(image=image)['image']
            frames.append(image.transpose(2, 0, 1)[np.newaxis, :])  # HWC -> 1xCxHxW

        # Zero-pad short clips up to select_frame_nums
        pad_num = self.select_frame_nums - len(frames)
        for _ in range(pad_num):
            frames.append(np.zeros((1, 3, 224, 224)))

        frames_tensor = np.concatenate(frames, axis=0)                        # TxCxHxW
        frames_tensor = torch.from_numpy(frames_tensor).float().unsqueeze(0)  # 1xTxCxHxW
        label_onehot = torch.zeros(2)
        label_onehot[label] = 1.0
        binary_label = torch.FloatTensor([label])
        # Return the DataFrame index of the sampled row. The original returned
        # self.index_list[idx], but idx walks the shuffled balanced list, so that
        # index did not correspond to the row actually loaded.
        return real_idx, frames_tensor, label_onehot, binary_label
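
# Per-epoch rebalancing sketch (assumption: a conventional training loop; the
# names num_epochs/train_loader are placeholders). resample() draws a fresh
# random undersample of the majority class, so call it before each epoch:
#
#   for epoch in range(num_epochs):
#       train_dataset.resample()                      # refresh the balanced index list
#       for original_index, frames, onehot, label in train_loader:
#           ...                                       # forward / backward / step
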
# =========== VALIDATION DATASET ===========
class VideoDatasetVal(Dataset):
    """A compatible validation dataset loader."""
    def __init__(self, df, index_list, base_data_path, transform=None, select_frame_nums=8):
        self.df = df
        self.index_list = index_list
        self.base_data_path = base_data_path
        self.transform = transform
        self.select_frame_nums = select_frame_nums

    def __len__(self):
        return len(self.index_list)
    def __getitem__(self, idx):
        # Validation does not use balanced sampling; it uses the provided index directly
        real_idx = self.index_list[idx]
        row = self.df.iloc[real_idx]
        video_id = row['content_path']
        label = int(row['label'])
        frame_list = ast.literal_eval(row['frame_seq'])

        # Frame selection and preprocessing are identical to the training dataset's
        # __getitem__ (note: the random window start makes validation non-deterministic)
        frames = []
        if len(frame_list) >= self.select_frame_nums:
            start_index = random.randint(0, len(frame_list) - self.select_frame_nums)
            selected_frames = frame_list[start_index : start_index + self.select_frame_nums]
        else:
            selected_frames = frame_list

        for frame_path in selected_frames:
            try:
                image = cv2.imread(frame_path)
                if image is None:
                    raise ValueError(f"Failed to load {frame_path}")
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            except Exception:
                image = np.zeros((224, 224, 3), dtype=np.uint8)
            if self.transform:
                image = self.transform(image=image)['image']
            frames.append(image.transpose(2, 0, 1)[np.newaxis, :])

        pad_num = self.select_frame_nums - len(frames)
        for _ in range(pad_num):
            frames.append(np.zeros((1, 3, 224, 224)))

        frames_tensor = np.concatenate(frames, axis=0)
        frames_tensor = torch.from_numpy(frames_tensor).float().unsqueeze(0)
        label_onehot = torch.zeros(2)
        label_onehot[label] = 1.0
        binary_label = torch.FloatTensor([label])
        # The validation loader additionally returns video_id (the content_path)
        return real_idx, frames_tensor, label_onehot, binary_label, video_id
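
# Evaluation sketch (assumptions: a binary-logit model and placeholder names;
# not part of the original script). The trailing video_id lets clip scores be
# aggregated back to their source video:
#
#   from collections import defaultdict
#   scores = defaultdict(list)
#   for idx, frames, onehot, label, video_id in val_loader:
#       probs = torch.sigmoid(model(frames)).squeeze(-1)
#       for vid, p in zip(video_id, probs.tolist()):
#           scores[vid].append(p)
#   video_pred = {vid: sum(ps) / len(ps) for vid, ps in scores.items()}
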
# =========== DATALOADER GENERATOR FUNCTION ===========
def generate_dataset_loader(cfg):
    """
    The main function to create train and validation dataloaders using the new classes.
    """
    df_train = pd.read_csv('/home/kalpit/workspace/aigc/repos/DeMamba/csv/veo_train.csv')

    # This logic for selecting different validation sets is preserved
    task = cfg.get('task', 'normal')
    if task == 'normal':
        df_val = pd.read_csv('GenVideo/datasets/val_id.csv')
    elif task == 'robust_compress':
        df_val = pd.read_csv('GenVideo/datasets/com_28.csv')
    # ... (add other elif conditions from your original script if needed) ...
    else:
        df_val = pd.read_csv('/home/kalpit/workspace/aigc/repos/DeMamba/csv/veo_test.csv')

    # This logic for subsetting the training data is also preserved
    if cfg.get('train_sub_set'):
        prefixes = [f"fake/{cfg['train_sub_set']}", "real"]
        condition = df_train['content_path'].str.startswith(tuple(prefixes))
        df_train = df_train[condition]

    df_train.reset_index(drop=True, inplace=True)
    df_val.reset_index(drop=True, inplace=True)
    index_train = df_train.index.tolist()
    index_val = df_val.index.tolist()

    # --- Use the new VideoDataset classes ---
    base_data_path = 'GenVideo'
    train_dataset = VideoDataset(
        df=df_train,
        index_list=index_train,
        base_data_path=base_data_path,
        transform=get_train_transforms(),
        select_frame_nums=8
    )
    val_dataset = VideoDatasetVal(
        df=df_val,
        index_list=index_val,
        base_data_path=base_data_path,
        transform=get_val_transforms(cfg),
        select_frame_nums=8
    )

    train_loader = DataLoader(
        train_dataset, batch_size=cfg['train_batch_size'], shuffle=True,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=cfg['val_batch_size'], shuffle=False,
        num_workers=cfg['num_workers'], pin_memory=True, drop_last=False
    )

    print(f"******* Training Videos {len(index_train)}, Batch size {cfg['train_batch_size']} *******")
    print(f"******* Testing Videos {len(index_val)}, Batch size {cfg['val_batch_size']} *******")
    return train_loader, val_loader
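
if __name__ == '__main__':
    # Minimal smoke test (assumptions: the CSV paths read above exist on this
    # machine, and these cfg keys are the only ones the loaders require).
    demo_cfg = {
        'task': 'normal',
        'train_batch_size': 4,
        'val_batch_size': 4,
        'num_workers': 2,
    }
    train_loader, val_loader = generate_dataset_loader(demo_cfg)
    original_index, frames, onehot, label = next(iter(train_loader))
    # Expected: frames is (batch, 1, select_frame_nums, 3, 224, 224)
    print(frames.shape, onehot.shape, label.shape)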