ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

ONNXRuntime学习笔记(三)

2022-05-01 00:35:45  阅读:230  来源: 互联网

标签:ONNXRuntime onnx 笔记 学习 onnxruntime 推理 time data testloader


接上一篇完成的pytorch模型训练结果,模型结构为ResNet18+fc,参数量约为11M,最终测试集Acc达到94.83%。接下来有分两个部分:导出onnx和使用onnxruntime推理。

一、pytorch导出onnx

直接放函数吧,这部分我是直接放在test.py里面的,直接从dataloader中拿到一个batch的数据走一遍推理即可。

def export_onnx(net, testloader, output_file):
    net.eval()
    with torch.no_grad():
        for data in testloader: 
            images, labels = data

            torch.onnx.export(net, 
                            (images), 
                            output_file,
                            training=False,
                            do_constant_folding=True,
                            input_names=["img"], 
                            output_names=["output"],
                  dynamic_axes={"img": {0: "b"},"output": {0: "b"}}
                  )
            print("onnx export done!")
            break

上面函数中几个比较重要的参数:do_constant_folding是常量折叠,建议打开;输入张量通过一个tuple传入,并且最好指定每个输入和输出的名称,此外,为保证使用onnxruntime推理的时候batchsize可变,dynamic_axes的第一维需要像上述一样设置为动态的。如果是全卷积做分割的网络,类似的输入h和w也应该是动态的。

单独运行test.py计算测试集效果和平均相应时间,结果为:

Test Acc is: 94.83%
Average response time cost:  0.10121344916428192

二、使用onnxruntime推理

这里我们使用gpu版本的onnxruntime库进行推理,其python包可直接pip install onnxruntime-gpu安装。onnxruntime推理代码和测试集推理代码很类似,如下:

import numpy as np
import onnxruntime as ort
import argparse, os
from lib import CIFARDataset

def onnxruntime_test(session, testloader):
    print("Start Testing!")
    input_name = session.get_inputs()[0].name
    correct = 0
    total = 0   # 计数归零(初始化)
    for data in testloader:
        images, labels = data
        images, labels = images.numpy(), labels.numpy()
        outputs = session.run(None, {input_name:images})
        predicted = np.argmax(outputs[0], axis=1)  # 取得分最高的那个类
        total += labels.shape[0]                        # 累加样本总数
        correct += (predicted == labels).sum()        # 累加预测正确的样本个数
    acc = correct / total
    print('ONNXRuntime Test Acc is: %.2f%%' % (100*acc))
            
if __name__ == '__main__':
    # 命令行参数解析
    parser = argparse.ArgumentParser("CNN backbone on cifar10")
    parser.add_argument('--onnx', default='./output/test_resnet18_10_autoaug/densenet_best.onnx')
    args = parser.parse_args()

    NUM_CLASS =10
    BATCH_SIZE = 128  # 批处理尺寸(batch_size)

    # 数据集迭代器
    data_path="./data"
    dataset = CIFARDataset(dataset_path=data_path, batchsize=BATCH_SIZE)
    _, testloader = dataset.get_cifar10_dataloader()

    # 构建session
    sess = ort.InferenceSession(args.onnx, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    #onnxruntime推理
    import time
    start = time.time()
    onnxruntime_test(sess, testloader)
    end = time.time()
    print("Average response time cost: ", (end-start)/len(testloader))

使用onnxruntime加载导出的onnx模型,计算测试集效果和平均响应时间,结果为:

ONNXRuntime Test Acc is: 94.83%
Average response time cost:  0.07324151147769976

三、小结

分析上面的pytorch和onnxruntime的测试结果可知,最终测试集效果是一致的,Acc均为94.83%,但onnxruntime的效率更高,耗时是pytorch的75%,但比最初目标设定的50ms高,需要进一步优化,两个方向:模型量化或并行化推理。下一篇再分析。

标签:ONNXRuntime,onnx,笔记,学习,onnxruntime,推理,time,data,testloader
来源: https://www.cnblogs.com/lee-zq/p/16211934.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有