UNet

PyTorch model from Pytorch-UNet (https://github.com/milesial/Pytorch-UNet).

Contributors

- YuzhouPeng (https://github.com/YuzhouPeng)
- East-Face (https://github.com/East-Face)
- irvingzhang0512 (https://github.com/irvingzhang0512)
- wang-xinyu (https://github.com/wang-xinyu)
- nengwp (https://github.com/nengwp)

Requirements

TensorRT 8.x is now supported. The bug that previously broke it was caused by an incorrect stride setting on the pooling layers.
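The pooling stride matters because UNet's encoder halves the feature map at each of its four down-sampling steps. In Pytorch-UNet each step is `nn.MaxPool2d(2)`, and PyTorch defaults a pooling layer's stride to its kernel size. A TensorRT port therefore has to set stride 2 explicitly to match. The sketch below is an illustration under these assumptions and does not call TensorRT. Its helper names (`pool_out`, `encoder_shapes`) are hypothetical. It applies the standard pooling output-size formula, floor((in + 2p - k) / s) + 1, to the 816x672 benchmark input. It is not a claim about the exact bug that was fixed.

```python
from typing import List, Tuple


def pool_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of one spatial dimension for pooling (floor mode, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_shapes(h: int, w: int, levels: int = 4,
                   kernel: int = 2, stride: int = 2) -> List[Tuple[int, int]]:
    """Feature-map (H, W) after each of `levels` pooling steps, input included.

    Assumes four 2x2 max-pool steps, as in Pytorch-UNet's Down blocks.
    """
    shapes = [(h, w)]
    for _ in range(levels):
        h, w = pool_out(h, kernel, stride), pool_out(w, kernel, stride)
        shapes.append((h, w))
    return shapes


if __name__ == "__main__":
    # Correct stride: each level halves H and W.
    print(encoder_shapes(816, 672))
    # → [(816, 672), (408, 336), (204, 168), (102, 84), (51, 42)]

    # Stride left at 1: the maps barely shrink, so the decoder's 2x upsampled
    # features can no longer line up with the encoder's skip connections.
    print(encoder_shapes(816, 672, stride=1))
    # → [(816, 672), (815, 671), (814, 670), (813, 669), (812, 668)]
```

With stride 2 the bottleneck is 51x42, which is exactly 1/16 of the input.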

Build and Run

  1. Generate .wts
cp {path-of-tensorrtx}/unet/gen_wts.py Pytorch-UNet/
cd Pytorch-UNet/
wget https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth
python gen_wts.py unet_carvana_scale0.5_epoch2.pth
  2. Generate TensorRT engine
cd tensorrtx/unet/
mkdir build
cd build
cmake ..
make
cp {path-of-Pytorch-UNet}/unet.wts .
./unet -s
  3. Run inference
wget https://raw.githubusercontent.com/wang-xinyu/tensorrtx/f60dcc7bec28846cd973fc95ac829c4e57a11395/unet/samples/0cdf5b5d0ce1_01.jpg
./unet -d 0cdf5b5d0ce1_01.jpg
  4. Check result.jpg
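The unet.wts file produced in step 1 is what `./unet -s` reads to build the engine. The sketch below shows the plain-text layout that tensorrtx's gen_wts.py scripts generally use; it is assumed here, not verified against this exact script. The first line holds the number of tensors. Each following line holds a tensor name, its element count, and each float32 value as big-endian hex. The function names `write_wts` and `read_wts` are hypothetical. The plain `dict` stands in for a flattened PyTorch `state_dict`, so the sketch runs without torch.

```python
import struct
from typing import Dict, List


def write_wts(weights: Dict[str, List[float]], path: str) -> None:
    """Serialize name -> flat float list into the assumed .wts text layout."""
    with open(path, "w") as f:
        f.write(f"{len(weights)}\n")                 # number of tensors
        for name, values in weights.items():
            f.write(f"{name} {len(values)}")         # tensor name and element count
            for v in values:
                # float32, big-endian, as 8 hex chars (the IEEE-754 bit pattern)
                f.write(" " + struct.pack(">f", float(v)).hex())
            f.write("\n")


def read_wts(path: str) -> Dict[str, List[float]]:
    """Parse the assumed .wts layout back into name -> flat float list."""
    weights: Dict[str, List[float]] = {}
    with open(path) as f:
        count = int(f.readline())
        for _ in range(count):
            name, n, *hex_values = f.readline().split()
            if len(hex_values) != int(n):
                raise ValueError(f"{name}: expected {n} values, got {len(hex_values)}")
            weights[name] = [struct.unpack(">f", bytes.fromhex(h))[0] for h in hex_values]
    return weights


if __name__ == "__main__":
    import os
    import tempfile

    # Illustrative tensor names in Pytorch-UNet style; values are exact in float32.
    demo = {"inc.double_conv.0.weight": [1.0, -2.0, 0.5], "outc.conv.bias": [0.25]}
    path = os.path.join(tempfile.mkdtemp(), "demo.wts")
    write_wts(demo, path)
    with open(path) as f:
        print(f.read(), end="")
    # → 2
    # → inc.double_conv.0.weight 3 3f800000 c0000000 3f000000
    # → outc.conv.bias 1 3e800000
    assert read_wts(path) == demo
```

Storing each float as its raw 32-bit pattern in hex keeps the weights bit-exact while still using a text file. The C++ side only needs to parse each value as a hex integer to recover them.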

Benchmark

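|            | PyTorch | TensorRT FP32        | TensorRT FP16        |
|------------|---------|----------------------|----------------------|
| Input size | 816x672 | 816x672              | 816x672              |
| Time       | 58 ms   | 43 ms (batch size 8) | 14 ms (batch size 8) |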

More Information

See the README on the tensorrtx home page.