ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

bp神经网络---MATLAB

2021-12-26 14:32:43  阅读:229  来源: 互联网

标签:10 matrix 输出 photo --- num bp MATLAB 神经元


文章目录

前言

这个程序可以识别1-9这几个数字,其中训练集一共270张图片用作训练,也就是每个数字取30张用作训练。最后才测试集中,每个数字找了10张,也就是90张,用于是否能够准确识别。注本文转载于:(BP神经网络识别手写数字项目解析及代码),原有代码无法直接运行,故写下本文,在其基础上提供了数据集和测试集。

1:数据集来源

准备工作,将训练集和测试集准备好。使用的库来自于http://www.ee.surrey.ac.uk/CVSSP/demos/chars74k/#download
使用其中的EnglishHnd.tgz库,对每个数字选出前30张作为训练,再另外取10张作为测试。
在这里插入图片描述

图1

在这里插入图片描述

图2

在这里插入图片描述

图3

在这里插入图片描述

图4

在这里插入图片描述

图5

2:原理

(1) 大白话讲解传统神经网络
首先,我们看一下神经网络的基本单元——单个的神经元:
在这里插入图片描述

图中圆形表示一个神经元,我们知道,一个神经元接收相邻的神经元传来的刺激,神经元对这些刺激以不同的权重进行积累,到一定的时候产生自己的刺激将其传递给一些与它相邻的神经元。这样工作的无数个神经元便构成了人脑对外界的感知。而人脑对世界的学习的机制就是通过调节这些相邻连接的神经元刺激的权重。
在图中,周围神经元传过来的刺激表示为Y,权重表示为W,圆形表示的神经元得到的刺激是所有刺激按照权重累加起来,即在这里插入图片描述
同时这个神经元作为网络的一份子,也像其他神经元一样需要向外传播刺激信号,但是不是直接把s传播,而是传播一个f(s)出去,为什么呢?其实无关大局,我们后面分析。其中f(s)学名称为“激活函数”,常用的函数如下:
在这里插入图片描述
好了,理解到这里如果没什么问题的话,恭喜你你已经入门了,现在我们把这一个个的基本单元连接起来,就构成我们最终的神经网络了。传统的神经网络结构如下图所示
在这里插入图片描述
是不是觉得很乱?不着急我们一点一点看,由整体到细微来解剖它。首先整体上,它的结构分为三部分,输入层,隐藏层和输出层,一般输入层和输出层各一个,隐藏层若干个,图中画出了一个。细微处,连接结构上,后一层的每个神经元都由前一层的所有神经元连接进来。
手写数字识别实验使用的是三层的神经网络结构,即只有一个隐藏层,下面以此说明。
下面说明一下各层的表示和各层的关系:
输入层:X=(x1,x2,x3…xn)
隐藏层:Y=(y1,y2,y3…ym)
输出层:O=(o1,o2,o3…or)

两个权重:
输入层到隐藏层的权重:V=(V1,V2,V3…Vm),Vj是一个列向量,表示输入层所有神经元通过Vj加权,得到隐藏层的第j个神经元
隐藏层到输出层的权重:W=(W1,W2,W3…Wr),Wk是一个列向量,表示隐藏层的所有神经元通过Wk加权,得到输出层的第k个神经元

根据我们上面说到的单个神经元的刺激传入和刺激传出,相信到这里很多人应该已经得出下面的各层之间的关系了:
在这里插入图片描述
到这里,神经网络的工作过程就清楚一些了。实例说明一下,假设输入是一张图像, 16x16大小,转换为一个二维的灰度值矩阵,再把每一行拼接在上一行的末尾,拼接成一个1x256的行向量,作为输入层的输入,即X,接下来按照公式2就可以计算出隐藏层,然后根据公式1又可以计算出输出层,输出层的输出就得到了。在这个手写数字识别的项目中,我使用的图片输入正是16x16,所以输入层有256个神经元,隐藏层的神经元我取了64个,最后的输出层神经元我取的是10个,为什么是10个呢?因为数字0到9一共10个,期望上是,比如输入一张写着数字1的图像,在输出端得到的输出是{1 0 0 0 0 0 0 0 0 0},输入的图像为2时,输出{ 0 1 0 0 0 0 0 0 0 0},以此类推,实际上输出的未必就是刚好1和刚好0,经过调参和训练,基本是输出0.9多和正负的0.0多,不过也足够了,仅仅用判断最大值所在位置的方式就可以识别到图像上的数字。

