Using PyTorch with TensorRT through ONNX:

quickstart/IntroNotebooks/2. Using PyTorch through ONNX.ipynb

TensorRT is a great way to take a trained PyTorch model and optimize it to run more efficiently during inference on an NVIDIA GPU.

One approach to converting a PyTorch model to TensorRT is to export it to ONNX (an open exchange format for deep learning models) and then convert the ONNX file into a TensorRT engine. Essentially, we will follow this path to convert and deploy our model: PyTorch model → ONNX file → TensorRT engine.

Models from PyTorch, TensorFlow, and many other frameworks can be exported to ONNX. This allows models created in any of these frameworks to flow into common downstream pipelines.

To get started, let's take a well-known computer vision model and follow five key steps to deploy it to the TensorRT Python runtime:

  1. What format should I save my model in?
  2. What batch size(s) am I running inference at?
  3. What precision am I running inference at?
  4. What TensorRT path am I using to convert my model?
  5. What runtime am I targeting?

1. What format should I save my model in?

We are going to use ResNet50, a widely used CNN architecture first described in <a href=https://arxiv.org/abs/1512.03385>this paper</a>.

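Let's start by loading dependencies and downloading the pretrained model. We also set it to evaluation mode, which switches layers such as batch normalization to their inference behavior. Later, we will load a separate copy onto the GPU for benchmarking.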

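```python
import torchvision.models as models
import torch
import torch.onnx

# load the pretrained model (IMAGENET1K_V1 is the checkpoint the older pretrained=True flag selected)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1, progress=False).eval()
```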

When saving a model to ONNX, PyTorch requires a test batch in proper shape and format. We pick a batch size:

```python
BATCH_SIZE = 32

dummy_input = torch.randn(BATCH_SIZE, 3, 224, 224)
```

Next, we will export the model using the dummy input batch:

```python
# export the model to ONNX
torch.onnx.export(resnet50, dummy_input, "resnet50_pytorch.onnx", verbose=False)
```

Because we did not mark any axis as dynamic, the exported ONNX model has a fixed batch size of 32, taken from the dummy input.

Now Test with a Real Image:

Let's try a real image batch! For this example, we will simply repeat one open-source dog image from http://www.dog.ceo:

```python
from skimage import io
from skimage.transform import resize
from matplotlib import pyplot as plt
import numpy as np

url='https://images.dog.ceo/breeds/retriever-golden/n02099601_3004.jpg'
img = resize(io.imread(url), (224, 224))
img = np.expand_dims(np.array(img, dtype=np.float32), axis=0) # Expand image to have a batch dimension
input_batch = np.array(np.repeat(img, BATCH_SIZE, axis=0), dtype=np.float32) # Repeat across the batch dimension

input_batch.shape
```

```python
plt.imshow(input_batch[0].astype(np.float32))
```
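```python
# a second copy of the pretrained model, moved to the GPU for the PyTorch baseline
resnet50_gpu = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1, progress=False).to("cuda").eval()
```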

We need to move our batch onto the GPU and reorder it to shape [32, 3, 224, 224].

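```python
from torchvision.transforms import Normalize

# NHWC -> NCHW: swap axes 1 and 3, then 2 and 3
input_batch_chw = torch.from_numpy(input_batch).transpose(1,3).transpose(2,3)
# apply the same ImageNet normalization the TensorRT path uses below, so the two runs see identical inputs
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
input_batch_gpu = normalize(input_batch_chw.to("cuda"))

input_batch_gpu.shape
```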
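The two `transpose` calls above reorder the axes from NHWC (the layout of `input_batch`) to NCHW (the layout PyTorch's convolution layers expect). Because height and width are both 224 here, the numeric shapes alone hide what happens. The minimal stdlib sketch below tracks axis labels instead. The helpers `swap_axes` and `permute` are hypothetical and only mimic what `Tensor.transpose` and `Tensor.permute` do to a shape:

```python
def swap_axes(axes, a, b):
    """Return axes with positions a and b exchanged, like Tensor.transpose(a, b) does to a shape."""
    dims = list(axes)
    dims[a], dims[b] = dims[b], dims[a]
    return tuple(dims)


def permute(axes, order):
    """Reorder axes by an explicit permutation, like Tensor.permute(*order)."""
    return tuple(axes[i] for i in order)


nhwc = ("N", "H", "W", "C")        # layout of input_batch: batch, height, width, channels
step1 = swap_axes(nhwc, 1, 3)      # ("N", "C", "W", "H")
nchw = swap_axes(step1, 2, 3)      # ("N", "C", "H", "W")
print(step1, nchw)
print(permute(nhwc, (0, 3, 1, 2)))  # the same result in a single step
```

In PyTorch, the same reordering can be written in one call as `torch.from_numpy(input_batch).permute(0, 3, 1, 2)`.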

We can run a prediction on a batch by calling the model directly, which invokes its .forward() method:

```python
with torch.no_grad():
    predictions = np.array(resnet50_gpu(input_batch_gpu).cpu())

predictions.shape
```

Verify Baseline Model Performance/Accuracy:

For a baseline, let's time our prediction in FP32:

```python
%%timeit

with torch.no_grad():
    preds = np.array(resnet50_gpu(input_batch_gpu).cpu())
```
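`%%timeit` only works inside IPython/Jupyter. Outside a notebook, a small timing loop built on `time.perf_counter` gives a rough equivalent. The sketch below is an illustrative stand-in and is not part of the original notebook; the helper name `benchmark` is hypothetical. CUDA work runs asynchronously, so the timed callable must block until the GPU has finished. In the cell above, the `.cpu()` copy does that implicitly.

```python
import statistics
import time


def benchmark(fn, warmup=3, iters=20):
    """Call fn() `warmup` times untimed, then `iters` times timed.

    Returns (mean, stdev) of the per-call wall-clock time in seconds.
    fn must block until its work is finished (e.g. by copying results to the CPU).
    """
    if iters < 2:
        raise ValueError("need at least 2 timed iterations to compute a standard deviation")
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.stdev(times)
```

For example, inside `with torch.no_grad():` you could call `benchmark(lambda: resnet50_gpu(input_batch_gpu).cpu())`.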

We can also time FP16 precision performance:

```python
# note: .half() converts resnet50_gpu itself to FP16 in place and returns the same module
resnet50_gpu_half = resnet50_gpu.half()
input_half = input_batch_gpu.half()

with torch.no_grad():
    preds = np.array(resnet50_gpu_half(input_half).cpu()) # Warm Up

preds.shape
```
```python
%%timeit

with torch.no_grad():
    preds = np.array(resnet50_gpu_half(input_half).cpu())
```

Let's also make sure our results are accurate. We will look at the top 5 predicted classes for a single image. The image we are using is of a Golden Retriever, which is class 207 in the ImageNet dataset our model was trained on.

```python
# top 5 classes by raw score; the model outputs logits, not probabilities
indices = (-predictions[0]).argsort()[:5]
print("Class | Logit")
list(zip(indices, predictions[0][indices]))
```
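ResNet50's final fully connected layer produces raw scores (logits), not probabilities. The minimal stdlib sketch below shows both the top-k selection used above and a softmax that turns logits into probabilities summing to 1. The helper names `softmax` and `top_k` are illustrative and are not from the notebook:

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities; subtracting the max avoids overflow in exp."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(scores, k=5):
    """Return (index, score) pairs for the k largest scores, highest first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(i, scores[i]) for i in order[:k]]
```

Softmax is monotonic, so the top-5 indices are the same whether you rank logits or probabilities. Only the reported values change.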

We have a model exported to ONNX and a baseline to compare against! Let's now take our ONNX model and convert it to a TensorRT inference engine.

2. What batch size(s) am I running inference at?

We are going to run with a fixed batch size of 32 for this example. Note that above we set BATCH_SIZE to 32 when saving our model to ONNX. We need to create another dummy batch of the same size (this time it will need to be in our target precision) to test out our engine.

First, as before, we will set our BATCH_SIZE to 32. The batch size is baked into the ONNX file we exported, and TensorRT's ONNX parser uses explicit batch dimensions. As a result, the engine we build below will expect exactly 32 images per inference call.

```python
BATCH_SIZE = 32
```
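The engine only accepts batches of exactly 32. A workload with a different number of images therefore has to be split into full batches, with the last one padded. The minimal plain-Python sketch below shows that bookkeeping. The helper `fixed_batches` and its pad-with-a-placeholder strategy are illustrative assumptions and are not part of the notebook:

```python
def fixed_batches(items, batch_size, pad_item):
    """Split items into batches of exactly batch_size, padding the last batch with pad_item.

    Returns a list of (batch, n_real) pairs; only the first n_real results of each
    batch correspond to real inputs, and the rest can be discarded.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    batches = []
    for start in range(0, len(items), batch_size):
        batch = list(items[start:start + batch_size])
        n_real = len(batch)
        batch += [pad_item] * (batch_size - n_real)
        batches.append((batch, n_real))
    return batches


# 70 images with a fixed batch size of 32: two full batches and one padded batch
print([n for _, n in fixed_batches(list(range(70)), 32, pad_item=0)])  # [32, 32, 6]
```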

Importantly, the engine's precision and its input/output formats are fixed when the engine is built. The data we feed at runtime must therefore use the same precision we build with. So before we create our new dummy batch, we also need to choose a precision, as in the next section:

3. What precision am I running inference at?

Remember that precisions lower than FP32 tend to run faster. There are two common reduced-precision modes: FP16 and INT8. GPUs differ in how much they accelerate each of these types, so the best choice depends on your hardware. This guide was developed on an NVIDIA V100, which favors FP16, so we will use that here by default. INT8 is a more complicated process that requires a calibration step.

NOTE: USE_FP16 controls both the dtype of the test batch below and the trtexec build flags in step 4, so set it here before running either.

```python
import numpy as np

USE_FP16 = True
target_dtype = np.float16 if USE_FP16 else np.float32
```
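To get a feel for the precision trade-off behind USE_FP16, we can round-trip Python floats through IEEE 754 half precision using the standard library's `struct` format `'e'`. This only illustrates the number format, not TensorRT's kernels. The helper name `to_fp16` is hypothetical:

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


print(to_fp16(0.1))      # 0.0999755859375: only about 3 decimal digits of precision
print(to_fp16(4097.0))   # 4096.0: above 2048, FP16 cannot represent every integer
print(to_fp16(65504.0))  # 65504.0: the largest finite FP16 value
```

Small rounding differences like these are why FP16 results can differ slightly from the FP32 baseline.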

To create a test batch, we will once again repeat one open-source dog image from http://www.dog.ceo:

```python
from skimage import io
from skimage.transform import resize
from matplotlib import pyplot as plt
import numpy as np

url='https://images.dog.ceo/breeds/retriever-golden/n02099601_3004.jpg'
img = resize(io.imread(url), (224, 224))
input_batch = np.array(np.repeat(np.expand_dims(np.array(img, dtype=np.float32), axis=0), BATCH_SIZE, axis=0), dtype=np.float32)

input_batch.shape
```

```python
plt.imshow(input_batch[0].astype(np.float32))
```

Preprocess Images:

PyTorch's pretrained vision models expect inputs normalized with the ImageNet per-channel mean and standard deviation. We can preprocess our images to match this normalization as follows, making sure our final result is in our target precision (FP16 by default):

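```python
import torch
from torchvision.transforms import Normalize

def preprocess_image(img):
    # ImageNet statistics the torchvision models were trained with
    norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # HWC -> CHW, then normalize each channel
    result = norm(torch.from_numpy(img).transpose(0,2).transpose(1,2))
    # cast to the precision the engine input expects (FP16 or FP32, per USE_FP16)
    return np.array(result, dtype=target_dtype)

preprocessed_images = np.array([preprocess_image(image) for image in input_batch])
```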
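`Normalize` applies `(x - mean[c]) / std[c]` to every value in channel `c`. The stdlib sketch below applies the same arithmetic to a single RGB pixel, using the ImageNet statistics from the cell above. The helper `normalize_pixel` is hypothetical:

```python
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (x - mean[c]) / std[c] to each channel of one pixel with values in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


print(normalize_pixel((1.0, 1.0, 1.0)))  # white pixel: roughly (2.249, 2.429, 2.64)
print(normalize_pixel((0.0, 0.0, 0.0)))  # black pixel: roughly (-2.118, -2.036, -1.804)
```

This arithmetic assumes input values in [0, 1]. `skimage.transform.resize` already returns floats in that range for 8-bit images.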

4. What TensorRT path am I using to convert my model?

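We can use trtexec, a command-line tool for working with TensorRT, to convert an ONNX model originally from PyTorch into an engine file.

Let's make sure the TensorRT Python package is installed. Note that trtexec itself ships with the full TensorRT distribution (for example, NVIDIA's TensorRT containers) rather than with the pip wheel: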

```python
import tensorrt
```

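To convert the model we saved in the previous step, we point trtexec to the ONNX file and give it a name to save the engine as. In FP16 mode, we also enable FP16 kernels (--fp16) and request FP16 input and output tensors (--inputIOFormats/--outputIOFormats). No batch-size flag is needed, because the batch size of 32 is fixed in the ONNX file.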

```python
# step out of Python for a moment to convert the ONNX model to a TRT engine using trtexec
if USE_FP16:
    !trtexec --onnx=resnet50_pytorch.onnx --saveEngine=resnet_engine_pytorch.trt   --inputIOFormats=fp16:chw --outputIOFormats=fp16:chw --fp16
else:
    !trtexec --onnx=resnet50_pytorch.onnx --saveEngine=resnet_engine_pytorch.trt
```

This will save our engine as 'resnet_engine_pytorch.trt'.
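If you would rather not rely on IPython's `!` shell escape, you can assemble the same trtexec invocation as an argument list and run it with `subprocess`. The sketch below assumes `trtexec` is on your `PATH` and uses only the flags from the cell above. The helper names `trtexec_command` and `build_engine` are hypothetical:

```python
import shutil
import subprocess


def trtexec_command(onnx_path, engine_path, use_fp16):
    """Build the trtexec argument list used in this notebook."""
    cmd = ["trtexec", f"--onnx={onnx_path}", f"--saveEngine={engine_path}"]
    if use_fp16:
        # FP16 input/output tensors plus FP16 kernels, matching the notebook's FP16 branch
        cmd += ["--inputIOFormats=fp16:chw", "--outputIOFormats=fp16:chw", "--fp16"]
    return cmd


def build_engine(onnx_path, engine_path, use_fp16):
    """Run trtexec if it is installed; raise if it is missing or exits with an error."""
    if shutil.which("trtexec") is None:
        raise FileNotFoundError("trtexec not found on PATH")
    subprocess.run(trtexec_command(onnx_path, engine_path, use_fp16), check=True)
```

Usage would look like `build_engine("resnet50_pytorch.onnx", "resnet_engine_pytorch.trt", USE_FP16)`. Passing a list instead of a shell string avoids quoting issues with file names.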

5. What TensorRT runtime am I targeting?

Now we have converted our model to a TensorRT engine. Great! That means we are ready to load it into the native Python TensorRT runtime. This runtime strikes a balance between the ease of use of the high-level Python APIs used in frameworks and the fast, low-level C++ runtimes available in TensorRT.

```python
%%time

import tensorrt as trt
from cuda.bindings import runtime as cudart
import numpy as np

# select GPU 0 for the CUDA calls that follow
err, = cudart.cudaSetDevice(0)
assert err == cudart.cudaError_t.cudaSuccess

runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))

# read and deserialize the engine; the with-block closes the file afterwards
with open("resnet_engine_pytorch.trt", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()
```

Now we allocate input and output memory on the GPU and give TensorRT pointers to it by setting each I/O tensor's address:

```python
import numpy as np

# need to set input and output precisions to FP16 to fully enable it
output = np.empty([BATCH_SIZE, 1000], dtype = target_dtype)

# allocate device memory
err, d_input = cudart.cudaMalloc(input_batch.nbytes)
assert err == cudart.cudaError_t.cudaSuccess
err, d_output = cudart.cudaMalloc(output.nbytes)
assert err == cudart.cudaError_t.cudaSuccess

# this engine has exactly one input and one output tensor
tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
assert(len(tensor_names) == 2)
# the addresses below assume the input is listed first; check that assumption
assert engine.get_tensor_mode(tensor_names[0]) == trt.TensorIOMode.INPUT

context.set_tensor_address(tensor_names[0], d_input)
context.set_tensor_address(tensor_names[1], d_output)

# CUDA stream used for the asynchronous copies and inference in predict()
err, stream = cudart.cudaStreamCreate()
assert err == cudart.cudaError_t.cudaSuccess
```
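`cudaMalloc` takes a size in bytes, so each buffer needs (number of elements) × (bytes per element). The quick stdlib check below computes the sizes involved for a batch of 32. The helper `nbytes` is hypothetical; it mirrors what NumPy's `ndarray.nbytes` reports for a dense array:

```python
import math


def nbytes(shape, itemsize):
    """Bytes needed for a dense tensor: product of the dimensions times bytes per element."""
    return math.prod(shape) * itemsize


BATCH = 32
for name, itemsize in (("FP32", 4), ("FP16", 2)):
    inp = nbytes((BATCH, 3, 224, 224), itemsize)
    out = nbytes((BATCH, 1000), itemsize)
    print(f"{name}: input {inp:,} bytes, output {out:,} bytes")
# FP32: input 19,267,584 bytes, output 128,000 bytes
# FP16: input 9,633,792 bytes, output 64,000 bytes
```

The notebook sizes `d_input` from `input_batch`, which is always FP32. In FP16 mode, the input buffer is therefore twice as large as needed. This is harmless, because `predict` copies only `batch.nbytes`.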

Next, set up the prediction function.

This involves a copy from CPU RAM to GPU VRAM, executing the model, then copying the results back from GPU VRAM to CPU RAM:

```python
def predict(batch): # result gets copied into output
    # transfer input data to device
    err, = cudart.cudaMemcpyAsync(d_input, batch.ctypes.data, batch.nbytes,
                                  cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)
    assert err == cudart.cudaError_t.cudaSuccess

    # execute model
    context.execute_async_v3(stream)

    # transfer predictions back
    err, = cudart.cudaMemcpyAsync(output.ctypes.data, d_output, output.nbytes,
                                  cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)
    assert err == cudart.cudaError_t.cudaSuccess

    # synchronize stream
    err, = cudart.cudaStreamSynchronize(stream)
    assert err == cudart.cudaError_t.cudaSuccess

    return output
```

Let's time the function!

```python
print("Warming up...")

pred = predict(preprocessed_images)

print("Done warming up!")
```

```python
%%timeit

pred = predict(preprocessed_images)
```
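`%%timeit` reports time per batch. To compare the PyTorch and TensorRT runs, it can help to convert that figure into images per second and a speedup ratio. The helpers below are a small sketch with hypothetical names. The timings in the example are placeholders, not measurements from this notebook:

```python
def images_per_second(batch_size, seconds_per_batch):
    """Throughput implied by a per-batch latency."""
    return batch_size / seconds_per_batch


def speedup(baseline_seconds, optimized_seconds):
    """How many times faster the optimized run is than the baseline."""
    return baseline_seconds / optimized_seconds


# Placeholder numbers: substitute the per-batch means reported by %%timeit above.
baseline, optimized = 0.040, 0.010
print(images_per_second(32, optimized))  # 3200.0
print(speedup(baseline, optimized))      # 4.0
```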

Finally, we should verify that our TensorRT output is still accurate.

```python
# top 5 classes by raw score; like the PyTorch model, the engine outputs logits
indices = (-pred[0]).argsort()[:5]
print("Class | Logit")
list(zip(indices, pred[0][indices]))
```

When we are done, we release the CUDA stream and the device buffers:

```python
err, = cudart.cudaStreamDestroy(stream)
assert err == cudart.cudaError_t.cudaSuccess
err, = cudart.cudaFree(d_input)
assert err == cudart.cudaError_t.cudaSuccess
err, = cudart.cudaFree(d_output)
assert err == cudart.cudaError_t.cudaSuccess
```

Look for ImageNet indices in the 151-268 range (the dog breeds) above, where 207 is the ground-truth class (Golden Retriever). Compare with the results of the original unoptimized model in the first section!
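To turn that comparison into a check rather than an eyeball exercise, you could test whether the two runs agree on the top-5 classes and measure how far their scores drift. The stdlib sketch below assumes both outputs come from identically preprocessed inputs and are available as plain lists of scores for one image, for example `predictions[0].tolist()` and `pred[0].tolist()`. The helper name `compare_outputs` is hypothetical:

```python
def compare_outputs(reference, candidate, k=5):
    """Compare two score vectors for the same image.

    Returns (same_top1, topk_overlap, max_abs_diff): whether the best class matches,
    how many of the top-k classes are shared, and the largest absolute score difference.
    """
    if len(reference) != len(candidate):
        raise ValueError("score vectors must have the same length")

    def top(scores):
        # indices of the k highest scores, best first
        return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

    ref_top, cand_top = top(reference), top(candidate)
    overlap = len(set(ref_top) & set(cand_top))
    max_abs_diff = max(abs(a - b) for a, b in zip(reference, candidate))
    return ref_top[0] == cand_top[0], overlap, max_abs_diff
```

With FP16, expect small score differences. A matching top-1 class and a large top-5 overlap are the more meaningful signals.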

Next Steps:

<h4> TRT Dev Docs </h4>

Main documentation page for the ONNX, layer builder, C++, and legacy APIs

You can find it here: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html

<h4> TRT OSS GitHub </h4>

Contains OSS TRT components, sample applications, and plugin examples

You can find it here: https://github.com/NVIDIA/TensorRT

TRT Supported Layers:

https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#layers-precision-matrix

TRT ONNX Plugin Example:

https://github.com/NVIDIA/TensorRT/tree/main/samples/sampleOnnxMnistCoordConvAC