ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

MEAL V2: Boosting Vanilla ResNet-50 to 80%+ Top-1 Accuracy on ImageNet without Tricks

2021-05-17 10:34:16  阅读:233  来源: 互联网

标签:right mathbf 训练 Vanilla ResNet 80% theta mathcal left


在这里插入图片描述

无需额外数据、Tricks、架构调整,CMU开源首个将ResNet50精度提升至80%+新方法

代码可用:https://github.com/szq0214/MEAL-V2

摘要:本文介绍了一种简单而有效的方法,该方法可以提高 ImageNet 上的普通 ResNet-50-80%+ Top-1的准确性,而无需任何技巧。我们的方法是基于最近提出的 MEAL ,即通过鉴别器的集成知识蒸馏。我们进一步简化了这一过程:

  1. 仅对最终输出采用相似性损失和判别损失;
  2. 使用来自所有教师集成的平均 softmax 概率作为更强的监督信息进行蒸馏。

作者提出了一个观点:在蒸馏过程中不应使用one-hot方式的标签编码。实验证明:这样一个简单的方案可以在不涉及任何常用技术的情况下获得SOTA的性能。如:(1)架构修改;(2)额外的训练数据;(3)autoaug/randaug;(4)cosine learning rate;(5)mixup/cutmix 数据增强训练;(6)标签平滑。
我们证明,这样一个简单的框架可以实现最佳的结果,而不涉及任何常用的技术,如:1)架构修改;2)超出 ImageNet 的训练数据;3) autoaug / randomaug ;4) cosine 学习率;5) mix / cucumix 训练;6)标签平滑;等等。

在ImageNet数据集上,本文所提方法取得了80.67%的Top1精度(single crop@224),以极大的优势超越其他同架构方案。该方法可以视作采用知识蒸馏对ResNet50涨点的一个新的基准,该文可谓首个在不改变网路架构、无需额外训练数据的前提下将ResNet提升到超过80%Top1精度的方法。

1. 介绍

卷积神经网络(CNNs)[14]已被证明在许多视觉任务中是有用的,例如图像分类[13,8]、目标检测[6,17]、语义分割[15],以及一些特定场景,例如迁移特征表示[23]、从头开始学习检测器[19]等。为了获得尽可能高的准确率,人们提出了许多训练技术和数据扩充方法,如:mixup [25],cutmix [24], autoaug [1], randaug [2], fix resolution discrepancy [22], 等。还有一些工作集中在修改网络结构上,如SENET[11],ResNeST[26]等。本文的目标是获得网络类似的最佳性能,但我们提出的方法与上述技术正交。一般来说,我们的方法只依赖于教师-学生范式,具有强大的教师集成和良好的学生初始化。该算法简单、直观、有效,能够在大规模数据集上实现最先进的性能。我们的方法的优点是:

  1. 不需要修改架构;
  2. 除imagenet外,没有额外的训练数据;
  3. 没有余弦学习率;
  4. 没有额外的数据增强;
  5. 没有标签平滑

在训练过程中也有一些有趣的发现,the one-hot/hard label1is not necessary and could not be used in the distillation process,该发现对于知识蒸馏特别重要。Sec4提供了一些关于这一观点的讨论。以前的一些研究认为,在目标检测、分割等下游任务中,结构可能比预先训练的参数更重要、更关键。作者仍然相信,提高标准和经典网络结构的性能是有趣和有用的,特别是网络已经很小很紧凑,如MobileNet V3,EfficientNetB0,因为所提出的方法可以毫不费力地推广到其他精心设计或搜索到的体系结构中。
关于这一观点的一些讨论将在第4节中提供。也就是说,我们提出的框架是一个通用的设计,字面上易于使用,可以作为一个后处理过程来提取小而紧凑的模型,以进一步提高它们的性能,同时不需要修改。
在这里插入图片描述

2 方法

我们首先介绍了我们提出的框架中的每个组成部分,包括:

  1. 教师集合;
  2. KL散度损失;
  3. 判别器。
    然后,我们介绍我们在蒸馏训练中使用过和没有使用过的训练细节和技术。

Teachers Ensemble

