ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

文献阅读:A New Meta-Baseline for Few-Shot Learning

2021-06-14 20:02:13  阅读:315  来源: 互联网

标签:分类器 Shot Baseline 训练 样本 学习 Learning 基类


原文链接:https://arxiv.org/abs/2003.04390

源代码链接:https://github.com/yinboc/few-shot-meta-baseline

背景知识:

  1. meta-learning(元学习)
    本质是一种“learning to learn”的学习过程,不同于常用的深度学习模型(依据数据集去学习如何预测或者分类),meta-learning是学习“如何更快学习一个模型”的过程

  2. MAML算法:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks 
    模型无关框架,定义一种架构,使用meta-learning去训练base-learner。

  3. base-learner(基学习器):从足够样本的公共类(基类)中训练过的的深度学习模型

问题定义:

小样本分类:给定一个基类Cbase的标记数据集,每个类中有大量的图像,目标是学习新类Cnovel中每个类中有几个样本的概念,其中Cbase∩Cnovel=∅。(N-way指训练数据中有N个类别,K-shot指每个类别下有K个被标记数据)在一个N类K样本的小样本任务中,支持集(基类)包含每个类中具有K样本的N个类,查询集(新类)包含相同的N个类,每个类中有Q个样本,其目标是将查询集中的N个×Q未标记的样本正确分类为N个类。

文章成果:

  1.        定义了一个Classifier-Baseline(基线分类器),在基类上预先训练一个分类器来学习视觉表示,删除最后一个全连接(FC)层,得到特征映射。然后在少量样本的新类上进行训练,计算新类的样本平均特征,利用特征空间中的余弦距离,用最近质心对查询样本(验证集)进行分类,即为余弦最近质心分类法。只是用于估计新类的最后FC权重,不需要重新训练所有模型参数。
  2.        在1的基础上,用元学习改进Classifier-Baseline,提出Meta-Beaseline。在Meta-Baseline中,使用预先训练的Classifier-Baseline初始化模型,并使用余弦最近轮廓度量执行元学习,这是Clssififier-Baseline中的评估度量 。(一种基于度量的元学习)
  3.        文章观察到训练过程中测试性能下降,在元学习的背景下对两种泛化类型进行了评估,基类泛化和新类泛化。
  4.        文章还对数据集因素对元学习的影响进行了研究。

 

文章具体研究内容:

  1. Classifier-Baseline

    分类器-基线是指在所有基类上训练一个具有分类功能的分类器,并使用余弦最近质心方法执行小样本任务。具体来说,我们在所有具有标准交叉熵损失的基类上训练一个分类器,然后删除它的最后一个FC层,得到编码器fθ,将输入映射到特征空间。给定一个具有支持集S的小样本任务,让Sc表示c类中的小样本,我们计算平均特征Wc作为c类的质心:

     

     

    然后,对于小样本任务中的查询样本x,我们预测样本x属于c类的概率作为样本x的特征向量与c类质心之间的余弦相似度:

    其中<.,.>表示两个向量的余弦相似度。 请注意,Wc也可以看作是新FC层对新概念的预测权重。

  2. Meta-Baseline

    一般来说,Meta-Baseline包含两个训练阶段。

    第一阶段是预训练阶段,即训练Classifier-Baseline(即在所有基类上训练分类器,并删除其最后一个FC层以获得fθ)。

    第二阶段是元学习阶段,在元学习阶段同样使用基类中的数据分成多个task,在每个task中对support-set用fθ编码,然后用(1)式求每个类的平均特征表示。同时对query-set也进行编码操作,利用(2)式余弦相似度求query-set和support set之间的距离,使用softmax进行分类。

    我们在分类器-基线评估算法上对模型进行优化。具体来说,给定预先训练的特征编码器fθ,我们从基类中的训练数据中采样N类K样本任务(具有N×Q查询样本)

    为了计算每个任务的损失,在支持集中,我们在方程1中计算定义的N个类的质心,然后用于方程2中计算定义的查询集中每个样本的预测概率分布。损失是由p和查询集中样本的标签计算的交叉熵损失。 请注意,我们将每个任务视为训练中的数据点,每批可能包含多个任务,并计算平均损失。

未完待续

 

标签:分类器,Shot,Baseline,训练,样本,学习,Learning,基类
来源: https://www.cnblogs.com/xihongqing/p/14883169.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有