量子变分自编码器 VQ-VAE

简介

VQ-VAE(Vector Quantised - Variational AutoEncoder,量子变分自编码器)出自 2017 年 Google 团队的论文 Neural Discrete Representation Learning。顾名思义,VQ-VAE 是 VAE( Variational AutoEncoder,变分自编码器)的变种。主要是为了解决 VAE 所存在的” 后验坍塌 “问题。VQ-VAE 与 VAE 的主要区别在于:

  • 隐变量是离散的,而非连续的
  • 先验分布是学习得来的,而非固定不变的

研究动机与背景

离散型隐变量

离散型隐变量对于某些任务是更为自然与恰当的,例如语言是由离散的字符组成的,图像的像素是 0-255 的自然数。然而,离散 VAE 往往难以训练,现有的训练方法无法弥补其与连续型 VAE 存在的性能上的差距。尽管连续型 VAE 会存在后验坍塌问题,但是由于从高斯分布中使用重参数化技巧采样隐变量,连续型 VAE 中能够获得方差更小,即更稳定的参数梯度。

自回归模型

自回归模型(Autoregressive model)是一种处理时间序列的方法,使用 \(x_1,x_2,\dots,x_{t-1}\) 来预测 \(x_t\),并假设它们是线性关系。由于其使用 \(x\) 本身来预测 \(x\),因而得名为自回归模型。形式化来讲,自回归模型定义如下: \[ X_t=c+\sum_{i=1}^p\phi_iX_{t-i}+\epsilon_t \] 其中,\(c\) 是常数项,\(\epsilon_t\) 假设为一个均值为 0,标准差为 \(\sigma\) 的随机误差。

典型的自回归模型有循环神经网络(Recurrent Neural Network, RNN),PixelCNN 等。下面以文中提到的 PixelCNN 为例进行介绍。

PixelCNN 是虽然是 CNN,但它与传统的 CNN 不同,而是参考了 RNN 的思路,将图片扁平化为一维后,将其看成时间序列进行逐像素的生成。即: \[ \begin{align} p(x)&=p(x_1,x_2,\dots,x_t)\\ &=p(x_1)p(x_2|x_1)\dots p(x_t|x_1,x_2,\dots,x_{t-1}) \end{align} \] 可以看到,符合上述的自回归模型的定义(令 \(X_t=p(x_1,x_2,\dots,x_t)\))。

变分自编码器

变分自编码器(Variational AutoEncoder,VAE)是一类重要的生成模型。由于篇幅原因这里只做简单介绍,后面可能会单独出一篇博客介绍。VAE 假设存在一个无法观测的隐变量 \(z\) 控制数据 \(x\) 的生成,它主要由以下几部分组成:

  • 编码网络,拟合后验分布 \(q(z|x)\) ,将数据 \(x\) 映射到连续隐变量 \(z\)
  • 生成网络,拟合分布 \(p(x|z)\)
  • 隐变量的先验分布 \(p(z)\)

在训练过程中,从 \(q(z|x)\) 中采样隐变量 \(z\) 来重构数据。在推理过程中,从 \(p(z)\) 中采样隐变量来生成数据。

模型细节

整体结构如下图所示:

离散隐变量

模型定义了一个 \(K*D\) 的隐变量嵌入空间,其中 \(K\) 为空间大小,\(D\) 为隐变量向量的维度。在得到编码网络的输出 \(z_e(x)\) 后,通过最近邻算法将其映射为隐变量嵌入空间中的某个隐变量 \(e_k\)(简记为 \(z\)),投喂到解码器。后验分布 \(q(z|x)\) 定义为如下的独热分布: \[ q(z=k|x) = \begin{cases} 1 &if\ k=\arg\min_j||z_e(x)-e_j|| , \\ 0 & otherwise. \end{cases} \] 进而: \[ z_q(x)=e_k, where\ k=\arg\min_j||z_e(x)-e_j|| \]

梯度计算

注意到上述公式中的 \(\arg\min\) 操作是无法求梯度的,这使得模型无法进行反向传播。VQ-VAE 采取直通估计(straight-through estimator )来解决这个问题。原论文中具体做法描述为 ” 将解码器输入 \(z_q(x)\) 的梯度复制到解码器的输出 \(z_e(x)\)。对应上述结构图中的红线。

损失函数

损失函数表示如下: \[ L=logp(x|z_q(x))+||sg[z_e(x)]-e||_2^2+\beta||z_e(x)-sg[e]||^2_2 \] 其中,\(sg\) 代表停止梯度,即反向传播时不再向前计算梯度。这个符号的含义我个人感觉论文解释的有点不清楚,可能需要对照代码进一步看一下。我目前的理解是,在前向传播的时候,sg 是恒等式,即被忽略掉了,此时计算得到的 loss 是真正的 loss。在反向传播时,sg 部分的计算图相当于断开了,以 \(||sg[z_e(x)]-e||_2^2\) 为例,前项传播时等价于 \(||z_e(x)-e||_2^2\)。反向传播时等价于 \(||const-e||_2^2\),即将 \(z_e(x)\) 看做常数,不对其进行优化。

损失函数的各项含义解释如下:

  • 第一项为重构损失,用以训练编码器和解码器,个人感觉这里是不是少了个负号,这一部分是似然函数,按理说应该是要最大化的。
  • 第二项为 L2 范数损失函数。通过矢量量化(Vector Quantisation,VQ)学习嵌入空间的字典,即希望编码器的输出 \(z_e(x)\) 与最近邻算法得到的 \(e\) 距离越近越好,用以优化嵌入空间。
  • 第三项为 L2 范数损失函数。与第二项的区别在于优化的是编码器。原论文中的说法是,由于嵌入空间是无量纲的,当仅存在第二项时,若 \(e\) 的参数训练速度慢于编码器参数,会使得 \(e\) 的参数向任意方向增长。

第二项和第三项本质上都是希望编码器的输出 \(z_e(x)\) 与离散化隐变量 \(e\) 相互接近,相较于 \(||z_e(x)-e||_2^2\),个人理解这里的设计是为了控制二者的优化速度。如果希望编码器输出相对稳定,则调小 \(\beta\),让嵌入空间更多地靠近编码器的输出,也可以反之。

论文中实验发现 \(\beta\) 从 0.1-2.0 都是非常鲁棒的,实验设置 \(\beta=0.25\),可能意味着二者靠近的速度影响不大(这也更符合直观认知)。

先验分布 \(p(z)\)

先验分布 \(p(z)\) 是个分类分布(categotical distribution),在训练过程中保持不变。在训练结束后,在隐变量 \(z\) 上拟合一个自回归分布,即 \(p(z)\),进而通过祖先采样(ancestral sampling)来生成 \(x\)

参考