ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解(转载)

2020-04-26 18:53:54  阅读:427  来源: 互联网

标签:task set 训练 梯度 模型 Agnostic Meta MAML


转载于:Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解

MAML在学术界已经是非常重要的模型了,论文Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks自2017年发表至今已经收获了400+的引用。由于当前网上关于MAML的中文介绍少之又少,可能很多小伙伴对其还不是特别理解。所以今天我整理了这段时间来的学习心得,与大家分享自己对MAML的认识与理解。MAML可以用于Supervised Regression and Classification以及Reinforcement Learning。由于我对强化学习不是特别了解,因此这篇文章,均是基于MAML在Supervised Regression and Classification中的运用。

一、一些相关概念的介绍

在原论文中,作者直接引用了许多元学习相关的概念,例如 meta-learning, model-agnostic, N-way K-shot, tasks等等,其中有些概念在MAML中还有特殊的含义。在此,我尽量用通俗易懂的方式对这些概念为大家做一个介绍。

(1) meta-learning

meta-learning即元学习,也可以称为“learning to learn”。常见的深度学习模型,目的是学习一个用于预测的数学模型。而元学习面向的不是学习的结果,而是学习的过程。其学习的不是一个直接用于预测的数学模型,而是学习“如何更快更好地学习一个数学模型”。

举一个现实生活的例子。我们教小朋友读英语时,可以直接让他们模仿apple、banana的发音。但是他们很快又会遇到新的单词,例如strawberry,这是小朋友就需要重新听你的发音,才能正确地读出这个新单词。我们换一种方式,这一次我们不教每个单词的发音,而是教音标的发音。从此小朋友再遇见新单词,他们只要根据音标,就可以正确地读出这个单词。学习音标的过程,正是一个元学习的过程。

在深度学习中,已经被提出的元学习模型有很多,大致上可以分类为learning good weight initializations,meta-models that generate the parameters of other models 以及learning transferable optimizers。其中MAML属于第一类。MAML学习一个好的初始化权重,从而在新任务上实现fast adaptation,即在小规模的训练样本上迅速收敛并完成fine-tune。

(2) model-agnostic

model-agnostic即模型无关。MAML与其说是一个深度学习模型,倒不如说是一个框架,提供一个meta-learner用于训练base-learner。这里的meta-learner即MAML的精髓所在,用于learning to learn;而base-learner则是在目标数据集上被训练,并实际用于预测任务的真正的数学模型。绝大多数深度学习模型都可以作为base-learner无缝嵌入MAML中,而MAML甚至可以用于强化学习中,这就是MAML中model-agnostic的含义。

(3) N-way K-shot

N-way K-shot是few-shot learning中常见的实验设置。few-shot learning指利用很少的被标记数据训练数学模型的过程,这也正是MAML擅长解决的问题之一。N-way指训练数据中有N个类别,K-shot指每个类别下有K个被标记数据。

(4) task

MAML的论文中多次出现名词task,模型的训练过程都是围绕task展开的,而作者并没有给它下一个明确的定义。要正确地理解task,我们需要了解的相关概念包括[公式][公式], support set, query set, meta-train classes, meta-test classes等等。是不是有点眼花缭乱?不要着急,举个简单的例子,大家就可以很轻松地掌握这些概念。

我们假设这样一个场景:我们需要利用MAML训练一个数学模型模型 [公式] ,目的是对未知标签的图片做分类,类别包括 [公式] (每类5个已标注样本用于训练。另外每类有15个已标注样本用于测试)。我们的训练数据除了 [公式] 中已标注的样本外,还包括另外10个类别的图片 [公式] (每类30个已标注样本),用于帮助训练元学习模型 [公式] 。我们的实验设置为5-way 5-shot。

关于具体的训练过程,会在下一节MAML算法详解中介绍。这里我们只需要有一个大概的了解:MAML首先利用 [公式] 的数据集训练元模型[公式],再在[公式]的数据集上精调(fine-tune)得到最终的模型 [公式] 。

此时,[公式]即meta-train classes,[公式]包含的共计300个样本,即 [公式] ,是用于训练 [公式] 的数据集。与之相对的,[公式]即meta-test classes,[公式]包含的共计100个样本,即 [公式] ,是用于训练和测试 [公式]的数据集。

根据5-way 5-shot的实验设置,我们在训练 [公式] 阶段,从 [公式] 中随机取5个类别,每个类别再随机取20个已标注样本,组成一个task [公式] 。其中的5个已标注样本称为 [公式] 的support set,另外15个样本称为 [公式] 的query set。这个task [公式] , 就相当于普通深度学习模型训练过程中的一条训练数据。那我们肯定要组成一个batch,才能做随机梯度下降SGD对不对?所以我们反复在训练数据分布中抽取若干个这样的task [公式] ,组成一个batch。在训练 [公式] 阶段,task、support set、query set的含义与训练 [公式] 阶段均相同。

