ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

在 C++ 中利用 onnxruntime (CUDA) 和 OpenCV 部署 ONNX 模型

2021-10-30 16:02:24  阅读:1581  来源: 互联网

标签:imgSource onnxruntime onnx cols opencv int input include cv


概述

先下载并安装 onnxruntime,再将训练得到的模型转换为 ONNX 格式,然后在 C++ 中加载并运行,以完成模型的部署。

CMakeLists.txt:

# CMake 2.8 compatibility has been removed from current CMake releases.
# 3.11 is used because from that version on the target name "test" is only
# reserved when enable_testing() is called (policy CMP0037), so the
# "test" executable below remains valid.
cmake_minimum_required(VERSION 3.11)

# Use clang++ as the C++ compiler. This must happen BEFORE project(): project()
# enables the CXX language and fixes the compiler, and changing
# CMAKE_CXX_COMPILER afterwards is either ignored or forces CMake to discard
# its cache and re-configure.
set(CMAKE_CXX_COMPILER clang++)
project(test)

# Default to an optimized build, but let -DCMAKE_BUILD_TYPE=... override it.
if(NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE "Release")
endif()
set(CMAKE_INCLUDE_CURRENT_DIR ON)

# OpenCV (main.cpp uses core, imgproc, imgcodecs and highgui).
find_package(OpenCV 4.5.1 REQUIRED)

# onnxruntime, built from source. The root is a cache variable so it can be
# overridden with -DONNXRUNTIME_ROOT_PATH=... instead of editing this file.
set(ONNXRUNTIME_ROOT_PATH /home/zyl/ubuntu/tensorRT/onnxruntime-master
    CACHE PATH "Root of the onnxruntime source/build tree")
set(ONNXRUNTIME_INCLUDE_DIRS ${ONNXRUNTIME_ROOT_PATH}/include/onnxruntime
                             ${ONNXRUNTIME_ROOT_PATH}/onnxruntime
                             ${ONNXRUNTIME_ROOT_PATH}/include/onnxruntime/core/session/)
set(ONNXRUNTIME_LIB ${ONNXRUNTIME_ROOT_PATH}/build/Linux/Release/libonnxruntime.so)

add_executable(${PROJECT_NAME} "main.cpp")
target_include_directories(${PROJECT_NAME} PRIVATE ${ONNXRUNTIME_INCLUDE_DIRS})
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_LIB})

C++源码:

#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/core.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>

//onnxruntime
#include <core/session/onnxruntime_cxx_api.h>
#include <core/providers/cuda/cuda_provider_factory.h>
#include <core/session/onnxruntime_c_api.h>
#include <core/providers/tensorrt/tensorrt_provider_factory.h>

using namespace std;

