Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # -*- coding:utf-8 -*- | |
| import argparse | |
| from collections import OrderedDict | |
| import megengine as mge | |
| import torch | |
def make_parser():
    """Build the command-line parser for the weight-conversion script.

    Options:
        -w/--weights: path of the PyTorch weight file to read (required).
        -o/--output: path the converted weights are written to
            (default: ``weight_mge.pkl``).

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    # Required: without it the converter would end up calling
    # torch.load(None) and fail with an unhelpful error instead of a
    # proper usage message.
    parser.add_argument(
        "-w", "--weights", type=str, required=True, help="path of weight file"
    )
    parser.add_argument(
        "-o",
        "--output",
        default="weight_mge.pkl",
        type=str,
        # Previously duplicated the --weights help text verbatim.
        help="path of output weight file",
    )
    return parser
def numpy_weights(weight_file):
    """Load a PyTorch weight file and convert every entry to a numpy array.

    Args:
        weight_file: path passed to ``torch.load`` (loaded onto the CPU).

    Returns:
        OrderedDict mapping each key to ``value.cpu().numpy()``, in the
        file's original key order. If the loaded mapping contains a
        ``"model"`` key, that inner mapping is converted instead.
    """
    checkpoint = torch.load(weight_file, map_location="cpu")
    # Some checkpoints nest the actual state dict under a "model" key.
    state = checkpoint["model"] if "model" in checkpoint else checkpoint
    return OrderedDict(
        (name, tensor.cpu().numpy()) for name, tensor in state.items()
    )
def map_weights(weight_file, output_file):
    """Convert a PyTorch weight file and save it with ``mge.save``.

    Per key:
      * keys containing ``num_batches_tracked`` are dropped;
      * keys ending in ``bias`` are reshaped to ``(1, -1, 1, 1)``;
      * keys containing both ``dconv`` and ``conv.weight`` are reshaped from
        ``(cout, cin, k1, k2)`` to ``(cout, 1, cin, k1, k2)``;
      * every other key is copied unchanged.

    NOTE(review): the reshapes presumably match MegEngine's expected
    parameter layouts for conv biases and grouped (depthwise) conv weights;
    confirm against the target MegEngine model definition.

    Args:
        weight_file: PyTorch weight file to read (see ``numpy_weights``).
        output_file: destination path for ``mge.save``.
    """
    converted = OrderedDict()
    for name, array in numpy_weights(weight_file).items():
        if "num_batches_tracked" in name:
            print(f"drop: {name}")
            continue
        if name.endswith("bias"):
            print(f"bias key: {name}")
            array = array.reshape(1, -1, 1, 1)
        elif "dconv" in name and "conv.weight" in name:
            print(f"depthwise conv key: {name}")
            out_ch, in_ch, kh, kw = array.shape
            array = array.reshape(out_ch, 1, in_ch, kh, kw)
        converted[name] = array
    mge.save(converted, output_file)
    print(f"save weights to {output_file}")
def main():
    """Script entry point: parse CLI options and run the conversion."""
    options = make_parser().parse_args()
    map_weights(options.weights, options.output)
# Run the converter only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()