核心要点
forward 描述输入张量如何流经各子层得到输出,本质是定义这张前向计算图
调用要写 model(x) 而非 model.forward(x):前者经 call 触发 pre/forward hook,后者会跳过
init 负责声明子模块与参数,forward 只负责连线计算,二者职责分开
forward 里可以写分支、循环、残差 x+block(x) 等任意 Python 控制流,这正是动态图的灵活之处
简要回答
在 PyTorch 中,继承 nn.Module的类必须实现forward(self, args),定义前向传播*:输入张量如何经各子层得到输出
标准回答
在 PyTorch 中,继承 nn.Module的类必须实现forward(self, args),定义前向传播*:输入张量如何经各子层得到输出。
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, 64, 3)
self.fc = nn.Linear(64, 10)
def forward(self, x):
x = F.relu(self.conv(x))
x = F.adaptive_avg_pool2d(x, 1).flatten(1)
return self.fc(x)Autograd 在 forward 时建图,backward 时求导。详见 深度学习基础。
常见误区
⚠️ 常见踩坑
到处写 model.forward(x) 跳过 hooks;在 forward 里创建 nn.Linear(应放 init);忘记 return 导致输出 None。
追问
追问 1:self.training 这个标志为什么不在 __init__ 里固定,而要在 forward 里读?
self.training 是 nn.Module 的内置状态,由 model.train()/model.eval() 在运行期切换,所以应在 forward 里读取它来决定 Dropout、BN 等行为,而不能在 init 里写死。init 只在建模型时跑一次,写死会导致 eval 时仍走训练分支。
追问 2:forward 能有多输入多输出吗?
可以。forward 是普通 Python 方法,签名随意:def forward(self, x, mask, state=None) 支持多输入,return 一个 tuple/dict 支持多输出,调用时 model(x, mask) 即可。Transformer、检测头等都靠这个特性传 attention mask、返回多分支结果。
追问 3:torch.jit.script 对 forward 有何要求?
script 会把 forward 当作静态类型的 TorchScript 编译,要求代码是其支持的 Python 子集:变量类型尽量可推断、避免不支持的动态特性和任意第三方库调用,张量与控制流要类型一致。含复杂动态分支时可改用 torch.jit.trace(但 trace 不记录数据相关分支)或新的 torch.export。
延伸学习
与本题相关的知识库文章、术语、工具与行业资讯。
📰 AI 资讯
🛠️ AI 工具
- Pytorch
Meta 开源的深度学习框架,100K+ stars。以动态计算图和 Pythonic 风格著称,在学术界和工业界都有广泛应用,支持分布式训练、移动端部署和 ONNX 导出
- Tensorflow
全球最流行的机器学习框架之一,195K+ stars。Google 开源的端到端 ML 平台,支持 TensorFlow、Keras 等多种 API,覆盖深度学习、强化学习、移动端部署等全场景,是 AI 工程师的必备工具