咱们好,我是半虹,这篇文章来讲长短期回忆网络 (Long Short-Term Memory, LSTM)

文章行文思路如下:

  1. 首要经过循环神经网络引出为啥需求长短期回忆网络
  2. 然后介绍长短期回忆网络的中心思维与运作方法
  3. 最终经过简略的代码深化了解长短期回忆网络的运作方法

长短期回忆网络能够看作是循环神经网络的改善版本,想要了解长短期回忆网络,首要要了解循环神经网络

由于咱们之前已具体介绍过循环神经网络,所以这儿咱们只会做一个简略的回顾,想看具体的阐明请戳这儿


对比前馈神经网络,循环神经网络经过添加隐状况实现对躲藏层信息的传递,以此到达记住前史输入的意图

网络在每个时刻步里读取上一躲藏层输出作为当时躲藏层输入,并保存当时躲藏层输出作为下一躲藏层输入

其结构简图如下:

NLP学习笔记(二) LSTM基本介绍

其间 XX 是输入 ,HH 是躲藏层的输出,图中的每个矩形都表明同一个循环神经网络躲藏层

下面咱们把躲藏层中的细节也画出来,方便后边与长短期回忆网络来对比

NLP学习笔记(二) LSTM基本介绍

其间 XX 是输入 ,HH 是躲藏层的输出,图中的灰色矩形同样代表躲藏层,\sigma 表明一个带激活函数的线性层

对应的公式表达如下:

Ht=(XtWxh+Ht−1Whh+bh)H_{t} = \alpha(X_{t} W_{xh} + H_{t-1} W_{hh} + b_{h})

其间 XtX_{t} 是当时输入,HtH_{t} 是当时躲藏层输出,Ht−1H_{t-1} 是早年躲藏层输出,WxhW_{xh}WhhW_{hh}bhb_{h} 都是网络参数


理论上,上述介绍的循环神经网络能处理任意长的序列,但实际上却并非如此

在实际使用循环神经网络处理长序列时一般会呈现梯度爆破或梯度消失的状况,导致网络难以捕捉长时刻依靠

这是为什么呢?经过简略分析一下梯度核算公式就能发现端倪

为了论述方便,咱们暂时假定一切的参数都是一维的,用字母 \theta 表明,对参数求导:

dHtd=∂Ht∂+∂Ht∂Ht−1dHt−1d\frac{d H_{t}}{d \theta} = \frac{\partial H_{t}}{\partial \theta} + \frac{\partial H_{t}}{\partial H_{t-1}} \frac{d H_{t-1}}{d \theta}

按时刻展开:

dHtd=∂Ht∂+∂Ht∂Ht−1∂Ht−1∂+∂Ht∂Ht−1∂Ht−1∂Ht−2dHt−2d+⋯\frac{d H_{t}}{d \theta} = \frac{\partial H_{t}}{\partial \theta} + \frac{\partial H_{t}}{\partial H_{t-1}} \frac{\partial H_{t-1}}{\partial \theta} + \frac{\partial H_{t}}{\partial H_{t-1}} \frac{\partial H_{t-1}}{\partial H_{t-2}} \frac{d H_{t-2}}{d \theta} + \cdots

不难发现,当时梯度 dHtd\frac{d H_{t}}{d \theta} 由当时梯度值 ∂Ht∂\frac{\partial H_{t}}{\partial \theta} 以及早年梯度 dHt−1d\frac{d H_{t-1}}{d \theta} 决定,关于早年梯度权重 ∂Ht∂Ht−1\frac{\partial H_{t}}{\partial H_{t-1}}

  • ∣∂Ht∂Ht−1∣<1|\frac{\partial H_{t}}{\partial H_{t-1}}| < 1 时,表明前史的梯度信息是逐步减弱的,随着时刻步不断添加,很可能会呈现梯度消失
  • ∣∂Ht∂Ht−1∣>1|\frac{\partial H_{t}}{\partial H_{t-1}}| > 1 时,表明前史的梯度信息是逐步增强的,随着时刻步不断添加,很可能会呈现梯度爆破

由推导式能够看出,梯度爆破和梯度消失更简单呈现在与当时时刻步间隔更远的梯度