至此,我们已经了解了整个网络的结构和正向工作的具体流程。可以说我们已经对神经网络理解了有50%了。为什么才50%呢?仔细想想相信你会发现,我们还不知道两个在网络中很重要的量,就是权重矩阵W和V。

如何求得W和V呢,这里要用到一种算法,就是误差反向传播算法(Error Back Propagation Algorithm) ,简称BP 算法。说的很晦涩,我们来翻译成人话。先来看看算法的工作过程,首先随机地初始化W和V的值,然后代入一些图片进行计算,得到一个输出,当然由于W和V参数不会刚好很完美,输出自然不会是像上文说的,刚好就是{1 0 0 0 0 0 0 0 0 0}这一类,所以存在误差,根据这个误差就可以反过来修正W和V的值,修正后的W和V可以使输出更加的靠近于理想的输出,这就是所谓的“误差反向传播”的意思,修正一次之后,再代入其他一些图片,输出离理想输出又靠近了一点,我们又继续计算误差,然后修正W和V的值,就这样经过很多次的迭代计算,最终多次修正得到了比较完美的W和V矩阵,它可以使得输出非常靠近于理想的输出,至此我们的工作完成度才是100%了。这种在输出端计算误差,根据误差来作调节的思想,学自动化的或者接触过飞思卡尔一类的智能车比赛的同学体会应该是比较深的,跟PID自控算法有很大相似性。

下面是数学推导,关于实际输出和理想输出之间的误差如何具体来调节W和V的值,调节多少的问题。上面说过了,暂时不理解的话可以先跳过推导,看最后的结论就好,自己最后跟着代码实践一遍,有了更深的体会,慢慢会理解的。

(2)逆向传播算法的数学推导
输出层的理想输出:d=(d1,d2,d3…dr),例如{1 0 0 0 0 0 0 0 0 0}和{0 1 0 0 0 0 0 0 0 0}等
假设实际输出和理想输出之间的差距是E,明显W是一个关于输入X,权重W和V,输出O的函数。要修正W,则需要知道具体的修正增量ΔW,离散情况下,表征微分增量,可以得到:
在这里插入图片描述
在这里插入图片描述
这样,改变η的大小即可改变每一次调节的幅度,η大的话调节更快,小则调节慢,但是过大容易导致振荡,这一点也跟PID中的比例系数P是一样的。一般η的大小需要经过多次尝试来找到合适值。

好了,到这里神经网络就讲解完毕,下面是一个较次要的内容,我们上面说了,通过不断迭代来调整权重W和V,那么如何衡量迭代是否可以停止了呢。一个自然的想法是判断每次的输出和理想输出是否足够接近,所以我们可以用算向量距离的方法,跟均方差是一个 道理,如下:
在这里插入图片描述
这样,主要s足够小,迭代就可以结束了。

3:代码部分

train.m的代码
V=double(rand(256,64));
W=double(rand(64,10));
delta_V=double(rand(256,64));
delta_W=double(rand(64,10));

yita=0.2;%缩放系数,有的文章称学习率
yita1=0.05;%我自己加的参数,缩放激活函数的自变量防止输入过大进入函数的饱和区,可以去掉体会一下变化
train_number=9;%训练样本中,有多少个数字,一共9个,没有0
train_num=30;%训练样本中,每种数字多少张图,一共100张
x=double(zeros(1,256));%输入层
y=double(zeros(1,64));%中间层,也是隐藏层
output=double(zeros(1,10));%输出层
tar_output=double(zeros(1,10));%目标输出,即理想输出
delta=double(zeros(1,10));%一个中间变量,可以不管

