核心要点
先定位是哪一步出现 NaN:是 loss、某层激活,还是梯度先变 NaN,用 anomaly detection 抓首个出错算子
查数值不稳定运算:log(0)、除 0、sqrt(负数)、exp 溢出,给 log 加 eps、用 logits 版损失避免手写 softmax
查梯度爆炸:学习率过大或缺梯度裁剪,加 clip_grad_norm 并降 LR / 加 warmup
查混合精度与脏数据:FP16 需 loss scaling,输入本身含 NaN/Inf 也会传播,先过滤数据
标准回答
先确定 NaN 最先出现在哪里,再对症下药(独占一行)
打开 torch.autograd.set_detect_anomaly(True) 或逐步打印,判断是输入数据、前向某层激活、loss,还是反向梯度第一个变成 NaN/Inf,缩小范围后再排查具体原因。
数值不稳定的运算
最常见是 log(0)、除以 0、sqrt(负数)、softmax/exp 溢出。修法:log 加小 eps、用框架自带的「带 logits」损失(如 CrossEntropy 直接吃 logits)避免手写 exp/log、对分母加 eps,并做数值范围 clamp。
梯度爆炸
学习率过大会让梯度发散为 Inf。对策:降低学习率、加 warmup、加梯度裁剪 clip_grad_norm;观察梯度范数是否在某步骤突然飙升。
混合精度与脏数据
FP16 混合精度 动态范围小,需配 loss scaling(GradScaler),或改用动态范围更大的 BF16;同时检查输入/标签里是否本就含 NaN、Inf 或异常大值,脏数据会直接污染前向。
常见误区
⚠️ 常见踩坑
直接把所有数值 clamp 或调小学习率「压住」NaN,却不定位首次出现的算子与根因;以及忘了 FP16 需要 loss scaling,误以为是模型本身问题。
追问
追问 1:为什么 FP16 容易出 NaN,而 BF16 相对稳定?
FP16 指数位少、动态范围窄,小梯度容易下溢为 0、大值容易上溢为 Inf,所以需要 loss scaling 把梯度放大到可表示范围。BF16 指数位与 FP32 相同、动态范围大得多(只是精度低),不易溢出,因此训练大模型时更常用 BF16。
追问 2:anomaly detection 开销大,生产长训练怎么排查 NaN?
平时关闭 anomaly detection,只在检测到 loss 为 NaN 时回退到最近 checkpoint,开启 detect_anomaly 重放那几步定位;同时常驻监控梯度范数与 loss,设阈值告警并自动跳过/裁剪异常 batch,避免整次训练作废。
延伸学习
与本题相关的知识库文章、术语、工具与行业资讯。