ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

从TensorFlow的mnist数据集导出手写体数字图片

2019-09-03 15:02:22  阅读:228  来源: 互联网

标签:labels images shape train 手写体 TensorFlow mnist 图片


在TensorFlow的官方入门课程中,多次用到mnist数据集。

mnist数据集是一个数字手写体图片库,但它的存储格式并非常见的图片格式,所有的图片都集中保存在四个扩展名为idx3-ubyte的二进制文件。

如果我们想要知道大名鼎鼎的mnist手写体数字都长什么样子,就需要从mnist数据集中导出手写体数字图片。了解这些手写体的总体形状,也有助于加深我们对TensorFlow入门课程的理解。

下面先给出通过TensorFlow api接口导出mnist手写体数字图片的python代码,再对代码进行分析。代码在win7下测试通过,linux环境也可以参考本处代码。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
from PIL import Image

# 声明图片宽高
rows = 28
cols = 28

# 要提取的图片数量
images_to_extract = 8000

# 当前路径下的保存目录
save_dir = "./mnist_digits_images"

# 读入mnist数据
mnist = input_data.read_data_sets("C:\\Users\\Administrator\\Desktop\\Tensorflow\\数据集\\mnist\\", one_hot=False)

# 创建会话
sess = tf.Session()

# 获取图片总数
shape = sess.run(tf.shape(mnist.train.images))
images_count = shape[0]
pixels_per_image = shape[1]

# 获取标签总数
shape = sess.run(tf.shape(mnist.train.labels))
labels_count = shape[0]

# mnist.train.labels是一个二维张量,为便于后续生成数字图片目录名,有必要一维化(后来发现只要把数据集的one_hot属性设为False,mnist.train.labels本身就是一维)
# labels = sess.run(tf.argmax(mnist.train.labels, 1))
labels = mnist.train.labels

# 检查数据集是否符合预期格式
if (images_count == labels_count) and (shape.size == 1):
    print("数据集总共包含 %s 张图片,和 %s 个标签" % (images_count, labels_count))
    print("每张图片包含 %s 个像素" % (pixels_per_image))
    print("数据类型:%s" % (mnist.train.images.dtype))

    # mnist图像数据的数值范围是[0,1],需要扩展到[0,255],以便于人眼观看
    if mnist.train.images.dtype == "float32":
        print("准备将数据类型从[0,1]转为binary[0,255]...")
        for i in range(0, images_to_extract):
            for n in range(pixels_per_image):
                if mnist.train.images[i][n] != 0:
                    mnist.train.images[i][n] = 255
            # 由于数据集图片数量庞大,转换可能要花不少时间,有必要打印转换进度
            if ((i + 1) % 50) == 0:
                print("图像浮点数值扩展进度:已转换 %s 张,共需转换 %s 张" % (i + 1, images_to_extract))

    # 创建数字图片的保存目录
    for i in range(10):
        dir = "%s/%s/" % (save_dir, i)
        if not os.path.exists(dir):
            print("目录 ""%s"" 不存在!自动创建该目录..." % dir)
            os.makedirs(dir)

    # 通过python图片处理库,生成图片
    indices = [0 for x in range(0, 10)]
    for i in range(0, images_to_extract):
        img = Image.new("L", (cols, rows))
        for m in range(rows):
            for n in range(cols):
                img.putpixel((n, m), int(mnist.train.images[i][n + m * cols]))
        # 根据图片所代表的数字label生成对应的保存路径
        digit = labels[i]
        path = "%s/%s/%s.bmp" % (save_dir, labels[i], indices[digit])
        indices[digit] += 1
        img.save(path)
        # 由于数据集图片数量庞大,保存过程可能要花不少时间,有必要打印保存进度
        if ((i + 1) % 50) == 0:
            print("图片保存进度:已保存 %s 张,共需保存 %s 张" % (i + 1, images_to_extract))
else:
    print("图片数量和标签数量不一致!")

上述代码的实现思路如下:

1.读入mnist手写体数据;

2.把数据的值从[0,1]浮点范围转化为黑白格式(背景为0-黑色,前景为255-白色);

3.根据mnist.train.labels的内容,生成数字索引,也就是建立每一张图片和其所代表数字的关联,由此创建对应的保存目录;

4.循环遍历mnist.train.images,把每张图片的像素数据赋值给python图片处理库PIL的Image类实例,再调用Image类的save方法把图片保存在第3步骤中创建的对应目录。

标签:labels,images,shape,train,手写体,TensorFlow,mnist,图片
来源: https://www.cnblogs.com/answerThe/p/11453041.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有