Teachers Ensemble用于生成更准确的预测来指导学生训练。与MEAL[18]在每次迭代训练中通过教师集成模块选择一名教师不同,我们采用多个预训练教师的softmax概率的平均值作为集合。令 T θ \mathcal{T}_{\theta} Tθ​为学生网络,输出的集成概率 p ^ e T θ \hat{\mathbf{p}}_{\mathbf{e}}^{\mathcal{T}_{\theta}} p^​eTθ​​可描述为:
p ^ e T θ ( X ) = 1 K ∑ t = 1 K p t T T θ ( X ) \hat{\mathbf{p}}_{e}^{\mathcal{T}_{\theta}}(X)=\frac{1}{K} \sum_{\mathbf{t}=1}^{K} \mathbf{p}_{\mathbf{t}}{\mathcal{T}}^{\mathcal{T}_{\theta}}(X) p^​eTθ​​(X)=K1​t=1∑K​pt​TTθ​(X)其中 p ^ T θ \hat{\mathbf{p}}_{\mathbf{}}^{\mathcal{T}_{\theta}} p^​Tθ​​是第 t t t个教师网络的softmax输出的预测概率。 X X X是输入,K是老师模型的个数。

KL-divergence

KL散度是知识蒸馏领域最常用的一种损失,它用度量两个概率分布之间的相似性。在该文中,KL散度用于度量学生模型的预测概率与前述老师模型的平均预测概率之间的相似性。KL散度损失函数可以描述如下:
L K L ( S θ ) = − 1 N ∑ i = 1 N p ^ e T θ ( x i ) log ⁡ ( p S θ ( x i ) p ^ e T θ ( x i ) ) = − 1 N ∑ i = 1 N p ^ e T θ ( x i ) log ⁡ p S θ ( x i ) + 1 N ∑ i = 1 N p ^ e T θ ( x i ) log ⁡ p ^ e T θ ( x i ) \begin{aligned} \mathcal{L}_{K L}\left(\mathcal{S}_{\theta}\right)=-\frac{1}{N} & \sum_{i=1}^{N} \hat{\mathbf{p}}_{e}^{\mathcal{T}_{\theta}}\left(x_{i}\right) \log \left(\frac{\mathbf{p}^{S_{\theta}}\left(x_{i}\right)}{\hat{\mathbf{p}}_{e}^{\mathcal{T}_{\theta}}\left(x_{i}\right)}\right) \\ =&-\frac{1}{N} \sum_{i=1}^{N} \hat{\mathbf{p}}_{e}^{\mathcal{T}_{\theta}}\left(x_{i}\right) \log \mathbf{p}^{\mathcal{S}_{\theta}}\left(x_{i}\right) \\ &+\frac{1}{N} \sum_{i=1}^{N} \hat{\mathbf{p}}_{e}^{\mathcal{T}_{\theta}}\left(x_{i}\right) \log \hat{\mathbf{p}}_{e}^{\mathcal{T}_{\theta}}\left(x_{i}\right) \end{aligned} LKL​(Sθ​)=−N1​=​i=1∑N​p^​eTθ​​(xi​)log(p^​eTθ​​(xi​)pSθ​(xi​)​)−N1​i=1∑N​p^​eTθ​​(xi​)logpSθ​(xi​)+N1​i=1∑N​p^​eTθ​​(xi​)logp^​eTθ​​(xi​)​
当然,各位同学不用花费精力去研究上述公式,目前各大深度学习框架中均有该损失函数的实现,直接调用就好。除了KL散度损失外,另一个常用的损失函数就是交叉熵损失,定义如下:
L C E ( S θ ) = − 1 N ∑ i = 1 N p ^ e T θ ( x i ) log ⁡ p S θ ( x i ) \mathcal{L}_{C E}\left(\mathcal{S}_{\theta}\right)=-\frac{1}{N} \sum_{i=1}^{N} \hat{\mathbf{p}}_{e}^{\mathcal{T}_{\theta}}\left(x_{i}\right) \log \mathbf{p}^{\mathcal{S}_{\theta}}\left(x_{i}\right) LCE​(Sθ​)=−N1​i=1∑N​p^​eTθ​​(xi​)logpSθ​(xi​)

Discriminator

