ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

在 C++ 中利用 onnxruntime (CUDA) 和 OpenCV 部署 ONNX 模型

2021-10-30 16:02:24  阅读:92  来源: 互联网

标签:imgSource onnxruntime onnx cols opencv int input include cv


概述

将训练得到的模型转换为 ONNX 格式，再加载到 C++ 程序中运行，以完成模型部署。开始之前需要先下载并安装 onnxruntime。

CMakeLists.txt:

# CMake 4.x rejects cmake_minimum_required() versions below 3.5.
cmake_minimum_required(VERSION 3.10)

# Use clang++ unless a compiler was chosen on the command line
# (e.g. -DCMAKE_CXX_COMPILER=g++). The compiler must be selected before
# project() enables the CXX language; setting it afterwards is unreliable.
if(NOT CMAKE_CXX_COMPILER)
    set(CMAKE_CXX_COMPILER clang++)
endif()

project(test)

# Default to an optimized build, but respect -DCMAKE_BUILD_TYPE=...
if(NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE "Release")
endif()
set(CMAKE_INCLUDE_CURRENT_DIR ON)

# OpenCV
find_package(OpenCV 4.5.1 REQUIRED)

# onnxruntime: point this at your onnxruntime tree,
# e.g. -DONNXRUNTIME_ROOT_PATH=/path/to/onnxruntime
set(ONNXRUNTIME_ROOT_PATH /home/zyl/ubuntu/tensorRT/onnxruntime-master
    CACHE PATH "Root directory of the onnxruntime source/build tree")
set(ONNXRUNTIME_INCLUDE_DIRS ${ONNXRUNTIME_ROOT_PATH}/include/onnxruntime
                             ${ONNXRUNTIME_ROOT_PATH}/onnxruntime
                             ${ONNXRUNTIME_ROOT_PATH}/include/onnxruntime/core/session/)
set(ONNXRUNTIME_LIB ${ONNXRUNTIME_ROOT_PATH}/build/Linux/Release/libonnxruntime.so)

add_executable(${PROJECT_NAME} "main.cpp")
# Scope the onnxruntime headers to this target instead of the whole directory.
target_include_directories(${PROJECT_NAME} PRIVATE ${ONNXRUNTIME_INCLUDE_DIRS})
target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_LIB})

C++源码:

#include <algorithm>
#include <array>
#include <chrono>
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/core.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>

//onnxruntime
#include <core/session/onnxruntime_cxx_api.h>
#include <core/providers/cuda/cuda_provider_factory.h>
#include <core/session/onnxruntime_c_api.h>
#include <core/providers/tensorrt/tensorrt_provider_factory.h>

using namespace std;

