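"""Compute the average next-frame perplexity of a vqlm checkpoint over a
collection of videos (or directories of pre-extracted frames)."""
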
import os
import glob

from tqdm import tqdm
import mlxu
import numpy as np
import torch

from vqlm_demo.inference import MultiProcessInferenceModel
from vqlm_demo.utils import (
    is_video, read_frames_from_dir, read_frames_from_video
)
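
# Command-line flags. `checkpoint` and one of `input_files` / `read_file_list`
# are required; the remaining flags control frame sampling and batching.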
FLAGS, _ = mlxu.define_flags_with_default(
    checkpoint='',
    input_files='',
    frame_input=False,
    read_file_list='',
    center_crop=1.0,
    n_context_frames=15,
    n_target_frames=1,
    n_workers=8,
    stride=8,
    batch_size=2,
    torch_devices='',
    shuffle=False,
    random_start=True,
    max_examples=0,
)


class VideoDataset(torch.utils.data.Dataset):
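    """Yields (context_frames, target_frames) pairs, read either from a video
    file or from a directory of pre-extracted frames."""
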
    def __init__(self, videos, frame_input=False, n_context_frames=15,
                 n_target_frames=1, stride=1):
        self.videos = videos
        self.frame_input = frame_input
        self.n_context_frames = n_context_frames
        self.n_target_frames = n_target_frames
        self.stride = stride

    def __getitem__(self, index):
        if self.frame_input:
            # Input is a directory of pre-extracted frames.
            frames = read_frames_from_dir(
                self.videos[index],
                self.n_context_frames + self.n_target_frames,
                self.stride,
                center_crop=FLAGS.center_crop,
                random_start=FLAGS.random_start,
            )
        else:
            # Input is a video file.
            frames = read_frames_from_video(
                self.videos[index],
                self.n_context_frames + self.n_target_frames,
                self.stride,
                center_crop=FLAGS.center_crop,
                random_start=FLAGS.random_start,
            )
        if frames is None:
            # Unreadable or too-short input: retry with a random substitute.
            return self[np.random.randint(0, len(self))]
        return frames[:self.n_context_frames], frames[self.n_context_frames:]

    def __len__(self):
        return len(self.videos)
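

# Driver: build the multi-process inference model, gather the input videos,
# and report the mean perplexity of target frames given context frames.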
def main(_):
    assert FLAGS.checkpoint != '', 'checkpoint must be specified'
    assert FLAGS.read_file_list != '' or FLAGS.input_files != '', (
        'either read_file_list or input_files must be specified'
    )

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        perplexity_batch_size=FLAGS.batch_size,
    )

    # Collect inputs either from an explicit file list or from a glob pattern.
    if FLAGS.read_file_list != '':
        with open(FLAGS.read_file_list, 'r') as f:
            videos = [x.strip() for x in f.readlines()]
    else:
        videos = glob.glob(FLAGS.input_files)

    # Keep only entries of the expected kind: frame directories or video files.
    if FLAGS.frame_input:
        videos = [x for x in videos if os.path.isdir(x)]
    else:
        videos = [x for x in videos if is_video(x)]

    if FLAGS.shuffle:
        np.random.shuffle(videos)

    if FLAGS.max_examples > 0:
        videos = videos[:FLAGS.max_examples]

    dataset = VideoDataset(
        videos,
        frame_input=FLAGS.frame_input,
        n_context_frames=FLAGS.n_context_frames,
        n_target_frames=FLAGS.n_target_frames,
        stride=FLAGS.stride,
    )
    # The dataloader batch is oversized so it can be split across the model's
    # worker processes.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size * model.n_processes * 4,
        shuffle=False,
        num_workers=FLAGS.n_workers,
        prefetch_factor=4,
        drop_last=True,
    )

    perplexities = []
    for batch_context_frames, batch_target_frames in tqdm(dataloader, ncols=0):
        batch_context_frames = batch_context_frames.numpy()
        batch_target_frames = batch_target_frames.numpy()
        perplexity = model.compute_perplexity(
            batch_context_frames, batch_target_frames
        )
        perplexities.append(perplexity)

    perplexities = np.concatenate(perplexities, axis=0)
    print(f'Perplexity: {np.mean(perplexities)}')


if __name__ == '__main__':
    mlxu.run(main)
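

# Example invocation (the script name and paths below are hypothetical; the
# flags are the ones defined at the top of this file):
#
#   python compute_perplexity.py \
#       --checkpoint=/path/to/checkpoint \
#       --input_files='/data/videos/*.mp4' \
#       --n_context_frames=15 --n_target_frames=1 \
#       --batch_size=2 --n_workers=8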