核心要点
将 d_model 切成 H 个头,每头维度 d_h = d_model / H,让模型在不同子空间学习不同关系。
Q/K/V 投影后 reshape 为 (B, H, L, d_h),对每个头独立做缩放点积注意力。
各头输出在最后一维 concat 回 (B, L, d_model),再经输出投影 Wo 融合。
易错点:reshape 后要 transpose 把 H 维提到前面;缩放因子用每头维度 d_h 的 sqrt。
标准回答
多头注意力把模型维度切分成 H 个并行的头,每个头在更低维的子空间中独立计算缩放点积注意力,从而能同时捕捉不同类型的依赖(如语法、指代等)。实现上先用线性层得到 Q、K、V,reshape 成 (B, H, L, d_h) 并把头维度 transpose 到前面,对每个头并行做注意力;最后把所有头的输出拼接回 d_model 维,再经过输出投影矩阵 Wo 融合信息。下面用 PyTorch 实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
assert d_model % n_heads == 0
self.h = n_heads
self.d_h = d_model // n_heads
self.qkv = nn.Linear(d_model, 3 * d_model) # 一次投影出 Q,K,V
self.wo = nn.Linear(d_model, d_model) # 输出投影
def forward(self, x, mask=None):
B, L, d = x.shape
qkv = self.qkv(x) # (B, L, 3d)
q, k, v = qkv.chunk(3, dim=-1)
# reshape 成 (B, H, L, d_h)
def split(t):
return t.view(B, L, self.h, self.d_h).transpose(1, 2)
q, k, v = split(q), split(k), split(v)
scores = q @ k.transpose(-2, -1) / (self.d_h ** 0.5) # (B,H,L,L)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = attn @ v # (B,H,L,d_h)
# 合并多头:transpose 回来再 reshape
out = out.transpose(1, 2).contiguous().view(B, L, d)
return self.wo(out)
if __name__ == '__main__':
mha = MultiHeadAttention(d_model=64, n_heads=8)
x = torch.randn(2, 10, 64)
print(mha(x).shape) # torch.Size([2, 10, 64])常见误区
⚠️ 常见踩坑
reshape 后忘记 transpose(1,2) 会让头维与序列维错位;合并时若不加 .contiguous() 直接 view 会因内存不连续报错。
追问
追问 1:复杂度是多少?多头相比单头开销如何?
每头注意力为 O(L²·d_h),H 个头合计仍是 O(L²·d_model),与同维度的单头基本持平——多头是把维度切分而非叠加,几乎不增加计算量却提升表达力。瓶颈仍是 L² 项,可用 FlashAttention 缓解。
延伸学习
与本题相关的知识库文章、术语、工具与行业资讯。
🛠️ AI 工具
- Pytorch
Meta 开源的深度学习框架,100K+ stars。以动态计算图和 Pythonic 风格著称,在学术界和工业界都有广泛应用,支持分布式训练、移动端部署和 ONNX 导出