"""Export a tiny PyTorch model to ONNX and check that TensorRT can parse it.

The model doubles its input and keeps the first slice along dim 0. The script
exports it to ``ONNX_PATH`` and feeds that file to TensorRT's ONNX parser.
If parsing fails, it raises with every parser error listed.
"""
import torch
from torch import nn
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
ONNX_PATH = 'test.onnx'


class Model(nn.Module):
    """Compute ``(2 * x)[0:1]``: double the input, keep the first slice of dim 0."""

    def forward(self, x):
        y = (2 * x)[0:1]
        return y


def _network_flags():
    """Return network-creation flags that request an explicit-batch network.

    The ONNX parser requires an explicit-batch network on TensorRT 7-9. On
    TensorRT 10 the flag is deprecated (explicit batch is always on). If a
    release no longer defines it, no flags are needed.
    """
    flag = getattr(trt.NetworkDefinitionCreationFlag, 'EXPLICIT_BATCH', None)
    return 0 if flag is None else 1 << int(flag)


print('TensorRT version:', trt.__version__)

model = Model().eval()
dummy_input = torch.randn(4, 4, 4)
with torch.no_grad():
    torch.onnx.export(model, dummy_input, ONNX_PATH, verbose=True)

with trt.Builder(TRT_LOGGER) as builder, \
        builder.create_network(_network_flags()) as network, \
        trt.OnnxParser(network, TRT_LOGGER) as parser:
    # Use a distinct name so the nn.Module bound to `model` is not shadowed.
    with open(ONNX_PATH, 'rb') as onnx_file:
        success = parser.parse(onnx_file.read())
    # Raise instead of assert: asserts vanish under `python -O`, and the
    # individual parser errors are what make a failure diagnosable.
    if not success:
        errors = '\n'.join(
            str(parser.get_error(i)) for i in range(parser.num_errors)
        )
        raise RuntimeError(
            f'{parser.num_errors} error(s) detected during parsing:\n{errors}'
        )