Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import cv2 | |
| import os, argparse, json | |
| from os.path import join | |
| from glob import glob | |
| import torch | |
| import torch.nn.functional as F | |
| from model.model import ResHalf | |
| from model.model import Quantize | |
| from model.loss import l1_loss | |
| from utils import util | |
| from utils.dct import DCT_Lowfrequency | |
| from utils.filters_tensor import bgr2gray | |
| from collections import OrderedDict | |
class Inferencer:
    """Load a checkpoint into a model and run it without gradient tracking.

    Calling the instance pads the input so that its height and width are
    multiples of ``_SCALE``, runs the wrapped model, and crops the outputs
    back to the original spatial size.
    """

    # Height/width are padded up to a multiple of this before the forward
    # pass (presumably the model's total down-sampling factor -- confirm
    # against the ResHalf definition).
    _SCALE = 8

    # Prefix that torch.nn.DataParallel puts in front of every state-dict key.
    _PARALLEL_PREFIX = 'module.'

    def __init__(self, checkpoint_path, model, use_cuda=True, multi_gpu=True):
        """Load ``checkpoint_path`` into ``model`` and put it in eval mode.

        Args:
            checkpoint_path: Path passed to ``torch.load``. The loaded object
                must expose the weights under the ``'state_dict'`` key.
            model: ``torch.nn.Module`` that receives the weights.
            use_cuda: When true, the model (and every later input) is moved
                to the GPU with ``.cuda()``.
            multi_gpu: When true, the model is wrapped in
                ``torch.nn.DataParallel`` and the checkpoint keys are loaded
                unchanged. When false, a leading ``'module.'`` is removed
                from each key that has one before loading.
        """
        # NOTE(security): depending on the PyTorch version and its
        # ``weights_only`` default, torch.load may unpickle arbitrary
        # objects -- only open checkpoints from a trusted source.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.use_cuda = use_cuda
        self.model = model.eval()
        if multi_gpu:
            self.model = torch.nn.DataParallel(self.model)
            state_dict = self.checkpoint['state_dict']
        else:
            # Drop the DataParallel prefix only where it is actually present,
            # so checkpoints saved from an unwrapped model load as well.
            prefix = self._PARALLEL_PREFIX
            state_dict = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, value)
                for key, value in self.checkpoint['state_dict'].items()
            )
        if self.use_cuda:
            self.model = self.model.cuda()
        self.model.load_state_dict(state_dict)

    def __call__(self, input_img, decoding_only=False):
        """Run the model on ``input_img`` and return outputs at its size.

        Args:
            input_img: 4-D tensor; the last two dimensions are treated as
                height and width.
            decoding_only: Forwarded to the model as its second positional
                argument and used to select the return shape below.

        Returns:
            If ``decoding_only`` is true, the single tensor returned by the
            model. Otherwise a ``(resHalftone, resColor)`` tuple, where the
            first output is mapped from [-1, 1] to [0, 1], passed through
            ``Quantize.apply`` and mapped back to [-1, 1]. All returned
            tensors are cropped to the input's original height and width.

        Note:
            Reflect padding needs each padded dimension to be larger than
            the amount of padding added to it, so very small inputs can
            still be rejected by ``F.pad``.
        """
        with torch.no_grad():
            scale = self._SCALE
            _, _, H, W = input_img.shape
            # Pad each axis only by what it needs; an axis that is already a
            # multiple of `scale` gets 0 instead of a full extra `scale`.
            pad_h = (scale - H % scale) % scale
            pad_w = (scale - W % scale) % scale
            padded = pad_h > 0 or pad_w > 0
            if padded:
                input_img = F.pad(input_img, [0, pad_w, 0, pad_h], mode='reflect')
            if self.use_cuda:
                input_img = input_img.cuda()

            if decoding_only:
                resColor = self.model(input_img, decoding_only)
                if padded:
                    resColor = resColor[:, :, :H, :W]
                return resColor

            resHalftone, resColor = self.model(input_img, decoding_only)
            # Quantize in [0, 1], then map back to the [-1, 1] range.
            resHalftone = Quantize.apply((resHalftone + 1.0) * 0.5) * 2.0 - 1.
            if padded:
                resHalftone = resHalftone[:, :, :H, :W]
                resColor = resColor[:, :, :H, :W]
            return resHalftone, resColor
if __name__ == '__main__':
    # Command-line entry point: run the model over every image in --data_dir
    # and write the results as PNG files into --save_dir.
    parser = argparse.ArgumentParser(description='invHalf')
    parser.add_argument('--model', default=None, type=str,
                        help='model weight file path')
    parser.add_argument('--decoding', action='store_true', default=False, help='restoration from halftone input')
    parser.add_argument('--data_dir', default=None, type=str,
                        help='where to load input data (RGB images)')
    parser.add_argument('--save_dir', default=None, type=str,
                        help='where to save the result')
    args = parser.parse_args()

    # Fail with a usage message instead of an opaque error further down
    # when one of the paths was not supplied.
    missing = [flag for flag, value in (('--model', args.model),
                                        ('--data_dir', args.data_dir),
                                        ('--save_dir', args.save_dir))
               if value is None]
    if missing:
        parser.error('missing required argument(s): ' + ', '.join(missing))

    invhalfer = Inferencer(
        checkpoint_path=args.model,
        model=ResHalf(train=False),
        # Only ask for the GPU when one is present so the script also runs
        # on CPU-only hosts.
        use_cuda=torch.cuda.is_available()
    )
    save_dir = args.save_dir
    util.ensure_dir(save_dir)
    # '*.*g' picks up extensions ending in "g" (e.g. .png, .jpg, .jpeg).
    test_imgs = glob(join(args.data_dir, '*.*g'))
    print('------loaded %d images.' % len(test_imgs) )
    for img in test_imgs:
        print('[*] processing %s ...' % img)
        # Halftone inputs are read as single-channel, colour inputs as BGR.
        read_flag = cv2.IMREAD_GRAYSCALE if args.decoding else cv2.IMREAD_COLOR
        raw = cv2.imread(img, flags=read_flag)
        if raw is None:
            # cv2.imread signals an unreadable/unsupported file with None.
            print('[!] skipping %s: could not be read as an image' % img)
            continue
        # Map 8-bit pixel values from [0, 255] to [-1, 1].
        input_img = raw / 127.5 - 1.
        # File name without directory or final extension; portable across
        # path separators and safe for names that contain extra dots.
        stem = os.path.splitext(os.path.basename(img))[0]
        if args.decoding:
            c = invhalfer(util.img2tensor(input_img), decoding_only=True)
            # Map the output from [-1, 1] back to [0, 255] before saving.
            c = util.tensor2img(c / 2. + 0.5) * 255.
            cv2.imwrite(join(save_dir, 'restored_' + stem + '.png'), c)
        else:
            h, c = invhalfer(util.img2tensor(input_img), decoding_only=False)
            h = util.tensor2img(h / 2. + 0.5) * 255.
            c = util.tensor2img(c / 2. + 0.5) * 255.
            cv2.imwrite(join(save_dir, 'halftone_' + stem + '.png'), h)
            cv2.imwrite(join(save_dir, 'restored_' + stem + '.png'), c)