import decord
import torch
import torchvision.transforms as T

from itertools import islice
from einops import rearrange
from torch.utils.data import Dataset

# Expected to supply the helpers used below: get_prompt_ids, process_video,
# and sensible_buckets.
from utils.dataset_utils import *


class SingleVideoDataset(Dataset):
    """Serves fixed-length frame chunks sampled from a single video file."""

    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 4,
        frame_step: int = 1,
        single_video_path: str = "",
        single_video_prompt: str = "",
        use_caption: bool = False,
        use_bucketing: bool = False,
        **kwargs
    ):
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing

        # Chunk list and read cursor; refreshed by create_video_chunks()
        # and single_video_batch() respectively.
        self.frames = []
        self.index = 1

        self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
        self.n_sample_frames = n_sample_frames
        self.frame_step = frame_step

        self.single_video_path = single_video_path
        self.single_video_prompt = single_video_prompt

        self.width = width
        self.height = height

    def create_video_chunks(self):
        # Split the video's frame indices (taken every frame_step frames)
        # into chunks of n_sample_frames.
        vr = decord.VideoReader(self.single_video_path)
        vr_range = range(0, len(vr), self.frame_step)

        self.frames = list(self.chunk(vr_range, self.n_sample_frames))
        return self.frames

    def chunk(self, it, size):
        # Greedily split an iterable into `size`-item tuples.
        it = iter(it)
        return iter(lambda: tuple(islice(it, size)), ())
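
    # For example, chunk(range(10), 4) yields (0, 1, 2, 3), (4, 5, 6, 7),
    # then (8, 9); the final chunk may hold fewer than `size` indices.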

    def get_frame_batch(self, vr, resize=None):
        # Fetch the frames of the chunk selected by the current index.
        frames = vr.get_batch(self.frames[self.index])

        # Depending on the active decord bridge, get_batch may return a
        # decord NDArray instead of a torch tensor; convert if so.
        if isinstance(frames, decord.ndarray.NDArray):
            frames = torch.from_numpy(frames.asnumpy())

        video = rearrange(frames, "f h w c -> f c h w")

        if resize is not None:
            video = resize(video)
        return video

    def get_frame_buckets(self, vr):
        # Pick a training resolution that respects the source aspect ratio,
        # then return the matching resize transform.
        h, w, c = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, w, h)
        return T.Resize((height, width), antialias=True)

    def process_video_wrapper(self, vid_path):
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch
        )
        return video, vr

    def single_video_batch(self, index):
        train_data = self.single_video_path
        self.index = index

        # str.endswith accepts a tuple of suffixes, so this checks every
        # extension in vid_types at once.
        if train_data.endswith(self.vid_types):
            video, _ = self.process_video_wrapper(train_data)

            prompt = self.single_video_prompt
            prompt_ids = get_prompt_ids(prompt, self.tokenizer)

            return video, prompt, prompt_ids
        else:
            raise ValueError(f"Single video is not a video type. Types: {self.vid_types}")

    # Without @staticmethod, the self.__getname__() call in __getitem__
    # would pass self to a zero-argument function and raise TypeError.
    @staticmethod
    def __getname__():
        return 'single_video'

    def __len__(self):
        # Rebuilding the chunk list here keeps the dataset length in sync
        # with the frame_step and n_sample_frames settings.
        return len(self.create_video_chunks())

    def __getitem__(self, index):
        video, prompt, prompt_ids = self.single_video_batch(index)

        example = {
            # Map uint8 pixels from [0, 255] into [-1, 1].
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            "dataset": self.__getname__()
        }

        return example
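

# A minimal usage sketch, assuming a CLIP tokenizer (as used in Stable
# Diffusion-style pipelines) and a local video at "clip.mp4"; the path,
# prompt, and tokenizer choice are illustrative placeholders, not values
# taken from this module.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = SingleVideoDataset(
        tokenizer=tokenizer,
        width=256,
        height=256,
        n_sample_frames=4,
        frame_step=1,
        single_video_path="clip.mp4",            # placeholder path
        single_video_prompt="a person dancing",  # placeholder prompt
    )

    loader = DataLoader(dataset, batch_size=1)
    for batch in loader:
        # pixel_values: (batch, frames, channels, height, width) in [-1, 1]
        print(batch["pixel_values"].shape, batch["text_prompt"])
        break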