int main(int argc,char ** argv)
{
    //输入网络的维度
    static constexpr const int width = 600;
    static constexpr const int height = 600;
    static constexpr const int channel = 3;
    std::array<int64_t, 4> input_shape_{ 1,height, width,channel};

    //-------------------------------------------------------------onnxruntime-------------------------------------------------
    //图片和模型位置
    string img_path = argv[1];
    string model_path = "model.onnx";
    cv::Mat imgSource = cv::imread(img_path);

    Ort::Env env(OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, "Detection");
    Ort::SessionOptions session_options;
    //CUDA加速开启
    OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);

    Ort::AllocatorWithDefaultOptions allocator;
    //加载ONNX模型
    Ort::Session session(env, model_path.c_str(), session_options);
    //获取输入输出的维度
    std::vector<int64_t> input_dims = session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
    std::vector<int64_t> output_dims = session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();

    /*
    session.GetOutputName(1, allocator);
    session.GetInputName(1, allocator);
    //输出模型输入节点的数量
    size_t num_input_nodes = session.GetInputCount();
    size_t num_output_nodes = session.GetOutputCount();
    */

    std::vector<const char*> input_node_names = {"image_tensor:0"};
    std::vector<const char*> output_node_names = {"detection_boxes:0","detection_scores:0","detection_classes:0","num_detections:0" };
    input_dims[0] = output_dims[0] = 1;//batch size = 1

    std::vector<Ort::Value> input_tensors;
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);

    //将图像存储到uchar数组中,BGR--->RGB
    std::array<uchar, width * height *channel> input_image_{};
    uchar* input =  input_image_.data();
    for (int i = 0; i < imgSource.rows; i++) {
        for (int j = 0; j < imgSource.cols; j++) {
            for (int c = 0; c < 3; c++)
            {
                //NHWC 格式
                if(c==0)
                    input[i*imgSource.cols*3+j*3+c] = imgSource.ptr<uchar>(i)[j*3+2];
                if(c==1)
                    input[i*imgSource.cols*3+j*3+c] = imgSource.ptr<uchar>(i)[j*3+1];
                if(c==2)
                    input[i*imgSource.cols*3+j*3+c] = imgSource.ptr<uchar>(i)[j*3+0];
                //NCHW 格式
//                if (c == 0)
//                     input[c*imgSource.rows*imgSource.cols + i * imgSource.cols + j] = imgSource.ptr<uchar>(i)[j * 3 + c];
//                if (c == 1)
//                     input[c*imgSource.rows*imgSource.cols + i * imgSource.cols + j] = imgSource.ptr<uchar>(i)[j * 3 + c];
//                if (c == 2)
//                     input[c*imgSource.rows*imgSource.cols + i * imgSource.cols + j] = imgSource.ptr<uchar>(i)[j * 3 + c];


            }
        }
    }

    input_tensors.push_back(Ort::Value::CreateTensor<uchar>(
            memory_info, input, input_image_.size(), input_shape_.data(), input_shape_.size()));
    //不知道输入维度时
    //input_tensors.push_back(Ort::Value::CreateTensor<uchar>(
    //        memory_info, input, input_image_.size(), input_dims.data(), input_dims.size()));

    chrono::steady_clock::time_point t1 = chrono::steady_clock::now();

    std::vector<Ort::Value> output_tensors;
    for(int i=0; i<100;i++)
         output_tensors = session.Run(Ort::RunOptions { nullptr },
                                                            input_node_names.data(), //输入节点名
                                                            input_tensors.data(),     //input tensors
                                                            input_tensors.size(),     //1
                                                            output_node_names.data(), //输出节点名
                                                            output_node_names.size()); //4
    chrono::steady_clock::time_point t2 = chrono::steady_clock::now();

    chrono::duration<double> delay_time = chrono::duration_cast<chrono::duration<double>>(t2 - t1); //milliseconds 毫秒

    cout<<"前向传播平均耗时:"<<delay_time.count()*1000/100.0<<"ms"<<endl;

    float* boxes_ = output_tensors[0].GetTensorMutableData<float>();
    float* scores_ = output_tensors[1].GetTensorMutableData<float>();
    float* class_ = output_tensors[2].GetTensorMutableData<float>();
    float* num_detection = output_tensors[3].GetTensorMutableData<float>();

    //-------------------------------------------------------------onnxruntime-------------------------------------------------


    //------------------循环遍历显示检测框--------------------------
    cv::Mat frame(imgSource.clone());
        for (int i = 0; i < num_detection[0]; i++)
            {
                float confidence = scores_[i];
                size_t objectClass = (size_t)class_[i];

                if (confidence >= 0.8)
                {
                    int xLeftBottom = static_cast<int>(boxes_[i*4 + 1] * frame.cols);
                    int yLeftBottom = static_cast<int>(boxes_[i*4 + 0] * frame.rows);
                    int xRightTop = static_cast<int>(boxes_[i*4 + 3] * frame.cols);
                    int yRightTop = static_cast<int>(boxes_[i*4 + 2]* frame.rows);

                    //显示检测框
                    cv::Rect object((int)xLeftBottom, (int)yLeftBottom,
                        (int)(xRightTop - xLeftBottom),
                        (int)(yRightTop - yLeftBottom));

                    cv::rectangle(frame, object, cv::Scalar(0,0,255), 2);
                    cv::String label = cv::String("confidence :") +to_string(confidence);
                    int baseLine = 0;
                    cv::Size labelSize = cv::getTextSize(label, cv::FONT_HERSHEY_SIMPLEX, 0.3, 1, &baseLine);
                    cv::rectangle(frame, cv::Rect(cv::Point(xLeftBottom, yLeftBottom - labelSize.height),
                        cv::Size(labelSize.width, labelSize.height + baseLine)),
                        cv::Scalar(255, 255, 0), cv::FILLED);
                    cv::putText(frame, label, cv::Point(xLeftBottom, yLeftBottom),
                        cv::FONT_HERSHEY_SIMPLEX, 0.3, cv::Scalar(0, 0, 0));
                }
            }
    cv::imshow("frame",frame);
    cv::waitKey(0);
    return 0;

}

 
 
参考链接:

  1. https://codechina.csdn.net/mirrors/tenglike1997/onnxruntime-projects/-/blob/master/Ubuntu/onnx_mobilenet
  2. https://blog.csdn.net/mightbxg/article/details/119237326

标签:imgSource,onnxruntime,onnx,cols,opencv,int,input,include,cv
来源: https://blog.csdn.net/qq_42995327/article/details/121051991

专注分享技术,共同学习,共同进步。侵权联系[admin#icode9.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有