EA-VQ-VAE 代码学习(1)
简介
之前学习 EA-VQ-VAE 的时候发现只读论文本身还是有很多细节问题不太懂,而 EA-VQ-VAE 的代码开源在 github 上。今天正好通过学习代码更深层地理解一下这个模型以及基础的 VQ-VAE 模型。
目录结构
代码的目录结构如下:
1 | │ README.md |
可以看到,整个代码目录结构还是比较清晰的四部分:
- data/:用以数据的获取和预处理
- estimator/:估计先验分布的模型
- generator/:推理阶段生成推理文本(光束搜索等)
- vq-vae/:vq-vae 的模型定义:包含 codebook、编码器、解码器等
这次先介绍最为核心的 vq-vae 模型,处在 vq-vae/model.py。剩下的部分后续有时间再进行分享。
VQ-VAE
model.py
首先是 CodeBook。codebook 在 EA-VQ-VAE 充当了隐变量表的角色,保存了一张由 K 个 D 维隐变量组成的 \(R^{K*D}\)。CodeBook 类代码如下:
1 | class CodeBook(nn.Module): |
整个代码是比较清晰的。在初始化中根据传入参数初始化嵌入空间,并保存了 commitment cost。commitment cost 指的是 VQ-VAE 损失函数的第三项的权重 \(\beta\)。由论文可知,CodeBook 的前向过程应该是输入编码器输出 \(z_e(x)\),输出最近的隐变量 \(z\)。那么代码中 inputs 的 shape 应该为 [batch_size,embedding_dim],进而距离的计算过程就很自然了。其他见代码的注释部分。
接下来是 seq2seq 模型:
1 | class Model(nn.Module): |
init 方法比较简单,只是保存参数和新建 codebook。前向过程也比较简单:训练阶段,seq2seq 的输入是事件和推理文本的拼接,然后进行编码和解码(这里编码器为 2 层 Transformer,解码器为预训练的 GPT 模型)。