int main(int argc,char ** argv)
{
    //输入网络的维度
    static constexpr const int width = 600;
    static constexpr const int height = 600;
    static constexpr const int channel = 3;
    std::array<int64_t, 4> input_shape_{ 1,height, width,channel};

    //-------------------------------------------------------------onnxruntime-------------------------------------------------
    //图片和模型位置
    string img_path = argv[1];
    string model_path = "model.onnx";
    cv::Mat imgSource = cv::imread(img_path);

    Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "Detection");
    Ort::SessionOptions session_options;
    //CUDA加速开启
    OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

    Ort::AllocatorWithDefaultOptions allocator;
    //加载ONNX模型
    Ort::Session session(env, model_path.c_str(), session_options);
    //获取输入输出的维度
    std::vector<int64_t> input_dims = session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
    std::vector<int64_t> output_dims = session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();

    /*
    session.GetOutputName(1, allocator);
    session.GetInputName(1, allocator);
    //输出模型输入节点的数量
    size_t num_input_nodes = session.GetInputCount();
    size_t num_output_nodes = session.GetOutputCount();
    */

    std::vector<const char*> input_node_names = {"image_tensor:0"};
    std::vector<const char*> output_node_names = {"detection_boxes:0","detection_scores:0","detection_classes:0","num_detections:0" };
    input_dims[0] = output_dims[0] = 1;//batch size = 1

    std::vector<Ort::Value> input_tensors;
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);

    //将图像存储到uchar数组中,BGR--->RGB
    std::array<uchar, width * height *channel> input_image_{};
    uchar* input =  input_image_.data();
    for (int i = 0; i < imgSource.rows; i++) {
        for (int j = 0; j < imgSource.cols; j++) {
            for (int c = 0; c < 3; c++)
            {
                //NHWC 格式
                if(c==0)
                    input[i*imgSource.cols*3+j*3+c] = imgSource.ptr<uchar>(i)[j*3+2];
                if(c==1)
                    input[i*imgSource.cols*3+j*3+c] = imgSource.ptr<uchar>(i)[j*3+1];
                if(c==2)
                    input[i*imgSource.cols*3+j*3+c] = imgSource.ptr<uchar>(i)[j*3+0];
                //NCHW 格式
//                if (c == 0)
//                     input[c*imgSource.rows*imgSource.cols + i * imgSource.cols + j] = imgSource.ptr<uchar>(i)[j * 3 + c];
//                if (c == 1)
//                     input[c*imgSource.rows*imgSource.cols + i * imgSource.cols + j] = imgSource.ptr<uchar>(i)[j * 3 + c];
//                if (c == 2)
//                     input[c*imgSource.rows*imgSource.cols + i * imgSource.cols + j] = imgSource.ptr<uchar>(i)[j * 3 + c];


            }
        }
    }

    input_tensors.push_back(Ort::Value::CreateTensor<uchar>(
            memory_info, input, input_image_.size(), input_shape_.data(), input_shape_.size()));
    //不知道输入维度时
    //input_tensors.push_back(Ort::Value::CreateTensor<uchar>(
    //        memory_info, input, input_image_.size(), input_dims.data(), input_dims.size()));

    chrono::steady_clock::time_point t1 = chrono::steady_clock::now();

    std::vector<Ort::Value> output_tensors;
    for(int i=0; i<100;i++)
         output_tensors = session.Run(Ort::RunOptions { nullptr },
                                                            input_node_names.data(), //输入节点名
                                                            input_tensors.data(),     //input tensors
                                                            input_tensors.size(),     //1
                                                            output_node_names.data(), //输出节点名
                                                            output_node_names.size()); //4
    chrono::steady_clock::time_point t2 = chrono::steady_clock::now();

    chrono::duration<double> delay_time = chrono::duration_cast<chrono::duration<double>>(t2 - t1); //milliseconds 毫秒

    cout<<"前向传播平均耗时:"<<delay_time.count()*1000/100.0<<"ms"<<endl;

    float* boxes_ = output_tensors[0].GetTensorMutableData<float>();
    float* scores_ = output_tensors[1].GetTensorMutableData<float>();
    float* class_ = output_tensors[2].GetTensorMutableData<float>();
    float* num_detection = output_tensors[3].GetTensorMutableData<float>();

    //-------------------------------------------------------------onnxruntime-------------------------------------------------


    //------------------循环遍历显示检测框--------------------------
    cv::Mat frame(imgSource.clone());
        for (int i = 0; i < num_detection[0]; i++)
            {
                float confidence = scores_[i];
                size_t objectClass = (size_t)class_[i];

                if (confidence >= 0.8)
                {
                    int xLeftBottom = static_cast<int>(boxes_[i*4 + 1] * frame.cols);
                    int yLeftBottom = static_cast<int>(boxes_[i*4 + 0] * frame.rows);
                    int xRightTop = static_cast<int>(boxes_[i*4 + 3] * frame.cols);
                    int yRightTop = static_cast<int>(boxes_[i*4 + 2]* frame.rows);

                    //显示检测框
                    cv::Rect object((int)xLeftBottom, (int)yLeftBottom,
                        (int)(xRightTop - xLeftBottom),
                        (int)(yRightTop - yLeftBottom));

                    cv::rectangle(frame, object, cv::Scalar(0,0,255), 2);
                    cv::String label = cv::String("confidence :") +to_string(confidence);
                    int baseLine = 0;
                    cv::Size labelSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.3, 1, &baseLine);
                    cv::rectangle(frame, cv::Rect(cv::Point(xLeftBottom, yLeftBottom - labelSize.height),
                        cv::Size(labelSize.width, labelSize.height + baseLine)),
                        cv::Scalar(255, 255, 0), cv::FILLED);
                    cv::putText(frame, label, cv::Point(xLeftBottom, yLeftBottom),
                        cv::FONT_HERSHEY_SIMPLEX, 0.3, cv::Scalar(0, 0, 0));
                }
            }
    cv::imshow("frame",frame);
    cv::waitKey(0);
    return 0;

}

 
 
参考链接:

  1. https://codechina.csdn.net/mirrors/tenglike1997/onnxruntime-projects/-/blob/master/Ubuntu/onnx_mobilenet
  2. https://blog.csdn.net/mightbxg/article/details/119237326

标签:imgSource,onnxruntime,onnx,cols,opencv,int,input,include,cv
来源: https://blog.csdn.net/qq_42995327/article/details/121051991

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有