标准回答
两层 MLP 的反向传播是链式法则的逐层应用。前向缓存每层的线性输出与激活;反向从损失对输出 logits 的梯度 dz2=p−y 出发,先得到 W2、b2 的梯度,再把梯度乘以 W2ᵀ 传回第一层,经过 ReLU 掩码后得到 dz1,进而求 W1、b1 的梯度。每个权重梯度都是「上游梯度」与「本层输入激活」的外积式矩阵乘。完整实现如下:
python
import numpy as np
def softmax(z):
z = z - z.max(axis=1, keepdims=True)
ez = np.exp(z)
return ez / ez.sum(axis=1, keepdims=True)
class TwoLayerMLP:
def __init__(self, d_in, d_hidden, d_out, seed=0):
rng = np.random.default_rng(seed)
# He 初始化(配合 ReLU)
self.W1 = rng.normal(0, np.sqrt(2 / d_in), (d_in, d_hidden))
self.b1 = np.zeros(d_hidden)
self.W2 = rng.normal(0, np.sqrt(2 / d_hidden), (d_hidden, d_out))
self.b2 = np.zeros(d_out)
def forward(self, X):
self.X = X
self.z1 = X @ self.W1 + self.b1
self.a1 = np.maximum(0, self.z1) # ReLU
self.z2 = self.a1 @ self.W2 + self.b2
self.p = softmax(self.z2)
return self.p
def backward(self, y):
N = self.X.shape[0]
# 输出层:Softmax+交叉熵合并梯度 dz2 = p - y
dz2 = self.p.copy()
dz2[np.arange(N), y] -= 1.0
dz2 /= N
dW2 = self.a1.T @ dz2 # (H, C)
db2 = dz2.sum(axis=0)
# 回传到隐藏层,过 ReLU 掩码
da1 = dz2 @ self.W2.T
dz1 = da1 * (self.z1 > 0) # ReLU 导数:z>0 处为 1
dW1 = self.X.T @ dz1 # (D, H)
db1 = dz1.sum(axis=0)
return {'W1': dW1, 'b1': db1, 'W2': dW2, 'b2': db2}
def step(self, grads, lr):
for k in ('W1', 'b1', 'W2', 'b2'):
setattr(self, k, getattr(self, k) - lr * grads[k])
if __name__ == '__main__':
rng = np.random.default_rng(1)
X = rng.normal(0, 1, (200, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int) # 二分类标签
net = TwoLayerMLP(4, 16, 2)
for epoch in range(300):
p = net.forward(X)
loss = -np.mean(np.log(p[np.arange(len(y)), y] + 1e-12))
net.step(net.backward(y), lr=0.5)
acc = (net.forward(X).argmax(1) == y).mean()
print('loss=', round(float(loss), 4), 'acc=', round(float(acc), 3))常见误区
⚠️ 常见踩坑
ReLU 反向要用前向缓存的 z1(或 a1)的掩码,而非对输出重新算;权重梯度的矩阵乘顺序易写反(应是输入激活的转置 @ 上游梯度)。还有人忘了在 dz2 里除以 N,导致梯度尺度随 batch 变化、学习率难调。
追问
追问 1:复杂度是多少?如何验证梯度正确?
前向与反向都由矩阵乘主导,单层 O(N·d_in·d_out),整体随层规模线性扩展。验证用数值梯度检查:对每个参数做 (L(θ+ε)−L(θ−ε))/(2ε) 与解析梯度比对,相对误差应在 1e−5 量级。这是手写反向传播最可靠的自检手段。
追问 2:为什么用 He 初始化而不是全零或 Xavier?
全零会让同层神经元对称、梯度相同无法区分;Xavier 假设激活近似线性,而 ReLU 会丢掉一半信号,He 初始化方差取 2/d_in 正好补偿这一点,保持前向激活与反向梯度的方差稳定,避免深层时的梯度消失或爆炸。
延伸学习
与本题相关的知识库文章、术语、工具与行业资讯。