Spaces:
Configuration error
Configuration error
| """ | |
| Batch generation for sequences of images. This script accepts a JSONL file | |
| as input. Each line of the JSONL file is a dictionary that | |
| represents one example in the evaluation set. The dictionary should have two keys: | |
| input: a list of paths to the input images used as context for the model. | |
| output: a string giving the path where the generated output is saved. | |
| This script runs the model to generate the output images and, when include_input | |
| is set, concatenates the input and output images together before saving them to the output path. | |
| """ | |
| import os | |
| import json | |
| from PIL import Image | |
| import numpy as np | |
| import mlxu | |
| from tqdm import tqdm, trange | |
| from multiprocessing import Pool | |
| import einops | |
| import torch | |
| from .inference import MultiProcessInferenceModel | |
| from .utils import read_image_to_tensor, MultiProcessImageSaver | |
# Command-line flags for this script. The comments describe how each flag is
# consumed by `main` in this file.
FLAGS, _ = mlxu.define_flags_with_default(
    # Path to the JSONL file listing the examples (one JSON record per line).
    input_file='',
    # Checkpoint forwarded to MultiProcessInferenceModel; must be non-empty.
    checkpoint='',
    # If non-empty, prefixed to every input (and target) image path.
    input_base_dir='',
    # If non-empty, created if missing and prefixed to every output path.
    output_base_dir='',
    # When true, target images are loaded and an MSE against them is printed.
    evaluate_mse=False,
    # Keys used to look up inputs / output / targets in each JSON record.
    json_input_key='input',
    json_output_key='output',
    json_target_key='target',
    # Forwarded to the model call as the number of new frames and the number
    # of candidates; also used to scale the saved image size.
    n_new_frames=1,
    n_candidates=2,
    # Forwarded to the MultiProcessInferenceModel constructor.
    context_frames=16,
    # Forwarded to the model call (presumably sampling parameters -- confirm
    # against MultiProcessInferenceModel).
    temperature=1.0,
    top_p=1.0,
    # Used both as DataLoader num_workers and as the MultiProcessImageSaver argument.
    n_workers=8,
    # Forwarded to the MultiProcessInferenceModel constructor.
    dtype='float16',
    torch_devices='',
    # DataLoader batch size is batch_size_factor * model.n_processes.
    batch_size_factor=4,
    # If > 0, only the first max_examples records are processed.
    max_examples=0,
    # '' (no resize), 'original' (size of the last input image), or a
    # comma-separated pair of integers giving a per-frame size.
    resize_output='',
    # When true, the input frames are concatenated in front of the generated ones.
    include_input=False,
)
| # create this according to the json file. | |
class MultiFrameDataset(torch.utils.data.Dataset):
    """Dataset of multi-frame examples described by lists of image paths.

    Args:
        input_files: Non-empty list; element ``i`` is the list of image paths
            that make up the input frames of example ``i``.
        output_files: List of output paths, indexed like ``input_files``.
        target_files: Optional list indexed like ``input_files``; element
            ``i`` is the list of target image paths of example ``i``.

    Raises:
        AssertionError: If ``input_files`` is empty.
    """

    def __init__(self, input_files, output_files, target_files=None):
        assert len(input_files)
        self.input_files = input_files
        self.output_files = output_files
        self.target_files = target_files

    def __len__(self):
        """Return the number of examples."""
        return len(self.input_files)

    def __getitem__(self, idx):
        """Load example ``idx``.

        Returns:
            A tuple ``(input_images, target_images, output_file, original_size)``:
            the stacked input frames, the stacked target frames (``None``
            when the dataset has no target files), the output path, and the
            PIL ``size`` of the last input image as a NumPy array.
        """
        frame_paths = self.input_files[idx]

        # Only the size is needed here. The context manager closes the file
        # handle right away instead of leaving it open until garbage
        # collection, which matters when many examples are loaded by workers.
        with Image.open(frame_paths[-1]) as last_frame:
            original_size = last_frame.size

        input_images = np.stack(
            [read_image_to_tensor(path) for path in frame_paths],
            axis=0,
        )

        target_images = None
        if self.target_files is not None:
            target_images = np.stack(
                [read_image_to_tensor(path) for path in self.target_files[idx]],
                axis=0,
            )

        return (
            input_images,
            target_images,
            self.output_files[idx],
            np.array(original_size),
        )
def _collate_examples(examples):
    """Collate ``MultiFrameDataset`` items into a batch, allowing ``None`` targets.

    The dataset returns ``None`` in the target slot when no target files are
    configured, and ``torch.utils.data.default_collate`` raises ``TypeError``
    on ``None`` entries. Every other slot is collated with the default
    behaviour; the target slot is passed through as ``None`` when absent.
    """
    inputs, targets, outputs, sizes = zip(*examples)
    collate = torch.utils.data.default_collate
    batch_targets = None if targets[0] is None else collate(targets)
    return collate(inputs), batch_targets, collate(outputs), collate(sizes)


