核心要点

  • forward 描述输入张量如何流经各子层得到输出,本质是定义这张前向计算图

  • 调用要写 model(x) 而非 model.forward(x):前者经 call 触发 pre/forward hook,后者会跳过

  • init 负责声明子模块与参数,forward 只负责连线计算,二者职责分开

  • forward 里可以写分支、循环、残差 x+block(x) 等任意 Python 控制流,这正是动态图的灵活之处

简要回答

PyTorch 中,继承 nn.Module的类必须实现forward(self, args),定义前向传播*:输入张量如何经各子层得到输出

标准回答

在 PyTorch 中,继承 nn.Module的类必须实现forward(self, args),定义前向传播*:输入张量如何经各子层得到输出。

python
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.fc = nn.Linear(64, 10)
    def forward(self, x):
        x = F.relu(self.conv(x))
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return self.fc(x)
调用约定:使用model(x) 而非model.forward(x)。nn.Module.__call__ 会在 forward 前后执行注册钩子(forward_pre_hook、forward_hook),直接调 forward 会跳过它们。职责划分:__init__ 声明子模块与参数;forward描述计算图连接。复杂模型可在 forward 中做分支、多输入、残差x + self.block(x)。

Autograd 在 forward 时建图,backward 时求导。详见 深度学习基础

常见误区

⚠️ 常见踩坑

到处写 model.forward(x) 跳过 hooks;在 forward 里创建 nn.Linear(应放 init);忘记 return 导致输出 None。

追问

追问 1self.training 这个标志为什么不在 __init__ 里固定,而要在 forward 里读?

self.training 是 nn.Module 的内置状态,由 model.train()/model.eval() 在运行期切换,所以应在 forward 里读取它来决定 Dropout、BN 等行为,而不能在 init 里写死。init 只在建模型时跑一次,写死会导致 eval 时仍走训练分支。

追问 2forward 能有多输入多输出吗?

可以。forward 是普通 Python 方法,签名随意:def forward(self, x, mask, state=None) 支持多输入,return 一个 tuple/dict 支持多输出,调用时 model(x, mask) 即可。Transformer、检测头等都靠这个特性传 attention mask、返回多分支结果。

追问 3torch.jit.script 对 forward 有何要求?

script 会把 forward 当作静态类型的 TorchScript 编译,要求代码是其支持的 Python 子集:变量类型尽量可推断、避免不支持的动态特性和任意第三方库调用,张量与控制流要类型一致。含复杂动态分支时可改用 torch.jit.trace(但 trace 不记录数据相关分支)或新的 torch.export。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。

🛠️ AI 工具

  • Pytorch

    Meta 开源的深度学习框架,100K+ stars。以动态计算图和 Pythonic 风格著称,在学术界和工业界都有广泛应用,支持分布式训练、移动端部署和 ONNX 导出

  • Tensorflow

    全球最流行的机器学习框架之一,195K+ stars。Google 开源的端到端 ML 平台,支持 TensorFlow、Keras 等多种 API,覆盖深度学习、强化学习、移动端部署等全场景,是 AI 工程师的必备工具