PyTorch Lightning: 让 PyTorch 更为易用
摘要
今天来介绍一个很好用的深度学习框架,PyTorch Lightning。从名字就可以看出,它是基于 PyTorch 的框架。它的核心思想是,将学术代码(模型定义、前向 / 反向、优化器、验证等)与工程代码(for-loop,保存、tensorboard 日志、训练策略等)解耦开来,使得代码更为简洁清晰。工程代码经常会出现在深度学习代码中,PyTorch Lightning 对这部分逻辑进行了封装,只需要在 Trainer 类中简单设置即可调用,无需重复造轮子。
介绍
近些年来,各种深度学习框架层出不穷,TensorFlow、PyTorch 已经成为深度学习研究人员使用最多的两个框架。其中,PyTorch 的热度连年飙升,大有超越 TensorFlow 之势。 基础的张量运算、计算图等功能已经被这些框架支持的很好了,后面的框架开始着力于解决其他问题,例如:
- transformers: 致力于简单高效地使用预训练模型,支持 PyTorch 和 TensorFlow 作为后端
- Fairseq:提供通用建模序列任务的工具包,包含丰富高效的命令行接口,基于 PyTorch
- Pytorch Lightning:解决深度学习代码中,学术代码、工程代码耦合性高,工程代码需要重复造轮子等问题
我之前已经介绍过 transformers 了,相信做 NLP 的同学都对这个框架很熟悉了。Fairseq 我也接触过一段时间,但是由于其文档不是很详细,大部分问题都要读源码才能找到答案, 入门起来比较痛苦。Pytorch Lightning 就是本文的重点了,下面做详细介绍。
一个例子
引用一个官网的 gif,介绍 Pytorch Lightning 的动机是什么。一段典型的、基于 PyTorch 的深度学习代码可能是下面左侧的这样的, 包含:模型、优化器、数据定义,训练、验证循环逻辑。
可以将其转化为右侧的 Lightning Module,按照以下步骤:
- 将模型定义代码写在
__init__
中 - 定义前向传播逻辑
- 将优化器代码写在
configure_optimizers
钩子中 - 训练代码写在
training_step
钩子中,可以使用self.log
随时记录变量的值,会保存在 tensorboard 中 - 验证代码写在
validation_step
钩子中 - 移除硬件调用
.cuda()
等,PyTorch Lightning 会自动将模型、张量的设备放置在合适的设备;移除.train()
等代码,这也会自动切换 - 根据需要,重写其他钩子函数,例如
validation_epoch_end
,对validation_step
的结果进行汇总;train_dataloader
,定义训练数据的加载逻辑 - 实例化 Lightning Module 和 Trainer 对象,传入数据集
- 定义训练参数和回调函数,例如训练设备、数量、保存策略,Early Stop、半精度等
其中,最为直接的好处是,你无需关注模型和张量的设备,可以省去不计其数的.cuda()
,再也不需要担心 device 报错了(当然,自己新建的张量还是需要调整下设备的)。 此外就是功能强大的 Trainer 和 tensorboard 的集成,可以非常优雅地调用。
进阶
如果想要解锁 PyTorch Lightning 的更多玩法,可以参考官方文档,详细地介绍了各种技巧,辅以代码示例,很容易理解上手。对于核心的 API,Lightning Module 的各种钩子,Trainer 的参数、用法,也做了详细的介绍。文档中还包含各种常见的工作流、上手教程,内容非常的齐全。
例如,configure_optimizers
支持配置多个优化器、搭配学习率衰减策略。
1 | # most cases. no learning rate scheduler |
可以在,epoch_end
系列的钩子中,完成 Epoch-level 的 metric 计算。以 validation_epoch_end
为例,其接收一个参数 validation_step_outputs
,是一个 list,包含了 validation_step
的所有返回结果。
1 | def validation_step(self, batch, batch_idx): |
Trainer 中,常见的回调函数有 ModelCheckpoint
,完成模型的定期保存;EarlyStopping
,定义模型的早停策略。Trainer 定义时,只需要传入 precision=16
,即可实现 PyTorch naive 的混合半精度,如果要制定 apex 后端,也只需要加上一行 amp_backend='apex'
即可。使用 accelerator
可以方便切换各种加速设备,CPU、GPU、TPU、IPU 等等。指定 strategy="ddp"
即可使用数据并行策略。
过多的就不再赘述了,官方文档介绍的还挺详细的,按需引入即可。