ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

基于seq2seq文本生成的解码/采样策略

2021-04-22 19:05:21  阅读:354  来源: 互联网

标签:采样 temperature probability 解码 seq2seq Sampling 生成 decoding


基于seq2seq文本生成的解码/采样策略

目录
基于Seq2Seq模型的文本生成有各种不同的decoding strategy。文本生成中的decoding strategy主要可以分为两大类:

  • Argmax Decoding: 主要包括beam search, class-factored softmax等
  • Stochastic Decoding: 主要包括temperature sampling, top-k sampling等。

在Seq2Seq模型中,RNN Encoder对输入句子进行编码,生成一个大小固定的hidden state \(h_c\);基于输入句子的hidden state \(h_c\) 和先前生成的第1到t-1个词\(x_{1:t-1}\),RNN Decoder会生成当前第t个词的hidden state \(h_t\) ,最后通过softmax函数得到第t个词 \(x_t\) 的vocabulary probability distribution \(P(x|x_{1:t-1})\)。

两类decoding strategy的主要区别就在于,如何从vocabulary probability distribution \(P(x|x_{1:t-1})\)中选取一个词 \(x_t\) :

  • Argmax Decoding的做法是选择词表中probability最大的词,即\(x_t=argmax\quad P(x|x_{1:t-1})\) ;
  • Stochastic Decoding则是基于概率分布\(P(x|x_{1:t-1})\) 随机sample一个词 \(x_t\),即 \(x_t \sim P(x|x_{1:t-1})\) 。

在做seq predcition时,需要根据假设模型每个时刻softmax的输出概率来sample单词,合适的sample方法可能会获得更有效的结果。

1. 贪婪采样

  1. Greedy Search

    核心思想:每一步取当前最大可能性的结果,作为最终结果。

    具体方法:获得新生成的词是vocab中各个词的概率,取argmax作为需要生成的词向量索引,继而生成后一个词。

  2. Beam Search

    核心思想: beam search尝试在广度优先基础上进行进行搜索空间的优化(类似于剪枝)达到减少内存消耗的目的。

    具体方法:在decoding的每个步骤,我们都保留着 top K 个可能的候选单词,然后到了下一个步骤的时候,我们对这 K 个单词都做下一步 decoding,分别选出 top K,然后对这 K^2 个候选句子再挑选出 top K 个句子。以此类推一直到 decoding 结束为止。当然 Beam Search 本质上也是一个 greedy decoding 的方法,所以我们无法保证自己一定可以得到最好的 decoding 结果。

Greedy Search和Beam Search存在的问题:

  1. 容易出现重复的、可预测的词;
  2. 句子/语言的连贯性差。

2. 随机采样

核心思想: 根据单词的概率分布随机采样。

  1. Temperature Sampling:

    具体方法:在softmax中引入一个temperature来改变vocabulary probability distribution,使其更偏向high probability words:

    \[P(x|x_{1:t-1})=\frac{exp(u_t/temperature)}{\sum_{t'}exp(u_{t'}/temperature)},temperature\in[0,1) \]

    另一种表示:假设\(p(x)\)为模型输出的原始分布,给定一个 temperature 值,将按照下列方法对原始概率分布(即模型的 softmax 输出) 进行重新加权,计算得到一个新的概率分布。

    \[\pi(x_{k})=\frac{e^{log(p(x_k))/temperature}} {\sum_{i=1}^{n}e^{log(p(x_i))/temperature}},temperature\in[0,1) \]

    当\(temperature \to 0\),就变成greedy search;当\(temperature \to \infty\),就变成均匀采样(uniform sampling)。详见论文:The Curious Case of Neural Text Degeneration

  2. Top-k Sampling:

    可以缓解生成罕见单词的问题。比如说,我们可以每次只在概率最高的50个单词中按照概率分布做采样。我只保留top-k个probability的单词,然后在这些单词中根据概率做sampling。

    核心思想:对概率进行降序排序,然后对第k个位置之后的概率转换为0。

    具体方法:在decoding过程中,从 \(P(x|x_{1:t-1})\) 中选取probability最高的前k个tokens,把它们的probability加总得到 \(p'=\sum P(x|x_{1:t-1})\) ,然后将 \(P(x|x_{1:t-1})\) 调整为 \(P'(x|x_{1:t-1})=P(x|x_{1:t-1})/p'\) ,其中 \(x\in V^{(k)}\)! ,最后从 \(P'(x|x_{1:t-1})\) 中sample一个token作为output token。详见论文:Hierarchical Neural Story Generation

    但Top-k Sampling存在的问题是,常数k是提前给定的值,对于长短大小不一,语境不同的句子,我们可能有时需要比k更多的tokens。

  3. Top-p Sampling (Nucleus Sampling ):

    核心思想:通过对概率分布进行累加,然后当累加的值超过设定的阈值p,则对之后的概率进行置0。

    具体方法:提出了Top-p Sampling来解决Top-k Sampling的问题,基于Top-k Sampling,它将 \(p'=\sum P(x|x_{1:t-1})\) 设为一个提前定义好的常数\(p'\in(0,1)\) ,而selected tokens根据句子history distribution的变化而有所不同。详见论文:The Curious Case of Neural Text Degeneration

    本质上Top-p Sampling和Top-k Sampling都是从truncated vocabulary distribution中sample token,区别在于置信区间的选择。

随机采样存在的问题:

  1. 生成的句子容易不连贯,上下文比较矛盾。
  2. 容易生成奇怪的句子,出现罕见词。

3. 参考

LSTM文本生成:《Python深度学习》第8章第1节:8.1 使用LSTM生成文本P228-P234。

标签:采样,temperature,probability,解码,seq2seq,Sampling,生成,decoding
来源: https://www.cnblogs.com/dangui/p/14690958.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有