# seetrails_aigvdet_v2.0.0 / extract_frames.py
# kalpitbcontrails's picture
# Update extract_frames.py
# be0a2fb verified
import os
import pandas as pd
from datasets import load_dataset
import numpy as np
import tqdm.auto as tqdm
import io
import torch
import av
from torchvision.utils import save_image
def preprocess_and_save_frames(
    file_like: io.BytesIO,
    video_id: str,
    output_root: str = "./frames",
    crop_size: int = -1,
    every: int = 10,
    max_memory: int = 50 * 1024 * 1024,
    device: str = "cpu",
):
    """Decode a video and save every ``every``-th frame as a PNG image.

    Frames are written flat (no per-video subfolders) into ``output_root``
    and named ``<video_id>_frame_<n>.png``, where ``n`` starts at 1 and
    counts saved frames, not decoded frames. Each saved file name is
    printed to stdout.

    Args:
        file_like: Seekable binary buffer holding the encoded video. It is
            rewound to position 0 before being opened.
        video_id: Identifier used as the prefix of every frame file name.
        output_root: Directory receiving the images; created if missing.
        crop_size: When positive, each frame is center-cropped to this size
            before saving. Values <= 0 disable cropping.
        every: Sampling stride; frames whose decode index is a multiple of
            this value are saved. Must be a positive integer.
        max_memory: Budget, in bytes, for the cumulative size of the saved
            float32 frame tensors. Extraction stops once it is reached.
        device: Currently unused; kept so existing callers keep working.

    Raises:
        ValueError: If ``every`` is not a positive integer.
    """
    # Validate before touching the filesystem or the decoder: `i % 0` would
    # otherwise fail with an opaque ZeroDivisionError in the middle of decoding.
    if every <= 0:
        raise ValueError(f"every must be a positive integer, got {every}")

    # Ensure the base frames directory exists
    os.makedirs(output_root, exist_ok=True)

    center_crop_transform = None
    if crop_size > 0:
        # Imported lazily so torchvision.transforms is only needed when cropping.
        from torchvision import transforms

        center_crop_transform = transforms.CenterCrop(crop_size)

    file_like.seek(0)
    current_memory = 0
    frame_count = 1  # Start counting frames at 1

    # The context manager closes the container even when decoding or saving
    # raises, so demuxer/decoder resources are not leaked across videos.
    with av.open(file_like) as container:
        for i, frame in enumerate(container.decode(video=0)):
            if i % every != 0:
                continue

            frame_array = frame.to_ndarray(format="rgb24")
            # HWC uint8 -> CHW float in [0, 1], the layout save_image expects.
            frame_tensor = (
                torch.from_numpy(frame_array).permute(2, 0, 1).float() / 255.0
            )
            if center_crop_transform is not None:
                frame_tensor = center_crop_transform(frame_tensor)

            # Flat naming pattern (no subfolders)
            frame_name = f"{video_id}_frame_{frame_count}.png"
            frame_path = os.path.join(output_root, frame_name)
            save_image(frame_tensor, frame_path)

            # Print saved frame name
            print(frame_name)
            frame_count += 1

            # Account for the tensor just saved (4 bytes per float32 element)
            # and stop once the per-video budget is exhausted.
            current_memory += frame_tensor.numel() * frame_tensor.element_size()
            if current_memory >= max_memory:
                break
# -------- Main section --------
# Streams the "test" split of the dataset at DATASET_PATH and extracts frames
# from each record's video bytes into ./frames. A failure on one record is
# reported and counted, and processing continues with the next record.
if __name__ == "__main__":
    DATASET_PATH = "/tmp/data"
    dataset_remote = load_dataset(DATASET_PATH, split="test", streaming=True)

    # Make sure ./frames exists before processing
    os.makedirs("./frames", exist_ok=True)

    failed = 0
    for el in tqdm.tqdm(dataset_remote, desc="Extracting frames"):
        # Resolved inside the try; the placeholder keeps the error report from
        # raising a second exception when the record has no usable "id".
        video_id = "<unknown>"
        try:
            video_id = str(el["id"])
            file_like = io.BytesIO(el["video"]["bytes"])
            preprocess_and_save_frames(file_like, video_id)
        except Exception as e:
            # Best-effort batch job: report the failure and keep going.
            failed += 1
            print(f"⚠️ Failed {video_id}: {e}")

    if failed:
        # Do not claim full success when some videos could not be processed.
        print(f"⚠️ Finished with {failed} failed video(s); frames saved under ./frames/")
    else:
        print("✅ All frames saved under ./frames/")