Spaces:
Runtime error
Runtime error
| from utils.dataset_utils import * | |
class VideoFolderDataset(Dataset):
    """Dataset yielding frame clips from every ``*.mp4`` file in one folder.

    Each item pairs a sampled clip with the tokenized ``fallback_prompt``;
    no per-video caption is read by this class.
    """

    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 16,
        fps: int = 8,
        path: str = "./data",
        fallback_prompt: str = "",
        use_bucketing: bool = False,
        **kwargs
    ):
        """Collect the video files under ``path`` and store sampling settings.

        Args:
            tokenizer: Callable used by ``get_prompt_ids``; it must also expose
                ``model_max_length``.
            width: Target frame width.
            height: Target frame height.
            n_sample_frames: Maximum number of frames sampled per clip.
            fps: Sampling frame rate; the native rate is subsampled towards it.
            path: Directory searched (non-recursively) for ``*.mp4`` files.
            fallback_prompt: Text prompt attached to every item.
            use_bucketing: Forwarded to ``process_video``.
            **kwargs: Accepted and ignored.
        """
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing
        self.fallback_prompt = fallback_prompt

        self.video_files = glob(f"{path}/*.mp4")

        self.width = width
        self.height = height

        self.n_sample_frames = n_sample_frames
        self.fps = fps

    def get_frame_buckets(self, vr):
        """Build a resize transform sized by ``sensible_buckets``.

        Args:
            vr: Video reader; ``vr[0].shape`` must unpack as
                ``(height, width, channels)``.

        Returns:
            A ``Resize`` transform targeting the bucketed ``(height, width)``.
        """
        h, w, _ = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, w, h)
        resize = T.transforms.Resize((height, width), antialias=True)

        return resize

    def get_frame_batch(self, vr, resize=None):
        """Sample a contiguous run of evenly strided frames from ``vr``.

        Args:
            vr: Video reader supporting ``len()``, ``get_avg_fps()`` and
                ``get_batch(indices)``.
                NOTE(review): presumably a decord ``VideoReader`` - confirm.
            resize: Optional transform applied to the rearranged clip.

        Returns:
            Tuple ``(video, vr)`` where ``video`` is laid out as
            ``(frames, channels, height, width)``.
        """
        n_sample_frames = self.n_sample_frames
        native_fps = vr.get_avg_fps()

        # Stride that brings the native frame rate down to roughly self.fps,
        # clamped to at least 1 and at most the number of frames available.
        every_nth_frame = max(1, round(native_fps / self.fps))
        every_nth_frame = min(len(vr), every_nth_frame)

        # Number of frames reachable at that stride; shorten the clip when
        # the video cannot supply the requested count.
        effective_length = len(vr) // every_nth_frame
        if effective_length < n_sample_frames:
            n_sample_frames = effective_length

        # Random start (inclusive bounds) such that the clip fits in the video.
        effective_idx = random.randint(0, (effective_length - n_sample_frames))
        idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)

        video = vr.get_batch(idxs)
        video = rearrange(video, "f h w c -> f c h w")

        if resize is not None:
            video = resize(video)
        return video, vr

    def process_video_wrapper(self, vid_path):
        """Run ``process_video`` on ``vid_path`` with this dataset's settings.

        Returns:
            The ``(video, vr)`` pair produced by ``process_video``.
        """
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch
        )

        return video, vr

    def get_prompt_ids(self, prompt):
        """Tokenize ``prompt``, truncated/padded to the tokenizer's max length.

        Returns:
            The ``input_ids`` of the tokenizer output, requested as PyTorch
            tensors (``return_tensors="pt"``).
        """
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    # Must be a staticmethod: it takes no arguments but is invoked through
    # the instance in __getitem__, which would otherwise raise TypeError
    # because Python passes the instance as an implicit first argument.
    @staticmethod
    def __getname__():
        """Return the identifier of this dataset type."""
        return 'folder'

    def __len__(self):
        """Return the number of video files found in the folder."""
        return len(self.video_files)

    def __getitem__(self, index):
        """Load the clip at ``index`` and pair it with the fallback prompt.

        Returns:
            Dict with keys ``pixel_values``, ``prompt_ids``, ``text_prompt``
            and ``dataset``.
        """
        video, _ = self.process_video_wrapper(self.video_files[index])

        prompt = self.fallback_prompt
        prompt_ids = self.get_prompt_ids(prompt)

        # NOTE(review): video[0] presumably selects the frame tensor out of the
        # (video, vr) pair returned by get_frame_batch via process_video -
        # confirm against process_video. The scaling maps 0..255 onto -1..1,
        # assuming 8-bit pixel values.
        return {"pixel_values": (video[0] / 127.5 - 1.0), "prompt_ids": prompt_ids[0], "text_prompt": prompt, 'dataset': self.__getname__()}