#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
"""
A script to benchmark builtin models.

Note: this script has an extra dependency of psutil.
"""
import itertools
import logging
import psutil
import torch
import tqdm
from fvcore.common.timer import Timer
from torch.nn.parallel import DistributedDataParallel

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, get_cfg, instantiate
from detectron2.data import (
    DatasetFromList,
    build_detection_test_loader,
    build_detection_train_loader,
)
from detectron2.data.benchmark import DataLoaderBenchmark
from detectron2.engine import AMPTrainer, SimpleTrainer, default_argument_parser, hooks, launch
from detectron2.modeling import build_model
from detectron2.solver import build_optimizer
from detectron2.utils import comm
from detectron2.utils.collect_env import collect_env_info
from detectron2.utils.events import CommonMetricPrinter
from detectron2.utils.logger import setup_logger
logger = logging.getLogger("detectron2")

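# setup() accepts both config styles used in detectron2: classic yacs yaml
# configs (get_cfg + merge_from_file) and the newer "lazy" python configs
# (LazyConfig.load + apply_overrides).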
def setup(args):
    if args.config_file.endswith(".yaml"):
        cfg = get_cfg()
        cfg.merge_from_file(args.config_file)
        cfg.SOLVER.BASE_LR = 0.001  # Avoid NaNs. Not useful in this script anyway.
        cfg.merge_from_list(args.opts)
        cfg.freeze()
    else:
        cfg = LazyConfig.load(args.config_file)
        cfg = LazyConfig.apply_overrides(cfg, args.opts)
    setup_logger(distributed_rank=comm.get_rank())
    return cfg

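# To benchmark the dataloader, the `_target_` of the dataloader config is
# swapped for DataLoaderBenchmark, so instantiate() builds a benchmark object
# with the same dataset/mapper/sampler arguments as the real train loader.
# For yaml configs, the equivalent kwargs are recovered via from_config();
# aspect_ratio_grouping is dropped because the benchmark does not accept it.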
def create_data_benchmark(cfg, args):
    if args.config_file.endswith(".py"):
        dl_cfg = cfg.dataloader.train
        dl_cfg._target_ = DataLoaderBenchmark
        return instantiate(dl_cfg)
    else:
        kwargs = build_detection_train_loader.from_config(cfg)
        kwargs.pop("aspect_ratio_grouping", None)
        kwargs["_target_"] = DataLoaderBenchmark
        return instantiate(kwargs)

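# psutil reports machine-wide memory, so the numbers below include every
# process on the host, not just this script.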
def RAM_msg():
    vram = psutil.virtual_memory()
    return "RAM Usage: {:.2f}/{:.2f} GB".format(
        (vram.total - vram.available) / 1024**3, vram.total / 1024**3
    )

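# One long run to measure steady-state dataloading throughput, then several
# short runs while logging RAM usage, which helps spot memory leaks or
# steadily growing worker memory.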
def benchmark_data(args):
    cfg = setup(args)
    logger.info("After spawning " + RAM_msg())

    benchmark = create_data_benchmark(cfg, args)
    benchmark.benchmark_distributed(250, 10)
    # test for a few more rounds
    for k in range(10):
        logger.info(f"Iteration {k} " + RAM_msg())
        benchmark.benchmark_distributed(250, 1)

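# Break the dataloader down into its stages (raw dataset reading, the mapper,
# multi-process workers, inter-process communication) to help locate the
# bottleneck; the stage benchmarks run on rank 0 only.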
def benchmark_data_advanced(args):
    # benchmark dataloader with more details to help analyze performance bottleneck
    cfg = setup(args)
    benchmark = create_data_benchmark(cfg, args)

    if comm.get_rank() == 0:
        benchmark.benchmark_dataset(100)
        benchmark.benchmark_mapper(100)
        benchmark.benchmark_workers(100, warmup=10)
        benchmark.benchmark_IPC(100, warmup=10)
    if comm.get_world_size() > 1:
        benchmark.benchmark_distributed(100)
        logger.info("Rerun ...")
        benchmark.benchmark_distributed(100)

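# Training speed is measured on 100 batches cached in memory and replayed in
# an endless loop, so data loading is excluded from the numbers. A
# TorchProfiler trace is written to cfg.OUTPUT_DIR at the last iteration.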
def benchmark_train(args):
    cfg = setup(args)
    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )
    optimizer = build_optimizer(cfg, model)
    checkpointer = DetectionCheckpointer(model, optimizer=optimizer)
    checkpointer.load(cfg.MODEL.WEIGHTS)

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 2
    data_loader = build_detection_train_loader(cfg)
    dummy_data = list(itertools.islice(data_loader, 100))

    def f():
        data = DatasetFromList(dummy_data, copy=False, serialize=False)
        while True:
            yield from data

    max_iter = 400
    trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(model, f(), optimizer)
    trainer.register_hooks(
        [
            hooks.IterationTimer(),
            hooks.PeriodicWriter([CommonMetricPrinter(max_iter)]),
            hooks.TorchProfiler(
                lambda trainer: trainer.iter == max_iter - 1,
                cfg.OUTPUT_DIR,
                save_tensorboard=True,
            ),
        ]
    )
    trainer.train(1, max_iter)

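# Inference speed is likewise measured on 100 cached inputs replayed in a
# loop; the first 5 forward passes warm up the model (e.g. CUDA kernels)
# before the timer starts.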
@torch.no_grad()  # inference only; skip autograd bookkeeping
def benchmark_eval(args):
    cfg = setup(args)
    if args.config_file.endswith(".yaml"):
        model = build_model(cfg)
        DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0
        data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    else:
        model = instantiate(cfg.model)
        model.to(cfg.train.device)
        DetectionCheckpointer(model).load(cfg.train.init_checkpoint)

        cfg.dataloader.num_workers = 0
        data_loader = instantiate(cfg.dataloader.test)

    model.eval()
    logger.info("Model:\n{}".format(model))
    dummy_data = DatasetFromList(list(itertools.islice(data_loader, 100)), copy=False)

    def f():
        while True:
            yield from dummy_data

    for k in range(5):  # warmup
        model(dummy_data[k])

    max_iter = 300
    timer = Timer()
    with tqdm.tqdm(total=max_iter) as pbar:
        for idx, d in enumerate(f()):
            if idx == max_iter:
                break
            model(d)
            pbar.update()
    logger.info("{} iters in {} seconds.".format(max_iter, timer.seconds()))

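# --task selects one of the benchmarks above; the chosen function runs through
# detectron2's launch() helper so multi-GPU / multi-machine setups use the
# same code path as training.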
def main() -> None:
    parser = default_argument_parser()
    parser.add_argument("--task", choices=["train", "eval", "data", "data_advanced"], required=True)
    args = parser.parse_args()
    assert not args.eval_only

    logger.info("Environment info:\n" + collect_env_info())
    if "data" in args.task:
        print("Initial " + RAM_msg())
    if args.task == "data":
        f = benchmark_data
    elif args.task == "data_advanced":
        f = benchmark_data_advanced
    elif args.task == "train":
        # Note: training speed may not be representative.
        # The training cost of an R-CNN model varies with the content of the data
        # and the quality of the model.
        f = benchmark_train
    elif args.task == "eval":
        f = benchmark_eval
        # only benchmark single-GPU inference.
        assert args.num_gpus == 1 and args.num_machines == 1
    launch(
        f,
        args.num_gpus,
        args.num_machines,
        args.machine_rank,
        args.dist_url,
        args=(args,),
    )

if __name__ == "__main__":
    main()  # pragma: no cover