核心要点

  • 能说清标准模式:继承 nn.Module,init 里建参数/子模块并先调 super().init(),forward 写计算逻辑

  • 能区分可学习权重用 nn.Parameter(自动进 parameters() 参与优化),非训练状态用 register_buffer(随 state_dict 保存但不更新)

  • 能说明把子模块或 Parameter 赋给 self.xxx 会自动注册,但放进普通 list/dict 不会,需用 nn.ModuleList / nn.ParameterList

  • 能指出参数应在 init 创建,forward 里 new 层会每步重建、丢失梯度状态

标准回答

PyTorch 实现自定义层的标准模式:

python
class LinearCustom(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

class MyNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.register_buffer('eps', torch.tensor(eps))  # 非参数状态
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps)

要点

  • nn.Parameter:自动加入 parameters(),参与优化
  • register_buffer:持久化状态(如 running mean)但不训练
  • super().init() 必须调用
  • 子模块赋值给 self.xxx 自动注册

复杂层(Multi-Head Attention、自定义卷积)均遵循此模式。也可用 torch.autograd.Function 写自定义 autograd。详见 深度学习基础

常见误区

⚠️ 常见踩坑

把若干子层放进 Python list(self.layers = [nn.Linear(...), ...])而非 nn.ModuleList,导致这些层不出现在 parameters() 里、优化器根本不更新它们、.to(device) 也搬不过去;另一个是用普通 torch.tensor 当权重而非 nn.Parameter,前向能跑但 backward 后该张量永远不被优化。

追问

追问 1Parameter 和 register_parameter 区别?

本质相同——self.w = nn.Parameter(...) 内部就是调 register_parameter("w", ...)。区别在 register_parameter 接受字符串名、可显式传 None 占位、参数名含特殊字符或动态生成时更方便;日常直接赋值 nn.Parameter 更简洁。

追问 2何时用 autograd.Function?

当你要自定义反向传播公式时——比如实现数值更稳定的梯度、对接 C++/CUDA 自定义算子、或做梯度截断/STE(直通估计器)。它要求实现 forward 和 backward 两个静态方法。普通层只组合已有可微算子时无需用它,Autograd 会自动求导。

追问 3自定义层如何初始化权重?

init 末尾显式初始化,常用 torch.nn.init,如 nn.init.kaiming_uniform_(self.weight)(配 ReLU)或 xavier_uniform_(配 tanh/sigmoid),bias 通常置 0。也可写 reset_parameters() 方法集中管理。默认 randn 方差不当会导致梯度爆炸或消失。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。

🛠️ AI 工具

  • Pytorch

    Meta 开源的深度学习框架,100K+ stars。以动态计算图和 Pythonic 风格著称,在学术界和工业界都有广泛应用,支持分布式训练、移动端部署和 ONNX 导出

  • Tensorflow

    全球最流行的机器学习框架之一,195K+ stars。Google 开源的端到端 ML 平台,支持 TensorFlow、Keras 等多种 API,覆盖深度学习、强化学习、移动端部署等全场景,是 AI 工程师的必备工具