LSTM前向传播与反向传播算法推导(非常详细)
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了LSTM前向传播与反向传播算法推导(非常详细) ,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含9366字,纯文字阅读大概需要14分钟 。
内容图文
1.长短期记忆网络LSTM
LSTM(Long short-term memory)通过刻意的设计来避免长期依赖问题,是一种特殊的RNN。长时间记住信息实际上是 LSTM 的默认行为,而不是需要努力学习的东西!
所有递归神经网络都具有神经网络的链式重复模块。在标准的RNN中,这个重复模块具有非常简单的结构,例如只有单个tanh层,如下图所示。
[外链图片转存失败(img-EwKxtSFp-1569051242265)(./images/lstm-rnn.jpg)]
LSTM具有同样的结构,但是重复的模块拥有不同的结构,如下图所示。与RNN的单一神经网络层不同,这里有四个网络层,并且以一种非常特殊的方式进行交互。
1.1 LSTM–遗忘门
LSTM 的第一步要决定从细胞状态中舍弃哪些信息。这一决定由所谓“遗忘门层”的 S 形网络层做出。它接收 h t ? 1 h_{t-1} ht?1? 和 x t x_t xt?,并且对细胞状态 C t ? 1 C_{t?1} Ct?1? 中的每一个数来说输出值都介于 0 和 1 之间。1 表示“完全接受这个”,0 表示“完全忽略这个”。
1.2 LSTM–输入门
下一步就是要确定需要在细胞状态中保存哪些新信息。这里分成两部分。第一部分,一个所谓“输入门层”的 S 形网络层确定哪些信息需要更新。第二部分,一个 tanh 形网络层创建一个新的备选值向量—— C ~ t \tilde{C}_t C~t?,可以用来添加到细胞状态。在下一步中我们将上面的两部分结合起来,产生对状态的更新。
1.3 LSTM–细胞状态更新
现在更新旧的细胞状态 C t ? 1 C_{t?1} Ct?1? 更新到 C t C_t Ct?。先前的步骤已经决定要做什么,我们只需要照做就好。
我们对旧的状态乘以 f t f_t ft?,用来忘记我们决定忘记的事。然后我们加上 i t ⊙ C ~ t i_t\odot\tilde{C}_t it?⊙C~t?,这是新的候选值,根据我们对每个状态决定的更新值按比例进行缩放。
1.4 LSTM–输出门
最后,我们需要确定输出值。输出依赖于我们的细胞状态,但会是一个“过滤的”版本。首先我们运行 S 形网络层,用来确定细胞状态中的哪些部分可以输出。然后,我们把细胞状态输入 tanh(把数值调整到 ?1 和 1 之间)再和 S 形网络层的输出值相乘,部这样我们就可以输出想要输出的分。
1.5 LSTM的变种
目前我所描述的还只是一个相当一般化的 LSTM 网络。但并非所有 LSTM 网络都和之前描述的一样。事实上,几乎所有文章都会改进 LSTM 网络得到一个特定版本。差别是次要的,但有必要认识一下这些变种。
(1) 一个流行的 LSTM 变种由 Gers 和 Schmidhuber 提出,在 LSTM 的基础上添加了一个“窥视孔连接”,这意味着我们可以让门网络层输入细胞状态。
上图中我们为所有门添加窥视孔,但许多论文只为部分门添加.
(2)另一个变种把遗忘和输入门结合起来。同时确定要遗忘的信息和要添加的新信息,而不再是分开确定。当输入的时候才会遗忘,当遗忘旧信息的时候才会输入新数据。
(3)一个更有意思的 LSTM 变种称为 Gated Recurrent Unit(GRU),由 Cho 等人提出。GRU 把遗忘门和输入门合并成为一个“更新门”,把细胞状态和隐含状态合并,还有其他变化。这样做使得 GRU 比标准的 LSTM 模型更简单,因此正在变得流行起来。
2.LSTM前向传播与反向传播
本小节只推导添加“窥视孔连接”的变种LSTM,如下图所示,其它LSTM变种的推导方法与该方法类似,这里不做过多介绍。对反向传播算法了解不够透彻的,请参考https://zhuanlan.zhihu.com/p/79657669 ,这里有详细的推导过程,本文将直接使用https://zhuanlan.zhihu.com/p/79657669 的结论。
为了更直观的推导反向传播算法,将其转化为右图所示形式。
2.1 LSTM前向传播
LSTM在t时刻的前向传播公式为:
{ i t = σ ( i ~ t ) = σ ( W x i x t + W h i h t ? 1 + W c i c t ? 1 + b i ) f t = σ ( f ~ t ) = σ ( W x f x t + W h f h t ? 1 + W c f c t ? 1 + b f ) g t = tanh ? ( g ~ t ) = tanh ? ( W x g x t + W h g h t ? 1 + b g ) o t = σ ( o ~ t ) = σ ( W x o x t + W h o h t ? 1 + W c o c t + b o ) c t = c t ? 1 ⊙ f t + g t ⊙ i t m t = tanh ? ( c t ) h t = o t ⊙ m t y t = W y h h t + b y
\left\{
\begin{array}{l}
{i_t=\sigma(\tilde{i}_t)=\sigma(W_{xi}x_t+W_{hi}h_{t-1}+W_{ci}c_{t-1}+b_i)} \\
{f_t=\sigma(\tilde{f}_t)=\sigma(W_{xf}x_t+W_{hf}h_{t-1}+W_{cf}c_{t-1}+b_f) }\\
{g_t=\tanh(\tilde{g}_t)=\tanh(W_{xg}x_t+W_{hg}h_{t-1}+b_g)} \\
{o_t=\sigma(\tilde{o}_t)=\sigma(W_{xo}x_t+W_{ho}h_{t-1}+W_{co}c_{t}+b_o) }\\
{c_t=c_{t-1}\odot f_t+g_t\odot i_t}\\
{m_t=\tanh(c_t)}\\
{h_t=o_t\odot m_t}\\
{y_t=W_{yh}h_t+b_y}
\end{array}\right. ????????????????????????it?=σ(i~t?)=σ(Wxi?xt?+Whi?ht?1?+Wci?ct?1?+bi?)ft?=σ(f~?t?)=σ(Wxf?xt?+Whf?ht?1?+Wcf?ct?1?+bf?)gt?=tanh(g~?t?)=tanh(Wxg?xt?+Whg?ht?1?+bg?)ot?=σ(o~t?)=σ(Wxo?xt?+Who?ht?1?+Wco?ct?+bo?)ct?=ct?1?⊙ft?+gt?⊙it?mt?=tanh(ct?)ht?=ot?⊙mt?yt?=Wyh?ht?+by??
2.2 LSTM反向传播
已知:? J ? y t , ? J ? c t + 1 , ? J ? o ~ t + 1 , , ? J ? f ~ t + 1 , ? J ? i ~ t + 1 , ? J ? g ~ t + 1 \frac{\partial J}{\partial y_t},\frac{\partial J}{\partial c_{t+1}},\frac{\partial J}{\partial \tilde{o}_{t+1}},,\frac{\partial J}{\partial \tilde{f}_{t+1}},\frac{\partial J}{\partial \tilde{i}_{t+1}},\frac{\partial J}{\partial \tilde{g}_{t+1}} ?yt??J?,?ct+1??J?,?o~t+1??J?,,?f~?t+1??J?,?i~t+1??J?,?g~?t+1??J?,求某个节点梯度时,首先应该找到该节点的输出节点,然后分别计算所有输出节点的梯度乘以输出节点对该节点的梯度,最后相加即可得到该节点的梯度。如计算? J ? h t \frac{\partial J}{\partial h_t} ?ht??J?时,找到h t h_t ht?节点的所有输出节点y t 、 o ~ t + 1 、 f ~ t + 1 、 i ~ t + 1 、 g ~ t + 1 y_t、 \tilde{o}_{t+1}、\tilde{f}_{t+1}、\tilde{i}_{t+1}、\tilde{g}_{t+1} yt?、o~t+1?、f~?t+1?、i~t+1?、g~?t+1?,然后分别计算输出节点的梯度(如? J ? y t \frac{\partial J}{\partial y_t} ?yt??J?)与输出节点对h t h_t ht?的梯度的乘积(如? J ? y t W y h T \frac{\partial J}{\partial y_t}W_{yh}^T ?yt??J?WyhT?),最后相加即可得到节点h t h_t ht?的梯度:
? J ? h t = ? J ? y t W y h T + ? J ? o ~ t + 1 W h o T + ? J ? f ~ t + 1 W h f T + ? J ? i ~ t + 1 W h i T + ? J ? g ~ t + 1 W h g T
\frac{\partial J}{\partial h_t}=\frac{\partial J}{\partial y_t}W_{yh}^T+\frac{\partial J}{\partial \tilde{o}_{t+1}}W_{ho}^T+\frac{\partial J}{\partial \tilde{f}_{t+1}}W_{hf}^T+\frac{\partial J}{\partial \tilde{i}_{t+1}}W_{hi}^T+\frac{\partial J}{\partial \tilde{g}_{t+1}}W_{hg}^T
?ht??J?=?yt??J?WyhT?+?o~t+1??J?WhoT?+?f~?t+1??J?WhfT?+?i~t+1??J?WhiT?+?g~?t+1??J?WhgT?
同理可得t时刻其它节点的梯度:
{ ? J ? h t = ? J ? y t W y h T + ? J ? o ~ t + 1 W h o T + ? J ? f ~ t + 1 W h f T + ? J ? i ~ t + 1 W h i T + ? J ? g ~ t + 1 W h g T ? J ? m t = ? J ? h t ⊙ o t ? J ? c t = ? J ? m t d m t d c t + ? J ? c t + 1 ⊙ f t + 1 + ? J ? f ~ t + 1 W c f T + ? J ? i ~ t + 1 W c i T ? J ? g t = ? J ? c t ⊙ i t ? J ? i t = ? J ? c t ⊙ g t ? J ? f t = ? J ? c t ⊙ c t ? 1 ? J ? o t = ? J ? h t ⊙ m t } ? { ? J ? g ~ t = ? J ? g t ( 1 ? g t 2 ) ? J ? i ~ t = ? J ? i t i t ( 1 ? i t ) ? J ? f ~ t = ? J ? f t f t ( 1 ? f t ) ? J ? o ~ t = ? J ? o t i t ( 1 ? o t ) ? J ? x t = ? J ? o ~ t W x o T + ? J ? f ~ t W x f T + ? J ? i ~ t W x i T + ? J ? g ~ t W x g T
\left \{\begin{array}{l}
\frac{\partial J}{\partial h_t}=\frac{\partial J}{\partial y_t}W_{yh}^T+\frac{\partial J}{\partial \tilde{o}_{t+1}}W_{ho}^T+\frac{\partial J}{\partial \tilde{f}_{t+1}}W_{hf}^T+\frac{\partial J}{\partial \tilde{i}_{t+1}}W_{hi}^T+\frac{\partial J}{\partial \tilde{g}_{t+1}}W_{hg}^T \\ \\
\frac{\partial J}{\partial m_t} = \frac{\partial J}{\partial h_t} \odot o_t \\ \\
\frac{\partial J}{\partial c_t} = \frac{\partial J}{\partial m_t}\frac{dm_t}{dc_t}+ \frac{\partial J}{\partial c_{t+1}}\odot f_{t+1} +\frac{\partial J}{\partial \tilde{f}_{t+1}}W_{cf}^T+\frac{\partial J}{\partial \tilde{i}_{t+1}}W_{ci}^T \\ \\
\left. \begin{array}{l}
\frac{\partial J}{\partial g_t} = \frac{\partial J}{\partial c_t}\odot i_t \\
\frac{\partial J}{\partial i_t} = \frac{\partial J}{\partial c_t} \odot g_t \\
\frac{\partial J}{\partial f_t} = \frac{\partial J}{\partial c_t} \odot c_{t-1} \\
\frac{\partial J}{\partial o_t} = \frac{\partial J}{\partial h_t} \odot m_t
\end{array} \right \} \Rightarrow \left\{ \begin{array}{l}
\frac{\partial J}{\partial \tilde{g}_t} = \frac{\partial J}{\partial g_t}(1-g_t^2) \\
\frac{\partial J}{\partial \tilde{i}_t} = \frac{\partial J}{\partial i_t}i_t(1-i_t) \\
\frac{\partial J}{\partial \tilde{f}_t} = \frac{\partial J}{\partial f_t}f_t(1-f_t) \\
\frac{\partial J}{\partial \tilde{o}_t} = \frac{\partial J}{\partial o_t}i_t(1-o_t) \\
\end{array}\right. \\ \\
\frac{\partial J}{\partial x_t} = \frac{\partial J}{\partial \tilde{o}_t}W_{xo}^T+\frac{\partial J}{\partial \tilde{f}_t}W_{xf}^T+ \frac{\partial J}{\partial \tilde{i}_t}W_{xi}^T+\frac{\partial J}{\partial \tilde{g}_t}W_{xg}^T\\
\end{array}\right.
???????????????????????????????????????????????ht??J?=?yt??J?WyhT?+?o~t+1??J?WhoT?+?f~?t+1??J?WhfT?+?i~t+1??J?WhiT?+?g~?t+1??J?WhgT??mt??J?=?ht??J?⊙ot??ct??J?=?mt??J?dct?dmt??+?ct+1??J?⊙ft+1?+?f~?t+1??J?WcfT?+?i~t+1??J?WciT??gt??J?=?ct??J?⊙it??it??J?=?ct??J?⊙gt??ft??J?=?ct??J?⊙ct?1??ot??J?=?ht??J?⊙mt????????????????????????g~?t??J?=?gt??J?(1?gt2?)?i~t??J?=?it??J?it?(1?it?)?f~?t??J?=?ft??J?ft?(1?ft?)?o~t??J?=?ot??J?it?(1?ot?)??xt??J?=?o~t??J?WxoT?+?f~?t??J?WxfT?+?i~t??J?WxiT?+?g~?t??J?WxgT??
对参数的梯度:
{ ? J ? W h o = h t T ? J ? o ~ t + 1 ? J ? W h f = h t T ? J ? f ~ t + 1 ? J ? W h i = h t T ? J ? i ~ t + 1 ? J ? W h g = h t T ? J ? g ~ t + 1 { ? J ? W y h = h t T ? J ? y t ? J ? W c f = c t T ? J ? f ~ t + 1 ? J ? W c i = c t T ? J ? i ~ t + 1 ? J ? W c o = c t T ? J ? o ~ t { ? J ? W x o = x t T ? J ? o ~ t ? J ? W x f = x t T ? J ? f ~ t ? J ? W x i = x t T ? J ? i ~ t ? J ? W x g = x t T ? J ? g ~ t
\left \{\begin{array}{l}
\frac{\partial J}{\partial W_{ho}} = h_t^T\frac{\partial J}{\partial \tilde{o}_{t+1}} \\
\frac{\partial J}{\partial W_{hf}} = h_t^T\frac{\partial J}{\partial \tilde{f}_{t+1}} \\
\frac{\partial J}{\partial W_{hi}} = h_t^T\frac{\partial J}{\partial \tilde{i}_{t+1}} \\
\frac{\partial J}{\partial W_{hg}} = h_t^T\frac{\partial J}{\partial \tilde{g}_{t+1}}
\end{array} \right.
\left \{\begin{array}{l}
\frac{\partial J}{\partial W_{yh}} = h_t^T\frac{\partial J}{\partial y_t} \\
\frac{\partial J}{\partial W_{cf}} = c_t^T\frac{\partial J}{\partial \tilde{f}_{t+1}} \\
\frac{\partial J}{\partial W_{ci}} = c_t^T\frac{\partial J}{\partial \tilde{i}_{t+1}} \\
\frac{\partial J}{\partial W_{co}} = c_t^T\frac{\partial J}{\partial \tilde{o}_{t}}
\end{array} \right.
\left \{\begin{array}{l}
\frac{\partial J}{\partial W_{xo}} = x_t^T\frac{\partial J}{\partial \tilde{o}_{t}} \\
\frac{\partial J}{\partial W_{xf}} = x_t^T\frac{\partial J}{\partial \tilde{f}_{t}} \\
\frac{\partial J}{\partial W_{xi}} = x_t^T\frac{\partial J}{\partial \tilde{i}_{t}} \\
\frac{\partial J}{\partial W_{xg}} = x_t^T\frac{\partial J}{\partial \tilde{g}_{t}} \\
\end{array} \right.
???????????Who??J?=htT??o~t+1??J??Whf??J?=htT??f~?t+1??J??Whi??J?=htT??i~t+1??J??Whg??J?=htT??g~?t+1??J?????????????Wyh??J?=htT??yt??J??Wcf??J?=ctT??f~?t+1??J??Wci??J?=ctT??i~t+1??J??Wco??J?=ctT??o~t??J?????????????Wxo??J?=xtT??o~t??J??Wxf??J?=xtT??f~?t??J??Wxi??J?=xtT??i~t??J??Wxg??J?=xtT??g~?t??J??
参考资料:https://www.cnblogs.com/xuruilong100/p/8506949.html
内容总结
以上是互联网集市为您收集整理的LSTM前向传播与反向传播算法推导(非常详细) 全部内容,希望文章能够帮你解决LSTM前向传播与反向传播算法推导(非常详细) 所遇到的程序开发问题。
如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
来源:【匿名】