核心要点

  • 机制:把大 batch 拆成 N 个 micro-batch,逐个 forward/backward,梯度累加在 .grad 上,累满 N 次再 optimizer.step() 并清零

  • 等效性:在不依赖跨样本统计的层上,累积 N 步等效于一个 N 倍大的 batch,但显存只需装下一个 micro-batch

  • 关键坑:损失要按累积步数缩放(或对 loss 取均值),否则相当于把学习率放大了 N 倍

  • BN 例外:BatchNorm 的统计量只在单个 micro-batch 内计算,无法等效大 batch,需用 LayerNorm/GroupNorm 或 SyncBN 规避

标准回答

核心思想

显存装不下大 batch 时,把目标大 batch 拆成 N 个小 micro-batch,每个都做一次前向和反向,梯度自然累加到参数的 .grad 上;累满 N 步后再调用一次 optimizer.step() 更新参数并清空梯度。

为什么等效

梯度对样本是线性可加的:N 个 micro-batch 累加的梯度,等于把它们拼成一个大 batch 算出的梯度(在不依赖跨样本统计的层上)。因此能用单卡小显存逼近大 batch 的训练效果,代价是耗时变为 N 倍(顺序计算)。

实现要点

loss 需除以 N(或对 batch 内求均值),保证累积后的梯度量级与真实大 batch 一致,否则等效于学习率被放大 N 倍。step 完成后再 zero_grad。

常见误区

⚠️ 常见踩坑

忘记按累积步数缩放损失,会让有效学习率被放大 N 倍;以及误以为它能让 BatchNorm 等效大 batch——BN 统计仍只在单个 micro-batch 内计算。

追问

追问 1梯度累积和数据并行(DDP)的区别?

数据并行用多张卡同时算不同 micro-batch、再 all-reduce 求平均,是空间换吞吐、速度快;梯度累积在单卡上顺序算多个 micro-batch,是时间换显存、不增速。两者可叠加:每卡先累积再跨卡 all-reduce。DDP 下累积阶段应配合 no_sync() 跳过中间通信,只在最后一步同步以省带宽。

追问 2增大有效 batch 后学习率要怎么调?

经验上 batch 扩大需相应增大学习率以维持每步更新幅度,常用线性缩放法则(batch 翻倍学习率翻倍),并配合 warmup 防止初期发散。但过大 batch 会损害泛化(sharp minima),收益边际递减,需结合验证集调整。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。