EA-VQ-VAE 代码学习(1)

简介

之前学习 EA-VQ-VAE 的时候发现只读论文本身还是有很多细节问题不太懂,而 EA-VQ-VAE 的代码开源在 github 上。今天正好通过学习代码更深层地理解一下这个模型以及基础的 VQ-VAE 模型。

目录结构

代码的目录结构如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
│  README.md
| LICENSE
├─data
│ get_atomic_data.sh
│ get_event2mind_data.sh
│ preprocess-atomic.py
│ preprocess-event2mind.py
├─estimator
│ model.py
│ run.py
├─generator
│ beam.py
│ model.py
│ run.py
└─vq-vae
gpt2.py
model.py
run.py

可以看到,整个代码目录结构还是比较清晰的四部分:

  • data/:用以数据的获取和预处理
  • estimator/:估计先验分布的模型
  • generator/:推理阶段生成推理文本(光束搜索等)
  • vq-vae/:vq-vae 的模型定义:包含 codebook、编码器、解码器等

这次先介绍最为核心的 vq-vae 模型,处在 vq-vae/model.py。剩下的部分后续有时间再进行分享。

VQ-VAE

model.py

首先是 CodeBook。codebook 在 EA-VQ-VAE 充当了隐变量表的角色,保存了一张由 K 个 D 维隐变量组成的 \(R^{K*D}\)。CodeBook 类代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class CodeBook(nn.Module):
def __init__(self, num_embeddings, embedding_dim, commitment_cost):
super(CodeBook, self).__init__()
self._embedding_dim = embedding_dim
self._num_embeddings = num_embeddings
self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
self._commitment_cost = commitment_cost

def forward(self, inputs):
# Calculate distances
distances = (torch.sum(inputs**2, dim=1, keepdim=True)
+ torch.sum(self._embedding.weight**2, dim=1)
- 2 * torch.matmul(inputs, self._embedding.weight.t()))

# Encoding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings).cuda()
encodings.scatter_(1, encoding_indices, 1) # 离散隐变量索引 [batch_size,num_embeddings]

# Quantize and unflatten
quantized = torch.matmul(encodings, self._embedding.weight) ## 乘法获得隐变量
# 整个隐变量的获取方法有点复杂了,argmin之后直接查询embedding即可,无需手动操作。这里这样处理是为了后续
# 还要计算perplexity

# Loss
# detach()从计算图中脱离,达到stop gradient的目的
e_latent_loss = torch.mean((quantized.detach() - inputs)**2)
q_latent_loss = torch.mean((quantized - inputs.detach())**2)
loss = q_latent_loss + self._commitment_cost * e_latent_loss

quantized = inputs + (quantized - inputs).detach()
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

# convert quantized from BHWC -> BCHW
return loss, quantized, perplexity, encodings

整个代码是比较清晰的。在初始化中根据传入参数初始化嵌入空间,并保存了 commitment cost。commitment cost 指的是 VQ-VAE 损失函数的第三项的权重 \(\beta\)。由论文可知,CodeBook 的前向过程应该是输入编码器输出 \(z_e(x)\),输出最近的隐变量 \(z\)。那么代码中 inputs 的 shape 应该为 [batch_size,embedding_dim],进而距离的计算过程就很自然了。其他见代码的注释部分。

接下来是 seq2seq 模型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class Model(nn.Module):
"""
Build Seqence-to-Sequence.

Parameters:

* `encoder`- encoder of seq2seq model. e.g. 2-layer transformer
* `decoder`- decoder of seq2seq model. e.g. GPT2
* `config`- configuration of encoder model.
* `args`- arguments.
"""
def __init__(self, encoder,decoder,config,args):
super(Model, self).__init__()
self.encoder = encoder
self.decoder=decoder
self.config=config
self.args=args

self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.codebook = CodeBook(args.z_size, config.n_embd,0.25)
self.codebook._embedding.weight.data.normal_(mean=0,std=0.1)
self.lsm = nn.LogSoftmax(dim=-1)
self.lm_head.weight=self.decoder.wte.weight

def forward(self, event_ids,target_ids):
"""
Forward the VQ-VAE model.
Parameters:
* `event_ids`- event ids of examples
* `target_ids`- target ids of examples
"""
input_ids=torch.cat((event_ids,target_ids),-1)
#obtain hidden of event+target by encoder
hidden_xy=self.encoder(input_ids,special=True)[0][:,-1]

#obtain latent variable z by coodebook
vae_loss, z, perplexity, encoding=self.codebook(hidden_xy)

#obtain hiddens of target
transformer_outputs=self.decoder(input_ids,z=z)
hidden_states = transformer_outputs[0][:,-target_ids.size(1):]

#calculate loss
lm_logits = self.lm_head(hidden_states+z[:,None,:])
# Shift so that tokens < n predict n
active_loss = target_ids[..., 1:].ne(0).view(-1) == 1 # 将推理文本展平并得到非0位置的索引,用以计算loss
shift_logits = lm_logits[..., :-1, :].contiguous() # 去除末尾的EOS
shift_labels = target_ids[..., 1:].contiguous() #
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
shift_labels.view(-1)[active_loss])

outputs = (loss,vae_loss,perplexity),loss*active_loss.sum(),active_loss.sum(),encoding
return outputs

init 方法比较简单,只是保存参数和新建 codebook。前向过程也比较简单:训练阶段,seq2seq 的输入是事件和推理文本的拼接,然后进行编码和解码(这里编码器为 2 层 Transformer,解码器为预训练的 GPT 模型)。