标准回答
Self-Attention(自注意力)让序列中每个位置都能关注到其他所有位置。先把输入投影成查询 Q、键 K、值 V,用 Q 与 K 的点积衡量两两相关性,除以 sqrt(d) 进行缩放避免点积数值过大使 softmax 进入饱和区。随后对每一行做 softmax 得到归一化权重,最后用权重对 V 加权求和得到输出。softmax 实现时务必先减去行最大值再取指数,保证数值稳定。下面给出纯 NumPy 实现:
python
import numpy as np
def softmax(x, axis=-1):
# 减去最大值保证数值稳定,避免 exp 溢出
x = x - np.max(x, axis=axis, keepdims=True)
e = np.exp(x)
return e / np.sum(e, axis=axis, keepdims=True)
def self_attention(X, Wq, Wk, Wv):
# X: (L, d_model) 输入序列;W*: (d_model, d_k) 投影矩阵
Q = X @ Wq # (L, d_k)
K = X @ Wk # (L, d_k)
V = X @ Wv # (L, d_v)
d_k = Q.shape[-1]
scores = Q @ K.T / np.sqrt(d_k) # (L, L) 缩放点积
attn = softmax(scores, axis=-1) # 每行归一化为注意力权重
out = attn @ V # (L, d_v) 加权聚合
return out, attn
if __name__ == '__main__':
np.random.seed(0)
L, d_model, d_k = 4, 8, 8
X = np.random.randn(L, d_model)
Wq = np.random.randn(d_model, d_k)
Wk = np.random.randn(d_model, d_k)
Wv = np.random.randn(d_model, d_k)
out, attn = self_attention(X, Wq, Wk, Wv)
print(out.shape, attn.sum(axis=-1)) # (4, 8) 权重每行和为 1常见误区
⚠️ 常见踩坑
忘记除以 sqrt(d_k)(仅除以 d_k 或不缩放)会让点积方差过大,softmax 饱和、梯度消失;softmax 不减最大值在大数值时会上溢出 NaN。
追问
追问 1:复杂度是多少?如何优化?
序列长度为 L 时,QKᵀ 与 attn@V 各为 O(L²·d),时间与显存均随 L 平方增长。优化手段包括 FlashAttention(分块计算、不显式存储 L×L 矩阵,把显存降到 O(L))、稀疏/线性注意力,以及推理时的 KV Cache 复用历史键值。
追问 2:如何加入因果掩码实现解码器自注意力?
在 softmax 之前,把上三角(未来位置)的 scores 置为 -inf(或一个很大的负数),这样 softmax 后这些位置权重为 0,保证位置 i 只能看到 i 及之前的 token,符合自回归生成的因果约束。
延伸学习
与本题相关的知识库文章、术语、工具与行业资讯。