|
# https://mxnet.apache.org/versions/1.9.1/api/python/docs/tutorials/deploy/export/onnx.html
import argparse
import os

import mxnet as mx
import numpy as np
-
# CLI: model file name plus the NCHW input shape used for the ONNX export.
parser = argparse.ArgumentParser(description='MxNet ONNX Example')

parser.add_argument('--model',
                    type=str,
                    help='model file name; its stem <name> selects '
                         '<name>-symbol.json and <name>-0000.params'
                    )
parser.add_argument('--n',
                    type=int,
                    default=64,
                    help='batch size for input shape type'
                    )
parser.add_argument('--c',
                    type=int,
                    default=1,
                    help='channel for input shape type'
                    )
parser.add_argument('--h',
                    type=int,
                    default=28,
                    help='height for input shape type'
                    )
parser.add_argument('--w',
                    type=int,
                    default=28,
                    help='width for input shape type'
                    )
-
-
if __name__ == "__main__":
    args = parser.parse_args()
    print('args:')
    print(args)

    # NOTE: the MXNet network must be built with HybridSequential() so that
    # the symbol/structure file ('model-symbol.json') can be saved.
    # (translated from the original Chinese comment)
    #
    # Expected file layout for e.g. --model resnet-18.onnx:
    #   /model/resnet-18-symbol.json
    #   /model/resnet-18-0000.params
    # NOTE(review): '/model/' is an absolute path (container mount?); the
    # original commented-out example used a relative 'model/' — confirm.
    #
    # splitext (instead of split('.')[0]) strips only the final extension,
    # so model names containing dots are not truncated.
    model_path = '/model/' + os.path.splitext(args.model)[0]
    input_shape = (args.n, args.c, args.h, args.w)  # NCHW
    sym = model_path + '-symbol.json'
    params = model_path + '-0000.params'
    output_file = model_path + '.onnx'

    # Export to ONNX; export_model returns the path of the converted model.
    converted_model_path = mx.onnx.export_model(
        sym, params, [input_shape], np.float32, output_file
    )
    print('converted model saved to:', converted_model_path)
|