核心要点

  • 将 d_model 切成 H 个头,每头维度 d_h = d_model / H,让模型在不同子空间学习不同关系。

  • Q/K/V 投影后 reshape 为 (B, H, L, d_h),对每个头独立做缩放点积注意力

  • 各头输出在最后一维 concat 回 (B, L, d_model),再经输出投影 Wo 融合。

  • 易错点:reshape 后要 transpose 把 H 维提到前面;缩放因子用每头维度 d_h 的 sqrt。

标准回答

多头注意力把模型维度切分成 H 个并行的头,每个头在更低维的子空间中独立计算缩放点积注意力,从而能同时捕捉不同类型的依赖(如语法、指代等)。实现上先用线性层得到 Q、K、V,reshape 成 (B, H, L, d_h) 并把头维度 transpose 到前面,对每个头并行做注意力;最后把所有头的输出拼接回 d_model 维,再经过输出投影矩阵 Wo 融合信息。下面用 PyTorch 实现:

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.h = n_heads
        self.d_h = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)  # 一次投影出 Q,K,V
        self.wo = nn.Linear(d_model, d_model)       # 输出投影

    def forward(self, x, mask=None):
        B, L, d = x.shape
        qkv = self.qkv(x)                            # (B, L, 3d)
        q, k, v = qkv.chunk(3, dim=-1)
        # reshape 成 (B, H, L, d_h)
        def split(t):
            return t.view(B, L, self.h, self.d_h).transpose(1, 2)
        q, k, v = split(q), split(k), split(v)
        scores = q @ k.transpose(-2, -1) / (self.d_h ** 0.5)  # (B,H,L,L)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = attn @ v                               # (B,H,L,d_h)
        # 合并多头:transpose 回来再 reshape
        out = out.transpose(1, 2).contiguous().view(B, L, d)
        return self.wo(out)

if __name__ == '__main__':
    mha = MultiHeadAttention(d_model=64, n_heads=8)
    x = torch.randn(2, 10, 64)
    print(mha(x).shape)  # torch.Size([2, 10, 64])

常见误区

⚠️ 常见踩坑

reshape 后忘记 transpose(1,2) 会让头维与序列维错位;合并时若不加 .contiguous() 直接 view 会因内存不连续报错。

追问

追问 1复杂度是多少?多头相比单头开销如何?

每头注意力为 O(L²·d_h),H 个头合计仍是 O(L²·d_model),与同维度的单头基本持平——多头是把维度切分而非叠加,几乎不增加计算量却提升表达力。瓶颈仍是 L² 项,可用 FlashAttention 缓解。

追问 2什么是 MQA / GQA,为什么能加速推理?

MQA(多查询注意力)让所有头共享同一组 K、V;GQA(分组查询注意力)让若干头共享一组 K、V,介于 MHA 与 MQA 之间。它们大幅减小 KV Cache 体积,降低解码时的显存带宽压力,从而提升长序列推理吞吐,是 Llama 2/3 等模型的常用设计。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。