Spaces:
Configuration error
Configuration error
| """ | |
| Evaluate the perplexity on few-shot tasks. This script accepts a jsonl file | |
| as input. Each line of the jsonl file represents a dictionary. Each line | |
| represents one example in the evaluation set. The dictionary should have two keys: | |
| input: a list of paths to the input images as context to the model. This | |
| list should include the few shot examples. | |
| target: a list of paths to the target images to evaluate perplexity | |
| This script runs the model and computes the average perplexity on the | |
| evaluation set. | |
| """ | |
| import os | |
| import json | |
| from PIL import Image | |
| import numpy as np | |
| import mlxu | |
| from tqdm import tqdm, trange | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import einops | |
| from .inference import MultiProcessInferenceModel | |
# Command-line flags for this script; the values below are the defaults.
FLAGS, _ = mlxu.define_flags_with_default(
    # Path of the jsonl file listing the evaluation examples.
    input_file='',
    # Checkpoint handed to the inference model; main() requires it to be
    # non-empty.
    checkpoint='',
    # When non-empty, joined in front of every image path from the jsonl file.
    input_base_dir='',
    # Passed to the model as perplexity_batch_size; the DataLoader batch size
    # is this value multiplied by model.n_processes.
    batch_size=2,
    # Record key holding the list of context (input) image paths.
    json_input_key='input',
    # Record key holding the list of target image paths.
    json_target_key='target',
    # Forwarded unchanged to the inference model.
    dtype='float16',
    # Forwarded unchanged to the inference model.
    torch_devices='',
    # Number of DataLoader worker processes.
    n_workers=4,
    # When greater than 0, only the first max_examples records are evaluated.
    max_examples=0,
)
def read_image_to_tensor(path, size=(256, 256)):
    """Load an image file as a normalized float32 RGB array.

    Despite the name, the result is a NumPy array, not a ``torch.Tensor``.

    Args:
        path: Location of the image; anything ``PIL.Image.open`` accepts.
        size: ``(width, height)`` the image is resized to. Defaults to
            ``(256, 256)``.

    Returns:
        A ``np.float32`` array of shape ``(height, width, 3)`` whose values
        are the 8-bit pixel values divided by 255.
    """
    # The context manager closes the underlying file handle deterministically
    # instead of leaving it to garbage collection; convert() returns a fully
    # loaded copy that stays valid after the file is closed.
    with Image.open(path) as pil_im:
        rgb_im = pil_im.convert('RGB')
    resized = rgb_im.resize(size)
    scaled = np.array(resized) / 255.0
    return scaled.astype(np.float32)
class MultiFrameDataset(torch.utils.data.Dataset):
    """Dataset pairing context image sequences with target image sequences.

    Example ``i`` is built from the image paths in ``input_files[i]`` and
    ``target_files[i]``; both outer lists must contain the same number of
    examples.
    """

    def __init__(self, input_files, target_files):
        """Store the per-example path lists after checking they line up."""
        assert len(input_files) == len(target_files)
        self.input_files = input_files
        self.target_files = target_files

    def __len__(self):
        """Return the number of examples."""
        return len(self.input_files)

    def __getitem__(self, idx):
        """Load example ``idx`` as an ``(input_frames, target_frames)`` pair.

        Each element stacks that example's images along a new leading axis.
        """
        return (
            self._load_frames(self.input_files[idx]),
            self._load_frames(self.target_files[idx]),
        )

    @staticmethod
    def _load_frames(paths):
        """Read every image in ``paths`` and stack the results along axis 0."""
        frames = [read_image_to_tensor(p) for p in paths]
        return np.stack(frames, axis=0)
def main(_):
    """Compute and print the mean perplexity over the evaluation set.

    Loads the inference model from ``FLAGS.checkpoint``, reads the example
    list from the jsonl file at ``FLAGS.input_file``, scores every batch with
    ``model.compute_perplexity`` and prints the mean of all returned values.

    Args:
        _: Unused; accepted because the runner calls ``main`` with one
            positional argument.

    Raises:
        AssertionError: If ``FLAGS.checkpoint`` is empty.
        ValueError: If the input file yields no examples.
    """
    assert FLAGS.checkpoint != ''
    print(f'Loading checkpoint from {FLAGS.checkpoint}')
    print(f'Evaluating input file from {FLAGS.input_file}')

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        dtype=FLAGS.dtype,
        use_lock=True,
        perplexity_batch_size=FLAGS.batch_size,
    )

    input_files = []
    target_files = []
    with mlxu.open_file(FLAGS.input_file, 'r') as f:
        for line in f:
            # Skip blank lines (e.g. a trailing newline at the end of the
            # file) instead of letting json.loads fail on them.
            if not line.strip():
                continue
            record = json.loads(line)
            input_files.append(record[FLAGS.json_input_key])
            target_files.append(record[FLAGS.json_target_key])

    if FLAGS.input_base_dir != '':
        # Paths in the jsonl file are relative to the base directory.
        input_files = [
            [os.path.join(FLAGS.input_base_dir, x) for x in y]
            for y in input_files
        ]
        target_files = [
            [os.path.join(FLAGS.input_base_dir, x) for x in y]
            for y in target_files
        ]

    if FLAGS.max_examples > 0:
        input_files = input_files[:FLAGS.max_examples]
        target_files = target_files[:FLAGS.max_examples]

    if not input_files:
        # Without this guard np.concatenate below fails on an empty list with
        # a message that does not mention the input file.
        raise ValueError(
            f'No examples found in input file {FLAGS.input_file!r}'
        )

    dataset = MultiFrameDataset(input_files, target_files)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        # One perplexity batch per model process in every loader step.
        batch_size=FLAGS.batch_size * model.n_processes,
        shuffle=False,
        num_workers=FLAGS.n_workers,
    )

    perplexities = []
    for input_images, target_images in tqdm(data_loader, ncols=0):
        perplexity = model.compute_perplexity(input_images, target_images)
        perplexities.append(perplexity)

    # NOTE(review): assumes compute_perplexity returns an array with one entry
    # per example along axis 0 -- confirm against MultiProcessInferenceModel.
    perplexities = np.concatenate(perplexities, axis=0)
    print(f'Perplexity: {np.mean(perplexities)}')
# Script entry point: hand main() to mlxu's runner.
# NOTE(review): mlxu.run presumably parses the flags defined above before
# calling main -- confirm against the mlxu documentation.
if __name__ == "__main__":
    mlxu.run(main)