Cosine annealing learning rate with warmup in PyTorch




No preamble, straight to the code (warmup_lr_scheduler.py):

import math
import warnings

from torch.optim.lr_scheduler import _LRScheduler


class CosineAnnealingLRWarmup(_LRScheduler):
    """Cosine annealing schedule with a geometric (exponential) warmup phase.

    For the first `warmup_steps` epochs the learning rate grows geometrically
    from `warmup_start_lr` up to each group's base lr; afterwards it decays
    from the base lr down to `eta_min` along the usual cosine curve over the
    remaining `T_max - warmup_steps` epochs.
    """

    def __init__(self, optimizer, T_max, eta_min=1.0e-5, last_epoch=-1, verbose=False,
                 warmup_steps=2, warmup_start_lr=1.0e-5):
        self.T_max = T_max
        self.eta_min = eta_min
        # _LRScheduler.__init__ already calls self.step() once, which calls get_lr();
        # the warmup attributes are only set afterwards, so _get_closed_form_lr()
        # guards with hasattr() for that very first call.
        super(CosineAnnealingLRWarmup, self).__init__(optimizer, last_epoch, verbose)
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr
        if warmup_steps > 0:
            # Per-group multiplicative factor so that after `warmup_steps` epochs
            # the lr has grown from warmup_start_lr exactly to base_lr.
            self.base_warmup_factors = [
                (base_lr / warmup_start_lr) ** (1.0 / warmup_steps)
                for base_lr in self.base_lrs
            ]

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        if hasattr(self, 'warmup_steps'):
            if self.last_epoch < self.warmup_steps:
                # Warmup phase: geometric growth from warmup_start_lr towards base_lr.
                return [self.warmup_start_lr * (warmup_factor ** self.last_epoch)
                        for warmup_factor in self.base_warmup_factors]
            else:
                # Cosine annealing over the remaining T_max - warmup_steps epochs.
                return [self.eta_min + (base_lr - self.eta_min) *
                        (1 + math.cos(math.pi * (self.last_epoch - self.warmup_steps) /
                                      (self.T_max - self.warmup_steps))) * 0.5
                        for base_lr in self.base_lrs]
        else:
            # Reached only on the initial step() triggered by _LRScheduler.__init__,
            # before the warmup attributes exist: plain cosine annealing.
            return [self.eta_min + (base_lr - self.eta_min) *
                    (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
                    for base_lr in self.base_lrs]
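
To make the closed form explicit: for epoch t < warmup_steps the learning rate grows geometrically, lr(t) = warmup_start_lr * (base_lr / warmup_start_lr)^(t / warmup_steps), so it lands exactly on base_lr at t = warmup_steps; after that it follows the standard cosine decay from base_lr down to eta_min over the remaining T_max - warmup_steps epochs. A minimal stand-alone sketch of that formula (assuming base_lr = 2e-3, which is what the log further below was produced with) reproduces the printed values without needing an optimizer at all:

import math

def expected_lr(epoch, base_lr=2.0e-3, eta_min=1.0e-4, T_max=100,
                warmup_steps=10, warmup_start_lr=1.0e-5):
    if epoch < warmup_steps:
        # Geometric warmup: constant per-epoch factor, reaching base_lr at warmup_steps.
        factor = (base_lr / warmup_start_lr) ** (1.0 / warmup_steps)
        return warmup_start_lr * factor ** epoch
    # Cosine decay from base_lr down to eta_min over the remaining epochs.
    progress = (epoch - warmup_steps) / (T_max - warmup_steps)
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * progress)) * 0.5

print(expected_lr(1))    # ~1.6986e-05, the value printed at i=0 (step() has advanced last_epoch to 1)
print(expected_lr(10))   # 0.002, warmup has just reached base_lr (printed at i=9)
print(expected_lr(100))  # 0.0001, cosine decay has reached eta_min (printed at i=99)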

Usage. The initial learning rate of each parameter group is whatever you set in the optimizer (in the log below, the two parameter groups were both created with an initial lr of 2e-3 and args.epochs was 100):

    lr_scheduler_warmup = CosineAnnealingLRWarmup(optimizer,
                                                  T_max=100,
                                                  eta_min=1.0e-4,
                                                  last_epoch=-1,
                                                  warmup_steps=10,
                                                  warmup_start_lr=1.0e-5)
    for i in range(args.epochs):
        lr_scheduler_warmup.step()
        print(i, 'lr: ', lr_scheduler_warmup.get_last_lr())

The printed log:

0 lr:  [1.6986464646342476e-05, 1.6986464646342476e-05]
1 lr:  [2.8853998118144274e-05, 2.8853998118144274e-05]
2 lr:  [4.9012741893949e-05, 4.9012741893949e-05]
3 lr:  [8.325532074018734e-05, 8.325532074018734e-05]
4 lr:  [0.00014142135623730956, 0.00014142135623730956]
5 lr:  [0.00024022488679628634, 0.00024022488679628634]
6 lr:  [0.00040805715467367407, 0.00040805715467367407]
7 lr:  [0.0006931448431551469, 0.0006931448431551469]
8 lr:  [0.0011774080373049502, 0.0011774080373049502]
9 lr:  [0.002, 0.002]
10 lr:  [0.001999421285668141, 0.001999421285668141]
11 lr:  [0.0019976858477468327, 0.0019976858477468327]
12 lr:  [0.0019947958005998596, 0.0019947958005998596]
13 lr:  [0.0019907546653044916, 0.0019907546653044916]
14 lr:  [0.0019855673653615975, 0.0019855673653615975]
15 lr:  [0.0019792402206971153, 0.0019792402206971153]
16 lr:  [0.0019717809399621964, 0.0019717809399621964]
17 lr:  [0.001963198611141403, 0.001963198611141403]
18 lr:  [0.001953503690480396, 0.001953503690480396]
19 lr:  [0.0019427079897466131, 0.0019427079897466131]
20 lr:  [0.001930824661838448, 0.001930824661838448]
21 lr:  [0.0019178681847604707, 0.0019178681847604707]
22 lr:  [0.0019038543439842087, 0.0019038543439842087]
23 lr:  [0.0018888002132159806, 0.0018888002132159806]
24 lr:  [0.001872724133595217, 0.001872724133595217]
25 lr:  [0.0018556456913486046, 0.0018556456913486046]
26 lr:  [0.0018375856939272896, 0.0018375856939272896]
27 lr:  [0.0018185661446562002, 0.0018185661446562002]
28 lr:  [0.001798610215926386, 0.001798610215926386]
29 lr:  [0.001777742220963029, 0.001777742220963029]
30 lr:  [0.0017559875842035246, 0.0017559875842035246]
31 lr:  [0.0017333728103217185, 0.0017333728103217185]
32 lr:  [0.0017099254519360473, 0.0017099254519360473]
33 lr:  [0.0016856740760409154, 0.0016856740760409154]
34 lr:  [0.0016606482292022124, 0.0016606482292022124]
35 lr:  [0.0016348784015593754, 0.0016348784015593754]
36 lr:  [0.0016083959896778495, 0.0016083959896778495]
37 lr:  [0.0015812332582972094, 0.0015812332582972094]
38 lr:  [0.0015534233010215447, 0.0015534233010215447]
39 lr:  [0.001525, 0.001525]
40 lr:  [0.0014959979846465962, 0.0014959979846465962]
41 lr:  [0.0014664525894496235, 0.0014664525894496235]
42 lr:  [0.0014363998109220102, 0.0014363998109220102]
43 lr:  [0.0014058762637451164, 0.0014058762637451164]
44 lr:  [0.0013749191361593855, 0.0013749191361593855]
45 lr:  [0.0013435661446562, 0.0013435661446562]
46 lr:  [0.0013118554880261492, 0.0013118554880261492]
47 lr:  [0.0012798258008196845, 0.0012798258008196845]
48 lr:  [0.0012475161062768714, 0.0012475161062768714]
49 lr:  [0.001214965768783584, 0.001214965768783584]
50 lr:  [0.0011822144459120625, 0.0011822144459120625]
51 lr:  [0.0011493020401042709, 0.0011493020401042709]
52 lr:  [0.0011162686500569192, 0.0011162686500569192]
53 lr:  [0.0010831545218673762, 0.0010831545218673762]
54 lr:  [0.00105, 0.00105]
55 lr:  [0.0010168454781326244, 0.0010168454781326244]
56 lr:  [0.0009837313499430809, 0.0009837313499430809]
57 lr:  [0.0009506979598957294, 0.0009506979598957294]
58 lr:  [0.0009177855540879379, 0.0009177855540879379]
59 lr:  [0.0008850342312164163, 0.0008850342312164163]
60 lr:  [0.0008524838937231288, 0.0008524838937231288]
61 lr:  [0.0008201741991803156, 0.0008201741991803156]
62 lr:  [0.0007881445119738509, 0.0007881445119738509]
63 lr:  [0.0007564338553438001, 0.0007564338553438001]
64 lr:  [0.0007250808638406148, 0.0007250808638406148]
65 lr:  [0.0006941237362548836, 0.0006941237362548836]
66 lr:  [0.00066360018907799, 0.00066360018907799]
67 lr:  [0.0006335474105503764, 0.0006335474105503764]
68 lr:  [0.0006040020153534041, 0.0006040020153534041]
69 lr:  [0.0005750000000000002, 0.0005750000000000002]
70 lr:  [0.0005465766989784554, 0.0005465766989784554]
71 lr:  [0.0005187667417027907, 0.0005187667417027907]
72 lr:  [0.0004916040103221507, 0.0004916040103221507]
73 lr:  [0.0004651215984406246, 0.0004651215984406246]
74 lr:  [0.0004393517707977876, 0.0004393517707977876]
75 lr:  [0.00041432592395908465, 0.00041432592395908465]
76 lr:  [0.0003900745480639528, 0.0003900745480639528]
77 lr:  [0.00036662718967828134, 0.00036662718967828134]
78 lr:  [0.0003440124157964757, 0.0003440124157964757]
79 lr:  [0.000322257779036971, 0.000322257779036971]
80 lr:  [0.0003013897840736142, 0.0003013897840736142]
81 lr:  [0.0002814338553438, 0.0002814338553438]
82 lr:  [0.00026241430607271046, 0.00026241430607271046]
83 lr:  [0.00024435430865139536, 0.00024435430865139536]
84 lr:  [0.00022727586640478324, 0.00022727586640478324]
85 lr:  [0.00021119978678401958, 0.00021119978678401958]
86 lr:  [0.00019614565601579133, 0.00019614565601579133]
87 lr:  [0.0001821318152395293, 0.0001821318152395293]
88 lr:  [0.00016917533816155207, 0.00016917533816155207]
89 lr:  [0.0001572920102533871, 0.0001572920102533871]
90 lr:  [0.00014649630951960415, 0.00014649630951960415]
91 lr:  [0.00013680138885859727, 0.00013680138885859727]
92 lr:  [0.00012821906003780334, 0.00012821906003780334]
93 lr:  [0.00012075977930288461, 0.00012075977930288461]
94 lr:  [0.00011443263463840238, 0.00011443263463840238]
95 lr:  [0.00010924533469550826, 0.00010924533469550826]
96 lr:  [0.00010520419940014038, 0.00010520419940014038]
97 lr:  [0.00010231415225316702, 0.00010231415225316702]
98 lr:  [0.00010057871433185903, 0.00010057871433185903]
99 lr:  [0.0001, 0.0001]
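
To see the shape of the schedule rather than the raw numbers, you can plot it. The sketch below is optional and assumes matplotlib is installed; it builds the optimizer on a single dummy parameter just so the scheduler has something to drive. The steep ramp over the first 10 epochs is the geometric warmup, followed by the smooth cosine decay down to eta_min.

import torch
import matplotlib.pyplot as plt

# A dummy parameter only so an optimizer can be constructed; its lr is the base_lr.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=2.0e-3)

scheduler = CosineAnnealingLRWarmup(optimizer, T_max=100, eta_min=1.0e-4,
                                    warmup_steps=10, warmup_start_lr=1.0e-5)
lrs = []
for _ in range(100):
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

plt.plot(lrs)
plt.xlabel('epoch')
plt.ylabel('learning rate')
plt.title('warmup + cosine annealing')
plt.show()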

Source: https://blog.csdn.net/qq_22751305/article/details/123421696
