Spaces:
Runtime error
Runtime error
| from argparse import ArgumentParser | |
| import os | |
| import json | |
| import sys | |
| from tqdm import tqdm | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
| import torchvision.transforms as transforms | |
| sys.path.append(".") | |
| sys.path.append("..") | |
| from criteria.lpips.lpips import LPIPS | |
| from datasets.gt_res_dataset import GTResDataset | |
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            behaviour of reading the real command line.

    Returns:
        An ``argparse.Namespace`` with the attributes ``mode``,
        ``data_path``, ``gt_path``, ``workers``, ``batch_size`` and
        ``is_cars``.
    """
    parser = ArgumentParser(
        description='Score result images against ground-truth images '
                    'with LPIPS or L2.')
    parser.add_argument('--mode', type=str, default='lpips',
                        choices=['lpips', 'l2'],
                        help='metric used to compare each image pair')
    parser.add_argument('--data_path', type=str, default='results',
                        help='directory containing the result images')
    parser.add_argument('--gt_path', type=str, default='gt_images',
                        help='directory containing the ground-truth images')
    parser.add_argument('--workers', type=int, default=4,
                        help='number of DataLoader worker processes')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='number of image pairs loaded per batch')
    parser.add_argument('--is_cars', action='store_true',
                        help='resize to 192x256 instead of 256x256')
    return parser.parse_args(argv)
def _build_loss_func(mode):
    """Return the module used to score one (result, ground-truth) pair."""
    if mode == 'lpips':
        return LPIPS(net_type='alex')
    if mode == 'l2':
        return torch.nn.MSELoss()
    # Unreachable through the CLI (argparse restricts ``choices``), but
    # guards programmatic callers of ``run``.
    raise ValueError('Not a valid mode!')


def run(args):
    """Score every result image against its ground truth and save the stats.

    Each image pair from ``GTResDataset`` is scored on its own, with LPIPS
    (AlexNet backbone) or ``torch.nn.MSELoss`` depending on ``args.mode``.
    The mean and standard deviation of the per-image scores are printed.
    They are then written to ``stat_<mode>.txt``, and the per-image scores
    (keyed by file basename) to ``scores_<mode>.json``. Both files go in an
    ``inference_metrics`` directory that sits next to ``args.data_path``.

    Args:
        args: Namespace as produced by ``parse_args``. Reads ``mode``,
            ``data_path``, ``gt_path``, ``workers``, ``batch_size`` and
            ``is_cars``.

    Raises:
        ValueError: If ``args.mode`` is neither ``'lpips'`` nor ``'l2'``.
    """
    # Car images are evaluated at a non-square resolution.
    resize_dims = (192, 256) if args.is_cars else (256, 256)
    transform = transforms.Compose([
        transforms.Resize(resize_dims),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    print('Loading dataset')
    dataset = GTResDataset(root_path=args.data_path,
                           gt_dir=args.gt_path,
                           transform=transform)
    # drop_last=False: the final, possibly partial batch must be scored too;
    # dropping it silently excluded up to batch_size - 1 images from the
    # reported statistics.
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=int(args.workers),
                            drop_last=False)

    loss_func = _build_loss_func(args.mode)
    loss_func.cuda()

    scores_dict = {}
    # Index into dataset.pairs. shuffle=False keeps loader order equal to
    # dataset order. NOTE(review): assumes pairs[k][0] is the path of the
    # k-th result image, as the original indexing implies; confirm in
    # GTResDataset.
    global_i = 0
    with torch.no_grad():  # pure evaluation: no autograd graph needed
        for result_batch, gt_batch in tqdm(dataloader):
            # Use the real batch length, not args.batch_size: the last batch
            # can be shorter.
            for i in range(result_batch.size(0)):
                loss = float(loss_func(result_batch[i:i + 1].cuda(),
                                       gt_batch[i:i + 1].cuda()))
                im_path = dataset.pairs[global_i][0]
                scores_dict[os.path.basename(im_path)] = loss
                global_i += 1

    all_scores = list(scores_dict.values())
    mean = np.mean(all_scores)
    std = np.std(all_scores)
    result_str = 'Average loss is {:.2f}+-{:.2f}'.format(mean, std)
    print('Finished with ', args.data_path)
    print(result_str)

    # normpath strips a trailing separator, so 'exp/results/' resolves to
    # 'exp' (the parent directory) rather than 'exp/results'.
    out_path = os.path.join(os.path.dirname(os.path.normpath(args.data_path)),
                            'inference_metrics')
    os.makedirs(out_path, exist_ok=True)
    with open(os.path.join(out_path, 'stat_{}.txt'.format(args.mode)), 'w',
              encoding='utf-8') as f:
        f.write(result_str)
    with open(os.path.join(out_path, 'scores_{}.json'.format(args.mode)), 'w',
              encoding='utf-8') as f:
        json.dump(scores_dict, f)
if __name__ == '__main__':
    # Script entry point: read the CLI flags and hand them straight to run().
    run(parse_args())