标准回答
问题来源
反向传播求每层梯度时需要该层前向的激活值,因此标准训练会缓存所有层的激活。对深层大模型,激活显存往往超过参数本身,成为显存瓶颈。
核心思想
梯度检查点在前向时只保留若干检查点处的激活,丢弃中间激活;反向传播到某段时,从最近的检查点重新前向计算出所需激活,再求梯度。
收益与代价
把网络切成约 √L 段、每段存一个检查点,可使激活显存从 O(L) 降到 O(√L);代价是反向阶段额外做一次前向,整体训练算力增加约 20~30%。
工程使用
PyTorch 通过 torch.utils.checkpoint 包裹子模块即可启用;DeepSpeed/Megatron 提供 activation checkpointing,常与混合精度、ZeRO、并行策略组合,用于在有限显存下训练更大模型或更大 batch。
常见误区
⚠️ 常见踩坑
误以为它能省参数或优化器显存——它只针对激活;忽视重算带来的训练变慢,在显存充足时盲目开启反而得不偿失。
追问
追问 1:梯度检查点和混合精度、ZeRO 能一起用吗?
可以且常组合。混合精度降低激活与参数的存储位宽,ZeRO 切分优化器状态/梯度/参数,梯度检查点削减激活显存,三者作用于不同显存组成部分,叠加后可在同等硬件上训练更大模型。
追问 2:检查点放在哪些位置比较合理?
通常按层或 Transformer block 的边界切分,使每段重算代价均衡,激活显存与算力开销取得平衡。等间隔切成约 √L 段是理论上的最优折中,实践中常以 block 为粒度。
延伸学习
与本题相关的知识库文章、术语、工具与行业资讯。
📖 术语表
🛠️ AI 工具