Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import argparse | |
| import os.path as osp | |
| import numpy as np | |
| from mmengine.fileio import dump, load | |
| from mmengine.utils import mkdir_or_exist, track_parallel_progress | |
prog_description = '''K-Fold coco split.

To split coco data for semi-supervised object detection:
    python tools/misc/split_coco.py
'''


def parse_args():
    """Parse the command-line options of the COCO split script.

    Returns:
        argparse.Namespace: Namespace with the attributes ``data_root``,
            ``out_dir``, ``labeled_percent`` (list) and ``fold`` (int).
    """
    # Show the module description in ``--help``; the raw formatter keeps the
    # line breaks and indentation of the usage example intact.
    parser = argparse.ArgumentParser(
        description=prog_description,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(
        '--data-root',
        type=str,
        help='The data root of coco dataset.',
        default='./data/coco/')
    parser.add_argument(
        '--out-dir',
        type=str,
        help='The output directory of coco semi-supervised annotations.',
        default='./data/coco/semi_anns/')
    # NOTE: argparse does not apply ``type`` to non-string defaults, so the
    # default percentages stay ints while values given on the command line
    # become floats.
    parser.add_argument(
        '--labeled-percent',
        type=float,
        nargs='+',
        help='The percentage of labeled data in the training set.',
        default=[1, 2, 5, 10])
    parser.add_argument(
        '--fold',
        type=int,
        help='K-fold cross validation for semi-supervised object detection.',
        default=5)
    return parser.parse_args()
def split_coco(data_root, out_dir, percent, fold, replace=True):
    """Split COCO data for Semi-supervised object detection.

    Loads ``annotations/instances_train2017.json`` under ``data_root``,
    picks a random subset of images as the labeled set and writes two
    annotation files into ``out_dir``:
    ``instances_train2017.{fold}@{percent}.json`` (labeled) and
    ``instances_train2017.{fold}@{percent}-unlabeled.json`` (the rest).

    Args:
        data_root (str): The data root of coco dataset.
        out_dir (str): The output directory of coco semi-supervised
            annotations.
        percent (float): The percentage of labeled data in the training set.
        fold (int): The fold of dataset and set as random seed for data split.
        replace (bool): Whether image indices are sampled with replacement.
            The default ``True`` keeps the historical behaviour, so splits
            generated with the same seed stay identical; duplicates then
            collapse and the labeled set may be slightly smaller than
            ``percent``. Pass ``False`` to label exactly
            ``int(percent / 100 * num_images)`` distinct images.
    """

    def save_anns(name, images, annotations):
        # Insertion order matches the historical writer, so output for
        # standard COCO files is unchanged.
        sub_anns = {'images': images, 'annotations': annotations}
        # Custom COCO-format files frequently omit some of these top-level
        # keys; copy only the ones that exist instead of raising KeyError
        # after all the splitting work has been done.
        for key in ('licenses', 'categories', 'info'):
            if key in anns:
                sub_anns[key] = anns[key]
        mkdir_or_exist(out_dir)
        dump(sub_anns, f'{out_dir}/{name}.json')

    # set random seed with the fold
    np.random.seed(fold)
    ann_file = osp.join(data_root, 'annotations/instances_train2017.json')
    anns = load(ann_file)

    image_list = anns['images']
    labeled_total = int(percent / 100. * len(image_list))
    # NOTE: with ``replace=True`` (numpy's default) the same index can be
    # drawn more than once; the set below drops the duplicates.
    labeled_inds = set(
        np.random.choice(
            range(len(image_list)), size=labeled_total, replace=replace))

    labeled_images, unlabeled_images = [], []
    for idx, image in enumerate(image_list):
        if idx in labeled_inds:
            labeled_images.append(image)
        else:
            unlabeled_images.append(image)

    # get all annotations of labeled images
    labeled_ids = {image['id'] for image in labeled_images}
    labeled_annotations, unlabeled_annotations = [], []
    for ann in anns['annotations']:
        if ann['image_id'] in labeled_ids:
            labeled_annotations.append(ann)
        else:
            unlabeled_annotations.append(ann)

    # save labeled and unlabeled
    labeled_name = f'instances_train2017.{fold}@{percent}'
    unlabeled_name = f'instances_train2017.{fold}@{percent}-unlabeled'

    save_anns(labeled_name, labeled_images, labeled_annotations)
    save_anns(unlabeled_name, unlabeled_images, unlabeled_annotations)
def multi_wrapper(args):
    """Call ``split_coco`` with a packed argument tuple.

    Lets a single-argument task runner drive ``split_coco``.

    Args:
        args (tuple): Positional arguments forwarded to ``split_coco``.

    Returns:
        Whatever ``split_coco`` returns.
    """
    result = split_coco(*args)
    return result
if __name__ == '__main__':
    args = parse_args()

    # One task per (fold, percent) pair: folds are numbered from 1 and the
    # percentages vary fastest.
    arguments_list = []
    for fold_idx in range(1, args.fold + 1):
        for labeled_percent in args.labeled_percent:
            arguments_list.append(
                (args.data_root, args.out_dir, labeled_percent, fold_idx))

    # The fold count is also passed as the third positional argument
    # (presumably the worker count -- confirm against mmengine's API).
    track_parallel_progress(multi_wrapper, arguments_list, args.fold)