核心要点

  • 先定位是哪一步出现 NaN:是 loss、某层激活,还是梯度先变 NaN,用 anomaly detection 抓首个出错算子

  • 查数值不稳定运算:log(0)、除 0、sqrt(负数)、exp 溢出,给 log 加 eps、用 logits 版损失避免手写 softmax

  • 查梯度爆炸:学习率过大或缺梯度裁剪,加 clip_grad_norm 并降 LR / 加 warmup

  • 查混合精度与脏数据:FP16 需 loss scaling,输入本身含 NaN/Inf 也会传播,先过滤数据

标准回答

先确定 NaN 最先出现在哪里,再对症下药(独占一行)

打开 torch.autograd.set_detect_anomaly(True) 或逐步打印,判断是输入数据、前向某层激活、loss,还是反向梯度第一个变成 NaN/Inf,缩小范围后再排查具体原因。

数值不稳定的运算

最常见是 log(0)、除以 0、sqrt(负数)、softmax/exp 溢出。修法:log 加小 eps、用框架自带的「带 logits」损失(如 CrossEntropy 直接吃 logits)避免手写 exp/log、对分母加 eps,并做数值范围 clamp。

梯度爆炸

学习率过大会让梯度发散为 Inf。对策:降低学习率、加 warmup、加梯度裁剪 clip_grad_norm;观察梯度范数是否在某步骤突然飙升。

混合精度与脏数据

FP16 混合精度 动态范围小,需配 loss scaling(GradScaler),或改用动态范围更大的 BF16;同时检查输入/标签里是否本就含 NaN、Inf 或异常大值,脏数据会直接污染前向。

常见误区

⚠️ 常见踩坑

直接把所有数值 clamp 或调小学习率「压住」NaN,却不定位首次出现的算子与根因;以及忘了 FP16 需要 loss scaling,误以为是模型本身问题。

追问

追问 1为什么 FP16 容易出 NaN,而 BF16 相对稳定?

FP16 指数位少、动态范围窄,小梯度容易下溢为 0、大值容易上溢为 Inf,所以需要 loss scaling 把梯度放大到可表示范围。BF16 指数位与 FP32 相同、动态范围大得多(只是精度低),不易溢出,因此训练大模型时更常用 BF16。

追问 2anomaly detection 开销大,生产长训练怎么排查 NaN?

平时关闭 anomaly detection,只在检测到 loss 为 NaN 时回退到最近 checkpoint,开启 detect_anomaly 重放那几步定位;同时常驻监控梯度范数与 loss,设阈值告警并自动跳过/裁剪异常 batch,避免整次训练作废。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。