ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

基于TensorFlow简单实现手写体数字识别

2019-06-28 17:55:18  阅读:395  来源: 互联网

标签:... batch tf cost 手写体 TensorFlow 识别 data mnist


本案例采用的是MNIST数据集[1],是一个入门级的计算机视觉数据集。

MNIST数据集已经被嵌入到TensorFlow中,可以直接下载和安装。

1 from tensorflow.examples.tutorials.mnist import input_data
2 mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)

此时,文件名为MNIST_data的数据集就下载下来了,其中one_hot=True为将样本标签转化为one_hot编码。

接下来将MNIST的信息打印出来。

3 print('输入数据:',mnist.train.images)
4 print('输入数据的尺寸:',mnist.train.images.shape)
5 import pylab
6 im=mnist.train.images[0]  #第一张图片
7 im=im.reshape(-1,28)
8 pylab.imshow(im)
9 pylab.show()

输出为:

输入数据: [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
输入数据的尺寸: (55000, 784)

MNIST的图片尺寸为28*28,数据集的存储把所有的图片存在一个矩阵中,将一张图片铺开存为一个行向量,从输出信息我们可以知道训练集包含55000张图片。

MNIST中还包括测试集和验证集,大小分别为10000和5000。

10 print("测试集大小:",mnist.test.images.shape)
11 print("验证集大小:",mnist.validation.images.shape)

测试集用于训练过程中评估模型的准确度,验证集用于最终评估模型的准确度。

接下来就可以进行识别了,采用最简单的单层神经网络的方法,大致顺序就是定义输入-学习参数-学习参数和输入计算-计算损失-定义优化函数-迭代优化

 1 import tensorflow as tf
 2 mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)
 3 tf.reset_default_graph()   #清除默认图形堆栈并重置全局默认图形
 4 #定义占位符
 5 x=tf.placeholder(tf.float32,[None,784])   #图像28*28=784
 6 y=tf.placeholder(tf.float32,[None,10])    #标签10类
 7 #定义学习参数
 8 w=tf.Variable(tf.random_normal([784,10])) #权值,初始化为正太随机值
 9 b=tf.Variable(tf.zeros([10]))             #偏置,初始化为0
10 #定义输出
11 pred=tf.nn.softmax(tf.matmul(x,w)+b)      #相当于单层神经网络,激活函数为softmax
12 #损失函数
13 cost=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1))  #reduction_indices指定计算维度
14 #优化函数
15 optimizer=tf.train.GradientDescentOptimizer(0.01).minimize(cost)   #梯度下降优化器,学习率为0.01
16 #定义训练参数
17 training_epochs=25   #训练次数
18 batch_size=100       #每次训练图像数量
19 display_step=1       #打印训练信息周期
20 
21 #开始训练
22 with tf.Session() as sess :
23     sess.run(tf.global_variables_initializer())   #初始化所有参数
24     for epoch in range(training_epochs) :
25         avg_cost=0.                               #平均损失
26         total_batch=int(mnist.train.num_examples/batch_size)   #计算总的训练批次
27         for i in range(total_batch) :
28             batch_xs, batch_ys=mnist.train.next_batch(batch_size)  #抽取数据
29             _, c=sess.run([optimizer,cost], feed_dict={x:batch_xs, y:batch_ys})  #运行
30             avg_cost+=c/total_batch
31         if (epoch+1) % display_step == 0 :
32             print("Epoch:",'%04d'%(epoch+1),"cost=","{:.9f}".format(avg_cost))
33     print("Finished!")

运行得到的结果:

Epoch: 0001 cost= 7.973125283
...
Epoch: 0025 cost= 0.898346810
Finished!

可以看出,损失降低了很多,得到的结果还不错,这只是简单的模型,使用复杂的模型可以得到更好的结果,将在以后给出。

[1] http://yann.lecun.com/exdb/mnist/

标签:...,batch,tf,cost,手写体,TensorFlow,识别,data,mnist
来源: https://www.cnblogs.com/xbyfight/p/11103979.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有