Spaces:
Running
Running
Evgeny Zhukov
Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
2ba4412
| import os | |
| import cv2 | |
| import json | |
| import torch | |
| import random | |
| import logging | |
| import tempfile | |
| import numpy as np | |
| from copy import copy | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from utils.registry_class import DATASETS | |
class VideoDataset(Dataset):
    """Dataset of video clips described by ``<video_key>|||<caption>`` list files.

    Every entry of ``data_list`` is a text file whose lines have the form
    ``<video_key>|||<caption>``; the matching entry of ``data_dir_list`` is the
    directory that ``video_key`` is joined onto to locate the video.

    ``__getitem__`` returns ``(ref_frame, vit_frame, video_data, caption,
    video_key)``. When a video cannot be decoded or transformed, zero tensors
    and empty strings are returned instead of raising, so one broken sample
    does not abort iteration.
    """

    def __init__(self,
                 data_list,
                 data_dir_list,
                 max_words=1000,
                 resolution=(384, 256),
                 vit_resolution=(224, 224),
                 max_frames=16,
                 sample_fps=8,
                 transforms=None,
                 vit_transforms=None,
                 get_first_frame=False,
                 **kwargs):
        """Store the sampling configuration and load the sample index.

        Args:
            data_list: Paths of list files, one ``<video_key>|||<caption>``
                entry per line.
            data_dir_list: Root directories, paired element-wise with
                ``data_list``.
            max_words: Stored on the instance; not used by this class.
            resolution: Size of the video frames; placeholder tensors are
                built as ``(3, resolution[1], resolution[0])``.
            vit_resolution: Size of the ViT frame, indexed the same way.
            max_frames: Number of frame slots in the returned clip tensor.
            sample_fps: Target sampling rate used to derive the frame stride.
            transforms: Callable applied to the list of PIL frames.
            vit_transforms: Callable applied to the single reference PIL frame.
            get_first_frame: Use the first sampled frame as reference instead
                of the middle one.
            **kwargs: Accepted and ignored.
        """
        self.max_words = max_words
        self.max_frames = max_frames
        self.resolution = resolution
        self.vit_resolution = vit_resolution
        self.sample_fps = sample_fps
        self.transforms = transforms
        self.vit_transforms = vit_transforms
        self.get_first_frame = get_first_frame

        image_list = []
        for item_path, data_dir in zip(data_list, data_dir_list):
            # Context manager so the list file is closed deterministically.
            with open(item_path, 'r') as list_file:
                for line in list_file:
                    # Drop the line terminator so it does not leak into the
                    # caption, and skip blank lines that can never be parsed.
                    line = line.rstrip('\r\n')
                    if line.strip():
                        image_list.append([data_dir, line])
        self.image_list = image_list

    def __getitem__(self, index):
        """Return ``(ref_frame, vit_frame, video_data, caption, video_key)``.

        On any failure the error is logged and zero tensors with an empty
        caption and an empty video key are returned.
        """
        data_dir, file_path = self.image_list[index]
        video_key = file_path.split('|||')[0]
        try:
            ref_frame, vit_frame, video_data, caption = self._get_video_data(data_dir, file_path)
        except Exception as e:
            logging.info('{} get frames failed... with error: {}'.format(video_key, e))
            caption = ''
            video_key = ''
            ref_frame = torch.zeros(3, self.resolution[1], self.resolution[0])
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
            video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        return ref_frame, vit_frame, video_data, caption, video_key

    def _read_frames(self, file_path):
        """Decode up to ``max_frames`` strided frames from a random window.

        Returns a (possibly empty) list of RGB PIL images. Raises ``IOError``
        when the video cannot be opened; other decoding errors propagate.
        """
        capture = cv2.VideoCapture(file_path)
        try:
            if not capture.isOpened():
                raise IOError('cannot open video file {}'.format(file_path))

            fps = capture.get(cv2.CAP_PROP_FPS)
            # OpenCV reports properties as floats; random.randint requires
            # integer bounds (a TypeError on Python 3.12+ otherwise).
            total_frame_num = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            # Never let the stride collapse to 0 for low-fps sources, which
            # would make the modulo below divide by zero.
            stride = max(1, round(fps / self.sample_fps))
            cover_frame_num = stride * self.max_frames

            if total_frame_num < cover_frame_num + 5:
                # Video too short for a random window: scan it from the start.
                start_frame = 0
                end_frame = total_frame_num
            else:
                start_frame = random.randint(0, total_frame_num - cover_frame_num - 5)
                end_frame = start_frame + cover_frame_num

            pointer, frame_list = 0, []
            # Cap at max_frames so the clip always fits the output tensor.
            while len(frame_list) < self.max_frames:
                ret, frame = capture.read()
                pointer += 1
                if (not ret) or (frame is None):
                    break
                if pointer < start_frame:
                    continue
                if pointer >= end_frame - 1:
                    break
                if (pointer - start_frame) % stride == 0:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame_list.append(Image.fromarray(frame))
            return frame_list
        finally:
            # Always free the decoder handle, including on error paths.
            capture.release()

    def _get_video_data(self, data_dir, file_path):
        """Load and transform one clip described by a list-file entry.

        Args:
            data_dir: Directory the video key is relative to.
            file_path: Raw list entry of the form ``<video_key>|||<caption>``.

        Returns:
            ``(ref_frame, vit_frame, video_data, caption)``. ``video_data`` is
            zero-padded to ``max_frames`` entries.

        Raises:
            ValueError: If the entry does not split into exactly two parts.
            RuntimeError: If no frames could be decoded or transformed.
        """
        video_key, caption = file_path.split('|||')
        file_path = os.path.join(data_dir, video_key)

        frame_list = []
        # Reading is retried a few times; each failure is logged.
        for _ in range(5):
            try:
                frame_list = self._read_frames(file_path)
                break
            except Exception as e:
                logging.info('{} read video frame failed with error: {}'.format(video_key, e))
                continue

        video_data = torch.zeros(self.max_frames, 3, self.resolution[1], self.resolution[0])
        if self.get_first_frame:
            ref_idx = 0
        else:
            ref_idx = int(len(frame_list) / 2)

        frames = None
        try:
            if len(frame_list) > 0:
                mid_frame = copy(frame_list[ref_idx])
                vit_frame = self.vit_transforms(mid_frame)
                # NOTE(review): assumes transforms returns an indexable stack
                # of per-frame tensors matching video_data's trailing dims.
                frames = self.transforms(frame_list)
                video_data[:len(frame_list), ...] = frames
            else:
                vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])
        except Exception as e:
            logging.info('{} transform frames failed with error: {}'.format(video_key, e))
            vit_frame = torch.zeros(3, self.vit_resolution[1], self.vit_resolution[0])

        if frames is None:
            # Fail explicitly instead of tripping over an unbound local; the
            # caller turns this into the zero-filled fallback sample.
            raise RuntimeError('no usable frames for {}'.format(video_key))
        ref_frame = copy(frames[ref_idx])
        return ref_frame, vit_frame, video_data, caption

    def __len__(self):
        """Return the number of list entries loaded at construction time."""
        return len(self.image_list)