Skip to content

Instantly share code, notes, and snippets.

@dsuess
Last active March 28, 2019 23:55
Show Gist options
  • Select an option

  • Save dsuess/bd4f3385451241a48338c0e01f74d4fc to your computer and use it in GitHub Desktop.

Select an option

Save dsuess/bd4f3385451241a48338c0e01f74d4fc to your computer and use it in GitHub Desktop.

Revisions

  1. dsuess revised this gist Mar 28, 2019. 1 changed file with 0 additions and 1 deletion.
    1 change: 0 additions & 1 deletion onnx_tensorrt.py
    Original file line number Diff line number Diff line change
    @@ -1,7 +1,6 @@
    import torch
    from torch import nn
    import tensorrt as trt
    import click

    TRT_LOGGER = trt.Logger(trt.Logger.INFO)

  2. dsuess renamed this gist Mar 28, 2019. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  3. dsuess created this gist Mar 28, 2019.
    31 changes: 31 additions & 0 deletions gistfile1.txt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,31 @@
    import torch
    from torch import nn
    import tensorrt as trt
    import click

    TRT_LOGGER = trt.Logger(trt.Logger.INFO)

    class Model(nn.Module):

    def forward(self, x):
    y = (2 * x)[0:1]
    return y


    print('TensorRT version:', trt.__version__)


    model = Model().eval()
    dummy_input = torch.randn(4, 4, 4)

    with torch.no_grad():
    torch.onnx.export(model, dummy_input, 'test.onnx', verbose=True)


    with trt.Builder(TRT_LOGGER) as builder, \
    builder.create_network() as network, \
    trt.OnnxParser(network, TRT_LOGGER) as parser:

    with open('test.onnx', 'rb') as model:
    success = parser.parse(model.read())
    assert success, f'{parser.num_errors} detected during parsing'