这是由于这些梯度的权重连乘项更多,举例来说,关于时刻步 tt,其梯度 dHtd\frac{d H_{t}}{d \theta} 由以下梯度相加组成

  • 时刻步 t−1t – 1 的梯度 dHt−1d\frac{d H_{t-1}}{d \theta},与时刻步 tt 的间隔为 11,其权重为 ∂Ht∂Ht−1\frac{\partial H_{t}}{\partial H_{t-1}}
  • 时刻步 t−2t – 2 的梯度 dHt−2d\frac{d H_{t-2}}{d \theta},与时刻步 tt 的间隔为 22,其权重为 ∂Ht∂Ht−1∂Ht−1∂Ht−2\frac{\partial H_{t}}{\partial H_{t-1}} \frac{\partial H_{t-1}}{\partial H_{t-2}}
  • 时刻步 t−3t – 3 的梯度 dHt−2d\frac{d H_{t-2}}{d \theta},与时刻步 tt 的间隔为 33,其权重为 ∂Ht∂Ht−1∂Ht−1∂Ht−2∂Ht−3∂Ht−3\frac{\partial H_{t}}{\partial H_{t-1}} \frac{\partial H_{t-1}}{\partial H_{t-2}} \frac{\partial H_{t-3}}{\partial H_{t-3}}
  • ……

这阐明了什么?这阐明了关于当时输入,距其更远的输入的梯度更简单呈现梯度爆破或梯度消失

从而导致长间隔的梯度反馈失效,这便是循环神经网络难以捕捉长时刻依靠的实际含义


最终提示咱们留意一个细节,关于时刻步 tt 的梯度 dHtd\frac{d H_{t}}{d \theta}

  • 假定有且仅有最终一项梯度爆破,那么就会导致整个梯度爆破,由于 dHt−1d+⋯+NaN=NaN\frac{d H_{t-1}}{d \theta} + \cdots + NaN = NaN
  • 假定有且仅有最终一项梯度消失,这并不会导致整个梯度消失,由于 dHt−1d+⋯+0≠0\frac{d H_{t-1}}{d \theta} + \cdots + 0 \neq 0

总结一下,梯度反向传达时产生的反常,首要能够分为两种,一是梯度爆破,二是梯度消失

梯度爆破比较简单处理,一个简略但有效的做法是设置一个梯度阈值,当梯度超过这个阈值时直接切断

梯度消失更难处理一些,而现在流行的做法正是将循环神经网络替换生长短期回忆网络

留意,长短期回忆网络能缓解梯度消失的问题,但并不能缓解梯度爆破的问题


上面咱们从反向传达的视点解说了什么是梯度消失

假如咱们早年向核算的视点来看,则梯度消失能够了解成隐状况对短期回忆敏感,对长时刻回忆作用有限

为了保持长时刻回忆,长短期回忆网络引进回忆元存放长时刻回忆,并经过门机制操控回忆元中的信息活动

从直觉上来说,早年重要的回忆会保存在回忆元,不重要的回忆会被过滤,以此来到达长时刻回忆的意图


这儿有两个概念需求解说,一是回忆元,二是门机制,这两个便是长短期回忆网络的中心

先说回忆元,能够了解成另一种隐状况,都是用来记载附加信息的,简称为单元,英文为 Cell\text{Cell}

再说门机制,这是用来操控回忆元中信息活动的机制,具体来说包括三个操控门:

  • 输入门:操控是否将信息写入回忆元,英文为 InputGate\text{Input Gate}
  • 忘记门:操控是否从回忆元丢弃信息,英文为 ForgetGate\text{Forget Gate}
  • 输出门:操控是否从回忆元读出信息,英文为 OutputGate\text{Output Gate}

本质上来说,上述三个操控门都是由一个线性层加一个激活函数组成的,这儿激活函数用的是 sigmoid\text{sigmoid}

由于这样能将输出限制在零到一之间,以表明门的打开程度,操控信息活动的程度


相比循环神经网络只要一个传输状况,即隐状况,长短期回忆网络有两个传输状况,即隐状况和回忆元

二者的输入输出对比图如下:

NLP学习笔记(二) LSTM基本介绍

其间 HH 表明隐状况,CC 表明回忆元,知道输入输出后,下面开端介绍长短期回忆网络的内部作业原理

首要,根据当时输入 XtX_{t} 和早年隐状况 Ht−1H_{t-1},核算得到输入门 ItI_t、忘记门 FtF_t、输出门 OtO_t

其间,WxiW_{xi}WhiW_{hi}bib_{i}WxfW_{xf}WhfW_{hf}bfb_{f}WxoW_{xo}WhoW_{ho}bob_{o} 都是网络参数,\sigmasigmoid\text{sigmoid} 激活函数

It=(XtWxi+Ht−1Whi+bi)I_{t} = \sigma (X_{t} W_{xi} + H_{t-1} W_{hi} + b_{i})
Ft=(XtWxf+Ht−1Whf+bf)F_{t} = \sigma (X_{t} W_{xf} + H_{t-1} W_{hf} + b_{f})
Ot=(XtWxo+Ht−1Who+bo)O_{t} = \sigma (X_{t} W_{xo} + H_{t-1} W_{ho} + b_{o})

