PyTorch Lightning: 让 PyTorch 更为易用

摘要

今天来介绍一个很好用的深度学习框架,PyTorch Lightning。从名字就可以看出,它是基于 PyTorch 的框架。它的核心思想是,将学术代码(模型定义、前向 / 反向、优化器、验证等)与工程代码(for-loop,保存、tensorboard 日志、训练策略等)解耦开来,使得代码更为简洁清晰。工程代码经常会出现在深度学习代码中,PyTorch Lightning 对这部分逻辑进行了封装,只需要在 Trainer 类中简单设置即可调用,无需重复造轮子。

介绍

近些年来,各种深度学习框架层出不穷,TensorFlow、PyTorch 已经成为深度学习研究人员使用最多的两个框架。其中,PyTorch 的热度连年飙升,大有超越 TensorFlow 之势。 基础的张量运算、计算图等功能已经被这些框架支持的很好了,后面的框架开始着力于解决其他问题,例如:

  • transformers: 致力于简单高效地使用预训练模型,支持 PyTorch 和 TensorFlow 作为后端
  • Fairseq:提供通用建模序列任务的工具包,包含丰富高效的命令行接口,基于 PyTorch
  • Pytorch Lightning:解决深度学习代码中,学术代码、工程代码耦合性高,工程代码需要重复造轮子等问题

我之前已经介绍过 transformers 了,相信做 NLP 的同学都对这个框架很熟悉了。Fairseq 我也接触过一段时间,但是由于其文档不是很详细,大部分问题都要读源码才能找到答案, 入门起来比较痛苦。Pytorch Lightning 就是本文的重点了,下面做详细介绍。

一个例子

引用一个官网的 gif,介绍 Pytorch Lightning 的动机是什么。一段典型的、基于 PyTorch 的深度学习代码可能是下面左侧的这样的, 包含:模型、优化器、数据定义,训练、验证循环逻辑。

可以将其转化为右侧的 Lightning Module,按照以下步骤:

  • 将模型定义代码写在__init__
  • 定义前向传播逻辑
  • 将优化器代码写在 configure_optimizers 钩子中
  • 训练代码写在 training_step 钩子中,可以使用 self.log 随时记录变量的值,会保存在 tensorboard 中
  • 验证代码写在 validation_step 钩子中
  • 移除硬件调用.cuda() 等,PyTorch Lightning 会自动将模型、张量的设备放置在合适的设备;移除.train() 等代码,这也会自动切换
  • 根据需要,重写其他钩子函数,例如 validation_epoch_end,对 validation_step 的结果进行汇总;train_dataloader,定义训练数据的加载逻辑
  • 实例化 Lightning Module 和 Trainer 对象,传入数据集
  • 定义训练参数和回调函数,例如训练设备、数量、保存策略,Early Stop、半精度等

其中,最为直接的好处是,你无需关注模型和张量的设备,可以省去不计其数的.cuda(),再也不需要担心 device 报错了(当然,自己新建的张量还是需要调整下设备的)。 此外就是功能强大的 Trainer 和 tensorboard 的集成,可以非常优雅地调用。

进阶

如果想要解锁 PyTorch Lightning 的更多玩法,可以参考官方文档,详细地介绍了各种技巧,辅以代码示例,很容易理解上手。对于核心的 API,Lightning Module 的各种钩子,Trainer 的参数、用法,也做了详细的介绍。文档中还包含各种常见的工作流、上手教程,内容非常的齐全。

例如,configure_optimizers 支持配置多个优化器、搭配学习率衰减策略。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# most cases. no learning rate scheduler
def configure_optimizers(self):
return Adam(self.parameters(), lr=1e-3)

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
return gen_opt, dis_opt

# example with learning rate schedulers
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
dis_sch = CosineAnnealing(dis_opt, T_max=10)
return [gen_opt, dis_opt], [dis_sch]

# example with step-based learning rate schedulers
# each optimizer has its own scheduler
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
gen_sch = {
'scheduler': ExponentialLR(gen_opt, 0.99),
'interval': 'step' # called after each training step
}
dis_sch = CosineAnnealing(dis_opt, T_max=10) # called every epoch
return [gen_opt, dis_opt], [gen_sch, dis_sch]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_dis.parameters(), lr=0.02)
n_critic = 5
return (
{'optimizer': dis_opt, 'frequency': n_critic},
{'optimizer': gen_opt, 'frequency': 1}
)

可以在,epoch_end 系列的钩子中,完成 Epoch-level 的 metric 计算。以 validation_epoch_end 为例,其接收一个参数 validation_step_outputs,是一个 list,包含了 validation_step 的所有返回结果。

1
2
3
4
5
6
7
8
9
10
11
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
pred = ...
return pred


def validation_epoch_end(self, validation_step_outputs):
all_preds = torch.stack(validation_step_outputs)
...

Trainer 中,常见的回调函数有 ModelCheckpoint,完成模型的定期保存;EarlyStopping,定义模型的早停策略。Trainer 定义时,只需要传入 precision=16,即可实现 PyTorch naive 的混合半精度,如果要制定 apex 后端,也只需要加上一行 amp_backend='apex' 即可。使用 accelerator 可以方便切换各种加速设备,CPU、GPU、TPU、IPU 等等。指定 strategy="ddp" 即可使用数据并行策略。

过多的就不再赘述了,官方文档介绍的还挺详细的,按需引入即可。

参考