ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

【机器学习】使用pyplot绘制MNIST数据集中的手写数字

2021-07-02 09:34:40  阅读:154  来源: 互联网

标签:字节 pyplot labels ax images path 手写 MNIST 标注


MNIST数据集是人工智能大佬Yann LeCun给出的一套手写数字的数据集,训练集包含60,000个样本和标注,测试集包含10,000个样本和标注。可以给新手用来练手用。

数据集表示

  1. 标注:数字分为0-9,总共10个数字,标注也是从0-9,分别对应0-910个数字;
  2. 图片:将每张图片切分成2828的矩阵,矩阵的每个元素使用灰度值来表示,所以总共使用一个2828的矩阵来表示图片;

下载数据集

数据集下载地址:http://yann.lecun.com/exdb/mnist/
整个数据集分为四个部分:

  • train-images-idx3-ubyte.gz: 训练集图片 (9912422 bytes)
  • train-labels-idx1-ubyte.gz: 训练集标注 (28881 bytes)
  • t10k-images-idx3-ubyte.gz: 测试集图片 (1648877 bytes)
  • t10k-labels-idx1-ubyte.gz: 测试集标注 (4542 bytes)

解析数据集文件

在LeCun的网站上,给出了数据集的格式,需要关注的点有:

  1. 存储二进制数据;
  2. 使用大端法存储;
  3. 标注集(包括训练标注集和测试标注集):第一个字节为魔数(Magic Number),第二个字节为标注总个数(训练集-60,000,测试集10,000),后续每个字节为对应标注数值;
  4. 图片集(包括训练图片集和测试图片集):第一个字节为魔数(Magic Number),第二个字节为图片总个数(训练集-60,000,测试集10,000),第四个字节为每张图片表示矩阵的rows,第五个字节为每张图片表示矩阵的cols

绘制图片

使用pyplot来绘制0-9这10个数字总的来说有以下几个步骤:

  1. 加载图片和标注数据;
  2. 使用subplots方法创建一张2*5的画布;
  3. 将画布展开,并把0-9是个数字使用imshow方法绘制进去;
  4. plt.show();

具体代码

  1. 加载数据(load_data.py)
import os
import struct
import numpy as np


def load_mnist(path, kind='train'):
    """拼接路径"""
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
    	"""使用大端法读取2个字节,第一个是魔数,第二个是个数"""
        magic, n = struct.unpack('>II',
                                 lbpath.read(8))
        """依次读取标注值"""
        labels = np.fromfile(lbpath,
                             dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        """使用大端法读取4个字节,第一个是魔数,第二个是个数,三四分别是rows、cols"""
        magic, num, rows, cols = struct.unpack('>IIII',
                                               imgpath.read(16))
        """依次读取值,并reshape为length*784的矩阵"""
        images = np.fromfile(imgpath,
                             dtype=np.uint8).reshape(len(labels), 784)

    return images, labels

  1. 绘制图片(main.py)
import load_data
import matplotlib.pyplot as plt

"""加载数据"""
images, labels = load_data.load_mnist('/Users/wowo/Documents/0-Tensorflow')

"""创建画布"""
fig, ax = plt.subplots(
    nrows=2,
    ncols=5,
    sharex=True,
    sharey=True,
)

"""平铺画布"""
ax = ax.flatten()
for i in range(10):
    """获取数据集中第一次出现的0-9数字,并reshape到28*28的矩阵"""
    img = images[labels == i][0].reshape(28, 28)
    """绘制数字"""
    ax[i].imshow(img, cmap='Greys', interpolation='nearest')

"""隐藏横纵坐标"""
ax[0].set_xticks([])
ax[0].set_yticks([])
"""美化画布,使之更紧凑"""
plt.tight_layout()
"""绘制画布"""
plt.show()

参开文献

  1. http://yann.lecun.com/exdb/mnist/
  2. https://www.cnblogs.com/xianhan/p/9145966.html

标签:字节,pyplot,labels,ax,images,path,手写,MNIST,标注
来源: https://blog.csdn.net/hfut_wowo/article/details/118399195

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有