Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| import argparse | |
| import os | |
| from typing import Dict, List, Tuple | |
| import torch | |
| from torch import Tensor, nn | |
| import detectron2.data.transforms as T | |
| from detectron2.checkpoint import DetectionCheckpointer | |
| from detectron2.config import get_cfg | |
| from detectron2.data import build_detection_test_loader, detection_utils | |
| from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format | |
| from detectron2.export import ( | |
| STABLE_ONNX_OPSET_VERSION, | |
| TracingAdapter, | |
| dump_torchscript_IR, | |
| scripting_with_instances, | |
| ) | |
| from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model | |
| from detectron2.modeling.postprocessing import detector_postprocess | |
| from detectron2.projects.point_rend import add_pointrend_config | |
| from detectron2.structures import Boxes | |
| from detectron2.utils.env import TORCH_VERSION | |
| from detectron2.utils.file_io import PathManager | |
| from detectron2.utils.logger import setup_logger | |
def setup_cfg(args):
    """Build the frozen detectron2 config used for the export.

    PointRend config keys are registered so PointRend configs can be merged.
    Values are then loaded from ``args.config_file`` and finally overridden
    by the ``KEY VALUE`` pairs collected in ``args.opts``.

    Args:
        args: parsed command-line namespace providing ``config_file`` and ``opts``.

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    # The CUDA context is initialized before the data loader is created, so
    # load data in the main process instead of forking workers.
    config.DATALOADER.NUM_WORKERS = 0
    add_pointrend_config(config)
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def export_caffe2_tracing(cfg, torch_model, inputs):
    """Export ``torch_model`` with detectron2's ``Caffe2Tracer``.

    The output format and the destination directory are read from the
    module-level ``args`` namespace (``args.format`` / ``args.output``) that
    ``main`` publishes.

    Args:
        cfg: the detectron2 config the model was built from.
        torch_model: the model to export.
        inputs: sample inputs the tracer runs the model on.

    Returns:
        The exported Caffe2 model for ``--format caffe2``; ``None`` for the
        ``onnx`` and ``torchscript`` formats.
    """
    from detectron2.export import Caffe2Tracer

    tracer = Caffe2Tracer(cfg, torch_model, inputs)
    if args.format == "caffe2":
        caffe2_model = tracer.export_caffe2()
        caffe2_model.save_protobuf(args.output)
        # draw the caffe2 graph
        caffe2_model.save_graph(os.path.join(args.output, "model.svg"), inputs=inputs)
        return caffe2_model
    if args.format == "onnx":
        import onnx

        onnx_model = tracer.export_onnx()
        # Open the destination through PathManager, like every other output of
        # this script, so the path is resolved the same way for all formats.
        with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
            onnx.save(onnx_model, f)
        return None
    if args.format == "torchscript":
        ts_model = tracer.export_torchscript()
        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)
        return None
    return None
# experimental. API not yet final
def export_scripting(torch_model):
    """Export ``torch_model`` to TorchScript by scripting (not tracing) it.

    The model is wrapped in an adapter whose ``forward`` returns the fields of
    each ``Instances`` result as a plain dict. The adapter is scripted with
    ``scripting_with_instances`` and saved as ``model.ts`` under
    ``args.output``, followed by ``dump_torchscript_IR``. Reads the
    module-level ``args`` namespace that ``main`` publishes.

    Args:
        torch_model: the model to export.

    Returns:
        Always ``None``; no Python inference wrapper exists for scripted
        models yet (see the TODO below).
    """
    assert TORCH_VERSION >= (1, 8)
    # Field name -> type mapping handed to scripting_with_instances below.
    fields = {
        "proposal_boxes": Boxes,
        "objectness_logits": Tensor,
        "pred_boxes": Boxes,
        "scores": Tensor,
        "pred_classes": Tensor,
        "pred_masks": Tensor,
        "pred_keypoints": torch.Tensor,
        "pred_keypoint_heatmaps": torch.Tensor,
    }
    assert args.format == "torchscript", "Scripting only supports torchscript format."

    class ScriptableAdapterBase(nn.Module):
        # Use this adapter to work around https://github.com/pytorch/pytorch/issues/46944
        # by not returning instances but dicts. Otherwise the exported model is not deployable
        def __init__(self):
            super().__init__()
            self.model = torch_model
            self.eval()

    # Two adapter variants: GeneralizedRCNN is called through .inference()
    # with postprocessing disabled; any other model is called directly.
    if isinstance(torch_model, GeneralizedRCNN):

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model.inference(inputs, do_postprocess=False)
                return [i.get_fields() for i in instances]

    else:

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model(inputs)
                return [i.get_fields() for i in instances]

    ts_model = scripting_with_instances(ScriptableAdapter(), fields)
    with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
        torch.jit.save(ts_model, f)
    dump_torchscript_IR(ts_model, args.output)
    # TODO inference in Python now missing postprocessing glue code
    return None
# experimental. API not yet final
def export_tracing(torch_model, inputs):
    """Export ``torch_model`` by tracing it through ``TracingAdapter``.

    Only the ``"image"`` tensor of the first sample is used as the trace
    input. The traced model is written under ``args.output`` as ``model.ts``
    (``--format torchscript``) or ``model.onnx`` (``--format onnx``). The
    adapter's flattened input/output schemas are then logged. Reads the
    module-level ``args`` and ``logger`` that ``main`` publishes.

    Args:
        torch_model: the model to export.
        inputs: sample inputs; a list of dicts whose first element has an
            ``"image"`` key.

    Returns:
        For torchscript exports of ``GeneralizedRCNN`` / ``RetinaNet`` models, a
        callable that runs the traced model and applies ``detector_postprocess``
        so it can be evaluated; ``None`` otherwise.

    Raises:
        ValueError: if ``args.format`` is neither ``"torchscript"`` nor ``"onnx"``.
    """
    assert TORCH_VERSION >= (1, 8)
    if args.format not in ("torchscript", "onnx"):
        # e.g. --format caffe2: tracing has no exporter for it. Fail loudly
        # instead of writing nothing and letting the script report success.
        raise ValueError(
            f"Tracing only supports torchscript and onnx formats, got format={args.format}."
        )

    image = inputs[0]["image"]
    inputs = [{"image": image}]  # remove other unused keys

    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # use do_postprocess=False so it returns ROI mask
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        inference = None  # assume that we just call the model directly

    traceable_model = TracingAdapter(torch_model, inputs, inference)

    if args.format == "torchscript":
        ts_model = torch.jit.trace(traceable_model, (image,))
        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)
    else:  # "onnx" (validated above)
        with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
            torch.onnx.export(traceable_model, (image,), f, opset_version=STABLE_ONNX_OPSET_VERSION)
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    if args.format != "torchscript":
        return None
    if not isinstance(torch_model, (GeneralizedRCNN, RetinaNet)):
        return None

    def eval_wrapper(inputs):
        """
        The exported model does not contain the final resize step, which is typically
        unused in deployment but needed for evaluation. We add it manually here.

        Only ``inputs[0]`` is used (assumes the evaluation loader yields batches
        of one -- TODO confirm). Its ``"height"`` / ``"width"`` are the size that
        ``detector_postprocess`` rescales the predictions to.
        """
        sample = inputs[0]
        instances = traceable_model.outputs_schema(ts_model(sample["image"]))[0]["instances"]
        postprocessed = detector_postprocess(instances, sample["height"], sample["width"])
        return [{"instances": postprocessed}]

    return eval_wrapper
def get_sample_inputs(args):
    """Return a list of sample input dicts to drive tracing-based exports.

    With ``--sample-image`` the image is read from disk using the module-level
    ``cfg`` and preprocessed like ``DefaultPredictor``: resized along its
    shortest edge, then converted from HWC to a float32 CHW tensor. The
    original height and width are recorded next to it. Without
    ``--sample-image`` the first batch of the first test dataset is returned.
    """
    if args.sample_image is not None:
        original = detection_utils.read_image(args.sample_image, format=cfg.INPUT.FORMAT)
        # Mirror DefaultPredictor's preprocessing.
        resize = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )
        orig_h, orig_w = original.shape[:2]
        resized = resize.get_transform(original).apply_image(original)
        chw = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))
        return [{"image": chw, "height": orig_h, "width": orig_w}]

    # No image given on the command line: use the dataset's first batch.
    loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    return next(iter(loader))
def main() -> None:
    """Command-line entry point: build a model, export it, optionally evaluate it.

    Parses the command line, builds the model from the config, and loads
    ``cfg.MODEL.WEIGHTS``. It then exports the model with the requested
    ``--export-method`` / ``--format`` into ``--output``. With ``--run-eval``
    the exported model is evaluated on the first test dataset with
    ``COCOEvaluator``.

    ``logger``, ``cfg`` and ``args`` are published as module globals because
    the export helpers in this file read them.
    """
    global logger, cfg, args
    parser = argparse.ArgumentParser(description="Export a model for deployment.")
    parser.add_argument(
        "--format",
        choices=["caffe2", "onnx", "torchscript"],
        help="output format",
        default="torchscript",
    )
    parser.add_argument(
        "--export-method",
        choices=["caffe2_tracing", "tracing", "scripting"],
        help="Method to export models",
        default="tracing",
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--sample-image", default=None, type=str, help="sample image for input")
    parser.add_argument("--run-eval", action="store_true")
    # Required: every export method writes into this directory, and without
    # it PathManager.mkdirs() below would be called with None.
    parser.add_argument(
        "--output", required=True, help="output directory for the converted model"
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Command line arguments: " + str(args))
    PathManager.mkdirs(args.output)
    # Disable re-specialization on new shapes. Otherwise --run-eval will be slow
    torch._C._jit_set_bailout_depth(1)

    cfg = setup_cfg(args)

    # create a torch model
    torch_model = build_model(cfg)
    DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
    torch_model.eval()

    # convert and save model
    if args.export_method == "caffe2_tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_caffe2_tracing(cfg, torch_model, sample_inputs)
    elif args.export_method == "scripting":
        exported_model = export_scripting(torch_model)
    elif args.export_method == "tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_tracing(torch_model, sample_inputs)

    # run evaluation with the converted model
    if args.run_eval:
        # The exporters return None when no Python-callable model is available.
        assert exported_model is not None, (
            "Python inference is not yet implemented for "
            f"export_method={args.export_method}, format={args.format}."
        )
        logger.info("Running evaluation ... this takes a long time if you export to CPU.")
        dataset = cfg.DATASETS.TEST[0]
        data_loader = build_detection_test_loader(cfg, dataset)
        # NOTE: hard-coded evaluator. change to the evaluator for your dataset
        evaluator = COCOEvaluator(dataset, output_dir=args.output)
        metrics = inference_on_dataset(exported_model, data_loader, evaluator)
        print_csv_format(metrics)
    logger.info("Success.")
# Run the exporter only when executed as a script, not when imported.
if __name__ == "__main__":  # pragma: no cover
    main()