核心要点
能说清标准模式:继承 nn.Module,init 里建参数/子模块并先调 super().init(),forward 写计算逻辑
能区分可学习权重用 nn.Parameter(自动进 parameters() 参与优化),非训练状态用 register_buffer(随 state_dict 保存但不更新)
能说明把子模块或 Parameter 赋给 self.xxx 会自动注册,但放进普通 list/dict 不会,需用 nn.ModuleList / nn.ParameterList
能指出参数应在 init 创建,forward 里 new 层会每步重建、丢失梯度状态
标准回答
在 PyTorch 实现自定义层的标准模式:
class LinearCustom(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_features, in_features))
self.bias = nn.Parameter(torch.zeros(out_features))
def forward(self, x):
return F.linear(x, self.weight, self.bias)
class MyNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.register_buffer('eps', torch.tensor(eps)) # 非参数状态
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.gamma * (x - mean) / (std + self.eps)要点:
- nn.Parameter:自动加入 parameters(),参与优化
- register_buffer:持久化状态(如 running mean)但不训练
- super().init() 必须调用
- 子模块赋值给 self.xxx 自动注册
复杂层(Multi-Head Attention、自定义卷积)均遵循此模式。也可用 torch.autograd.Function 写自定义 autograd。详见 深度学习基础。
常见误区
⚠️ 常见踩坑
把若干子层放进 Python list(self.layers = [nn.Linear(...), ...])而非 nn.ModuleList,导致这些层不出现在 parameters() 里、优化器根本不更新它们、.to(device) 也搬不过去;另一个是用普通 torch.tensor 当权重而非 nn.Parameter,前向能跑但 backward 后该张量永远不被优化。
追问
追问 1:Parameter 和 register_parameter 区别?
本质相同——self.w = nn.Parameter(...) 内部就是调 register_parameter("w", ...)。区别在 register_parameter 接受字符串名、可显式传 None 占位、参数名含特殊字符或动态生成时更方便;日常直接赋值 nn.Parameter 更简洁。
追问 2:何时用 autograd.Function?
当你要自定义反向传播公式时——比如实现数值更稳定的梯度、对接 C++/CUDA 自定义算子、或做梯度截断/STE(直通估计器)。它要求实现 forward 和 backward 两个静态方法。普通层只组合已有可微算子时无需用它,Autograd 会自动求导。
追问 3:自定义层如何初始化权重?
在 init 末尾显式初始化,常用 torch.nn.init,如 nn.init.kaiming_uniform_(self.weight)(配 ReLU)或 xavier_uniform_(配 tanh/sigmoid),bias 通常置 0。也可写 reset_parameters() 方法集中管理。默认 randn 方差不当会导致梯度爆炸或消失。
延伸学习
与本题相关的知识库文章、术语、工具与行业资讯。
📰 AI 资讯
🛠️ AI 工具
- Pytorch
Meta 开源的深度学习框架,100K+ stars。以动态计算图和 Pythonic 风格著称,在学术界和工业界都有广泛应用,支持分布式训练、移动端部署和 ONNX 导出
- Tensorflow
全球最流行的机器学习框架之一,195K+ stars。Google 开源的端到端 ML 平台,支持 TensorFlow、Keras 等多种 API,覆盖深度学习、强化学习、移动端部署等全场景,是 AI 工程师的必备工具