核心要点

  • 由输入分别线性映射出 Q、K、V,三者形状均为 (L, d)。

  • 注意力分数 = Q·Kᵀ / sqrt(d),除以 sqrt(d) 防止点积过大导致 softmax 梯度消失

  • 对分数最后一维做 softmax 得权重,再与 V 相乘聚合:out = softmax @ V。

  • 易错点:softmax 需减去每行最大值保证数值稳定;缩放因子是 sqrt(d_k) 而非 d_k。

标准回答

Self-Attention(自注意力)让序列中每个位置都能关注到其他所有位置。先把输入投影成查询 Q、键 K、值 V,用 Q 与 K 的点积衡量两两相关性,除以 sqrt(d) 进行缩放避免点积数值过大使 softmax 进入饱和区。随后对每一行做 softmax 得到归一化权重,最后用权重对 V 加权求和得到输出。softmax 实现时务必先减去行最大值再取指数,保证数值稳定。下面给出纯 NumPy 实现:

python
import numpy as np

def softmax(x, axis=-1):
    # 减去最大值保证数值稳定,避免 exp 溢出
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)

def self_attention(X, Wq, Wk, Wv):
    # X: (L, d_model) 输入序列;W*: (d_model, d_k) 投影矩阵
    Q = X @ Wq              # (L, d_k)
    K = X @ Wk              # (L, d_k)
    V = X @ Wv              # (L, d_v)
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # (L, L) 缩放点积
    attn = softmax(scores, axis=-1)   # 每行归一化为注意力权重
    out = attn @ V                     # (L, d_v) 加权聚合
    return out, attn

if __name__ == '__main__':
    np.random.seed(0)
    L, d_model, d_k = 4, 8, 8
    X = np.random.randn(L, d_model)
    Wq = np.random.randn(d_model, d_k)
    Wk = np.random.randn(d_model, d_k)
    Wv = np.random.randn(d_model, d_k)
    out, attn = self_attention(X, Wq, Wk, Wv)
    print(out.shape, attn.sum(axis=-1))  # (4, 8) 权重每行和为 1

常见误区

⚠️ 常见踩坑

忘记除以 sqrt(d_k)(仅除以 d_k 或不缩放)会让点积方差过大,softmax 饱和、梯度消失;softmax 不减最大值在大数值时会上溢出 NaN。

追问

追问 1复杂度是多少?如何优化?

序列长度为 L 时,QKᵀ 与 attn@V 各为 O(L²·d),时间与显存均随 L 平方增长。优化手段包括 FlashAttention分块计算、不显式存储 L×L 矩阵,把显存降到 O(L))、稀疏/线性注意力,以及推理时的 KV Cache 复用历史键值。

追问 2如何加入因果掩码实现解码器自注意力?

在 softmax 之前,把上三角(未来位置)的 scores 置为 -inf(或一个很大的负数),这样 softmax 后这些位置权重为 0,保证位置 i 只能看到 i 及之前的 token,符合自回归生成的因果约束。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。