
(Original) Using TensorRT in PyTorch



Original article: https://www.cnblogs.com/darkknightzh/p/11332155.html

 


Please cite the source when reposting:

https://www.cnblogs.com/darkknightzh/p/11332155.html

Code:

https://github.com/darkknightzh/TensorRT_pytorch

References:

the sample/python directory of the TensorRT installation package

https://github.com/pytorch/examples/tree/master/mnist

The code here uses TensorRT 5.1.5.

 

After installing TensorRT, using it from PyTorch mainly involves the following pieces of code:


1. Initialization

import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit    # never referenced in the code, but required; otherwise stream = cuda.Stream() fails with 'explicit_context_dependent failed: invalid device context - no currently active context?'

As the comment says, import pycuda.autoinit is never referenced in the program, but it must be included; without it the program fails at runtime with the error above.


2. Saving the ONNX model

import torch

def saveONNX(model, filepath, c, h, w):
    # move the model to the GPU and export it to ONNX by tracing a dummy input of shape (1, c, h, w)
    model = model.cuda()
    dummy_input = torch.randn(1, c, h, w, device='cuda')
    torch.onnx.export(model, dummy_input, filepath, verbose=True)
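A rough usage sketch (the Net class and the ONNX file name are placeholders, not taken from the repository; the .pth file name is the one that appears later in this post):

model = Net()                                     # hypothetical PyTorch model class
model.load_state_dict(torch.load('mnist_cnn_3.pth'))
model.eval()
saveONNX(model, 'mnist.onnx', c=1, h=28, w=28)    # MNIST inputs are 1x28x28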


3. Creating the TensorRT engine


def build_engine(onnx_file_path):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)   # INFO
    # For more information on TRT basics, refer to the introductory samples.
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
        if builder.platform_has_fast_fp16:
            print('this card support fp16')
        if builder.platform_has_fast_int8:
            print('this card support int8')

        builder.max_workspace_size = 1 << 30
        with open(onnx_file_path, 'rb') as model:
            parser.parse(model.read())
        return builder.build_cuda_engine(network)

# This function builds an int8 engine from an ONNX model, using the supplied calibrator.
def build_engine_int8(onnx_file_path, calib):
    TRT_LOGGER = trt.Logger()
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
        # We set the builder batch size to be the same as the calibrator's, as we use the same batches
        # during inference. Note that this is not required in general, and inference batch size is
        # independent of calibration batch size.
        builder.max_batch_size = 1  # calib.get_batch_size()
        builder.max_workspace_size = 1 << 30
        builder.int8_mode = True
        builder.int8_calibrator = calib
        with open(onnx_file_path, 'rb') as model:
            parser.parse(model.read())
        return builder.build_cuda_engine(network)

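A minimal usage sketch (the ONNX path is a placeholder): build the fp32 engine and check that the build succeeded, since build_cuda_engine returns None when parsing or building fails. As an aside not shown in the original code, on cards with fast fp16 this TensorRT version can also build a half-precision engine by setting builder.fp16_mode = True before build_cuda_engine.

engine = build_engine('mnist.onnx')    # placeholder path to the exported ONNX file
assert engine is not None, 'engine build failed; check the logger warnings'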


4. Saving and loading the engine


def save_engine(engine, engine_dest_path):
    buf = engine.serialize()
    with open(engine_dest_path, 'wb') as f:
        f.write(buf)

def load_engine(engine_path):
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)  # INFO
    with open(engine_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())

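Building the engine from ONNX is relatively slow, so a common pattern is to build once, serialize to disk, and deserialize on later runs. A minimal sketch with placeholder file names:

engine = build_engine('mnist.onnx')       # build once (placeholder paths)
save_engine(engine, 'mnist_fp32.trt')

engine = load_engine('mnist_fp32.trt')    # later runs: deserialize instead of rebuilding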


5. Allocating buffers


class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()

def allocate_buffers(engine):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda.Stream()
    for binding in engine:
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate host and device buffers
        host_mem = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings.
        bindings.append(int(device_mem))
        # Append to the appropriate list.
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    return inputs, outputs, bindings, stream

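One step the snippets above do not show: do_inference in the next section needs an execution context created from the engine. Roughly:

context = engine.create_execution_context()                     # required by do_inference below
inputs, outputs, bindings, stream = allocate_buffers(engine)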


6. Forward inference


def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    # Transfer input data to the GPU.
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference.
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return only the host outputs.
    return [out.host for out in outputs]



7. Calibration (Calibrator)

When using TensorRT in int8 mode, calibration is required. See test_onnx_int8 and calibrator.py in the repository for the details; a rough sketch of such a calibrator is shown below.
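calibrator.py itself is not reproduced in this post. As a rough sketch only (the class name, data handling and cache path are placeholders, and the get_batch signature can differ between TensorRT versions), an int8 calibrator subclasses trt.IInt8EntropyCalibrator2 and feeds calibration batches to TensorRT one at a time:

import os

class MNISTEntropyCalibrator(trt.IInt8EntropyCalibrator2):   # hypothetical sketch
    def __init__(self, data, cache_file, batch_size=1):
        trt.IInt8EntropyCalibrator2.__init__(self)
        self.data = data                    # numpy float32 array of calibration images
        self.cache_file = cache_file
        self.batch_size = batch_size
        self.current_index = 0
        # device buffer that holds one calibration batch
        self.device_input = cuda.mem_alloc(self.data[0].nbytes * self.batch_size)

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):             # signature may differ in other TensorRT versions
        if self.current_index + self.batch_size > self.data.shape[0]:
            return None                     # no batches left: calibration finishes
        batch = self.data[self.current_index:self.current_index + self.batch_size].ravel()
        cuda.memcpy_htod(self.device_input, batch)
        self.current_index += self.batch_size
        return [int(self.device_input)]

    def read_calibration_cache(self):
        # reuse an existing calibration cache to skip recalibration
        if os.path.exists(self.cache_file):
            with open(self.cache_file, 'rb') as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, 'wb') as f:
            f.write(cache)

An instance of such a class is what gets passed as calib to build_engine_int8 above.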


8. The actual inference code

img_numpy = img.ravel().astype(np.float32)   # flatten the image and match the engine's input dtype
np.copyto(inputs[0].host, img_numpy)         # copy into the page-locked host input buffer
output = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
output = [np.reshape(stage_i, (10,)) for stage_i in output]  # iterate when the network has multiple outputs


9. Code overview

The program mainly consists of the following six functions.


test_pytorch()            # test the PyTorch model
export_onnx()             # export the PyTorch model to ONNX
test_onnx_fp32()          # test the TensorRT fp32 model (includes code to save the engine)
test_onnx_fp32_engine()   # test the saved TensorRT fp32 engine
test_onnx_int8()          # test the TensorRT int8 model (includes code to save the engine)
test_onnx_int8_engine()   # test the saved TensorRT int8 engine



10. Notes

Some of the functions in Section 9 start with the line:

torch.load('mnist_cnn_3.pth')    # add this line if the results are wrong

This is because, occasionally, running the code without this line produces completely incorrect results; after adding it, the results are correct.

The exact cause has not been found yet, so it is simply noted here for the record.
