ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Pytorch学习率更新

2021-05-07 23:53:05  阅读:254  来源: 互联网

标签:optimizer torch list epoch 更新 学习 Pytorch lr scheduler


缘由

自己在尝试了官方的代码后就想提高训练的精度就想到了调整学习率,但固定的学习率肯定不适合训练就尝试了几个更改学习率的方法,但没想到居然更差!可能有几个学习率没怎么尝试吧!

更新方法

直接修改optimizer中的lr参数;

  • 定义一个简单的神经网络模型:y=Wx+b
import torchimport matplotlib.pyplot as plt%matplotlib inlinefrom torch.optim import *import torch.nn as nnclass net(nn.Module):
    def __init__(self):
        super(net,self).__init__()
        self.fc = nn.Linear(1,10)
    def forward(self,x):
        return self.fc(x)

  • 直接更改lr的值
model = net()LR = 0.01optimizer = Adam(model.parameters(),lr = LR)lr_list = []for epoch in range(100):
    if epoch % 5 == 0:
        for p in optimizer.param_groups:
            p['lr'] *= 0.9
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])plt.plot(range(100),lr_list,color = 'r')

关键是如下两行能达到手动阶梯式更改,自己也可按需求来更改变换函数

for p in optimizer.param_groups:
	p['lr'] *= 0.9

利用lr_scheduler()提供的几种衰减函数

  • torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
参数含义
lr_lambda会接收到一个int参数:epoch,然后根据epoch计算出对应的lr。如果设置多个lambda函数的话,会分别作用于Optimizer中的不同的params_group
import numpy as np 
lr_list = []model = net()LR = 0.01optimizer = Adam(model.parameters(),lr = LR)lambda1 = lambda epoch:np.sin(epoch) / epoch
scheduler = lr_scheduler.LambdaLR(optimizer,lr_lambda = lambda1)for epoch in range(100):
    scheduler.step()
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])plt.plot(range(100),lr_list,color = 'r')

  • torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1)
参数含义
T_max对应1/2个cos周期所对应的epoch数值
eta_min最小的lr值,默认为0
lr_list = []model = net()LR = 0.01optimizer = Adam(model.parameters(),lr = LR)scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max = 20)for epoch in range(100):
    scheduler.step()
    lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])plt.plot(range(100),lr_list,color = 'r')

  • torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=‘min’, factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode=‘rel’, cooldown=0, min_lr=0, eps=1e-08)

在发现loss不再降低或者acc不再提高之后,降低学习率。各参数意义如下:

参数含义
mode'min’模式检测metric是否不再减小,'max’模式检测metric是否不再增大;
factor触发条件后lr*=factor;
patience不再减小(或增大)的累计次数;
verbose触发条件后print;
threshold只关注超过阈值的显著变化;
threshold_mode有rel和abs两种阈值计算模式,rel规则:max模式下如果超过best(1+threshold)为显著,min模式下如果低于best(1-threshold)为显著;abs规则:max模式下如果超过best+threshold为显著,min模式下如果低于best-threshold为显著;
cooldown触发一次条件后,等待一定epoch再进行检测,避免lr下降过速;
min_lr最小的允许lr;
eps如果新旧lr之间的差异小与1e-8,则忽略此次更新。

如需了解其它学习率更新方法请访问: https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/

示例

使用的更新方法

  • 代码中可选的选项有:余弦方式(默认方式,其他两种注释了)、e^-x的方式以及按loss是否不在降低来判断的三种方式,其他就自己测试吧!

  • 训练截图(第一个图为trainingg_loss,第二个为学习率变化曲线)

完整代码

import torchimport torchvisionimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport numpy as npimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom datetime import datetimefrom torch.utils.tensorboard import SummaryWriterfrom torch.optim import *transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=0)testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')#如需了解示例完整代码及其后续内容请访问: [https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/](https://www.emperinter.info/2020/08/01/learning-rate-in-pytorch/)


标签:optimizer,torch,list,epoch,更新,学习,Pytorch,lr,scheduler
来源: https://blog.51cto.com/u_15193557/2760326

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有