核心要点

  • 先定位显存去向:模型参数、优化器状态、激活、batch/序列长度,哪块最大就先压哪块

  • 最直接:减小 batch size,再用梯度累积(gradient accumulation)等效大 batch 不增显存

  • 用混合精度(FP16/BF16)省一半激活显存;用梯度检查点(activation checkpointing)以算换存

  • 推理用 no_grad/释放无用张量;超大模型上 ZeRO/张量并行/CPU offload 分摊显存

标准回答

先搞清显存被谁吃掉,再分层释放(独占一行)

训练显存主要由四部分构成:模型参数、优化器状态(如 Adam 的一阶/二阶矩约为参数 2 倍)、前向激活、以及临时 buffer。先用显存分析工具看占比,针对最大项优化,避免盲目。

减 batch 与梯度累积

最直接是减小 batch size;为保持等效大 batch 的稳定性,用梯度累积:多个小 batch 累加梯度后再 step,不额外占显存。

混合精度与梯度检查点

混合精度 用 FP16/BF16 存激活,省近一半显存与带宽;梯度检查点(activation checkpointing)只保留部分激活、反向时重算其余,用计算时间换显存,对深层网络很有效。

推理与超大模型

推理时用 torch.no_grad()/inference_mode 不存计算图,及时 del 无用张量并 empty_cache;序列过长则截断或分块。模型本身放不下时用 ZeRO 切分优化器状态/梯度、张量并行或 CPU/NVMe offload(见 模型量化压缩 也可进一步降存)。

常见误区

⚠️ 常见踩坑

一遇 OOM 就只会调小 batch,忽略梯度累积保稳定、混合精度与梯度检查点等更省的手段;以及推理时忘了关梯度(no_grad),白白存了整张计算图。

追问

追问 1梯度检查点为什么能省显存,代价是什么?

常规前向会缓存每层激活供反向用,显存随深度线性增长。梯度检查点只保存少数「检查点」激活,反向时从最近检查点重新前向算出中间激活,于是激活显存大幅下降,代价是多一次前向计算,训练时间通常增加 20%~30%。

追问 2为什么有时减小 batch 后显存峰值反而没降多少?

若瓶颈不是激活而是模型参数+优化器状态(与 batch 无关)或某个超长序列、超大中间张量,减 batch 收效就有限。此时应改用混合精度、ZeRO 切分优化器状态、缩短序列长度或并行/offload,而不是继续压 batch。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。