File size: 2,407 Bytes
d39b279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be0a2fb
 
d39b279
 
 
 
 
 
 
 
 
 
 
 
 
b17e291
d39b279
 
 
 
 
 
 
 
 
be0a2fb
 
 
d39b279
be0a2fb
 
 
 
b17e291
d39b279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be0a2fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import pandas as pd
from datasets import load_dataset
import numpy as np
import tqdm.auto as tqdm
import io
import torch
import av
from torchvision.utils import save_image


def preprocess_and_save_frames(
    file_like: io.BytesIO,
    video_id: str,
    output_root: str = "./frames",
    crop_size: int = -1,
    every: int = 10,
    max_memory: int = 50 * 1024 * 1024,
    device: str = "cpu"
) -> int:
    """
    Decode a video and save every ``every``-th frame as a PNG in ``output_root``.

    Frames are written flat (no per-video subfolders) and named
    ``<video_id>_frame_<n>.png``, where ``n`` starts at 1 and counts only the
    frames that are actually saved. Each saved file name is printed.

    Args:
        file_like: Seekable binary stream holding the encoded video. It is
            rewound to position 0 before decoding.
        video_id: Prefix used in the output file names.
        output_root: Directory to write frames into (created if missing).
        crop_size: If > 0, each saved frame is passed through
            ``torchvision.transforms.CenterCrop(crop_size)``.
        every: Keep decoded frames whose index is a multiple of ``every``
            (indices 0, every, 2*every, ...). Must be >= 1.
        max_memory: Budget, in bytes, on the cumulative size of the float
            tensors built for saved frames. Extraction stops after the frame
            that reaches or exceeds the budget (that frame is still saved).
        device: Currently unused; kept for API compatibility.

    Returns:
        The number of frames written.

    Raises:
        ValueError: If ``every`` is less than 1.
    """
    if every < 1:
        # every == 0 would otherwise surface as an opaque ZeroDivisionError
        # from the modulo below; a non-positive stride is meaningless.
        raise ValueError(f"every must be >= 1, got {every}")

    # Ensure the base frames directory exists
    os.makedirs(output_root, exist_ok=True)

    center_crop_transform = None
    if crop_size > 0:
        from torchvision import transforms
        center_crop_transform = transforms.CenterCrop(crop_size)

    file_like.seek(0)
    current_memory = 0
    saved = 0

    # Use the container as a context manager so its resources are released
    # even when decoding raises or we stop early on the memory budget.
    with av.open(file_like) as container:
        for i, frame in enumerate(container.decode(video=0)):
            if i % every != 0:
                continue

            # HWC uint8 RGB array -> CHW float tensor scaled to [0, 1].
            frame_array = frame.to_ndarray(format="rgb24")
            frame_tensor = torch.from_numpy(frame_array).permute(2, 0, 1).float() / 255.0

            if center_crop_transform is not None:
                frame_tensor = center_crop_transform(frame_tensor)

            saved += 1
            # Flat naming pattern (no subfolders)
            frame_name = f"{video_id}_frame_{saved}.png"
            save_image(frame_tensor, os.path.join(output_root, frame_name))

            # Print saved frame name
            print(frame_name)

            current_memory += frame_tensor.numel() * frame_tensor.element_size()
            if current_memory >= max_memory:
                break

    return saved


# -------- Main section --------
if __name__ == "__main__":
    DATASET_PATH = "/tmp/data"
    # Single source of truth for the output directory (used for creation,
    # extraction, and the final message).
    OUTPUT_ROOT = "./frames"
    dataset_remote = load_dataset(DATASET_PATH, split="test", streaming=True)

    # Make sure the output directory exists before processing
    os.makedirs(OUTPUT_ROOT, exist_ok=True)

    for el in tqdm.tqdm(dataset_remote, desc="Extracting frames"):
        try:
            video_id = str(el["id"])
            file_like = io.BytesIO(el["video"]["bytes"])
            preprocess_and_save_frames(file_like, video_id, output_root=OUTPUT_ROOT)
        except Exception as e:
            # Use .get so a record missing "id" is reported and skipped
            # instead of raising KeyError from inside the handler and
            # aborting the whole run.
            print(f"⚠️ Failed {el.get('id', '<missing id>')}: {e}")

    print(f"✅ All frames saved under {OUTPUT_ROOT}/")