变分自编码器 VAE

简介

今天来回顾一下变分自编码器(Variational Autoencoder,VAE),这是 2013 年提出的一种生成模型,时至今日,它的各类变体还活跃在各类会议上。之前我读过它的离散变体 VQ-VAE,这里再回顾一下原本的 VAE。

数学知识

理解 VAE 需要一些信息论和概率论的知识,这里总结一下。

概率统计

数值计算 vs 采样计算

对于一个随机变量 X,如果我们想知道 X 的期望 \(E(X)\)。如果我们已知 X 的分布函数,很容易可以计算出准确的期望 \(E(X)=\sum p(x)x\)(连续型变量替换为积分即可),这当然是最好的。然而很多情况下,我们无法得知准确的分布函数,那么我们可以采用统计量进行估计,对于 n 个随机样本 \(x_1,x_2,\dots,x_n\)\(\overline X=\frac{1}{n}\sum x_i\) 就是期望 \(E(X)\) 的无偏估计。

信息论

信息熵

在信息论中,信息熵衡量了信息的不确定性,公式为 \(H(X)=-\sum_{x\in X}p(x)logp(x)\)。以单个事件 x 为例,概率越小的事件的信息熵越大。当一个事件必定会发生时(\(p(x)=1\)),其信息熵为 0,没有任何不确定性。对随机变量 X 而说,其信息熵就是 \(-logp(x)\) 的期望,熵越大代表随机变量越不确定,很自然可以想到,分布越均匀,变量的状态越不容易确定,其熵越大。

在通信领域,信息熵可以看作对随机变量 X 进行编码所需的最短期望位数,这也被称为编码定理。在通信编码问题中,将随机变量 X 的每个值编码为一个二进制序列,使得序列长度期望最短。同时为了避免混乱,一个序列不能是其他序列的延申。这时编码位数的最短期望位数就是信息熵,有兴趣的同学可以去看看证明。

交叉熵

对于随机变量 \(X\) 的真实分布 \(p(x)\),有时是未知的,我们只有它的近似分布 \(q(x)\),如果按照 \(q(x)\) 对变量 X 进行编码,得到的编码长度的期望称为交叉熵,记为 \(H(p,q)=-\sum_x p(x)logq(x)\)。容易知道交叉熵是大于等于信息熵的,因为信息熵是最短编码长度。

相对熵(KL 散度)

对于真实分布 \(p\) 和近似分布 \(q\),相对熵为使用近似分布编码得到的编码长度与最短编码长度的差,即交叉熵与信息熵的差,定义为 \(D(p||q)=H(p,q)-H(p)\)。KL 散度衡量了两个分布之间的差异,两个分布差异越大,KL 散度越大。不过 KL 散度并不是距离,因为它不是对称的。因此,KL 散度可以用于分类任务中计算真实概率分布与预测的概率分布之间的差异。事实上,由于这时信息熵为常数,往往将其略去使用交叉熵作为损失函数

变分自编码器

对于数据集 \(D={x_1,x_2,\dots,x_n}\)(假设是很多张猫的图片,每个样本 \(x\) 是像素矩阵)。所有可能的猫图构成数据总体 \(X\),即 \(x_i\in D\subseteq X\)\(X\) 上存在一个概率分布 \(p(x)\),当我们随机采样一张猫的图片时,这张图片有 \(p(x)\) 的概率被采样到。对于非猫图 \(y,\ p(y)=0\)。我们希望的是能够找到 \(p(x)=p(x^{(1)},x^{(2)},\dots)\) 的准确数学形式,其中 \(x^{(1)}\) 代表 x 一维展开后的第 1 个像素,依次类推。如果这个目标能够实现,我们就能分析出 \(p(x)\) 这个概率函数,对哪些输入 \(x\) 能够取到非 0 概率值(即猫图),进而能够随机采样猫图和非猫图。

如果上述描述还不好理解的话,可以想象一个二维坐标系,以原点为圆心的单位圆上的均匀分布。所有猫图都满足 \(x^2+y^2\le1\),在这个单位圆内外分别随机采样,即可得到猫图和非猫图。

