ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

CenterNet pytorch 转 libtorch模型并使用

2020-01-21 17:38:02  阅读:939  来源: 互联网

标签:std temp CenterNet col libtorch pytorch PlateDangerousDetection include data0


使用原版的github上的centerNet 生成模型,这部分参考我的另外一篇博文:

https://blog.csdn.net/qq_31610789/article/details/99938631

 

c++后,需要用到libtorch库,按照官方教程编译即可,CMakeList.txt如下:

cmake_minimum_required(VERSION 3.13)
project(CenterNetCppPro)

set(CMAKE_CXX_STANDARD 14)


include_directories(${CMAKE_SOURCE_DIR}/3rdParty/opencv-3.2.0/include)
include_directories(${CMAKE_SOURCE_DIR}/3rdParty/cuda/include)


include_directories(${CMAKE_SOURCE_DIR}/3rdParty/pytorch_libs/include)
include_directories(${CMAKE_SOURCE_DIR}/3rdParty/pytorch_libs/include/torch/csrc/api/include)

link_directories(${CMAKE_SOURCE_DIR}/3rdParty/opencv-3.2.0/lib)
link_directories(${CMAKE_SOURCE_DIR}/3rdParty/cuda/lib)
link_directories(${CMAKE_SOURCE_DIR}/3rdParty/pytorch_libs/lib)

#add_library()

add_executable(CenterNetCppPro main.cpp plateDangerousDetection/plateDangerousDetection.cpp plateDangerousDetection/include/plateDangerousDetection.hpp)

target_link_libraries(CenterNetCppPro opencv_core opencv_imgproc opencv_imgcodecs opencv_highgui caffe2 caffe2_gpu c10_cuda c10 torch cuda /usr/local/cuda/lib64/libnvrtc.so)

这里基本使用,

主函数:

处理函数,包括保存热力图:并且验证了是否和pytorch版本一致的代码。

