Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # -*- coding:utf-8 -*- | |
| # Copyright (c) Megvii, Inc. and its affiliates. | |
| import argparse | |
| import megengine as mge | |
| import numpy as np | |
| from megengine import jit | |
| from build import build_and_load | |
def make_parser():
    """Build the command-line parser for the YOLOX MegEngine dump script.

    Flags:
        -n/--name    model name (default ``"yolox-s"``).
        -c/--ckpt    checkpoint path (default ``None``).
        --dump_path  output path for the dumped graph (default ``"model.mge"``).

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    cli = argparse.ArgumentParser("YOLOX Demo Dump")
    # (flag strings, add_argument keyword options), registered in this order.
    option_specs = [
        (("-n", "--name"), dict(type=str, default="yolox-s", help="model name")),
        (("-c", "--ckpt"), dict(default=None, type=str, help="ckpt for eval")),
        (
            ("--dump_path",),
            dict(default="model.mge", help="path to save the dumped model"),
        ),
    ]
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)
    return cli
def dump_static_graph(model, graph_name="model.mge", input_shape=(1, 3, 640, 640)):
    """Trace ``model`` on a dummy input and dump it as a static MegEngine graph.

    Args:
        model: network to export. It is switched to eval mode, and its
            ``head.decode_in_inference`` flag is cleared in place.
        graph_name: output path for the serialized graph.
        input_shape: shape of the random dummy input used for tracing.
            Defaults to ``(1, 3, 640, 640)``, the size previously hard-coded.
    """
    model.eval()
    # Presumably makes the head emit raw outputs and leaves box decoding to the
    # consumer of the dumped graph -- NOTE(review): confirm against the head.
    model.head.decode_in_inference = False

    # The values are irrelevant; the input only drives one traced execution.
    # Cast explicitly because np.random.random returns float64.
    data = mge.Tensor(np.random.random(input_shape).astype(np.float32))

    # jit.trace is required: a plain function has no ``.dump`` method, and
    # trace.dump only works when capture_as_const=True. That flag bakes the
    # parameters into the graph as constants.
    @jit.trace(capture_as_const=True)
    def pred_func(data):
        return model(data)

    # Run once so the trace is recorded before dumping it.
    pred_func(data)
    pred_func.dump(
        graph_name,
        arg_names=["data"],
        optimize_for_inference=True,
        enable_fuse_conv_bias_nonlinearity=True,
    )
def main(args):
    """Load the requested model and dump it as a static graph.

    ``args`` is expected to be the namespace from ``make_parser().parse_args()``.
    Its ``ckpt``, ``name`` and ``dump_path`` attributes are read.
    """
    net = build_and_load(args.ckpt, name=args.name)
    dump_static_graph(net, graph_name=args.dump_path)
if __name__ == "__main__":
    # Script entry point: parse CLI flags, then export the model.
    cli_parser = make_parser()
    args = cli_parser.parse_args()
    main(args)