ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

学习率及其指数衰减

2020-01-31 09:06:24  阅读:847  来源: 互联网

标签:sess val 及其 global 学习 step tf 衰减


根据上一篇给出的梯度下降的公式,很明显,参数的变化量有两个数决定,一个是偏导,另一个是学习率.首先来说一下学习率对于学习过程又怎样的影响.
1.首先是学习率过小,这种情况其实只是将来损失函数收敛的会比较慢,比较费时间.
2.然后是学习率过大,你可以这么理解,在你的脚下是一个坑,学习率可以暂且理解为你的
步伐大小,如果你迈的步子大,一下就到坑的那边,就避免踩到坑里了,但对于机器学习它就
是要踩到坑里,所以学习率过大会导致最终损失函数在一个比较大的值就收敛,甚至可能
发散.

那么我们该怎样才可以既使得耗时短,又取到较好的拟合效果呢?
有一种方法是学习率的指数衰减,就是每经过固定的轮数,学习率就会乘一个固定的衰减率.

import tensorflow.compat.v1 as tf#设置为v1版本
tf.disable_v2_behavior()#禁用v2版本

LEARNING_RATE_BASE=0.2#初始学习率
LEARNING_RATE_DECAY=0.95#学习率衰减率
LEARNING_RATE_STEP=1#喂入多少BATCH_SIZE之后,更新一轮学习率,一般为:样本总数/BATCH-SIZE

#运行了几轮BATCH-SIZE的计数器,初始值为0,设为不训练
global_step=tf.Variable(0,trainable=False)
#定义值数下降学习率:
learning_rate=tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,LEARNING_RATE_STEP,LEARNING_RATE_DECAY,staircase=True)
#定义待优化参数,初始值为5
w=tf.Variable(tf.constant(5,tf.float32))
#定义损失函数loss:
loss=tf.square(w+1)
#定义方向传播方法:
train_step=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step=global_step)
#生成会话,训练40轮
with tf.Session() as sess:
    init_op=tf.global_variables_initializer()
    sess.run(init_op)
    for i in range(100):
        sess.run(train_step)
        learning_rate_val=sess.run(learning_rate)
        global_step_val=sess.run(global_step)
        W_val=sess.run(w)
        loss_val=sess.run(loss)
        print("运行了轮",i,",global_step是",global_step_val,'\n',",learning_rate是",learning_rate_val,",损失函数是",loss_val,
              "参数是",W_val)
          

但是值数衰减并不是万能的,如果衰减率比较小,训练的轮数有多,最后学习率就会非常小,这时电脑会把它当成零来判断,然后参数就不变了,损失函数也会相对过快的收敛,那就得不偿失了.

oahuyil 发布了8 篇原创文章 · 获赞 3 · 访问量 4228 私信 关注

标签:sess,val,及其,global,学习,step,tf,衰减
来源: https://blog.csdn.net/realliyuhao/article/details/104121206

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有