File size: 2,257 Bytes
d39b279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import pandas as pd
from datasets import load_dataset
import numpy as np
import tqdm.auto as tqdm
import io
import torch
import av
from torchvision.utils import save_image


def preprocess_and_save_frames(
    file_like: io.BytesIO,
    video_id: str,
    output_root: str = "./frames",
    crop_size: int = -1,
    every: int = 10,
    max_memory: int = 50 * 1024 * 1024,
    device: str = "cpu"
):
    """
    Decode a video and save every ``every``-th frame as a PNG image.

    Frames are written to ``output_root/<video_id>/frame_<index>.png``, where
    ``<index>`` is the zero-padded position of the frame in the decoded
    stream (so with ``every=10`` the files are 00000, 00010, 00020, ...).

    Args:
        file_like: Seekable binary buffer holding the encoded video. It is
            rewound to the start before decoding.
        video_id: Name of the sub-directory of ``output_root`` that receives
            the frames of this video.
        output_root: Base directory for all extracted frames; created if it
            does not exist.
        crop_size: Size handed to torchvision's ``CenterCrop``. Values <= 0
            disable cropping.
        every: Sampling stride over decoded frames. Must be >= 1.
        max_memory: Budget, in bytes, for the cumulative size of the float32
            tensors of the saved frames. Frames are not retained in memory,
            so this effectively caps how many frames are written per video.
            The frame that reaches the budget is still saved.
        device: Currently unused; kept so callers that pass it keep working.

    Raises:
        ValueError: If ``every`` is smaller than 1.
    """
    # Validate before touching the filesystem or the decoder; a zero stride
    # would otherwise surface as a ZeroDivisionError in the middle of decoding.
    if every < 1:
        raise ValueError(f"every must be a positive integer, got {every}")

    # makedirs creates the intermediate output_root as well.
    video_dir = os.path.join(output_root, video_id)
    os.makedirs(video_dir, exist_ok=True)

    center_crop_transform = None
    if crop_size > 0:
        from torchvision import transforms
        center_crop_transform = transforms.CenterCrop(crop_size)

    file_like.seek(0)
    current_memory = 0

    # The context manager closes the container (and its decoder resources)
    # even when decoding or saving raises, or when we stop early.
    with av.open(file_like) as container:
        for i, frame in enumerate(container.decode(video=0)):
            if i % every != 0:
                continue

            # HWC uint8 array -> CHW float tensor scaled to [0, 1], the
            # layout and range save_image expects.
            frame_array = frame.to_ndarray(format="rgb24")
            frame_tensor = torch.from_numpy(frame_array).permute(2, 0, 1).float() / 255.0

            if center_crop_transform is not None:
                frame_tensor = center_crop_transform(frame_tensor)

            frame_path = os.path.join(video_dir, f"frame_{i:05d}.png")
            save_image(frame_tensor, frame_path)

            # Account for the frame after saving it, so the frame that
            # crosses the budget is still written to disk.
            current_memory += frame_tensor.numel() * frame_tensor.element_size()
            if current_memory >= max_memory:
                break


# -------- Main section --------
# Extracts frames for every video of the dataset's test split into
# ./frames/<video_id>/, continuing past videos that fail to process.
if __name__ == "__main__":
    DATASET_PATH = "/tmp/data"
    dataset_remote = load_dataset(DATASET_PATH, split="test", streaming=True)

    # Make sure ./frames exists before processing
    os.makedirs("./frames", exist_ok=True)

    failed = 0
    for el in tqdm.tqdm(dataset_remote, desc="Extracting frames"):
        # Placeholder used in the error report when the id itself cannot be
        # read; the handler must never raise, or one bad record would abort
        # the whole run.
        video_id = "<unknown>"
        try:
            video_id = str(el["id"])
            file_like = io.BytesIO(el["video"]["bytes"])
            preprocess_and_save_frames(file_like, video_id)
        except Exception as e:
            # Best-effort batch job: report the failure and move on.
            failed += 1
            print(f"⚠️ Failed {video_id}: {e}")

    # Only claim full success when no video failed.
    if failed:
        print(f"⚠️ Finished with {failed} failed video(s); other frames saved under ./frames/<video_id>/")
    else:
        print("✅ All frames saved under ./frames/<video_id>/")