简要回答
Self-Attention 将每个 token 投影为 Q/K/V,用 Q 与所有 K 点积得分数,缩放后 Softmax 得权重,再对 V 加权求和,得到融合全局上下文的新表示。
标准回答
计算步骤:给定输入序列 X,先线性投影得到 Q=XW_Q、K=XW_K、V=XW_V,再计算缩放点积注意力:
Attention(Q,K,V) = softmax(QK^T / √d_k) · V
直觉:每个 token 的 Query 与所有 Key 算相似度,Softmax 归一化为权重,对 Value 加权求和,得到融合全局上下文的新表示。
多头注意力(MHA):在 h 个独立子空间并行计算 Attention,拼接后再线性变换。不同 head 可学习语法、共指、位置等不同依赖关系。
缩放因子 √d_k:防止 d_k 较大时点积过大,Softmax 进入饱和区导致梯度消失。
详见 Transformer 原理 与术语 注意力机制。
常见误区
⚠️ 常见踩坑
常把 √d_k 说成「为了归一化概率」——归一化是 Softmax 做的,缩放是为了控制点积方差、防梯度消失。另一个常见错误是认为多头是把同一注意力算 h 遍,实际是各 head 在不同低维子空间投影、互不相同,拼接后才还原维度。
追问
追问 1:Self-Attention 的时间/空间复杂度是多少?
标准 Self-Attention 对序列长度 n 为 O(n²·d),显存与计算都随上下文平方增长。Flash Attention 通过分块计算降低显存;线性 Attention、稀疏 Attention 可降至近似 O(n)。
追问 2:Causal Mask 在 Decoder 中的作用?
在自回归解码器中,掩码阻止位置 i 注意到未来 token j>i,保证训练时不会「偷看」答案,与逐步生成时的信息可见性一致,避免标签泄漏。
追问 3:Flash Attention 优化了什么?
通过分块计算避免物化完整 n×n 注意力矩阵,减少 HBM 读写,在相同数学结果下显著降显存、提速度。
延伸学习
与本题相关的知识库文章、术语、工具与行业资讯。
📰 AI 资讯
🛠️ AI 工具
- Pytorch
Meta 开源的深度学习框架,100K+ stars。以动态计算图和 Pythonic 风格著称,在学术界和工业界都有广泛应用,支持分布式训练、移动端部署和 ONNX 导出