核心要点

  • 注意力复杂度 O(L²):注意力矩阵与激活随序列长度平方增长,是显存与算力的主瓶颈

  • 推理阶段 KV Cache 随序列长度线性增长,长上下文时占用巨大

  • FFN 中间层维度为 4×d,其权重与激活也是显存大户

  • 优化手段:FlashAttention、稀疏/线性注意力、KV Cache 量化、梯度检查点

标准回答

回答这道题要把「训练显存」和「推理显存」分开看,并抓住复杂度随序列长度 L 的增长关系。

1. 自注意力是平方级瓶颈(O(L²))
注意力需要计算 Q·Kᵀ,得到一个 L×L 的注意力分数矩阵。无论是矩阵本身还是 softmax 后的激活,显存与计算量都随序列长度 平方增长。多头还要乘以 head 数。当 L 很大(长文档、长上下文)时,这一项迅速主导显存占用与算力开销,是 Transformer 最核心的性能瓶颈。

2. 推理时 KV Cache 随 L 线性增长
自回归生成时,为避免重复计算历史 token,会缓存每层每个 token 的 Key/Value,即 KV Cache。其大小约为 2(K 和 V)× 层数 × L × d_model × 每元素字节数(其中 d_model = 头数 × 每头维度,已包含所有头,不要再额外乘一次头数;GQA/MQA 则把 d_model 换成 KV 头数 × 每头维度),随序列长度 线性增长。长上下文推理时,KV Cache 往往比模型权重还吃显存,是部署长上下文模型的主要约束。

3. FFN 中间层与激活
每层的前馈网络(FFN)通常把维度从 d 扩到 4×d 再降回,参数量大,前向激活也占不少显存;训练时为反向传播保留的激活更是大头。

4. 优化方向

  • FlashAttention:分块计算注意力,不显式存下完整的 L×L 注意力矩阵,把注意力的显存从 O(L²) 降到接近 O(L),同时提升访存效率。
  • 稀疏/线性注意力:用局部窗口、稀疏模式或核技巧把复杂度降到近线性。
  • KV Cache 量化 / 压缩:用 INT8/INT4 存 KV,或用 MQA/GQA 减少 KV 头数。
  • 梯度检查点(activation checkpointing):训练时只保留部分激活,反向时重算,用算力换显存。

常见误区

⚠️ 常见踩坑

误以为「模型权重」永远是显存最大头。在长序列训练中,注意力激活的 O(L²) 增长会迅速超过权重;在长上下文推理中,KV Cache 也常常超过权重。瓶颈到底在哪,取决于序列长度与是训练还是推理。

追问

追问 1FlashAttention 既然不存完整注意力矩阵,为什么还能算出正确结果?

它把 Q、K、V 分块加载到片上 SRAM,逐块计算局部注意力,并用在线 softmax(online softmax)增量地维护归一化分母与累加结果,最终等价于完整 softmax 注意力。整个过程不需要在显存中物化 L×L 矩阵,因此显存接近 O(L),且大幅减少对 HBM 的读写,速度也更快。

追问 2MQA / GQA 是怎么降低 KV Cache 占用的?

标准多头注意力每个头都有独立的 K、V。MQA(Multi-Query Attention)让所有 Query 头共享同一组 K/V,KV Cache 缩小到原来的 1/头数;GQA(Grouped-Query Attention)介于两者之间,让若干头分一组共享 K/V,在质量和显存之间折中。两者都直接减少需要缓存的 K/V 数量。

追问 3梯度检查点为什么能省显存,代价是什么?

常规训练会保留前向过程中的所有中间激活以供反向传播使用。梯度检查点只保留少量关键节点的激活,反向时再从这些节点重新前向计算被丢弃的激活,从而把激活显存从 O(层数) 降到约 O(√层数)。代价是要多做一次前向计算,训练时间通常增加约 20%–30%。

🔗 相似问题

同一考点的不同问法,面试官可能换着问,一起刷更稳

没找到想看的面试题?把你想看的告诉我们 →

延伸学习

按主题分类的相关资源,便于系统复习