int PlateDangerousDetection::process(cv::Mat& img, std::vector<PlateDangerousOutput>& output) {
    std::vector<torch::jit::IValue> inputs;  //def an input
    cv::Mat img1 = cv::imread(
            "/data_1/vir/weixianpinche/train_ws/CenterNet/c++/sample/130642000000025553+冀FK8196+134414+130642000000025553+01+0+@1873077@@@2019-12-18#23#29#50+a1+0_0.jpg");
    cv::Mat image, float_image;
    std::cout << std::to_string(img1.at<cv::Vec3b>(0, 0)[0]) << std::endl;
    std::cout << std::to_string(img1.at<cv::Vec3b>(0, 0)[1]) << std::endl;
    std::cout << std::to_string(img1.at<cv::Vec3b>(0, 0)[2]) << std::endl;

    resize(img1, image, cv::Size(imageHeight_, imageWidth_), cv::INTER_LINEAR);  // resize 图像
    std::cout << std::to_string(image.at<cv::Vec3b>(0, 0)[0]) << std::endl;
    std::cout << std::to_string(image.at<cv::Vec3b>(0, 0)[1]) << std::endl;
    std::cout << std::to_string(image.at<cv::Vec3b>(0, 0)[2]) << std::endl;
//    cvtColor(image, image, CV_BGR2RGB);  // bgr -> rgb
    image.convertTo(float_image, CV_32F, 1.0 / 255);   //归一化到[0,1]区间 TODO
    float *point_img;
    std::cout << float_image.at<cv::Vec3f>(0, 0)[0] << std::endl;
    std::cout << float_image.at<cv::Vec3f>(0, 0)[1] << std::endl;
    std::cout << float_image.at<cv::Vec3f>(0, 0)[2] << std::endl;

//    point_img = float_image.ptr(32);
//    std::cout << *(float_image.data) << std::endl;  //输出一个像素点点值
    auto img_tensor = torch::CPU(torch::kFloat32).tensorFromBlob(float_image.data, {1, imageHeight_, imageWidth_,
                                                                                    3});   //将cv::Mat转成tensor,大小为1,224,224,3
    img_tensor = img_tensor.permute({0, 3, 1, 2});  //调换顺序变为torch输入的格式 1,3,224,224
    //img_tensor[0][0] = img_tensor[0][0].sub_(0.485).div_(0.229);  //减去均值,除以标准差
    //img_tensor[0][1] = img_tensor[0][1].sub_(0.456).div_(0.224);
    //img_tensor[0][2] = img_tensor[0][2].sub_(0.406).div_(0.225);
    auto img_var = torch::autograd::make_variable(img_tensor, false);  //不需要梯度
    inputs.emplace_back(img_var.to(at::kCUDA));  // 把预处理后的图像放入gpu
    torch::Tensor result = centerNet->forward(inputs).toTensor();  //前向传播获取结果
    inputs.pop_back();
    std::cout << "result.sizes() = " << result.sizes() << std::endl;
    std::cout << "Forward over!!!" << std::endl;
    for (int i = 0; i < 8; ++i) {
        std::cout << "result:" + std::to_string(i) + " " << result[0][i][0][0] << std::endl;
    }
    // result (1, 8, 128, 128)   (0, 0-3, 128, 128) pythorch-'hm'(but hasn't been sigmoided.) /(0, 4-5, 128, 128) pythorch-'wh' / (0, 6-7, 128, 128) pythorch-'reg'

    torch::Tensor hm = at::select(result, 1, 2);
    std::cout << "hm.sizes() = " << hm.sizes() << std::endl;
    int splitList[] = {4, 2, 2};
    std::vector<torch::Tensor> splitResult = torch::split_with_sizes(result, {4, 2, 2}, 1);
    for (auto tt:splitResult) {
        std::cout << "tt.sizes() = " << tt.sizes() << std::endl;
    }

    for (int i = 0; i < 8; ++i) {
        std::cout << "result:" + std::to_string(i) + " " << result[0][i][0][0] << std::endl;
    }


    std::vector<torch::Tensor> splitHeatMap = torch::split_with_sizes(splitResult[0], {3, 1}, 1);
    torch::Tensor heatMapTensor012 = splitHeatMap[0];
    heatMapTensor012 = heatMapTensor012.reshape({3, 128, 128});
    heatMapTensor012 = heatMapTensor012.squeeze().detach().permute({1, 2, 0});
    std::cout << "heatMapTensor012.sizes() = " << heatMapTensor012.sizes() << std::endl;
    torch::Tensor heatMapTensor012Img = heatMapTensor012.add(5).mul(255).clamp(0, 255).to(torch::kU8);
//    heatMapTensor012.sum(1)
    for (int i = 0; i < 3; ++i) {
        std::cout << "heatMapTensor012Img:" + std::to_string(i) + " " << heatMapTensor012Img[i][0][0] << std::endl;
    }
    heatMapTensor012Img = heatMapTensor012Img.to(torch::kCPU);
    cv::Mat resultImg(128, 128, CV_8UC3);
    std::memcpy((void *) resultImg.data, heatMapTensor012Img.data_ptr(),
                sizeof(torch::kU8) * heatMapTensor012Img.numel());
    cv::Mat resizedResultImg;
    cv::resize(resultImg, resizedResultImg, img1.size());
    int r, g, b;
    for (int row = 0; row < resizedResultImg.rows; ++row) {
        uchar *data0 = img1.ptr<uchar>(row);
        uchar *data1 = resizedResultImg.ptr<uchar>(row);
        for (int col = 0; col < resizedResultImg.cols; ++col) {
            // ---------【开始处理每个像素】-------------
            data0[col * 3] = data0[col * 3] * 0.3;
            data0[col * 3+1] = data0[col * 3+1] * 0.3;
            data0[col * 3+2] = data0[col * 3+2] * 0.3;
            int temp = 0;
            temp = data0[col * 3] + (data1[col * 3] + data0[col * 3 + 1] + data0[col * 3 + 2])/3.1;
            data0[col*3] = temp > 255? 255: temp;
            temp = data0[col * 3+1] + (data1[col * 3] + data0[col * 3 + 1] + data0[col * 3 + 2])/3.1;
            data0[col*3+1] = temp > 255? 255: temp;
            temp = data0[col*3+2] + (data1[col * 3] + data0[col * 3 + 1] + data0[col * 3 + 2]);
            data0[col*3+2] = temp > 255? 255: temp;

        }
    }




    cv::imwrite("/data_1/vir/weixianpinche/train_ws/CenterNet/c++/sample/centerNet.jpg", img1);
    cv::imshow("heapmap012", img1);
    cv::waitKey();
    bool centerNetDecoder(torch::Tensor &OutPutTensor, std::vector<PlateDangerousOutput> &output);
}
// 初始化等函数
PlateDangerousDetection::PlateDangerousDetection() {

}

PlateDangerousDetection::~PlateDangerousDetection()
{

}

PlateDangerousDetection &PlateDangerousDetection::ins() {
    static thread_local PlateDangerousDetection obj;
    return obj;
}

int PlateDangerousDetection::init(const std::string& configPath) {
    torch::NoGradGuard no_grad;
    centerNet = torch::jit::load(configPath+"/plateDangerousDetection/torch_model.pt");
    centerNet->to(at::kCUDA);
    assert(centerNet != nullptr);
    std::cout << "[INFO] init model done...\n";
    return 0;
}

主函数如下:一些非关键代码没有贴上来。

int main() {

    std::cout << "Hello, World!" << std::endl;
    int flag = PlateDangerousDetection::ins().init("/data_1/vir/weixianpinche/train_ws/CenterNet/c++");
    if (flag != 0) {
        std::cout << "VIRPlateRecognition init faild" << std::endl;
        return flag;
    }

    std::vector<std::vector<PlateDangerousOutput>> plateDangerousOutputs;
    std::vector<PlateDangerousOutput> plateDangerousOutput;
    cv::Mat carRect = cv::imread("/data_1/vir/weixianpinche/train_ws/CenterNet/c++/sample/130642000000025553+冀FK8196+134414+130642000000025553+01+0+@1873077@@@2019-12-18#23#29#50+a1+0_0.jpg");
    PlateDangerousDetection::ins().process(carRect, plateDangerousOutput);

    return 0;
}

下面是一些后处理(事实上后处理是比较复杂的部分,等待移植更新,可以使用c++版本numpy移植)

AaronJiang395 发布了38 篇原创文章 · 获赞 8 · 访问量 1万+ 私信 关注

标签:std,temp,CenterNet,col,libtorch,pytorch,PlateDangerousDetection,include,data0
来源: https://blog.csdn.net/qq_31610789/article/details/104063943

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有