二、MAML算法详解

作者在论文中给出的算法流程如下:

MAML算法

该算法实质上是MAML预训练阶段的算法,目的是得到模型 [公式] 。不要被这些数学符号吓到喔,这个算法的思路其实很简单。接下来,我们来一行一行地分析这个算法。

首先来看两个Require。

第一个Require指的是在 [公式] 中task的分布。结合我们在上一小节举的例子,这里即反复随机抽取task [公式] ,形成一个由若干个(e.g., 1000个) [公式] 组成的task池,作为MAML的训练集。有的小伙伴可能要纳闷了,训练样本就这么多,要组合形成那么多的task,岂不是不同task之间会存在样本的重复?或者某些task的query set会成为其他task的support set?没错!就是这样!我们要记住,MAML的目的,在于fast adaptation,即通过对大量task的学习,获得足够强的泛化能力,从而面对新的、从未见过的task时,通过fine-tune就可以快速拟合。task之间,只要存在一定的差异即可。再强调一下,MAML的训练是基于task的,而这里的每个task就相当于普通深度学习模型训练过程中的一条训练数据。

第二个Require就很好理解啦。step size其实就是学习率,读过MAML论文的小伙伴一定会对gradient by gradient这个词有印象。MAML是基于二重梯度的,每次迭代包括两次参数更新的过程,所以有两个学习率可以调整。

接下来,就是激动人心的算法流程。

步骤1,随机初始化模型的参数,没什么好说的,任何模型训练前都有这一步。

步骤2,是一个循环,可以理解为一轮迭代过程或一个epoch,当然啦预训练的过程是可以有多个epoch的。

步骤3,相当于pytorch中的DataLoader,即随机对若干个(e.g., 4个)task进行采样,形成一个batch。

步骤4~步骤7,是第一次梯度更新的过程。注意这里我们可以理解为copy了一个原模型,计算出新的参数,用在第二轮梯度的计算过程中。我们说过,MAML是gradient by gradient的,有两次梯度更新的过程。步骤4~7中,利用batch中的每一个task,我们分别对模型的参数进行更新(4个task即更新4次)。注意这一个过程在算法中是可以反复执行多次的,伪代码没有体现这一层循环,但是作者再分析的部分明确提到" using multiple gradient updates is a straightforward extension"。

步骤5,即对利用batch中的某一个task中的support set,计算每个参数的梯度。在N-way K-shot的设置下,这里的support set应该有NK个。作者在算法中写with respect to K examples,默认对每一个class下的K个样本做计算。实际上参与计算的总计有NK个样本。这里的loss计算方法,在回归问题中,就是MSE;在分类问题中,就是cross-entropy。

步骤6,即第一次梯度的更新。

步骤4~步骤7,结束后,MAML完成了第一次梯度更新。接下来我们要做的,是根据第一次梯度更新得到的参数,通过gradient by gradient,计算第二次梯度更新。第二次梯度更新时计算出的梯度,直接通过SGD作用于原模型上,也就是我们的模型真正用于更新其参数的梯度。

步骤8即对应第二次梯度更新的过程。这里的loss计算方法,大致与步骤5相同,但是不同点有两处。一处是我们不再是分别利用每个task的loss更新梯度,而是像常见的模型训练过程一样,计算一个batch的loss总和,对梯度进行随机梯度下降SGD。另一处是这里参与计算的样本,是task中的query set,在我们的例子中,即5-way*15=75个样本,目的是增强模型在task上的泛化能力,避免过拟合support set。步骤8结束后,模型结束在该batch中的训练,开始回到步骤3,继续采样下一个batch。

以上即时MAML预训练得到 [公式] 的全部过程,是不是很简单呢?事实上,MAML正是因为其简单的思想与惊人的表现,在元学习领域迅速流行了起来。接下来,应该是面对新的task,在 [公式] 的基础上,精调得到 [公式] 的方法。原文中没有介绍fine-tune的过程,这里我向小伙伴们简单介绍一下。

fine-tune的过程与预训练的过程大致相同,不同的地方主要在于以下几点:

  • 步骤1中,fine-tune不用再随机初始化参数,而是利用训练好的 [公式] 初始化参数。
  • 步骤3中,fine-tune只需要抽取一个task进行学习,自然也不用形成batch。fine-tune利用这个task的support set训练模型,利用query set测试模型。实际操作中,我们会在 [公式]上随机抽取许多个task(e.g., 500个),分别微调模型 [公式] ,并对最后的测试结果进行平均,从而避免极端情况。
  • fine-tune没有步骤8,因为task的query set是用来测试模型的,标签对模型是未知的。因此fine-tune过程没有第二次梯度更新,而是直接利用第一次梯度计算的结果更新参数。

 


参考资料:

编辑于 2019-03-29

标签:task,set,训练,梯度,模型,Agnostic,Meta,MAML
来源: https://www.cnblogs.com/king-lps/p/12781604.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有