ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

将mnist训练的caffemodel生成动态链接库DLL

2021-10-03 21:35:38  阅读:220  来源: 互联网

标签:string res namespace DLL 动态链接库 using caffemodel include class


在项目程序中经常看到动态链接库,非常好奇,想自己实现一下,于是乎尝试一波。就因为这种好奇,每天都被bug所困扰。。。

1. 训练caffemodel

在windows环境下搭建caffe无果,转投Ubuntu。。。

用的caffe--example--mnist中的文件,新建文件夹的话注意改路径,下面为train.sh

#!/usr/bin/env sh
set -e

/home/fish/caffe/build/tools/caffe train --solver=/home/fish/STUDY/lenet_solver.prototxt

 训练好后把lenet_train_test.prototxt和训练好的模型lenet_iter_10000.caffemodel拿出来。

 

 

 2. 使用cv::dnn里的API加载model,输入图片,进行测试(可跳过)

根据文章https://blog.csdn.net/sushiqian/article/details/78555891,修改模型文件。若图片为白底黑字,bitwise_not一下。

#include 
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>

using namespace std;
using namespace cv;
using namespace cv::dnn;

/* Find best class for the blob (i. e. class with maximal probability) */
static void getMaxClass(const Mat& probBlob, int* classId, double* classProb)
{
    Mat probMat = probBlob.reshape(1, 1);
    Point classNumber;
    minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
    *classId = classNumber.x;
}

int main(int argc, char* argv[])
{
    string modelTxt = "C:\\Users\\ATWER\\Desktop\\lenet_train_test.prototxt";
    string modelBin = "C:\\Users\\ATWER\\Desktop\\lenet_iter_10000.caffemodel";
    string imgFileName = "C:\\Users\\ATWER\\Desktop\\9.png";
    //read image
    Mat imgSrc = imread(imgFileName);
    if (imgSrc.empty()) {
        cout << "Failed to read image " << imgFileName << endl;
        exit(-1);
    }
    Mat img;
    cvtColor(imgSrc, img, COLOR_BGR2GRAY);
    //LeNet accepts 28*28 gray image
    resize(img, img, Size(28, 28));
    bitwise_not(img, img);
    img /= 255;

    //transfer image(1*28*28) to blob data with 4 dimensions(1*1*28*28) 
    Mat inputBlob = dnn::blobFromImage(img);
    dnn::Net net;
    try {
        net = dnn::readNetFromCaffe(modelTxt, modelBin);
    }
    catch (cv::Exception& ee) {
        cerr << "Exception: " << ee.what() << endl;
        if (net.empty()) {
            cout << "Can't load the network by using the flowing files:" << endl;
            cout << "modelTxt: " << modelTxt << endl;
            cout << "modelBin: " << modelBin << endl; exit(-1);
        }
    }
    Mat pred;
    net.setInput(inputBlob, "data");//set the network input, "data" is the name of the input layer 
    pred = net.forward("prob");//compute output, "prob" is the name of the output layer 
    cout << pred << endl; int classId; double classProb; getMaxClass(pred, &classId, &classProb);
    cout << "Best Class: " << classId << endl;
    cout << "Probability: " << classProb * 100 << "%" << endl;
}

 

3. 创建动态链接库

参考https://blog.csdn.net/qq_30139555/article/details/103621955

class.h

#include 
#include <opencv2/opencv.hpp>
#include <opencv2/dnn/dnn.hpp>


using namespace std;
using namespace cv;
using namespace cv::dnn;


extern "C" _declspec(dllexport) void Classfication(char* imgpath, char* result);

在此处卡的最久,原本我写的是Classfication(string imgpath, string result),生成dll时没问题,调用时总是System.AccessViolationException: 尝试读取或写入受保护的内存。后来发现要写成指针的形式。

class.cpp

#include 
#include <opencv2/opencv.hpp>
#include <opencv2/dnn/dnn.hpp>
#include "class.h"
using namespace std;
using namespace cv;
using namespace cv::dnn;

/* Find best class for the blob (i. e. class with maximal probability) */
static void getMaxClass(const Mat& probBlob, int* classId, double* classProb)
{
    Mat probMat = probBlob.reshape(1, 1);
    Point classNumber;
    minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
    *classId = classNumber.x;
}

