核心要点

  • 常规反向传播需保存前向所有层的激活以便求梯度,激活显存随网络深度线性增长

  • 梯度检查点只保存少量关键层的激活(检查点),其余激活在反向时按需重新前向计算

  • 本质是「时间换空间」:激活显存可从 O(L) 降到约 O(√L),代价是多一次前向、训练算力增加约 20~30%

  • PyTorch 用 torch.utils.checkpointDeepSpeed 提供 activation checkpointing,是训练大模型扩 batch/扩深度的常用手段

标准回答

问题来源

反向传播求每层梯度时需要该层前向的激活值,因此标准训练会缓存所有层的激活。对深层大模型,激活显存往往超过参数本身,成为显存瓶颈。

核心思想

梯度检查点在前向时只保留若干检查点处的激活,丢弃中间激活;反向传播到某段时,从最近的检查点重新前向计算出所需激活,再求梯度。

收益与代价

把网络切成约 √L 段、每段存一个检查点,可使激活显存从 O(L) 降到 O(√L);代价是反向阶段额外做一次前向,整体训练算力增加约 20~30%。

工程使用

PyTorch 通过 torch.utils.checkpoint 包裹子模块即可启用;DeepSpeed/Megatron 提供 activation checkpointing,常与混合精度、ZeRO、并行策略组合,用于在有限显存下训练更大模型或更大 batch。

常见误区

⚠️ 常见踩坑

误以为它能省参数或优化器显存——它只针对激活;忽视重算带来的训练变慢,在显存充足时盲目开启反而得不偿失。

追问

追问 1梯度检查点和混合精度、ZeRO 能一起用吗?

可以且常组合。混合精度降低激活与参数的存储位宽,ZeRO 切分优化器状态/梯度/参数,梯度检查点削减激活显存,三者作用于不同显存组成部分,叠加后可在同等硬件上训练更大模型。

追问 2检查点放在哪些位置比较合理?

通常按层或 Transformer block 的边界切分,使每段重算代价均衡,激活显存与算力开销取得平衡。等间隔切成约 √L 段是理论上的最优折中,实践中常以 block 为粒度。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。

🛠️ AI 工具

  • Pytorch

    Meta 开源的深度学习框架,100K+ stars。以动态计算图和 Pythonic 风格著称,在学术界和工业界都有广泛应用,支持分布式训练、移动端部署和 ONNX 导出

  • DeepSpeed

    深度学习训练优化库,42,156+ stars。微软开发的开源深度学习优化库,提供 ZeRO 内存优化、3D 并行等核心技术,大幅降低大模型训练成本