核心要点

  • 只测确定性逻辑:数据处理、特征工程、张量形状、loss 计算、梯度数值

  • 用极小的固定输入 + 固定随机种子,让测试稳定可复现

  • 断言不变量而非精度:输出范围、softmax 和为 1、形状一致、非 NaN

  • 过拟合一小批数据(overfit one batch)作为训练管线的 sanity check

标准回答

测确定性部分,不断言精度数值

模型训练本身带随机性,直接断言「准确率 > 0.9」既不稳定也不是单元测试该做的事。单元测试应锁定可确定验证的环节。

1. 数据与特征

  • 归一化、分桶、缺失值填充等转换:给定输入,断言输出精确等于预期。
  • 边界用例:空输入、单样本、含异常值。

2. 形状与数值不变量

  • 前向输出形状是否符合 (batch, num_classes)
  • 概率输出范围在 [0,1]、softmax 每行和为 1、无 NaN/Inf。

3. loss 与梯度

  • 用已知输入手算 loss 对比;调用一次 backward 后断言关键参数梯度非 0、量级合理。
python
def test_forward_shape():
    torch.manual_seed(0)
    x = torch.randn(4, 10)
    out = model(x)
    assert out.shape == (4, 3)
    assert torch.allclose(out.softmax(-1).sum(-1), torch.ones(4))

4. Sanity check:让模型在一小批样本上过拟合,loss 应能降到接近 0;降不下去说明管线有 bug。

常见误区

⚠️ 常见踩坑

把对精度/指标的断言当单元测试(不稳定且慢),或忘记固定随机种子导致测试时灵时不灵。

追问

追问 1为什么不直接在单元测试里断言模型准确率?

准确率依赖随机初始化、数据顺序、硬件,结果不确定,会造成测试 flaky;且训练慢,不适合频繁运行。精度应放在独立的评测/回归流程里跑,单元测试只覆盖确定性逻辑。

追问 2「过拟合一小批」具体怎么做、能发现什么问题?

取 1~2 个 batch 反复训练若干步,正常情况下 loss 应快速逼近 0。若降不下来,说明前向/反向、loss 定义、标签对齐学习率或数据管线存在 bug,是廉价又有效的冒烟测试。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。