def _load_file_lists():
    """Read the JSONL input file and resolve all paths according to FLAGS.

    Returns:
        ``(input_files, output_files, target_files)`` where ``target_files``
        is ``None`` unless ``FLAGS.evaluate_mse`` is set.
    """
    input_files = []
    output_files = []
    target_files = [] if FLAGS.evaluate_mse else None

    with mlxu.open_file(FLAGS.input_file, 'r') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing empty line) instead of
            # failing inside json.loads.
            if not line.strip():
                continue
            record = json.loads(line)
            input_files.append(record[FLAGS.json_input_key])
            output_files.append(record[FLAGS.json_output_key])
            if FLAGS.evaluate_mse:
                target_files.append(record[FLAGS.json_target_key])

    if FLAGS.max_examples > 0:
        input_files = input_files[:FLAGS.max_examples]
        output_files = output_files[:FLAGS.max_examples]
        if FLAGS.evaluate_mse:
            target_files = target_files[:FLAGS.max_examples]

    if FLAGS.input_base_dir != '':
        input_files = [
            [os.path.join(FLAGS.input_base_dir, x) for x in y]
            for y in input_files
        ]
        if FLAGS.evaluate_mse:
            target_files = [
                [os.path.join(FLAGS.input_base_dir, x) for x in y]
                for y in target_files
            ]

    if FLAGS.output_base_dir != '':
        os.makedirs(FLAGS.output_base_dir, exist_ok=True)
        output_files = [
            os.path.join(FLAGS.output_base_dir, x)
            for x in output_files
        ]

    return input_files, output_files, target_files


def main(_):
    """Generate frames for every example in the input file and save them.

    Builds the inference model from ``FLAGS.checkpoint``, loads the examples
    listed in ``FLAGS.input_file``, generates ``FLAGS.n_new_frames`` frames
    with ``FLAGS.n_candidates`` candidates per example, and hands the tiled
    result to a ``MultiProcessImageSaver``. When ``FLAGS.evaluate_mse`` is
    set, the mean squared error against the target frames is printed.

    Raises:
        AssertionError: If ``FLAGS.checkpoint`` is empty.
    """
    assert FLAGS.checkpoint != ''

    print(f'Loading checkpoint from {FLAGS.checkpoint}')
    print(f'Evaluating input file from {FLAGS.input_file}')

    model = MultiProcessInferenceModel(
        checkpoint=FLAGS.checkpoint,
        torch_devices=FLAGS.torch_devices,
        dtype=FLAGS.dtype,
        context_frames=FLAGS.context_frames,
        use_lock=True,
    )

    input_files, output_files, target_files = _load_file_lists()

    dataset = MultiFrameDataset(input_files, output_files, target_files)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=FLAGS.batch_size_factor * model.n_processes,
        shuffle=False,
        num_workers=FLAGS.n_workers,
        # The default collate function cannot batch the None targets that the
        # dataset yields when evaluate_mse is off.
        collate_fn=_collate_examples,
    )

    image_saver = MultiProcessImageSaver(FLAGS.n_workers)

    # The explicit per-frame size does not depend on the batch, so parse it once.
    if FLAGS.resize_output in ('', 'original'):
        fixed_size = None
    else:
        fixed_size = tuple(int(x) for x in FLAGS.resize_output.split(','))

    mses = []
    for batch_images, batch_targets, batch_output_files, batch_sizes in tqdm(data_loader, ncols=0):
        # (batch, frames, height, width, channels), per the einops patterns below.
        batch_images = batch_images.numpy()
        context_length = batch_images.shape[1]

        generated_images = np.array(
            model(
                batch_images,
                FLAGS.n_new_frames,
                FLAGS.n_candidates,
                temperature=FLAGS.temperature,
                top_p=FLAGS.top_p,
            )
        )

        if FLAGS.evaluate_mse:
            # Repeat the targets along a new candidate axis so they line up
            # with the generated images.
            batch_targets = einops.repeat(
                batch_targets.numpy(),
                'b s h w c -> b n s h w c',
                n=FLAGS.n_candidates,
            )
            channels = batch_targets.shape[-1]
            # Per-example mean over all non-batch axes, scaled by channel count.
            mse = np.mean((generated_images - batch_targets) ** 2, axis=(1, 2, 3, 4, 5))
            mses.append(mse * channels)

        if FLAGS.include_input:
            # Prepend the input frames to every candidate row.
            repeated_batch = einops.repeat(
                batch_images,
                'b s h w c -> b n s h w c',
                n=FLAGS.n_candidates,
            )
            frames = np.concatenate([repeated_batch, generated_images], axis=2)
        else:
            frames = generated_images

        # Tile candidates vertically and frames horizontally into one image.
        combined = einops.rearrange(frames, 'b n s h w c -> b (n h) (s w) c')
        combined = (combined * 255).astype(np.uint8)

        n_frames = FLAGS.n_new_frames
        if FLAGS.include_input:
            n_frames += context_length

        if FLAGS.resize_output == '':
            resizes = None
        else:
            if fixed_size is None:
                # 'original': use the size recorded by the dataset per example.
                base_sizes = batch_sizes.numpy()
            else:
                base_sizes = np.array([fixed_size] * len(batch_sizes))
            # Scale the per-frame size to the tiled image: n_frames across,
            # n_candidates down.
            resizes = base_sizes * np.array([[n_frames, FLAGS.n_candidates]])

        image_saver(combined, batch_output_files, resizes)

    if FLAGS.evaluate_mse:
        mses = np.concatenate(mses, axis=0)
        print(f'MSE: {np.mean(mses)}')

    image_saver.close()
# Script entry point: hand `main` to mlxu.run (presumably it parses the
# command-line flags before calling main -- confirm against mlxu).
if __name__ == "__main__":
    mlxu.run(main)