判别器是一个二分类器,它用于判别输入特征来自老师模型还是来自学生模型。它由sigmoid与二值交叉熵损失构成,定义如下:
L D = − 1 N ∑ i = 1 N [ y i ⋅ log ⁡ p i D + ( 1 − y i ) ⋅ log ⁡ ( 1 − p i D ) ] \mathcal{L}_{\mathcal{D}}=-\frac{1}{N} \sum_{i=1}^{N}\left[\mathbf{y}_{i} \cdot \log \mathbf{p}_{i}^{\mathcal{D}}+\left(1-\mathbf{y}_{i}\right) \cdot \log \left(1-\mathbf{p}_{i}^{\mathcal{D}}\right)\right] LD​=−N1​i=1∑N​[yi​⋅logpiD​+(1−yi​)⋅log(1−piD​)]
其中 y i y_i yi​是输入特征 x i x_i xi​的两值标签, p i D \mathbf{p}^{\mathcal{D}}_{i} piD​是对应的概率向量。
作者定义了一个sigmoid函数用于模拟老师-学生的概率,定义如下:
p D ( x ; θ ) = σ ( f θ ( { x T , x S } ) ) \mathbf{p}^{\mathcal{D}}(x ; \theta)=\sigma\left(f_{\theta}\left(\left\{x_{\mathcal{T}}, x_{\mathcal{S}}\right\}\right)\right) pD(x;θ)=σ(fθ​({xT​,xS​}))
其中表示一个三层感知器,即三个全连接 f θ f_{\theta} fθ​, σ ( x ) = 1 / ( 1 + exp ⁡ ( − x ) ) \sigma(x)=1 /(1+\exp (-x)) σ(x)=1/(1+exp(−x))表示逻辑函数,该文采用最后未经softmax处理的输出作为该判别器的输入。

3. 实验

3.1 数据集

我们在ILSVRC 2012分类数据集[3]上进行了实验,该数据集包含1,000个类,120万个训练图像和50,000个验证图像。

3.2 实验设定

在训练过程中,作者采用了最基本的数据增广:RandomResizedCrop、RandomHorizontalFlip,在测试阶段采用了CenterCrop。8GPU用训练,batch=512,优化器为SGD,未采用weight decay,StepLR,初始学习率为0.01,合计训练180epoch,在100epoch时学习率x0.1。

当学生模型的输入为224×224时,老师模型为senet154,resnet152_vl;当学生模型的输入为380×380时,老师模型为efficientnet_b4, efficientnet_b4_ns。

在实验方面,作者分别以ResNet50、MobileNetV3为基准进行了实验对比,那么接下来就分别进行相关结果的介绍。

3.3 实验结果

我们在ResNet-50上的结果如表2所示。在224×224个输入大小下,我们的方法达到了80.67%的Top-1准确率,比以前最先进的方法Meal[18] 提高了2.46%。此外,我们的结果甚至比ResNeST-5026更好,后者需要修改网络体系结构,并且需要通过许多训练技巧来学习。在将输入规模扩大到380×380之后,我们的性能进一步提高到81.72%,在输入略小的情况下比FixRes(*)[22]提高了2.62%。
在这里插入图片描述
我们选择已经是紧凑模型的MobileNet V3、Small0.75/1.0/Large-1.0和EfficientNet-B0网络来验证所提方法的有效性。我们的结果如表3所示,在MobileNet V3-Small 0.75和1.0上,我们的方法在没有任何架构修改的情况下将原始模型的准确率分别提高了2.20%和2.25%。如此巨大的增长在模型上是相当令人惊讶的,更重要的是,在不同的阶段,增长是惊人的。在MobileNet V3-Large 1.0和EfficientNet-B0上,虽然改进不像0.75和1.0那么大,但在ImageNet上我们仍然获得了1.72%和1.49%的增长。请注意,对于EfficientNet-B0,77.3/93.5的精确度来自他们的论文[21],76.8/93.2是他们在TIMM中预先训练的模型的精确度。
在这里插入图片描述

3.4 分析

  1. 一般来说,实力较强的教师网络能够始终如一地蒸馏出实力较强的学生。
  2. 教师网络的初始化,对最终的性能影响不大,只是有助于加快训练的收敛。

