循环神经网络 RNN 及其变体 GRU、LSTM

序言

同样,借着复习面试,把 RNN 家族再梳理回顾一下,包含 RNN、GRU、LSTM。

循环神经网络 RNN

模型结构

RNN 的结构如上图所示,其核心思想是使用同一套参数来更新状态 \(s\) 与计算输出 \(o\),箭头右侧是按时序展开的模型结构图。可以看到,RNN 仅使用了一个状态 \(s\)​来保存序列信息,共有三个参数矩阵。这一部分公式化描述如下: \[ s_t=f(Ux_t+Ws_{t-1}) \]

\[ o_t=g(Vs_t) \]

其中,\(f\) \(g\)​均为激活函数,激活函数可选的有 sigmoid,tanh,relu 等(下面会分析)。

RNN 有以下缺陷:

  • 容易发生梯度消失和梯度爆炸现象(由于导数连乘)。
  • 难以捕捉长距离的依赖。

在其中,梯度消失相对于梯度爆炸要更为严重,因为梯度爆炸是可以观测到的(NAN),梯度消失则难以直接观测。梯度爆炸问题很容易解决,可以通过梯度裁剪的方法进行解决。

梯度消失和梯度爆炸

梯度消失和爆炸的解决方法:

  • 梯度的剪切以及正则化(常见的是 l1 正则和 l2 正则)。
  • relu、elu 等激活函数。(梯度消失)
  • 批标准化(Batch Normalization)。
  • 残差结构(将映射 F (x) 改为 F (x)+x,使用 relu 激活函数的 F 在 x<0 时能够无损传播梯度,保证了深层网络的性能)。
  • LSTM、GRU 等结构。

批标准化 Batch Normalization

Batch Normalization 是一种常用于 CNN 的正则化方法,可以分为两个步骤:

(1)标准化:对 batch 的数据求均值与标准差,将数据标准化到标准正态分布

(2)进行放缩与平移

整个过程类似于 VAE 的重参数化,先获得一个正态分布的变量,再进行放缩平移,达到从任意正态分布中取样的效果。

也就是说,batch normlization 假设每个 batch 的数据服从一个正态分布(参数 γ 和 β 学习得来,即通过 batch 数据计算得来),先将数据标准化,再放缩与平移,使得数据 “看起来” 是从这个正态分布中取样而来的。

在预测阶段,所有参数的取值是固定的,对 BN 层而言,意味着 μ、σ、γ、β 都是固定值。γ 和 β 比较好理解,随着训练结束,两者最终收敛,预测阶段使用训练结束时的值即可。

对于 μ 和 σ,在训练阶段,它们为当前 mini batch 的统计量,随着输入 batch 的不同,μ 和 σ 一直在变化。在预测阶段,输入数据可能只有 1 条,该使用哪个 μ 和 σ,或者说,每个 BN 层的 μ 和 σ 该如何取值?可以采用训练收敛最后几批 mini batch 的 μ 和 σ 的期望,作为预测阶段的 μ 和 σ。

层标准化 Layer Normalization

Batch Normalization 是在 Batch 的方向上进行 Normalization。这种方法在 NLP 中不是很适合。由于文本序列的长度可变性,一个 batch 中的数据往往长度不同,进而对每个位置进行标准化不是很合适。

而 Layer Normalization 则是在序列的方向上进行 Normalization。这使得它可以处理变长序列。

激活函数

对于激活函数而言,sigmoid 的最大梯度为 0.25,因此很容易发生梯度消失现象,而 tanh 虽然最大梯度为 1,但也只有 0 处取得,也容易 发生梯度消失。因此 RNN 常使用 relu 作为激活函数。relu 的梯度非 0 即 1,这能够缓解梯度消失现象,但也有一定的问题:1. 容易发生梯度爆炸。(梯度恒为 1 时)2. 负数部分梯度恒为 0,部分神经元无法激活。elu 能够缓解 relu 的 0 梯度的问题,但是由于加入了幂运算,会更慢一点。

门控循环单元 GRU

