长短时记忆网络(LSTM)是一类特殊的循环神经网络,具有学习长时依赖关系的能力。
长时依赖问题
长时依赖是指,当预测点与相关依赖信息距离比较远的时候,就难以学到该依赖信息。例如在句子“我出生在法国,……,我会说法语”中,若要预测末尾“法语”,需要用到上下文“法国”,如果中间隔着很多句子,就有可能学不到首句的“法国”这个信息,导致不能准确预测。
理论上,循环神经网络能够处理这种问题,但实际上,常规的循环神经网络由于梯度消失或爆炸问题,并不能很好地解决长时依赖问题。为解决这个问题,Hochreiter和Schmidhuber两位科学家发明出了LSTM(长短时记忆网络)。
LSTM 神经网络
长短时记忆网络(LSTM / Long Short Term Mermory network)是一种特殊的循环神经网络(RNN),可以很好地解决长时依赖问题,是最常用的循环神经网络之一。
如前所诉,循环神经网络的结构:
循环神经网络是由重复的神经网络模块构成的一条链,可以看到它的处理层非常简单,通常是一个单tanh层,通过当前输入及上一时刻的输出来得到当前输出。与前馈神经网络相比,经过简单地改造,它可以利用上一时刻学习到的信息进行当前时刻的学习了。
LSTM的结构与上面相似,不同的是它的重复模块会比较复杂一点。原始循环神经网络的隐藏层只有一个状态,即h,它对于短期的输入非常敏感。LSTM增加了一个状态c,用于处理长期依赖,它的重复模块如下图所示,有四层结构:
其中,处理层出现的符号及意义如下:
LSTM原理
LSTM的关键,就是怎样控制长期状态c。
LSTM的重复单元,即下面的矩形方框,被称为记忆块(memory block),主要包含了三个门(遗忘门、输入门、输出门)与一个记忆单元(cell)。方框内上方的那条水平线,被称为单元状态(cell state),它就像一个传送带,控制给下一时刻的信息传递。
在这里,LSTM的思路是使用三个门(门类似开关)控制长期状态c。第一个门,负责控制继续保存长期状态c;第二个门,负责控制把即时状态输入到长期状态c;第三个门,负责控制是否把长期状态c作为当前的LSTM的输出。
这个矩形方框还可以表示为:
这两个图可以对应起来看,下图中心的c_t即记忆单元(cell),从下方输入(h_{t−1},x_t)到输出h_t的一条线,即为单元状态(cell state),f_t,i_t,o_t分别为遗忘门、输入门、输出门。上图中的两个tanh层则分别对应cell的输入与输出。
LSTM可以通过门控单元对记忆单元(cell)添加和删除信息。门实际上就是一个全连接层,它的输入是一个向量,输出是一个0到1之间的实数向量,门由一个sigmoid层和一个乘法操作组成,如下所示:
门的输出是一个介于0到1的数,表示允许通过多少信息,0 表示完全不允许通过,1表示允许完全通过。
LSTM前向计算
LSTM前向计算的步骤如下。
第一步 遗忘阶段。这个阶段主要是对上一个时刻的长期状态值进行选择性忘记。简单来说就是会 “忘记不重要的,记住重要的”。
具体来说,就是通过遗忘门控制上一个节点传进来的C_{t−1}通过或部分通过。 遗忘门的值f_t,在0到1之间,根据上一时刻的输出h_{t−1}和当前输入x_t产生。计算过程如下所示:
例如,在学习一段文本时,在之前的句子中学到了很多东西,其中一些东西对当前来讲是没用的,就可以进行过滤。
第二步 选择记忆阶段。这个阶段将这个阶段的输入有选择性地进行“记忆”。哪些重要则着重记录下来,哪些不重要,则少记一些。
具体来说,就是通过输入门控制\tilde{C_t}值通过或部分通过。其中:
- 输入门的值i_t,在0到1之间,根据上一时刻的输出h_{t−1}和当前输入x_t产生。
- \tilde{C_t}值,根据上一时刻的输出h_{t−1}和当前输入x_t,经过一个tanh层产生。
然后,将上面两步得到的结果相加,即可得到传输给下一个状态的C_t值。一二步结合起来就是丢掉不需要的信息,添加新信息的过程:
例如,在前面的句子中,我们保存的是张三的信息,现在有了新的李四信息,我们需要把张三的信息丢弃掉,然后把李四的信息保存下来。
最后一步 输出阶段。 这个阶段将决定哪些将会被当成当前状态的输出。
具体来说,就是通过输出门控制C_t值通过或部分通过。其中:
- C_t值首先会经过tanh处理,缩放到-1到1间,再与输出门得到的输出逐对相乘,最终得到模型的输出。
- 输出门的值O_t,在0到1之间,根据上一时刻的输出h_{t−1}和当前输入x_t产生。
LSTM训练
LSTM的训练算法仍然是反向传播算法,主要有下面三个步骤:
- 前向计算每个神经元的输出值,对于LSTM来说,即f_t、i_t、c_t、o_t、h_t五个向量的值。计算方法已经在上一节中描述过了。
- 反向计算每个神经元的误差项δ值。与循环神经网络一样,LSTM误差项的反向传播也是包括两个方向:一个是沿时间的反向传播,即从当前t时刻开始,计算每个时刻的误差项;一个是将误差项向上一层传播。
- 根据相应的误差项,计算每个权重的梯度。