ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

LSTM VS RNN改进

2021-09-22 16:35:09  阅读:214  来源: 互联网

标签:ft RNN 梯度 ht cell VS LSTM ct


1.rnn常见的图形表示

rnn是一种早期相对比较简单的循环神经网络,其结构图可以用如下表示。
在这里插入图片描述
图片来自网络。

其中,x,y,h分别表示神经元的输入,输出以及隐藏状态。
根据上面的图片不难看出,在时刻t,神经元的输入包括 x t x_t xt​与上一时刻的隐藏状态 h t − 1 h_{t-1} ht−1​,而输出包括当前时刻的隐藏状态 h t h_t ht​与当前时刻的输出 y t y_t yt​。

RNN的输入 x t x_t xt​只包含了t时刻的信息,而不包含顺序信息。而 h t h_t ht​则包含了历史信息与当前输入信息,所以RNN是能用到历史信息的。
h t = σ ( z t ) = σ ( U x t + W h t − 1 + b ) y t = σ ( V h t + c ) h_t = \sigma(z_t) = \sigma(Ux_t + Wh_{t-1} + b) \\ y_t = \sigma(Vh_t + c) ht​=σ(zt​)=σ(Uxt​+Wht−1​+b)yt​=σ(Vht​+c)

2.RNN的问题

RNN最主要的问题是梯度消失与梯度爆炸
具体梯度消失与梯度爆炸的原因,可以查看参考文献1

3.LSTM

LSTM,Long short-term memory,中文直译的话就是长短记忆模型,主要就是为了解决RNN训练中的梯度消失与梯度爆炸问题。
LSTM与RNN的对比,经常用下面一张图来表示。

在这里插入图片描述
LSTM的神经元除了隐状态 h t − 1 h_{t-1} ht−1​与当前输入 x t x_t xt​外,还多了一个细胞状态 c t − 1 c_{t-1} ct−1​ cell。其中,cell更多地与rnn中的h比较像,保存的是历史状态的信息,而LSTM中的h更多的保存上一时刻的输出信息。

LSTM内部的计算,可以分为遗忘门,输入门与输出门。

在这里插入图片描述
遗忘门主要是盘段cell状态 c t − 1 c_{t-1} ct−1​哪些信息被删除。输入的 ht-1 和 xt 经过 sigmoid 激活函数之后得到 ft,ft 中每一个值的范围都是 [0, 1]。ft 中的值越接近 1,表示 cell 状态 ct-1 中对应位置的值更应该记住;ft 中的值越接近 0,表示 cell 状态 ct-1 中对应位置的值更应该忘记。将 ft 与 ct-1 按位相乘,可以得到遗忘无用信息之后的 c’t-1。
f t = σ ( W f ( h t − 1 , x t ) + b f ) c t − 1 ′ = c t − 1 ⊙ f t f_t = \sigma(W_f(h_{t-1}, x_t) + b_f) \\ c'_{t-1} = c_{t-1} \odot f_t ft​=σ(Wf​(ht−1​,xt​)+bf​)ct−1′​=ct−1​⊙ft​

在这里插入图片描述
输入门主要是判断哪些信息需要加入到cell状态 c t − 1 ′ c'_{t-1} ct−1′​中。 h t − 1 h_{t-1} ht−1​与 x t x_t xt​经过tanh激活以后可以得到新的输入信息,但是这些输入信息不需要全部加入,因此需要用 h t − 1 h_{t-1} ht−1​与 x t x_t xt​经过sigmoid激活以后得到it,it表示哪些新信息有用,两向量相乘后的结果加到 c t − 1 ′ c'_{t-1} ct−1′​ 中,即得到 t 时刻的 cell 状态 c t c_t ct​。

在这里插入图片描述
输出门主要用来判断哪些信息到 h t h_t ht​中。cell 状态 ct 经过 tanh 函数得到可以输出的信息,然后 ht-1 和 xt 经过 sigmoid 函数得到一个向量 ot,ot 的每一维的范围都是 [0, 1],表示哪些位置的输出应该去掉,哪些应该保留。两向量相乘后的结果就是最终的 ht。

4.LSTM解决梯度爆炸与梯度消失

根据第二部分参考文献里面的内容,我们可以得知梯度爆炸与梯度消失主要是犹豫连乘项引起的,所以要解决这个问题主要是去掉连乘项。

LSTM 中通过门的作用,可以使连乘项约等于 0 或者 1。首先我们看一下 LSTM 中 ct 与 ht 的计算公式。

c t = c t − 1 ⊙ f t + ( i t ⊙ c t ~ h t = o t ⊙ c t ~ c_t = c_{t-1} \odot f_t + (i_t \odot \tilde{c_t} \\ h_t = o_t \odot \tilde{c_t} ct​=ct−1​⊙ft​+(it​⊙ct​~​ht​=ot​⊙ct​~​

在公式中 ft 与 ot 都是通过 sigmoid 函数得到的,意味着它们的值要么接近 0,要么接近 1。因此在 LSTM 中的连乘项变成:

∂ c t ∂ c t − 1 = f t ∂ t t ∂ t t − 1 = o t \frac{\partial c_t }{\partial c_{t-1}} = f_t \\ \frac{\partial t_t }{\partial t_{t-1}} = o_t ∂ct−1​∂ct​​=ft​∂tt−1​∂tt​​=ot​

因此当门的梯度接近1时,连乘项能够保证梯度很好地在 LSTM 中传递,避免梯度消失的情况发生。

而当门的梯度接近 0 时,意味着上一时刻的信息对当前时刻并没有作用,此时没有必要把梯度回传。

参考文献

1.https://zhuanlan.zhihu.com/p/28687529
2.https://juejin.cn/post/6949159845731762184

标签:ft,RNN,梯度,ht,cell,VS,LSTM,ct
来源: https://blog.csdn.net/bitcarmanlee/article/details/120415395

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有