|
- import torchvision
- import torch
- import argparse
- from torch.autograd import Variable
- import onnx
-
- import pycuda.autoinit
- import numpy as np
- import pycuda.driver as cuda
- import tensorrt as trt
- import os
- import time
-
-
print(torch.__version__)


def _str2bool(value):
    """Convert a command-line string to a real boolean.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so ``--fp16 False`` would
    silently enable FP16.  This converter accepts the usual spellings
    and rejects anything else with a clear error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


# CLI for converting a saved PyTorch model to ONNX and then building a
# TensorRT engine from it.
parser = argparse.ArgumentParser(
    description='PyTorch model to ONNX / TensorRT conversion example')

parser.add_argument('--model',
                    type=str,
                    help='model file name (looked up under /tmp/dataset/)'
                    )
parser.add_argument('--n',
                    type=int,
                    default=256,
                    help='batch size for input shape type'
                    )
parser.add_argument('--c',
                    type=int,
                    default=1,
                    help='channel for input shape type'
                    )
parser.add_argument('--h',
                    type=int,
                    default=28,
                    help='height for input shape type'
                    )
parser.add_argument('--w',
                    type=int,
                    default=28,
                    help='width for input shape type'
                    )
parser.add_argument('--fp16',
                    type=_str2bool,
                    default=False,
                    help='build the TensorRT engine in FP16 mode '
                         '(accepts true/false, yes/no, 1/0)'
                    )
-
if __name__ == "__main__":
    args = parser.parse_args()
    print('args:')
    print(args)

    # Load the full serialized PyTorch model (not just a state_dict) from
    # the dataset mount and dump its parameter shapes for inspection.
    model_file = '/tmp/dataset/' + args.model
    print(model_file)
    model = torch.load(model_file)
    print(model)
    print(type(model))
    for k, v in model.named_parameters():
        print("k:", k)
        print("v:", v.shape)

    # rfind (unlike rindex) returns -1 instead of raising ValueError when
    # the model name has no extension, so the no-extension fallback is
    # actually reachable.
    suffix = args.model.rfind(".")
    stem = args.model[:suffix] if suffix != -1 else args.model
    out_file = '/tmp/output/' + stem + ".onnx"
    print(out_file)

    # Export to ONNX with a dummy input of the requested NCHW shape.
    # torch.autograd.Variable is deprecated and a no-op wrapper on modern
    # PyTorch, so a plain tensor is exported directly.
    input_name = ['input']
    output_name = ['output']
    dummy_input = torch.randn(args.n, args.c, args.h, args.w).cuda()
    torch.onnx.export(model, dummy_input, out_file,
                      input_names=input_name, output_names=output_name,
                      verbose=True)

    TRT_LOGGER = trt.Logger()  # required to build an engine

    fp16_mode = args.fp16
    engine_file_path = '/tmp/output/' + stem + '_fp16_{}.trt'.format(fp16_mode)

    # TensorRT requires the explicit-batch flag for ONNX-imported networks.
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, \
            builder.create_network(explicit_batch) as network, \
            builder.create_builder_config() as config, \
            trt.OnnxParser(network, TRT_LOGGER) as onnx_parser, \
            trt.Runtime(TRT_LOGGER) as runtime:
        config.max_workspace_size = 1 << 28  # 256 MiB scratch space
        if fp16_mode:
            config.set_flag(trt.BuilderFlag.FP16)

        # Distinct names (onnx_parser / onnx_file) avoid shadowing the
        # argparse `parser` and the loaded `model` above.
        with open(out_file, 'rb') as onnx_file:
            if not onnx_parser.parse(onnx_file.read()):
                # Surface the parser diagnostics before failing; a bare
                # exception hides the actual ONNX incompatibility.
                for i in range(onnx_parser.num_errors):
                    print(onnx_parser.get_error(i))
                raise TypeError("Parser parse failed.")

        plan = builder.build_serialized_network(network, config)
        if plan is None:
            # build_serialized_network returns None on failure rather
            # than raising; deserializing None would crash obscurely.
            raise RuntimeError("TensorRT engine build failed.")
        engine = runtime.deserialize_cuda_engine(plan)

        with open(engine_file_path, "wb") as f:
            f.write(engine.serialize())
|