核心要点
最直接:减小 batch size,再用梯度累积(gradient accumulation)等效大 batch 不增显存
用混合精度(FP16/BF16)省一半激活显存;用梯度检查点(activation checkpointing)以算换存
推理用 no_grad/释放无用张量;超大模型上 ZeRO/张量并行/CPU offload 分摊显存
标准回答
先搞清显存被谁吃掉,再分层释放(独占一行)
训练显存主要由四部分构成:模型参数、优化器状态(如 Adam 的一阶/二阶矩约为参数 2 倍)、前向激活、以及临时 buffer。先用显存分析工具看占比,针对最大项优化,避免盲目。
减 batch 与梯度累积
最直接是减小 batch size;为保持等效大 batch 的稳定性,用梯度累积:多个小 batch 累加梯度后再 step,不额外占显存。
混合精度与梯度检查点
混合精度 用 FP16/BF16 存激活,省近一半显存与带宽;梯度检查点(activation checkpointing)只保留部分激活、反向时重算其余,用计算时间换显存,对深层网络很有效。
推理与超大模型
推理时用 torch.no_grad()/inference_mode 不存计算图,及时 del 无用张量并 empty_cache;序列过长则截断或分块。模型本身放不下时用 ZeRO 切分优化器状态/梯度、张量并行或 CPU/NVMe offload(见 模型量化压缩 也可进一步降存)。
常见误区
⚠️ 常见踩坑
一遇 OOM 就只会调小 batch,忽略梯度累积保稳定、混合精度与梯度检查点等更省的手段;以及推理时忘了关梯度(no_grad),白白存了整张计算图。
追问
追问 1:梯度检查点为什么能省显存,代价是什么?
常规前向会缓存每层激活供反向用,显存随深度线性增长。梯度检查点只保存少数「检查点」激活,反向时从最近检查点重新前向算出中间激活,于是激活显存大幅下降,代价是多一次前向计算,训练时间通常增加 20%~30%。
追问 2:为什么有时减小 batch 后显存峰值反而没降多少?
若瓶颈不是激活而是模型参数+优化器状态(与 batch 无关)或某个超长序列、超大中间张量,减 batch 收效就有限。此时应改用混合精度、ZeRO 切分优化器状态、缩短序列长度或并行/offload,而不是继续压 batch。
延伸学习
与本题相关的知识库文章、术语、工具与行业资讯。