Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
| # Extracting crops for pre-training | |
| # -------------------------------------------------------- | |
| import os | |
| import argparse | |
| from tqdm import tqdm | |
| from PIL import Image | |
| import functools | |
| from multiprocessing import Pool | |
| import math | |
def arg_parser():
    """Build the command-line parser of the crop-extraction script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--crops``, ``--root-dir``,
        ``--output-dir``, ``--imsize``, ``--nthread``,
        ``--max-subdir-levels`` and ``--ideal-number-pairs-in-dir``.
    """
    # NOTE: the text is given positionally, so argparse receives it as
    # ``prog`` (the program name shown in usage messages).
    cli = argparse.ArgumentParser(
        "Generate cropped image pairs from image crop list"
    )

    # One (flag, options) entry per argument, in registration order.
    option_table = (
        ("--crops", dict(type=str, required=True, help="crop file")),
        ("--root-dir", dict(type=str, required=True, help="root directory")),
        ("--output-dir", dict(type=str, required=True, help="output directory")),
        ("--imsize", dict(type=int, default=256, help="size of the crops")),
        (
            "--nthread",
            dict(type=int, required=True, help="number of simultaneous threads"),
        ),
        (
            "--max-subdir-levels",
            dict(type=int, default=5, help="maximum number of subdirectories"),
        ),
        (
            "--ideal-number-pairs-in-dir",
            dict(type=int, default=500, help="number of pairs stored in a dir"),
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli
def main(args):
    """Generate every requested crop pair and write the listing file.

    The crop file is parsed, the crops are spread over a hierarchy of
    sub-directories, the cropped images are written (in parallel when
    ``args.nthread > 1``) and the relative path of every saved pair is
    recorded in ``<output_dir>/listing.txt``.

    Args:
        args: parsed command-line namespace; reads ``crops``, ``output_dir``,
            ``nthread``, ``ideal_number_pairs_in_dir`` and
            ``max_subdir_levels`` here and is forwarded to
            ``save_image_crops``.
    """
    listing_path = os.path.join(args.output_dir, "listing.txt")

    print(f"Loading list of crops ... ({args.nthread} threads)")
    crops, num_crops_to_generate = load_crop_file(args.crops)

    print(f"Preparing jobs ({len(crops)} candidate image pairs)...")
    # math.log() rejects 0 and gives 0 levels for a single crop, which would
    # make the root below divide by zero; one level is enough in both cases.
    if num_crops_to_generate > 1:
        num_levels = math.ceil(
            math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)
        )
    else:
        num_levels = 1
    # Cap the depth, but never go below one level.
    num_levels = max(1, min(num_levels, args.max_subdir_levels))
    num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1 / num_levels))

    jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
    del crops  # drop the now-redundant reference; the jobs carry the data

    os.makedirs(args.output_dir, exist_ok=True)
    call = functools.partial(save_image_crops, args)
    # Only spawn worker processes when parallelism was actually requested.
    pool = Pool(args.nthread) if args.nthread > 1 else None
    mmap = pool.imap_unordered if pool is not None else map

    print(f"Generating cropped images to {args.output_dir} ...")
    try:
        with open(listing_path, "w") as listing:
            listing.write("# pair_path\n")
            for results in tqdm(mmap(call, jobs), total=len(jobs)):
                for path in results:
                    listing.write(f"{path}\n")
    finally:
        # All results have been consumed on the success path, so the workers
        # can be shut down immediately; on error this avoids leaking them.
        if pool is not None:
            pool.terminate()
            pool.join()
    print("Finished writing listing to", listing_path)
def load_crop_file(path):
    """Parse a crop listing file into image pairs and their crop rectangles.

    Fields on a line are separated by ``", "``. Blank lines and lines
    starting with ``#`` are ignored. A line with fewer than 8 fields must be
    ``img1, img2, rotation`` and starts a new image pair. A line with 8
    integer fields ``l1, r1, t1, b1, l2, r2, t2, b2`` adds one crop to the
    most recent pair.

    Args:
        path: path of the crop listing file.

    Returns:
        tuple: ``(pairs, num_crops_to_generate)`` where ``pairs`` is a list
        of ``(img1, img2, rotation, crops)`` and each entry of ``crops`` is
        ``((l1, t1, r1, b1), (l2, t2, r2, b2))``.

    Raises:
        ValueError: if a line is malformed or a crop line appears before any
            image-pair line.
    """
    # Read everything up front (so the progress bar knows the total) and
    # close the file right away instead of leaking the handle.
    with open(path) as crop_file:
        data = crop_file.read().splitlines()

    pairs = []
    num_crops_to_generate = 0
    for line in tqdm(data):
        if not line.strip() or line.startswith("#"):
            continue
        fields = line.split(", ")
        if len(fields) < 8:
            img1, img2, rotation = fields
            pairs.append((img1, img2, int(rotation), []))
        else:
            l1, r1, t1, b1, l2, r2, t2, b2 = map(int, fields)
            if not pairs:
                raise ValueError(
                    f"crop line found before any image pair in {path}: {line!r}"
                )
            # The file stores (left, right, top, bottom); the rectangles are
            # reordered to (left, top, right, bottom).
            rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
            pairs[-1][-1].append((rect1, rect2))
            num_crops_to_generate += 1
    return pairs, num_crops_to_generate