我们知道,在知识蒸馏过程中,有很多因素决定和影响着一个学生网络的表现。由于我们在224×224像素的的输入下对所有ResNet-50、MobileNet V3和EfficientNet-B0使用相同的教师集成,结果表明学生的结构或能力本身是一个重要的指标。我们比较了MEAL V1 和V2 可以进一步得出结论,教师的表现,监督的质量是学生网络性能的的另一个因素,一般来说,实力较强的教师网络能够始终如一地蒸馏出实力较强的学生。为了验证学生的初始化是否有很大的影响,我们在TIMM中分别采用TF Efficientnet b0(Top-1/5:76.85%/93.25%)和Efficientnet b0(77.70%/93.53%)作为学生初始化进行了消融研究。它们具有相同的架构,但训练设置和性能不同。有趣的是,对于相同的教师集合和训练超参数,两次初始化分别获得了78.29%和78.23%的TOP-17%。似乎对一个良好的初始化学生网络只有助于加快训练的收敛,但对学生的最终性能表现并没有很大的影响。

4.讨论

为什么在知识蒸馏中不需要使用one-hot标签?

ImageNet中的one-hot标签是人工标注的,因此肯定存在一些不正确或缺失的标注。另外,ImageNet在一张图像中可能包含多个对象,由多个对象之间的注释者确定的one-hot标签不能精确地表示这张图像的完整内容。我们认为,如果教师集合足够强大,可以为输入图像提供高质量的预测,使用这种不准确的硬标签会误导学生网络,并导致较差的表现。

判别器如何帮助优化?

判别器可以用于防止学生网络对训练数据的过拟合。它可以减缓学生模仿教师输出的动作,这可以看作是一种正则化效应。在训练中,采用很小的学习率来调整判别器的权重参数,以保证它不会收敛得太快,这和主干网络的收敛速度不同。在MEAL V2中,教师集成通常都很强大,同时,学生网络结构总是比教师网络结构更小更紧凑,这就意味着学生网络能力也比预训练的教师网络差得多,即使我们强迫学生网络做出与强大教师集成网络相同的预测,学生和教师集成之间的输出仍然存在不可避免的差距,这些差距不能通过KL散度损失来消除。也就是说,判别器很容易分辨出该特征是来自于一个学生或教师的集合,而正则化效应会减弱。然而,在MEAL V2中,我们仍然可以看到使用鉴别器的性能略有改善。

我们的方法对大型学生的泛化能力如何

我们尝试使用一些较大的模型,如ResNeXt-101 32×48d用于教师网络中的学生,这意味着学生网络具有与教师网络相似的能力。正如预期的那样,进步没有小学生那么大,但我们仍然看到一些性能的提升。一般来说,教师集成的软监督要好于人工注释的one-hot标签。尤其是当师生之间的规模和性能差距很大时,这种改善会更加有效和显著。也就是说,在我们的大多数实验案例中,实力较强的教师能够始终如一地训练和蒸馏出实力较强的学生。

原始ResNet-50的性能还有提升空间吗

用更多更强大的网络取代我们使用的教师组合可能会有所帮助,但训练成本将相应增加。此外,一些常见的技巧,如余弦衰减学习率,可能对性能有用,但我们没有足够的资源来测试所有这些技巧。目前的选择只是在训练效率、计算资源等方面的折衷和权衡。本文的目的主要是验证我们提出的观点的有效性,而不是很高的准确性。不过,探索固定结构网络(如ResNet-50)的上界性能将是非常有趣的。

与彩票假说的关系
彩票假设假设我们可以从一个训练过的巨型模型中找到一个子网,重新训练这种初始化的子网可以将其精度缩放到原来的巨型模型。我们还观察到,在我们的框架中,学生的初始化是教师网络在知识提炼过程中发挥作用的关键。然而,与彩票假设的不同之处在于,我们的学生仅仅是被训练过的,而不是被从一个庞大的预训练网络中挑选出来的子网络

5. 结论

我们提出了一种基于教师集合和判别器的知识蒸馏新范式。我们表明,这样一个简单的框架在各种网络结构上,包括极其微小和紧凑的模型上,都可以不费吹灰之力就能取得令人满意的结果。在ImageNet数据集上,我们方法在原始Resnet50使用224×224像素可以获得80.67%的分类准确度。我们的结果表明,现有网络的潜力尚未充分开发,通过我们的框架仍有提升和增强的空间。我们希望所提出的方法能够启发更多的研究沿着这一方向,通过知识蒸馏来推进模型的小型化。

标签:right,mathbf,训练,Vanilla,ResNet,80%,theta,mathcal,left
来源: https://blog.csdn.net/lihuanyu520/article/details/116902980

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有