ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

第二讲 神经网络优化--SGD

2020-04-21 23:52:04  阅读:311  来源: 互联网

标签:loss plt test 神经网络 train tf total 优化 SGD


#利用鸢尾花数据,实现前向传播、反向传播,可视化loss曲线

#导入所需模块
import tensorflow as tf
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import time


#导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data
y_data = datasets.load_iris().target


#随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
#seed:随机数种子,是一个整数,当设置之后,每次生成的随机数都一样
np.random.seed(116) #使用相同的seed,保证书例如特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)


#将打乱后的数据集风格为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]

#转换x的数据类型,否则后面矩阵相乘时会因为数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)

#from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)


#生成神经网络的参数,4个输入特征数,输入层为4个输入节点,因为3分类,故输出成为3个神经元
#用tf.Variable()标记参数可训练
#使用seed使每次生成的随机数相同
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1, seed=1))
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))

lr = 0.1 #学习率为0.1
train_loss_results = []  #将每轮的loss记录在此列表中,为后续画曲线提供数据
test_acc = [] #将每轮的acc记录在此列表中,为后续画acc曲线提供数据
epoch = 500 #循环500轮
loss_all = 0 #每轮分为4个step,loss_all 记录4个step生成的4个loss的和


#训练部分

now_time = time.time()
for epoch in range(epoch):#数据集级别的循环,每个epoch循环一次数据集
  for step, (x_train, y_train) in enumerate(train_db): #batch级别的循环,每个step循环一个batch
    with tf.GradientTape() as tape: #with结构记录梯度信息
      y = tf.matmul(x_train, w1) + b1 #神经网络的乘加运算
      y = tf.nn.softmax(y)  #使输出y符合概率分布(此操作后与独热码同量级,可相减求Loss)
      y_ = tf.one_hot(y_train, depth=3) #将标签值转化为独热码格式,方便计算和accuracy
      loss = tf.reduce_mean(tf.square(y_ - y))  #采用均方误差函数mse = mean(sum(y - out)^2)
      loss_all += loss # 将每个step计算出的loss累加,为后续求loss平均值提供数据,这样计算的loss更准确
    #计算loss对各个参数的梯度
    grads = tape.gradient(loss, [w1, b1])

    #实现梯度更新w1 = w1 - lr*w1_grad b = b - lr*b_grad
    w1.assign_sub(lr * grads[0]) #参数w1自更新
    b1.assign_sub(lr * grads[1]) #参数b1自更新

  #每个epoch,打印loss信息
  print("Epoch {}, loss: {}".format(epoch, loss_all/4)) 
  train_loss_results.append(loss_all/4) # 将4个step的loss求平均记录在此变量中
  loss_all = 0 # loss_all归零,为记录下一个epoch的loss做准备

  #测试部分
  #total_correct为预测对的样本个数,total_number为测试的总样本数,将这两个变量都初始化为0
  total_correct, total_number = 0, 0
  for x_test, y_test in test_db:
    #使用更新后的参数进行预测
    y = tf.matmul(x_test, w1) + b1
    y = tf.nn.softmax(y)
    pred = tf.argmax(y, axis=1) #返回y中最大值的索引,即预测的分类
    #将pred转换为y_test的数据类型
    pred = tf.cast(pred, dtype=y_test.dtype)
    #若分类正确,则correct=1, 否则为0, 将bool型的结果转换为int型
    correct = tf.cast(tf.equal(pred, y_test), dtype=tf.int32)
    #将每个batch的correct数加起来
    correct = tf.reduce_sum(correct)
    #将所有batch中的数加起来
    total_correct += int(correct)
    #total-number为测试的样本总数,也就是x_test的行数,shape[0]返回变量的行数
    total_number += x_test.shape[0]
  #总的准确率等于total_correct/total_number
  acc = total_correct / total_number
  test_acc.append(acc)
  print("Test_acc:", acc)
  print("-------------------------------------")
total_time = time.time() - now_time
print("total_time", total_time)


#绘制loss曲线
plt.title("Loss Function Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.plot(train_loss_results, "b-.", label="$Loss$")
plt.legend()
plt.show()


#绘制Accuracy曲线
plt.title("Acc Curve")
plt.xlabel("Epoch")
plt.ylabel("Acc")
plt.plot(test_acc, "b-.", label="$Accuracy$")
plt.legend()
plt.show()

 

标签:loss,plt,test,神经网络,train,tf,total,优化,SGD
来源: https://www.cnblogs.com/wbloger/p/12748978.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有