ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

代码笔记18 pytorch中加载ResNet,导致过拟合或者测试时model.train()高于model.eval()

2022-06-20 01:32:39  阅读:202  来源: 互联网

标签:训练 18 self ResNet freeze dict eval model


问题

  训练网络往往需要加载预训练模型,主流的就是ResNet一类的预训练好的参数
  但我在加载了预训练模型,并冻结与训练参数后,进行训练时,发现了两个问题

1

  在进行test中model.train()的准确率要远高于model.eval()差别大概在7个点左右。
  其中model.eval() 负责改变batchnorm、dropout的工作方式,如在eval()模式下,dropout是不工作的。
  这种问题出现一般网上有几种回答,可以看看这个[1]不过很遗憾,我并不是这里的问题。还有认为是eval时batchsize过小的原因,导致每个mini-batch的数据分布无法符合整个数据集,我的batchsize时32,比我训练是还要大。
  不过我测试了一下,在训练好的模型上,使用train()模式,带有dropout()系数的准确率是低于不使用dropout的,差别在4个点左右。

2

  过拟合,训练集和测试集准确率差了15个点。

解决

  在此非常感谢博客[2],说的太有道理了。
  问题就在于我冻结Resnet参数时,冻结了所有的参数包括BatchNorm2D中的权重与偏差值。这样导致的问题就是,在Resnet中的bn层所学习到的参数是基于Imagenet数据集的数据分布的并且被冻结后不会再学习,当然会导致train()模式下的精度降低,以及过拟合的发生,一切都解释的通了。

代码

把我冻结参数的代码放上来

  def _freeze_parameters(self):
      # cant freeze if not load, because the freeze_dict is null
      if bool(self.freezedict) == False:
          print('freeze params must after loading ResNet!')
          os._exit(0)
      freeze_dict = self.freezedict
      state_dict = self.state_dict(keep_vars=True)
      for k, v in freeze_dict.items():
          if k in state_dict:
              state_dict[k].requires_grad = False
      # freeze parameters in ResNet except BatchNorm!!!!
      # or the batchnorm will not trainable, and keep the mean and var (weights and bias) on Imagenet dataset
      for m in self.modules():
          if isinstance(m,nn.BatchNorm2d):
              m.weight.requires_grad = True
              m.bias.requires_grad = True

Refrences

[1] https://blog.csdn.net/yucong96/article/details/88652964
[2] https://blog.csdn.net/aojue1109/article/details/88181927?spm=1001.2014.3001.5506

标签:训练,18,self,ResNet,freeze,dict,eval,model
来源: https://www.cnblogs.com/HumbleHater/p/16391958.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有