ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

『迷你教程』Keras中为截断的BPTT进行序列预测

2022-01-28 09:05:00  阅读:165  来源: 互联网

标签:Keras 神经网络 步长 BPTT 时间 序列 输入 中为


文章目录

什么是BPTT?

循环神经网络能够学习序列预测问题中跨多个时间步长的时间依赖性。现代循环神经网络,如长短期记忆或 LSTM,网络是用反向传播算法的变体进行训练的,称为反向传播时间。该算法已被进一步修改,以提高非常长序列的序列预测问题的效率,称为Truncated Backpropagation Through Time

在这里插入图片描述

使用截断反向传播训练循环神经网络(如 LSTM)时的一个重要配置参数是决定使用多少时间步作为输入。也就是说,如何准确地将很长的输入序列拆分为子序列以获得最佳性能。

我们接下来使用 6 种不同的方法,您可以拆分非常长的输入序列,以使用 Keras 在 Python 中使用截断反向传播有效地训练循环神经网络。

通过时间截断反向传播

反向传播是一种训练算法,用于更新神经网络中的权重,以最小化给定输入的预期输出和预测输出之间的误差。

在这里插入图片描述
对于观察之间存在顺序依赖性的序列预测问题,使用循环神经网络代替经典的前馈神经网络。循环神经网络使用反向传播算法的变体进行训练,该算法称为 Backpropagation Through Time,简称 BPTT。

实际上,BPTT 展开循环神经网络,并在整个输入序列上向后传播错误,一次一个时间步。然后用累积的梯度更新权重。

BPTT 在输入序列很长的问题上训练循环神经网络可能很慢。除了速度之外,在如此多的时间步长上累积梯度可能会导致值缩小到零,或者最终溢出或爆炸的值增长。

BPTT 的一个修改是限制反向传播使用的时间步数,实际上估计用于更新权重的梯度而不是完全计算它。

BPTT训练算法有两个参数:

  • k1:定义在前向传递中向网络显示的时间步数。
  • k2:定义在估计反向传播的梯度时要查看的时间步数。

因此,我们可以在考虑如何配置训练算法时使用符号 BPTT(k1, k2),其中 k1 = k2 = n,其中 n 是经典非截断 BPTT 的输入序列长度。

BPTT 配置对 RNN 序列模型的影响,像 LSTM 这样的现代循环神经网络可以使用它们的内部状态来记住很长的输入序列,例如超过数千个时间步。

这意味着 BPTT 的配置不一定定义正在通过选择时间步数优化的网络的内存。可以选择何时将网络的内部状态与用于更新网络权重的机制分开重置。

相反,BPTT 参数的选择会影响网络如何估计用于更新权重的误差梯度。更一般地说,配置定义了可以考虑网络来模拟序列问题的时间步数。

我们可以将其正式表述为:

yhat(t) = f(X(t), X(t-1), X(t-2), ... X(t-n))

其中 yhat 是特定时间步长的输出,f(…) 是循环神经网络逼近的关系,X(t) 是特定时间步长的观测值。

它与在时间序列问题上训练的多层感知器上的窗口大小或线性时间序列模型(如 ARIMA)的 p 和 q 参数在概念上相似(但在实践中完全不同)。BPTT 定义了模型在训练期间输入序列的范围。

BPTT的Keras实现

Keras 深度学习库提供了用于训练循环神经网络的 BPTT 实现。具体来说 k1 和 k2 值彼此相等且固定。

  • BPTT(k1, k2),其中 k1 = k2

这是通过训练循环神经网络(如长短期记忆网络或 LSTM)所需的固定大小的 3D 输入来实现的。

LSTM 期望输入数据具有以下维度:样本、时间步长和特征。

这是此输入格式的第二个维度,时间步长定义了用于序列预测问题的前向和后向传递的时间步长数。

因此,在为 Keras 中的序列预测问题准备输入数据时,必须仔细选择指定的时间步数。

时间步长的选择将影响两者:

  • 前向传递期间累积的内部状态。
  • 用于更新反向传播权重的梯度估计。

注意,默认情况下,网络的内部状态会在每批后重置,但可以通过使用所谓的有状态 LSTM 并手动调用重置操作来实现对内部状态何时重置的更明确的控制。

在 Keras 中为 BPTT 准备序列数据

分解序列数据的方式将定义 BPTT 向前和向后传递中使用的时间步数。