模型结构

GRU 的思想是在 RNN 的基础上,引入门控信号来缓解 RNN 存在的梯度消失问题。模型结构如下:

公式化描述如下(公式中的 \(\odot\) 代表哈达玛积,即同型矩阵间逐元素乘法):

首先根据输入 \(x_t\) 与上一时刻隐藏状态 \(h_{t-1}\) 计算得到两个门控状态 \(z_t\) \(r_t\)​,假设 \(h_t\in \mathbb R^H\)\[ z_t=sigmoid(W_zx_t+U_zh_{t-1})\in \mathbb R^{H} \]

\[ r_t=sigmoid(W_rx_t+U_rh_{t-1})\in \mathbb R^{H} \]

之后,使用重置门计算得到一个新的隐藏状态(即图中的 \(h’\)): \[ \tilde h_t=tanh(Wx_t+U(r_t\odot h_{t-1}))\in \mathbb R^{H} \] 再使用更新门 \(z_t\) 更新隐藏状态: \[ h_t=(1-z)\odot h_{t-1}+z\odot \tilde h_t\in \mathbb R^{H} \]

长短期记忆网络 LSTM

模型结构

LSTM 的思想是在 RNN 的基础上,加入一个不易被改变的新状态 \(c_t\)​​,代表的是 0-t 时刻的全局信息。而 \(h_t\)​代表的是在 0~t-1 时刻全局信息的影响下,\(t\) 时刻的信息。换而言之,\(c_t\) 变化的很慢,而 \(h_t\) 变化的很快。

公式化描述如下:

首先计算得到三个门控状态(分别对应图中的 \(z^i,z^f,z^o\)): \[ i_t=sigmoid(W_ix_t+U_ih_{t-1})\in \mathbb R^{H} \]

\[ f_t=sigmoid(W_fx_t+U_fh_{t-1})\in \mathbb R^{H} \]

\[ o_t=sigmoid(W_ox_t+U_oh_{t-1})\in \mathbb R^{H} \]

以及一个与当前输入密切相关的向量(对应图中的 \(z\)\[ \tilde c_t=tanh(W_zx_t+U_zh_{t-1}) \] 接着,更新两种状态: \[ c_t=f_t\odot c_{t-1}+i_t\odot \tilde c_t \]

\[ h_t=o_t\odot tanh(c_t) \]

其中,\(i_t.f_t,o_t\) 分别代表信息、遗忘、输出门控。信息和遗忘门控负责 cell state 的更新,输出门控负责 hidden state 的更新。具体而言,LSTM 可以简单分为以下三个阶段:

  • 遗忘阶段,根据遗忘门控,忘记上一个 cell state 的部分信息。
  • 记忆阶段,根据信息门控,将输入信息进行选择记忆。
  • 输出阶段,根据输出门控,输出最终的状态。

LSTM VS GRU

本质上,LSTM 和 GRU 都是通过引入门控信号来解决 RNN 的梯度消失问题。在实现方法上,GRU 相对于 LSTM 要更为简单。GRU 抛弃了 LSTM 中的 hidden state(GRU 中的 hidden state 实际上是 LSTM 中的 cell state),因为 LSTM 中的 \(h_t\) 只是想保存当前时刻的信息,这一部分已经包含到 GRU 中的 \(\tilde h_t\) 中了。cell state 中的之前的全局信息与当前时刻的信息应当是一个此消彼长的状态,GRU 因此直接使用一个门控信号 \(z_t\) 同时控制了遗忘和更新。

在参数上,GRU 有着比 LSTM 更少的参数,收敛速度更快,并且与 LSTM 有着差不多的性能表现,因此实际工程中多使用 GRU。

参考

深度学习之 3—— 梯度爆炸与梯度消失 - 知乎 (zhihu.com)

人人都能看懂的 GRU - 知乎 (zhihu.com)

人人都能看懂的 LSTM - 知乎 (zhihu.com)

RNN vs LSTM vs GRU -- 该选哪个? - 知乎 (zhihu.com)