然后,根据当时输入 XtX_{t} 和早年隐状况 Ht−1H_{t-1},核算得到候选回忆元 C~t\widetilde{C}_{t}

其间,WxcW_{xc}WhcW_{hc}bcb_{c} 都是网络参数,tanh⁡\tanhtanh⁡\tanh 激活函数

C~t=tanh⁡(XtWxc+Ht−1Whc+bc)\widetilde{C}_{t} = \tanh (X_{t} W_{xc} + H_{t-1} W_{hc} + b_{c})

接着,输入门 ItI_t 操控选用多少来自 C~t\widetilde{C}_{t} 的新信息,忘记门 FtF_t 操控保存多少来自 Ct−1C_{t-1} 的旧信息,核算得 CtC_t

其间,⊙\odot 表明按元素乘法,当 It=0I_{t} = 0Ft=1F_{t} = 1 时,则曩昔回忆元被保存并传递到当时时刻步

Ct=Ft⊙Ct−1+It⊙C~tC_{t} = F_{t} \odot C_{t-1} + I_{t} \odot \widetilde{C}_{t}

最终,输出门 OtO_t 操控选用多少来自 CtC_{t} 的长回忆,核算得 HtH_{t}

其间,⊙\odot 表明按元素乘法,tanh⁡\tanh 表明 tanh⁡\tanh 激活函数,当 OtO_{t} 接近 11 时,就能够将长时刻回忆传递给隐状况

Ht=Ot⊙tanh⁡(Ct)H_{t} = O_{t} \odot \tanh (C_{t})

上述核算进程对应的核算图如下所示:

NLP学习笔记(二) LSTM基本介绍

为了协助咱们进一步了解长短期回忆网络的作业方法,下面咱们举一个例子来说,并给出关键代码

假定咱们用长短期回忆网络对下面这个语句进行编码:我在画画

import torch
import torch.nn as nn
# 界说输入数据
# 关于输入语句我在画画,首要用独热编码得到其向量表明
x1 = torch.tensor([1, 0, 0]).float() # 我
x2 = torch.tensor([0, 1, 0]).float() # 在
x3 = torch.tensor([0, 0, 1]).float() # 画
x4 = torch.tensor([0, 0, 1]).float() # 画
h0 = torch.zeros(5) # 初始化隐状况
c0 = torch.zeros(5) # 初始化回忆元
# 界说模型参数
# 模型的输入是三维向量,这儿界说模型的输出是五维向量
W_xi = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hi = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_i  = nn.Parameter(torch.randn(5)   , requires_grad = True)
W_xf = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hf = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_f  = nn.Parameter(torch.randn(5)   , requires_grad = True)
W_xo = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_ho = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_o  = nn.Parameter(torch.randn(5)   , requires_grad = True)
W_xc = nn.Parameter(torch.randn(3, 5), requires_grad = True)
W_hc = nn.Parameter(torch.randn(5, 5), requires_grad = True)
b_c  = nn.Parameter(torch.randn(5)   , requires_grad = True)
# 前向传达
def forward(X, H, C):
    # 核算各种门机制
    I = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i) # 输入门
    F = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f) # 忘记门
    O = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o) # 输出门
    # 核算候选回忆元
    C_tilde = torch.tanh(torch.matmul(X, W_xc) + torch.matmul(H, W_hc) + b_c)
    # 核算当时回忆元
    C = F * C + I * C_tilde
    # 核算当时隐状况
    H = O * C.tanh()
    # 回来成果
    return H, C
h1, c1 = forward(x1, h0, c0)
h2, c2 = forward(x2, h1, c1)
h3, c3 = forward(x3, h2, c2)
h4, c4 = forward(x4, h3, c3)
# 成果输出
print(h3) # tensor([-0.0408,  0.1785,  0.0455,  0.3802,  0.0235])
print(h4) # tensor([-0.0560,  0.1269,  0.0346,  0.3426,  0.0118])

最终提示咱们一点,假如长短期回忆网络后有接其他网络,例如后边接一个线性层做单词预测

那么一般不会用回忆元的输出,而是用躲藏层的输出


至此本文结束,关键总结如下:

  1. 循环神经网络在处理长序列时很简单会呈现梯度爆破和梯度消失的状况,导致网络难以捕捉长时刻依靠

    关于梯度爆破,一般能够选用梯度裁剪解决,关于梯度消失,能够选用长短期回忆网络缓解

  2. 除了有隐状况,长短期回忆网络还添加回忆元存放长时刻回忆,并经过门机制操控回忆元中的信息活动