ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Three-layer network on MNIST (Grokking 第八章学习笔记)

2020-09-27 23:34:28  阅读:256  来源: 互联网

标签:layer network labels Grokking np hot weights test


之前在MNIST上做的分类器都是调用现有的包,这次我们尝试手写构建一个含有隐藏层的神经网络来预测手写数字

因为构建的神经网络只有三层所以代码非常简单(但是Grokking上的代码,有几个非常明显的错误),效果也算过的去,具体就看代码吧,注释的很详细了。

# 调用包的版本判断
import sys
assert sys.version_info >= (3, 5)


import sklearn
assert sklearn.__version__ >= "0.20"


import numpy as np
import os

# 图像生成的相关参数
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# 图像保存位置
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "classification"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
os.makedirs(IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    path = os.path.join(IMAGES_PATH, fig_id + "." + fig_extension)
    print("get image:", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format=fig_extension, dpi=resolution)

from sklearn.datasets import fetch_openml

#使用sklearn自带的 fatch_openml加载数据集
#加载的数据集是一个key-value的字典结构
#DESCR 用于描述数据集;data 实例为行 特征为列
mnist = fetch_openml('mnist_784', version=1)        
mnist.keys()

#划分训练集与测试集
X, y = mnist["data"], mnist["target"]
x_train, x_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

images, labels = (x_train[0:60000].reshape(60000,28*28)/255, y_train[0:60000])

one_hot_labels = np.zeros((len(labels),10))                   #将图像的标签用0-1串表示例如'1'的one_hot_labels应该是[0,1,0,0,0,0,0,0,0,0]
for i in range(len(labels)):
    one_hot_labels[i][int(labels[i])] = 1
labels = one_hot_labels

test_images = x_test.reshape(len(x_test),28*28)/255          #像素为闭区间[0,255]的整数,这里将像素值压缩至闭区间[0,1]之间
test_one_hot_labels = np.zeros((len(y_test),10))
for j in range(len(test_one_hot_labels)):
    test_one_hot_labels[j][int(y_test[j])] = 1
test_labels = test_one_hot_labels

#设置激活函数、学习率等参数
np.random.seed(1)
relu = lambda x:(x>=0) * x 
relu2deriv = lambda x: x>=0 

#依次设置:学习率、迭代次数、隐藏层神经元、像素数、输出层神经元数目
alpha, iterations, hidden_size, pixels_per_image, num_labels = (0.005, 350, 40, 784, 10)

#构建神经网络的权重
weights_0_1 = 0.2*np.random.random((pixels_per_image,hidden_size)) - 0.1    #输入层-》隐藏层   大小为[784 x 40]
weights_1_2 = 0.2*np.random.random((hidden_size,num_labels)) - 0.1        #隐藏层-》输出层   大小为[40 x 10]
print(weights_0_1.shape)
print(weights_1_2.shape)
#print(weights_0_1)
#print(weights_1_2)
for j in range(iterations):
    error, correct_cnt = (0.0, 0)    
    for i in range(len(images)):
        error = 0.0
        layer_0 = images[i:i+1]
        layer_1 = relu(np.dot(layer_0,weights_0_1))   #将输入层的784个神经元与权重点乘后映射到隐藏层神经元上
        layer_2 = np.dot(layer_1,weights_1_2)         #将隐藏层40个神经原与权重点乘后映射到输出层神经元上
        error += np.sum((labels[i:i+1] - layer_2) ** 2)        #计算预测结果与实际结果之间的误差
        correct_cnt += int(np.argmax(layer_2) == np.argmax(labels[i:i+1]))    #若预测结果与实际标签结果相同,则正确数加一
        layer_2_delta = (labels[i:i+1] - layer_2)                             #隐藏层到输出层之间的误差
        layer_1_delta = layer_2_delta.dot(weights_1_2.T) * relu2deriv(layer_1)#输入层到隐藏层之间的误差
        
        #调整权重
        weights_1_2 += (alpha * layer_1.T.dot(layer_2_delta))                 
        weights_0_1 += (alpha * layer_0.T.dot(layer_1_delta)) 
    print(correct_cnt)

#使用test数据集进行检测
for i in range(100):
    layer_0 = test_images[i]
    layer_1 = relu(np.dot(layer_0,weights_0_1))   #将输入层的784个神经元与权重点乘后映射到隐藏层神经元上
    layer_2 = np.dot(layer_1,weights_1_2)
    print(np.argmax(layer_2), y_test[i])

 

夜里睡觉时脑子跟电影院一样全是过去的画面,

想了四天,觉得导致现在这个局面的主要原因还是自己太菜,以后要更加努力才行

另一方面,我觉得不是因为不爱才导致某些问题,而是因为某些问题才导致不爱了。我还在反思,但我一定要去着手解决这些问题了。

不管怎样,把自己变得更好才是正路。Good Night

 

标签:layer,network,labels,Grokking,np,hot,weights,test
来源: https://www.cnblogs.com/alan-W/p/13742501.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有