核心要点

  • 能写出并讲清公式 Attention(Q,K,V) = softmax(QK^T / √d_k) V:Q 找信息、K 被匹配、V 是被加权聚合的内容。

  • 说清缩放因子 √d_k 的必要性:d_k 大时点积方差变大,不缩放会让 Softmax 进饱和区、梯度趋零。

  • 讲清多头作用:在 h 个子空间并行注意,不同 head 学语法、共指、位置等不同关系,拼接后再投影。

  • 点出复杂度瓶颈:序列长 n 时为 O(n²·d),是长上下文显存/算力压力的根源,引出 Flash/稀疏/线性 Attention。

简要回答

Self-Attention 将每个 token 投影为 Q/K/V,用 Q 与所有 K 点积得分数,缩放后 Softmax 得权重,再对 V 加权求和,得到融合全局上下文的新表示。

标准回答

计算步骤:给定输入序列 X,先线性投影得到 Q=XW_Q、K=XW_K、V=XW_V,再计算缩放点积注意力:

Attention(Q,K,V) = softmax(QK^T / √d_k) · V

直觉:每个 token 的 Query 与所有 Key 算相似度,Softmax 归一化为权重,对 Value 加权求和,得到融合全局上下文的新表示。

多头注意力(MHA):在 h 个独立子空间并行计算 Attention,拼接后再线性变换。不同 head 可学习语法、共指、位置等不同依赖关系。

缩放因子 √d_k:防止 d_k 较大时点积过大,Softmax 进入饱和区导致梯度消失

详见 Transformer 原理 与术语 注意力机制。

常见误区

⚠️ 常见踩坑

常把 √d_k 说成「为了归一化概率」——归一化是 Softmax 做的,缩放是为了控制点积方差、防梯度消失。另一个常见错误是认为多头是把同一注意力算 h 遍,实际是各 head 在不同低维子空间投影、互不相同,拼接后才还原维度。

追问

追问 1Self-Attention 的时间/空间复杂度是多少?

标准 Self-Attention 对序列长度 n 为 O(n²·d),显存与计算都随上下文平方增长。Flash Attention 通过分块计算降低显存;线性 Attention、稀疏 Attention 可降至近似 O(n)。

追问 2Causal Mask 在 Decoder 中的作用?

自回归解码器中,掩码阻止位置 i 注意到未来 token j>i,保证训练时不会「偷看」答案,与逐步生成时的信息可见性一致,避免标签泄漏。

追问 3Flash Attention 优化了什么?

通过分块计算避免物化完整 n×n 注意力矩阵,减少 HBM 读写,在相同数学结果下显著降显存、提速度。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。