Spaces:
Runtime error
Runtime error
| # https://github.com/xinntao/facexlib/blob/master/inference/inference_matting.py | |
import argparse
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from facexlib.matting import init_matting_model
from facexlib.utils import img2tensor
from torchvision.transforms.functional import normalize
from tqdm import tqdm, trange
def matt_single(args):
    """Estimate the alpha matte of a single image and save matte + foreground.

    The matte is written to ``args.save_path``; the foreground composited on a
    white background is written next to it with ``_fg`` appended to the file
    stem (``test_matting.png`` -> ``test_matting_fg.png``).

    Args:
        args: Parsed command-line namespace. Reads ``args.img_path`` (input
            image path) and ``args.save_path`` (output matte path).

    Raises:
        OSError: If the input image cannot be read or an output image cannot
            be written.
    """
    modnet = init_matting_model()

    # read image; cv2.imread signals failure by returning None, not by raising
    img = cv2.imread(args.img_path)
    if img is None:
        raise OSError('Failed to read image: {}'.format(args.img_path))
    img = img / 255.

    # unify image channels to 3
    if img.ndim == 2:
        img = img[:, :, None]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, 0:3]

    img_t = img2tensor(img, bgr2rgb=True, float32=True)
    # (x - 0.5) / 0.5 per channel
    normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
    img_t = img_t.unsqueeze(0).cuda()

    # resize image for input: when both sides are below, or both above,
    # ref_size, scale so that the shorter side equals ref_size
    _, _, im_h, im_w = img_t.shape
    ref_size = 512
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh = im_h
        im_rw = im_w
    # round down to a multiple of 32 (presumably required by the network's
    # downsampling -- TODO confirm), but never down to 0, which would make
    # F.interpolate fail on images with a side shorter than 32 pixels
    im_rw = max(32, im_rw - im_rw % 32)
    im_rh = max(32, im_rh - im_rh % 32)

    # inference only: no autograd graph is needed
    with torch.no_grad():
        img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area')
        _, _, matte = modnet(img_t, True)
        # resize the matte back to the original resolution
        matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
    matte = matte[0][0].detach().cpu().numpy()

    # save matte
    if not cv2.imwrite(args.save_path, (matte * 255).astype('uint8')):
        raise OSError('Failed to write matte: {}'.format(args.save_path))

    # get foreground: alpha-blend the image over a white (all-ones) background
    matte = matte[:, :, None]
    foreground = img * matte + (1. - matte)

    # Build the foreground path from the file stem instead of str.replace, so
    # a save_path that does not end in '.png' no longer overwrites the matte.
    save_path = Path(args.save_path)
    fg_path = save_path.with_name(save_path.stem + '_fg' + save_path.suffix)
    # convert to 8-bit explicitly rather than handing float64 data to imwrite
    fg_uint8 = np.clip(np.rint(foreground * 255), 0, 255).astype('uint8')
    if not cv2.imwrite(str(fg_path), fg_uint8):
        raise OSError('Failed to write foreground: {}'.format(fg_path))
def matt_directory(args):  # for extracting ffhq imgs foreground
    """Estimate and save an alpha matte for every ``*.png`` under a directory.

    Images are searched recursively under ``args.img_dir_path``. Each matte is
    written below the hard-coded FFHQ matte root, mirroring the image's path
    relative to the hard-coded FFHQ 512 root; images that do not live under
    that root are mirrored relative to ``args.img_dir_path`` instead.

    Args:
        args: Parsed command-line namespace. Reads ``args.img_dir_path``.

    Raises:
        OSError: If an image cannot be read or a matte cannot be written.
    """
    modnet = init_matting_model()

    src_dir = Path(args.img_dir_path)
    all_imgs = list(src_dir.rglob('*.png'))
    print('all imgs: ', len(all_imgs))

    src_root = Path('/mnt/lustre/share/yslan/ffhq/unzipped_ffhq_512/')
    tgt_dir_path = Path('/mnt/lustre/share/yslan/ffhq/unzipped_ffhq_matte/')
    ref_size = 512

    for img_path in tqdm(all_imgs):
        # read image; cv2.imread signals failure by returning None
        img = cv2.imread(str(img_path))
        if img is None:
            raise OSError('Failed to read image: {}'.format(img_path))
        img = img / 255.

        # Mirror the source layout under the target root. Fall back to the
        # searched directory when the image is outside the FFHQ root, instead
        # of letting relative_to() abort the run with ValueError.
        try:
            relative_img_path = img_path.relative_to(src_root)
        except ValueError:
            relative_img_path = img_path.relative_to(src_dir)
        tgt_save_path = tgt_dir_path / relative_img_path
        tgt_save_path.parent.mkdir(parents=True, exist_ok=True)

        # unify image channels to 3
        if img.ndim == 2:
            img = img[:, :, None]
        if img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)
        elif img.shape[2] == 4:
            img = img[:, :, 0:3]

        img_t = img2tensor(img, bgr2rgb=True, float32=True)
        # (x - 0.5) / 0.5 per channel
        normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        img_t = img_t.unsqueeze(0).cuda()

        # resize image for input: when both sides are below, or both above,
        # ref_size, scale so that the shorter side equals ref_size
        _, _, im_h, im_w = img_t.shape
        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            else:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w
        # round down to a multiple of 32 (presumably required by the network's
        # downsampling -- TODO confirm), but never down to 0
        im_rw = max(32, im_rw - im_rw % 32)
        im_rh = max(32, im_rh - im_rh % 32)

        # inference only: no autograd graph is needed
        with torch.no_grad():
            img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area')
            _, _, matte = modnet(img_t, True)
            # resize the matte back to the original resolution
            matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
        matte = matte[0][0].detach().cpu().numpy()

        # save matte; check imwrite's result explicitly rather than with an
        # assert, which is stripped when Python runs with -O
        if not cv2.imwrite(str(tgt_save_path), (matte * 255).astype('uint8')):
            raise OSError('Failed to write matte: {}'.format(tgt_save_path))
if __name__ == '__main__':
    # Command-line entry point: every option is an optional string flag.
    cli = argparse.ArgumentParser()
    for flag, default in (
        ('--img_path', 'assets/test.jpg'),
        ('--save_path', 'test_matting.png'),
        ('--img_dir_path', 'assets'),
    ):
        cli.add_argument(flag, type=str, default=default)
    args = cli.parse_args()

    # Batch mode is the active entry point; swap in matt_single(args) to
    # process only --img_path / --save_path.
    matt_directory(args)