%记录总的均方差便于画图
s_record=1:400;
tic %计时
for train_control_num=1:400   %训练次数控制,在调参的最后发现1000次其实有多了,大概400次完全够了
    s=0;
    %读图,输入网络
    for number=2:(1+train_number) %train_number=10 由于图片的文件名是从img002开始的,所以这里是2
        ReadDir=['C:\Users\Lenovo\Desktop\BP\train_picture\'];%读取样本的路径
        for num=1:train_num  %控制多少张  train_num=30
            if number~=10  %当图片的文件名到10的时候名字会变化,故这里分情况讨论
                photo_name=['img00',num2str(number) ,'-',num2str(num,'%03d'),'.png'];%图片名是拼接而成的
                photo_index=[ReadDir,photo_name];%路径加图片名得到总的图片索引
                photo_matrix=imread(photo_index);%使用imread得到图像矩阵
                photo_matrix=rgb2gray(photo_matrix);
                photo_matrix=imresize(photo_matrix,[16,16]);
                photo_matrix=uint8(photo_matrix<=230);%二值化,黑色是1
                
                tmp=photo_matrix';
                tmp=tmp(:);%以上两步完成了图像二维矩阵转变为列向量,256维,作为输入
                %计算输入层输入
                x=double(tmp');%转化为行向量因为输入层X是行向量,并且化为浮点数
                %得到隐层输入
                y0=x*V;
                %激活
                y=1./(1+exp(-y0*yita1));
                %得到输出层输入
                output0=y*W;
                % lf=lf+1;之前用来看错误在哪一步的参考数据
                output=1./(1+exp(-output0*yita1));
                %计算预期输出
                tar_output=double(zeros(1,10));
                tar_output(number)=1.0;
                %计算误差
                %按照公式计算W和V的调整,为了避免使用for循环比较耗费时间,下面采用了直接矩阵乘法,更高效
                delta=(tar_output-output).*output.*(1-output);
                delta_W=yita*repmat(y',1,10).*repmat(delta,64,1);
                tmp=sum((W.*repmat(delta,64,1))');
                tmp=tmp.*y.*(1-y);
                delta_V=yita*repmat(x',1,64).*repmat(tmp,256,1);
                %计算均方差
                s=s+sum((tar_output-output).*(tar_output-output))/10;
                %更新权值
                W=W+delta_W;
                V=V+delta_V;
            else
                photo_name=['img0',num2str(number) ,'-',num2str(num,'%03d'),'.png'];%图片名
                photo_index=[ReadDir,photo_name];%路径加图片名得到总的图片索引
                photo_matrix=imread(photo_index);%使用imread得到图像矩阵
                photo_matrix=rgb2gray(photo_matrix);
                photo_matrix=imresize(photo_matrix,[16,16]);
                photo_matrix=uint8(photo_matrix<=230);%二值化,黑色是1
                
                tmp=photo_matrix';
                tmp=tmp(:);%以上两步完成了图像二维矩阵转变为列向量,256维,作为输入
                %计算输入层输入
                x=double(tmp');%转化为行向量因为输入层X是行向量,并且化为浮点数
                %得到隐层输入
                y0=x*V;
                %激活
                y=1./(1+exp(-y0*yita1));
                %得到输出层输入
                output0=y*W;
                % lf=lf+1;之前用来看错误在哪一步的参考数据
                output=1./(1+exp(-output0*yita1));
                %计算预期输出
                tar_output=double(zeros(1,10));
                tar_output(number)=1.0;
                %计算误差
                %按照公式计算W和V的调整,为了避免使用for循环比较耗费时间,下面采用了直接矩阵乘法,更高效
                delta=(tar_output-output).*output.*(1-output);
                delta_W=yita*repmat(y',1,10).*repmat(delta,64,1);
                tmp=sum((W.*repmat(delta,64,1))');
                tmp=tmp.*y.*(1-y);
                delta_V=yita*repmat(x',1,64).*repmat(tmp,256,1);
                %计算均方差
                s=s+sum((tar_output-output).*(tar_output-output))/10;
                %更新权值
                W=W+delta_W;
                V=V+delta_V;
            end
        end
    end
    s=s/train_number/train_num;  %不加分号,随时输出误差观看收敛情况
    train_control_num           %不加分号,随时输出迭代次数观看运行状态
    s_record(train_control_num)=s;%记录
end
toc %计时结束
plot(1:400,s_record);
save result W V yita1; %保存W V yita1的结果,将其命名为result2 
test.m的代码
correct_num=0;%记录正确的数量
incorrect_num=0;%记录错误数量
test_number=9;%测试集中,一共多少数字,9个,没有0
test_num=10;%测试集中,每个数字多少个,最大100个
%load W;%%之前训练得到的W保存了,可以直接加载进来
%load V;
%load yita1;

%记录时间
tic %计时开始
for number=2:(1+test_number)
    ReadDir=['C:\Users\Lenovo\Desktop\BP\test_picture\'];
    for num=31:(30+test_num)  %控制多少张
        if number~=10
            photo_name=['img00',num2str(number) ,'-',num2str(num,'%03d'),'.png'];%图片名
            photo_index=[ReadDir,photo_name];%路径加图片名得到总的图片索引
            photo_matrix=imread(photo_index);%使用imread得到图像矩阵
            photo_matrix=rgb2gray(photo_matrix);
            photo_matrix=imresize(photo_matrix,[16,16]);%大小改变
            photo_matrix=uint8(photo_matrix<=230);%二值化,黑色是1
            %行向量
            tmp=photo_matrix';
            tmp=tmp(:);
            %计算输入层输入
            x=double(tmp');
            %得到隐层输入
            y0=x*V;
            %激活
            y=1./(1+exp(-y0*yita1));
            %得到输出层输入
            o0=y*W;
            o=1./(1+exp(-o0*yita1));
            %最大的输出即是识别到的数字
            [o,index]=sort(o);
            if index(10)==number
                correct_num=correct_num+1
            else
                incorrect_num=incorrect_num+1;
                %显示不成功的数字,显示会比较花时间
%                      figure(incorrect_num)
%                      imshow((1-photo_matrix)*255);
%                      title(num2str(number));
            end
        else
            photo_name=['img0',num2str(number) ,'-',num2str(num,'%03d'),'.png'];%图片名
            photo_index=[ReadDir,photo_name];%路径加图片名得到总的图片索引
            photo_matrix=imread(photo_index);%使用imread得到图像矩阵
            photo_matrix=rgb2gray(photo_matrix);
            photo_matrix=imresize(photo_matrix,[16,16]);%大小改变
            photo_matrix=uint8(photo_matrix<=230);%二值化,黑色是1
            %行向量
            tmp=photo_matrix';
            tmp=tmp(:);
            %计算输入层输入
            x=double(tmp');
            %得到隐层输入
            y0=x*V;
            %激活
            y=1./(1+exp(-y0*yita1));
            %得到输出层输入
            o0=y*W;
            o=1./(1+exp(-o0*yita1));
            %最大的输出即是识别到的数字
            [o,index]=sort(o);
            if index(10)==number
                correct_num=correct_num+1
            else
                incorrect_num=incorrect_num+1;
                %显示不成功的数字,显示会比较花时间
                %     figure(incorrect_num)
                %     imshow((1-photo_matrix)*255);
                %     title(num2str(number));
            end
        end
    end
end
correct_rate=correct_num/test_number/test_num
toc %计时结束
save result2 correct_rate; %保存识别率的结果,将其命名为result2 

补充,在test.m中,有以下几行代码,这里对其进行解释:

o=1./(1+exp(-o0*yita1));
            %最大的输出即是识别到的数字
            [o,index]=sort(o);
            if index(10)==number
                correct_num=correct_num+1
            else

看图:在这里插入图片描述

图6

而在o中,输出有10个,我们要预测的是1-9这9个数(暂且别管0),那么o中就会给出这几个数的概率,我们取概率最大的就是其预测的结果。 [o,index]=sort(o);表示对o中的概率进行排序,index为排序后的索引,
if index(10)==number表示,如果index中排序最大的数的索引正好等于此时输入的数,就表示预测正确了。说起来有点难理解,看下图吧。
在这里插入图片描述

图7

标签:10,matrix,输出,photo,---,num,bp,MATLAB,神经元
来源: https://blog.csdn.net/qq_40077565/article/details/122154534

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有