void Classfication(char* imgpath, char* result)
{
    string res = "";
    string modelTxt = "C:\\Users\\ATWER\\Desktop\\lenet_train_test.prototxt";
    string modelBin = "C:\\Users\\ATWER\\Desktop\\lenet_iter_10000.caffemodel";
    //string imgFileName = "C:\\Users\\ATWER\\Desktop\\9.png";
    string imgFileName = imgpath;
    //read image
    Mat imgSrc = imread(imgFileName);
    if (imgSrc.empty()) {
        cout << "Failed to read image " << imgFileName << endl;
        exit(-1);
    }
    Mat img;
    cvtColor(imgSrc, img, COLOR_BGR2GRAY);
    //LeNet accepts 28*28 gray image
    resize(img, img, Size(28, 28));
    bitwise_not(img, img);
    img /= 255;

    //transfer image(1*28*28) to blob data with 4 dimensions(1*1*28*28) 
    Mat inputBlob = dnn::blobFromImage(img);
    dnn::Net net;
    try {
        net = dnn::readNetFromCaffe(modelTxt, modelBin);
    }
    catch (cv::Exception& ee) {
        cerr << "Exception: " << ee.what() << endl;
        if (net.empty()) {
            cout << "Can't load the network by using the flowing files:" << endl;
            cout << "modelTxt: " << modelTxt << endl;
            cout << "modelBin: " << modelBin << endl; exit(-1);
        }
    }
    Mat pred;
    net.setInput(inputBlob, "data");//set the network input, "data" is the name of the input layer 
    pred = net.forward("prob");//compute output, "prob" is the name of the output layer 
    int classId; 
   double classProb;
   getMaxClass(pred, &classId, &classProb); res += to_string(classId); res += '|'; res += to_string(classProb); strcpy_s(result, 15, res.c_str()); }

4. 调用动态链接库

根据数据的长度申请非托管空间参考:https://blog.csdn.net/xiaoyong_net/article/details/50178021

文中说:“一定要加1,否则后面是乱码,原因未找到 ”,应该是打印字符串时会打印到“\n”为止,没有遇到\n会一直打印下去。.Length方法没有计算"\n",+1的空间用于存放“\n”。

using System;
using System.Runtime.InteropServices;

namespace Test
{
    class Program
    {
        [DllImport("E:/c++project/caffedll/x64/Debug/caffedll.dll", EntryPoint = "Classfication")]

        unsafe private static extern void Classfication(IntPtr imgpath, IntPtr result);
        private static IntPtr mallocIntptr(string strData)
        {
            //先将字符串转化成字节方式   
            Byte[] btData = System.Text.Encoding.Default.GetBytes(strData);
            //申请非拖管空间   
            IntPtr m_ptr = Marshal.AllocHGlobal(btData.Length);
            //给非拖管空间清0    
            Byte[] btZero = new Byte[btData.Length + 1]; //一定要加1,否则后面是乱码,原因未找到   
            Marshal.Copy(btZero, 0, m_ptr, btZero.Length);
            //给指针指向的空间赋值   
            Marshal.Copy(btData, 0, m_ptr, btData.Length);
            return m_ptr;
        }
        private static IntPtr mallocIntptr(int length)
        {
            //申请非拖管空间   
            IntPtr m_ptr = Marshal.AllocHGlobal(length);
            //给非拖管空间清0    
            Byte[] btZero = new Byte[length]; 
            Marshal.Copy(btZero, 0, m_ptr, btZero.Length);
            //给指针指向的空间赋值   
            Marshal.Copy(btZero, 0, m_ptr, length);
            return m_ptr;
        }
        static void Main(string[] args)
        {
            string s = "C:\\Users\\ATWER\\Desktop\\9.png";
            IntPtr ptrFileName;
            IntPtr res;
            //根据数据的长度申请非托管空间   
            ptrFileName = mallocIntptr(s);
            res = mallocIntptr(50);
            Classfication(ptrFileName, res);
            string result = Marshal.PtrToStringAnsi(res);
            string[] a = result.Split('|');
            Console.WriteLine("class:"+a[0]+"\n"+"score:"+a[1]);
            Marshal.FreeHGlobal(res);
        }
    }
}

 

标签:string,res,namespace,DLL,动态链接库,using,caffemodel,include,class
来源: https://www.cnblogs.com/Fish0403/p/15365048.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有