Created
October 28, 2021 08:56
-
-
Save MrPanch/05260d72fe33a526cf96a50a05996aad to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import timm | |
| import torch | |
| import numpy as np | |
| import onnx | |
| import pycuda.driver as cuda | |
| import pycuda.autoinit | |
| import tensorrt as trt | |
def torch2onnx(model_name, onnx_file_path, batch_size, img_size=224):
    """Export a pretrained timm model to an ONNX file in half precision.

    The model is created with ``pretrained=True``, moved to the GPU, put in
    eval mode and converted to FP16. A random FP16 tensor of shape
    ``(batch_size, 3, img_size, img_size)`` is used as the example input for
    the export; no dynamic axes are passed to the exporter. After writing,
    the file is reloaded and run through the ONNX checker.

    Args:
        model_name: Name passed to ``timm.create_model``.
        onnx_file_path: Destination path of the exported ONNX file.
        batch_size: Leading dimension of the example input.
        img_size: Height and width of the example input.
    """
    net = timm.create_model(model_name, pretrained=True)
    net.cuda().eval()
    net.half()

    # Example input: uniform random values, cast to FP16 on the GPU.
    input_shape = (batch_size, 3, img_size, img_size)
    dummy_input = torch.from_numpy(np.random.random(input_shape)).cuda().half()

    torch.onnx.export(
        net,
        dummy_input,
        onnx_file_path,
        input_names=['input'],
        output_names=['output'],
        export_params=True,
    )

    # Reload what was just written and validate it.
    onnx.checker.check_model(onnx.load(onnx_file_path))
def build_engine(onnx_file_path):
    """Parse an ONNX file and build a TensorRT engine with FP16 enabled.

    Args:
        onnx_file_path: Path of the serialized ONNX model to parse.

    Returns:
        A tuple ``(engine, context)``: the built CUDA engine and an execution
        context created from it.

    Raises:
        RuntimeError: If the ONNX parser rejects the model, or if the builder
            does not produce an engine.
    """
    # Initialize the TensorRT builder, an empty network and the ONNX parser.
    trt_logger = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(trt_logger)
    # NOTE(review): newer TensorRT releases require the network to be created
    # with the EXPLICIT_BATCH flag for the ONNX parser -- confirm against the
    # installed version; the parse check below will surface that error.
    network = builder.create_network()
    parser = trt.OnnxParser(network, trt_logger)

    # Parse ONNX.
    with open(onnx_file_path, 'rb') as model_file:
        serialized_model = model_file.read()

    print('Beginning ONNX file parsing')
    # parse() reports failure through its return value, not an exception, so
    # it must be checked explicitly; otherwise a broken network is built.
    if not parser.parse(serialized_model):
        details = '\n'.join(
            parser.get_error(index).desc()
            for index in range(parser.num_errors)
        )
        raise RuntimeError(
            'Failed to parse ONNX file {!r}:\n{}'.format(onnx_file_path, details)
        )
    print('Completed parsing of ONNX file')

    builder.max_workspace_size = 1 << 30
    # Allow batches of up to 8 images.
    builder.max_batch_size = 8
    builder.fp16_mode = True

    # Generate a TensorRT engine optimized for the target platform.
    print('Building an engine...')
    engine = builder.build_cuda_engine(network)
    # The builder signals failure by returning None; fail here with a clear
    # message instead of an AttributeError on the next line.
    if engine is None:
        raise RuntimeError(
            'TensorRT failed to build an engine from {!r}'.format(onnx_file_path)
        )
    context = engine.create_execution_context()
    print("Completed creating Engine")

    return engine, context
| # ['efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'resnet34', 'resnet50'] | |
if __name__ == '__main__':
    # ONNX model file that is parsed and compiled into a TensorRT engine.
    ONNX_FILE_PATH = 'effnet0.onnx'
    engine, context = build_engine(onnx_file_path=ONNX_FILE_PATH)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment