ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

蒸馏论文二(PAYING MORE ATTENTION TO ATTENTION)

2021-08-02 14:31:21  阅读:318  来源: 互联网

标签:PAYING loss 简单 模型 ATTENTION 归一化 opt 注意力 MORE


本系列文章介绍一些知识蒸馏领域的经典文章。

知识蒸馏:提取复杂模型有用的先验知识,并与简单模型特征结合算出他们的距离,以此来优化简单模型的参数,让简单模型学习复杂模型,从而帮助简单模型提高性能。

1. Attention Transfer原理

论文Paying more attention to attention主要通过提取复杂模型生成的注意力图来指导简单模型,使简单模型生成的注意力图与复杂模型相似。这样,简单模型不仅可以学到特征信息,还能够了解如何提炼特征信息。使得简单模型生成的特征更加灵活,不局限于复杂模型。

在这里插入图片描述
其中,图a是输入,b是相应的空间注意力图,它可以表现出网络为了分类所给图片所需要注意的地方。所谓空间注意力图,其实就是将特征图[C, H , W]通过映射变换成特征[H, W]。作者将每层通道平方相加获得特征图对应的注意力图。

在这里插入图片描述
上图是人脸识别任务中,对不同维度的特征图进行变换求得的注意力图,可以发现高维注意力图会对整个脸作出反应。

在这里插入图片描述

2. 损失函数

论文中,作者将损失分为两部分:在这里插入图片描述
第一部分是分类损失是简单的交叉熵损失函数,作用是实现分类。

第二部分是衡量复杂模型于简单模型注意力图差异的函数。首先注意力图进行归一化,即除以自身的模值,然后计算两个注意力图的p范数。

第三部分也是衡量复杂模型于简单模型注意力图差异的函数,在实际代码实现中使用KL散度实现。KL散度是用来衡量两个概率分布之间的相似性的函数。不了解KL散度的见这里
在这里插入图片描述

3. 代码解读

参考官方代码,以下是使用attention transfer技巧的关键部分代码。

def f(inputs, params, mode):
    '''
    网络结构中使用到的函数
    返回:
        y_t: 学生模型输出结果
        y_t: 老师模型输出结果
        loss: 学生输出和老师输出结果归一化后的欧式距离,即第三部分loss
    '''
    # f_s和f_t分别是定义的学生和老师网络的网络
    y_s, g_s = f_s(inputs, params, mode, 'student.')
    with torch.no_grad():
        y_t, g_t = f_t(inputs, params, False, 'teacher.')
    return y_s, y_t, [utils.at_loss(x, y) for x, y in zip(g_s, g_t)]

def distillation(y, teacher_scores, labels, T, alpha):
    '''
    arguments:
	    y: 学生模型归一化后的输出
	    teacher_scores:老师模型归一化后的输出
	    labels:学生模型的标签
	    T, alpha:超参数
	returns:
		loss: 简单交叉熵函数和KL函数的加权和,即前两部分loss
    '''
    # 学生网络软化后结果
    p = F.log_softmax(y/T, dim=1)

    # 老师网络软化后结果
    q = F.softmax(teacher_scores/T, dim=1)

    # 两个模型之间的距离损失
    l_kl = F.kl_div(p, q, size_average=False) * (T**2) / y.shape[0]

    # 学生模型的分类损失
    l_ce = F.cross_entropy(y, labels)

    return l_kl * alpha + l_ce * (1. - alpha)

def at(x):
    '''归一化'''
    return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))

def at_loss(x, y):
    '''距离函数'''
    return (at(x) - at(y)).pow(2).mean()

def h(sample):
    '''网络结构'''
    inputs = utils.cast(sample[0], opt.dtype).detach()
    targets = utils.cast(sample[1], 'long')
    if opt.teacher_id != '':
    	# 分配到几个GPU上并行训练
        y_s, y_t, loss_groups = utils.data_parallel(f, inputs, params, sample[2], range(opt.ngpu))
        # 欧式距离
        loss_groups = [v.sum() for v in loss_groups]
        [m.add(v.item()) for m, v in zip(meters_at, loss_groups)]

        # 总的loss
        return utils.distillation(y_s, y_t, targets, opt.temperature, opt.alpha) \
               + opt.beta * sum(loss_groups), y_s

engine.train(h, train_loader, opt.epochs, optimizer)

论文理解部分参考文献:
知识蒸馏论文详解之:PAYING MORE ATTENTION TO ATTENTION

标签:PAYING,loss,简单,模型,ATTENTION,归一化,opt,注意力,MORE
来源: https://blog.csdn.net/weixin_44579633/article/details/119321067

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有