def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
    """Turn parsed image pairs into jobs with one output path per crop.

    Crops are numbered consecutively across all pairs. The output path of a
    crop is a ``/``-joined list of hexadecimal numbers: ``num_levels - 1``
    directory components derived from the crop number, followed by the crop
    number itself.

    Args:
        pairs: iterable of ``(img1, img2, rotation, crops)`` tuples.
        num_levels: number of components in each generated path.
        num_pairs_in_dir: base used to split the crop number into directory
            components.

    Returns:
        list: one ``((img1, img2), rotation, crops, paths)`` tuple per pair.
    """
    # Divisor for each level, from the outermost directory down to 1.
    capacities = [
        num_pairs_in_dir ** (num_levels - 1 - depth) for depth in range(num_levels)
    ]

    def get_path(crop_number):
        components = []
        remainder = crop_number
        # Every level except the last contributes a directory component.
        for capacity in capacities[:-1]:
            quotient, remainder = divmod(remainder, capacity)
            components.append(quotient)
        # The final component is the full crop number, not the remainder.
        components.append(crop_number)
        return "/".join(hex(component)[2:] for component in components)

    jobs = []
    next_number = 0
    for img1, img2, rotation, crops in tqdm(pairs):
        if -60 <= rotation <= 60:
            rotation = 0  # most likely not a true rotation
        count = len(crops)
        paths = [get_path(n) for n in range(next_number, next_number + count)]
        next_number += count
        jobs.append(((img1, img2), rotation, crops, paths))
    return jobs
def load_image(path):
    """Open an image file and return it converted to RGB.

    Args:
        path: path of the image file.

    Returns:
        The image converted to mode ``"RGB"``.

    Raises:
        OSError: if the image cannot be opened or converted; the original
            exception is chained as the cause.
    """
    try:
        # The context manager closes the underlying file once the pixel data
        # has been converted into a new, independent image.
        with Image.open(path) as img:
            return img.convert("RGB")
    except Exception as e:
        # Any failure is reported and normalised to OSError so callers have
        # a single exception type to handle when skipping unreadable images.
        print("skipping", path, e)
        raise OSError(f"could not load image {path!r}: {e}") from e
def save_image_crops(args, data):
    """Crop, resize and save every crop pair of one image pair.

    Args:
        args: parsed command-line namespace; reads ``root_dir``,
            ``output_dir`` and ``imsize``.
        data: job tuple ``(img_pair, rot, crops, paths)`` where ``img_pair``
            holds the two image paths relative to ``args.root_dir``, ``rot``
            is the rotation in degrees applied to the second image's crops,
            ``crops`` is a list of ``(rect1, rect2)`` and ``paths`` holds
            the matching output paths relative to ``args.output_dir``.

    Returns:
        list: the entries of ``paths`` that were written; empty when either
        image could not be loaded.

    Raises:
        FileExistsError: if an output file already exists.
    """
    # load images
    img_pair, rot, crops, paths = data
    try:
        img1, img2 = [
            load_image(os.path.join(args.root_dir, impath)) for impath in img_pair
        ]
    except OSError:
        # An unreadable image invalidates the whole pair: skip it.
        return []

    def area(sz):
        return sz[0] * sz[1]

    tgt_size = (args.imsize, args.imsize)

    def prepare_crop(img, rect, rot=0):
        # actual crop
        img = img.crop(rect)

        # resize to desired size; LANCZOS is used when the crop has more
        # than 4x the target area, BICUBIC otherwise
        interp = (
            Image.Resampling.LANCZOS
            if area(img.size) > 4 * area(tgt_size)
            else Image.Resampling.BICUBIC
        )
        img = img.resize(tgt_size, resample=interp)

        # rotate the image by the multiple of 90 degrees nearest to rot
        rot90 = (round(rot / 90) % 4) * 90
        if rot90 == 90:
            img = img.transpose(Image.Transpose.ROTATE_90)
        elif rot90 == 180:
            img = img.transpose(Image.Transpose.ROTATE_180)
        elif rot90 == 270:
            img = img.transpose(Image.Transpose.ROTATE_270)
        return img

    results = []
    for (rect1, rect2), path in zip(crops, paths):
        # only the crop of the second image is rotated
        crop1 = prepare_crop(img1, rect1)
        crop2 = prepare_crop(img2, rect2, rot)

        fullpath1 = os.path.join(args.output_dir, path + "_1.jpg")
        fullpath2 = os.path.join(args.output_dir, path + "_2.jpg")
        os.makedirs(os.path.dirname(fullpath1), exist_ok=True)

        # Refuse to overwrite existing output. An explicit raise is used
        # because assert statements are removed when Python runs with -O.
        for fullpath in (fullpath1, fullpath2):
            if os.path.isfile(fullpath):
                raise FileExistsError(fullpath)

        crop1.save(fullpath1)
        crop2.save(fullpath2)
        results.append(path)
    return results
# Script entry point: parse the command line and run the extraction.
if __name__ == "__main__":
    main(arg_parser().parse_args())