因此,必须仔细考虑如何准备训练数据,列出了可以考虑的 6 种方法。

1. 按原样使用数据

如果每个序列中的时间步数不多,例如几十或几百个时间步长,可以按原样使用的输入序列。

已经建议了大约 200 到 400 个时间步长的 BPTT 的实际限制。如果序列数据小于或等于此范围,可以将序列观测值重塑为输入数据的时间步长。

例如,如果有 25 个时间步长的 100 个单变量序列的集合,则可以将其重构为 100 个样本、25 个时间步长和 1 个特征或 [100, 25, 1]。

2. 朴素的数据拆分

如果有很长的输入序列,例如数千个时间步长,可能需要将长输入序列分解为多个连续的子序列。
在这里插入图片描述

这将需要在 Keras 中使用有状态的 LSTM,以便在子序列的输入中保留内部状态,并且仅在真正更完整的输入序列的末尾重置。

例如,如果有 50,000 个时间步长的 100 个输入序列,那么每个输入序列可以分为 500 个时间步长的 100 个子序列。一个输入序列将变成 100 个样本,因此 100 个原始样本将变成 10,000 个。Keras 的输入维度为 10,000 个样本、500 个时间步长和 1 个特征或 [10000, 500, 1]。需要注意保存每 100 个子序列的状态,并在每 100 个样本后明确地或使用 100 的批量大小重置内部状态。

将整个序列整齐地划分为固定大小的子序列的拆分是首选。全序列的因子(子序列长度)的选择是任意的,因此得名“naive data split”。

将序列拆分为子序列并没有考虑有关合适数量的时间步长的域信息来估计用于更新权重的误差梯度。

3. 特定领域的数据拆分

很难知道提供有用的误差梯度估计所需的正确时间步数。
在这里插入图片描述

我们可以使用朴素的方法快速得到一个模型,但模型可能远未优化。或者我们可以使用特定领域的信息来估计在学习问题时与模型相关的时间步数。

例如如果序列问题是回归时间序列,也许对自相关和偏自相关图的回顾可以告知时间步数的选择。

如果序列问题是自然语言处理问题,也许可以将输入序列按句子分割然后填充到固定长度,或者根据域中的平均句子长度进行分割。

广泛思考并考虑可以使用哪些特定于的领域的知识将序列分成有意义的块。

4. 系统数据拆分(例如网格搜索)

可以针对序列预测问题系统地评估一组不同的子序列长度,而不是猜测合适的时间步数。
在这里插入图片描述

可以对每个子序列长度执行网格搜索,并采用导致平均性能最佳模型的配置。

如果正在考虑使用这种方法,需要注意从作为完整序列长度因子的子序列长度开始。如果探索不是完整序列长度的因素的子序列长度,请使用填充和可能的掩码。
考虑使用比解决问题所需的稍微过度规定的网络(更多的记忆单元和更多的训练时期),以帮助排除网络容量限制的实验。

取每种不同配置多次运行(例如 30 次)的平均性能。

如果计算资源不是限制,则建议对不同数量的时间步进行系统调查。

5. 使用 BPTT(1, 1) 严重依赖内部状态

可以将序列预测问题重新表述为每个时间步长一个输入和一个输出。

例如,如果有 50 个时间步长的 100 个序列,则每个时间步长都将成为一个新样本。100 个样本将变成 5,000 个。三维输入将变为 5,000 个样本、1 个时间步长和 1 个特征,或 [5000, 1, 1]。

同样,这将需要在序列的每个时间步长内保留内部状态,并在每个实际序列(50 个样本)结束时重置。

这会将学习序列预测问题的负担放在循环神经网络的内部状态上。根据问题的类型,它可能超出网络的处理能力,预测问题可能无法学习。

个人经验表明,这种公式可能适用于需要对序列进行记忆的预测问题,但当结果是过去观察的复杂函数时表现不佳。

6.前向和后向序列长度

Keras 深度学习库用于支持通过时间截断反向传播的前向和后向传递的解偶数量的时间步长。
在这里插入图片描述

本质上,k1 参数可以由输入序列的时间步数指定,k2 参数可以由 LSTM 层上的 truncate_gradient 参数指定。

标签:Keras,神经网络,步长,BPTT,时间,序列,输入,中为
来源: https://blog.csdn.net/qq_20288327/article/details/122726434

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有