# new_dataloader.py import torch import cv2 import os import random import pandas as pd import albumentations import numpy as np from torch.utils.data import Dataset, DataLoader # =========== TRANSFORMATION HELPERS =========== def get_train_transforms(): """Defines the probabilistic augmentations for training.""" return albumentations.Compose([ albumentations.Resize(224, 224), albumentations.HorizontalFlip(p=0.5), albumentations.ImageCompression(quality_lower=50, quality_upper=100, p=0.5), albumentations.GaussNoise(p=0.3), albumentations.GaussianBlur(blur_limit=(3, 5), p=0.3), albumentations.ToGray(p=0.01), albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0) ]) def get_val_transforms(cfg): """Defines augmentations for validation, handling different attack tasks from the config.""" aug_list = [albumentations.Resize(224, 224)] task = cfg.get('task', 'normal') # Use .get for safety if task == 'JPEG_Compress_Attack': aug_list.append(albumentations.JpegCompression(quality_lower=35, quality_upper=35, p=1.0)) elif task == 'FLIP_Attack': aug_list.append(albumentations.HorizontalFlip(p=0.5)) # Original had random choice, 50% HFlip is common elif task == 'CROP_Attack': aug_list.append(albumentations.RandomCrop(height=192, width=192, p=1.0)) aug_list.append(albumentations.Resize(224, 224)) elif task == 'Color_Attack': aug_list.append(albumentations.ColorJitter(p=1.0)) elif task == 'Gaussian_Attack': aug_list.append(albumentations.GaussianBlur(blur_limit=(7, 7), p=1.0)) aug_list.append(albumentations.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, p=1.0)) return albumentations.Compose(aug_list) # =========== TRAINING DATASET =========== class VideoDataset(Dataset): """ A PyTorch Dataset for loading video frame sequences based on a DataFrame. Handles class balancing for each epoch. """ def __init__(self, df, index_list, base_data_path, transform=None, select_frame_nums=8): self.df = df self.index_list = index_list self.base_data_path = base_data_path self.transform = transform self.select_frame_nums = select_frame_nums self.positive_indices = self.df[self.df['label'] == 1].index.tolist() self.negative_indices = self.df[self.df['label'] == 0].index.tolist() self.balanced_indices = [] self.resample() def resample(self): min_samples = min(len(self.positive_indices), len(self.negative_indices)) self.balanced_indices.clear() self.balanced_indices.extend(random.sample(self.positive_indices, min_samples)) self.balanced_indices.extend(random.sample(self.negative_indices, min_samples)) random.shuffle(self.balanced_indices) def __len__(self): return len(self.balanced_indices) def __getitem__(self, idx): real_idx = self.balanced_indices[idx] row = self.df.iloc[real_idx] video_id = row['content_path'] label = int(row['label']) frame_list = eval(row['frame_seq']) frames = [] if len(frame_list) >= self.select_frame_nums: start_index = random.randint(0, len(frame_list) - self.select_frame_nums) selected_frames = frame_list[start_index : start_index + self.select_frame_nums] else: selected_frames = frame_list for frame_path in selected_frames: try: image = cv2.imread(frame_path) if image is None: raise ValueError("Failed to load") image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) except Exception: image = np.zeros((224, 224, 3), dtype=np.uint8) if self.transform: image = self.transform(image=image)['image'] frames.append(image.transpose(2, 0, 1)[np.newaxis, :]) pad_num = self.select_frame_nums - len(frames) if pad_num > 0: for _ in range(pad_num): frames.append(np.zeros((1, 3, 224, 224))) frames_tensor = np.concatenate(frames, axis=0) frames_tensor = torch.from_numpy(frames_tensor).float().unsqueeze(0) label_onehot = torch.zeros(2) label_onehot[label] = 1.0 binary_label = torch.FloatTensor([label]) original_index = self.index_list[idx] return original_index, frames_tensor, label_onehot, binary_label # =========== VALIDATION DATASET =========== class VideoDatasetVal(Dataset): """A compatible validation dataset loader.""" def __init__(self, df, index_list, base_data_path, transform=None, select_frame_nums=8): self.df = df self.index_list = index_list self.base_data_path = base_data_path self.transform = transform self.select_frame_nums = select_frame_nums def __len__(self): return len(self.index_list) def __getitem__(self, idx): # Validation does not use balanced sampling, it uses the provided index directly real_idx = self.index_list[idx] row = self.df.iloc[real_idx] video_id = row['content_path'] label = int(row['label']) frame_list = eval(row['frame_seq']) # This part is identical to the training dataset's __getitem__ frames = [] if len(frame_list) >= self.select_frame_nums: start_index = random.randint(0, len(frame_list) - self.select_frame_nums) selected_frames = frame_list[start_index : start_index + self.select_frame_nums] else: selected_frames = frame_list for frame_path in selected_frames: try: image = cv2.imread(frame_path) if image is None: raise ValueError("Failed to load") image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) except Exception: image = np.zeros((224, 224, 3), dtype=np.uint8) if self.transform: image = self.transform(image=image)['image'] frames.append(image.transpose(2, 0, 1)[np.newaxis, :]) pad_num = self.select_frame_nums - len(frames) if pad_num > 0: for _ in range(pad_num): frames.append(np.zeros((1, 3, 224, 224))) frames_tensor = np.concatenate(frames, axis=0) frames_tensor = torch.from_numpy(frames_tensor).float().unsqueeze(0) label_onehot = torch.zeros(2) label_onehot[label] = 1.0 binary_label = torch.FloatTensor([label]) # The original validation loader returned video_id at the end return self.index_list[idx], frames_tensor, label_onehot, binary_label, video_id # =========== DATALOADER GENERATOR FUNCTION =========== def generate_dataset_loader(cfg): """ The main function to create train and validation dataloaders using the new classes. """ df_train = pd.read_csv('/home/kalpit/workspace/aigc/repos/DeMamba/csv/veo_train.csv') # This logic for selecting different validation sets is preserved task = cfg.get('task', 'normal') if task == 'normal': df_val = pd.read_csv('GenVideo/datasets/val_id.csv') elif task == 'robust_compress': df_val = pd.read_csv('GenVideo/datasets/com_28.csv') # ... (add other elif conditions from your original script if needed) ... else: df_val = pd.read_csv('/home/kalpit/workspace/aigc/repos/DeMamba/csv/veo_test.csv') # This logic for subsetting the training data is also preserved if cfg.get('train_sub_set'): prefixes = [f"fake/{cfg['train_sub_set']}", "real"] condition = df_train['content_path'].str.startswith(tuple(prefixes)) df_train = df_train[condition] df_train.reset_index(drop=True, inplace=True) df_val.reset_index(drop=True, inplace=True) index_train = df_train.index.tolist() index_val = df_val.index.tolist() # --- Use the new VideoDataset classes --- base_data_path = 'GenVideo' train_dataset = VideoDataset( df=df_train, index_list=index_train, base_data_path=base_data_path, transform=get_train_transforms(), select_frame_nums=8 ) val_dataset = VideoDatasetVal( df=df_val, index_list=index_val, base_data_path=base_data_path, transform=get_val_transforms(cfg), select_frame_nums=8 ) train_loader = DataLoader( train_dataset, batch_size=cfg['train_batch_size'], shuffle=True, num_workers=cfg['num_workers'], pin_memory=True, drop_last=True ) val_loader = DataLoader( val_dataset, batch_size=cfg['val_batch_size'], shuffle=False, num_workers=cfg['num_workers'], pin_memory=True, drop_last=False ) print(f"******* Training Videos {len(index_train)}, Batch size {cfg['train_batch_size']} *******") print(f"******* Testing Videos {len(index_val)}, Batch size {cfg['val_batch_size']} *******") return train_loader, val_loader