但是事实上,这样的目标是很难实现的。在数理统计里,这是个典型的非参数估计问题,在未知分布形式的情况下,没有办法对分布形式和分布参数进行估计。而且很容易想到,这也绝对是一个非常复杂的概率分布。因此直接对 \(p(x)\) 建模是不现实的。我们可以曲线救国。假设存在一个隐变量 \(z\) 控制着数据 \(x\) 的生成。那么根据 \(p(x)=\int p(z)p(x|z)dz\) 可以计算得到 \(p(x)\)。然而这个边界似然是不可解的,每一项都不知道具体的数学形式,更不要说还要积分。

那么求其次,我们可以用一个分布 \(q(x,z)\) 近似联合概率分布 \(p(x,z)\),那么我们的优化目标就是 \(KL(p||q)\) 最小。 \[ \begin{align} KL(p||q)&=\int\int p(x,z)log\frac{p(x,z)}{q(x,z)}dzdx \\ &=\int p(x)\int p(z|x)log\frac{p(x,z)}{q(x,z)}dzdx \\ &=E_{x\sim p(x)}\int p(z|x)log\frac{p(x)p(z|x)}{q(x,z)}dz \\ &=E_{x\sim p(x)}\int p(z|x)(log\frac{p(z|x)}{q(x,z)}+logp(x))dz \end{align} \]\[ \begin{align} E_{x\sim p(x)}\int q(z|x)logp(x)dz&=E_{x\sim p(x)}logp(x)\int q(z|x)dz\\ &=E_{x\sim p(x)}logp(x) \end{align} \] 为一个常数,可以略去。令 \[ \begin{align} \mathcal{L}&=E_{x\sim p(x)}\int p(z|x)log\frac{p(z|x)}{q(x,z)}dz\\ &=E_{x\sim p(x)}\int p(z|x)log\frac{p(z|x)}{q(z)q(x|z)}dz\\ &=E_{x\sim p(x)}[\int -p(z|x)logq(x|z) dz+ \int p(z|x)log\frac{p(z|x)}{q(z)}]dz\\ &=E_{x\sim p(x)}[E_{z\sim p(z|x)}(-log(q(x|z)))+KL(p(z|x)||q(z))] \end{align} \] 最小化 KL 与最小化 \(\mathcal{L}\) 等价。进而得到了 VAE 的损失函数。只不过与原论文中的符号有些出入,将最后的 KL 项的 p 与 q 调换,即得到了论文中 VAE 的损失函数: \[ \mathcal{L}=E_{x\sim p(x)}[E_{z\sim p(z|x)}(-log(q(x|z)))+KL(q(z|x)||p(z))] \] 符号的差别是由于论文直接引入的 \(q(z|x)\),而这里引入的是联合概率分布 \(q(x,z)\)

注意上面的 \(\mathcal L\) 是损失函数,而不是变分下界 ELBO,VAE 的 ELBO 是损失函数的相反数,引用 VAE 原文中的公式: \[ \mathcal{L}(\theta,\phi;x^{(i)})=-D_{KL}(q_\phi(z|x^{(i)})||p(z))+\mathbb E_{q_\phi(z|x^{(i)})}[-log(p_\theta(x^{(i)}|z))] \]

ELBO 是要最大化的,而损失函数是要最小化的。上面的 ELBO 是单个样本的公式,没有对所有样本计算期望,所以形式上有所差异。

说明

值得注意的是,VAE 的两个损失函数项并不是割裂的。换而言之,VAE 并不是独立地优化每项损失。这很容易理解,如果 VAE 独立优化第二项损失至最小,\(p(z)=q(z|x)\),那么 \(q(z|x)\) 将不具备任何 \(x\) 的信息,这显然会使得第一项重构损失很大。同理,如果第一项重构损失很小,就意味着 \(q(z|x)\) 包含了过多 \(x\) 的信息,与 \(p(z)\) 的差异(即 KL 项)就会很大。因此,VAE 是在两项损失的相互作用下,取得一个最优解。

参考