核心要点

  • 能点出本质:FlashAttention 是 IO-aware 的精确算法,结果与标准注意力逐位相同,省的是 HBM 访存而非浮点运算

  • 能讲清手段:把 Q/K/V 切成块在 SRAM 内计算,永不把完整的 N×N 注意力矩阵物化到 HBM,显存从 O(N²) 降到 O(N)

  • 能解释 online softmax分块累加时用 running max 和 running sum 增量校正,无需先看到整行就能算出正确 softmax

  • 能区分瓶颈:标准注意力是访存受限(memory-bound),FlashAttention 把它推向计算受限,从而打满 GPU

标准回答

问题根源

标准注意力先算 S=QKᵀ(N×N),softmax 后再乘 V。这个 N×N 矩阵要写回 HBM 再读回来,序列越长访存量越大。注意力本身算力不高,真正的瓶颈是 HBM 读写——它是访存受限算子。

FlashAttention 做了什么

它不改数学,只改访存路径。把 Q、K、V 分成小块(tiling)载入片上 SRAM,在 SRAM 内完成 QKᵀ、softmax、乘 V 的整条链路,只把最终输出写回 HBM,绝不物化完整注意力矩阵。这样显存占用从 O(N²) 降到 O(N),HBM 读写量大幅下降。

online softmax 是关键

softmax 通常要先知道整行最大值才能稳定计算,但分块时只能看到局部。FlashAttention 用 online softmax:维护 running max 与 running sum,每来一个新块就按新最大值重新缩放已累积的结果,最终结果与一次性 softmax 完全等价。这保证了「精确」而非近似。

结果是同样输出、更少 HBM 访问、更省显存,长序列收益尤其明显。详见 LLM 推理加速实战

常见误区

⚠️ 常见踩坑

误以为 FlashAttention 是近似算法或减少了计算量——它输出与标准注意力逐位相同,省的是 HBM 访存;也不要把它和稀疏注意力混为一谈,后者才真正改变了计算的 token 对。

追问

追问 1为什么说注意力是 memory-bound 而 GEMM 是 compute-bound?

衡量标准是算术强度(FLOPs/字节)。注意力每读一个元素只做很少运算、要反复读写大矩阵,访存主导;大矩阵乘法计算密度高,数据复用充分,受算力上限约束。FlashAttention 通过提高数据复用把注意力从前者推向后者。

追问 2FlashAttention-2 相比 v1 改进在哪?

主要优化并行与工作划分:减少非矩阵乘的冗余缩放运算,把并行维度从 batch/head 扩展到序列长度,并改善 warp 间的工作分配,使 GPU 占用率更高,前向后向都更快。

追问 3它对反向传播有什么影响?

反向不保存 N×N 注意力矩阵,而是用前向存下的 softmax 归一化统计量(max 与 sum)按需重算,以少量重计算换大量显存,使长序列训练可行。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。

🛠️ AI 工具

  • vLLM

    高吞吐 LLM 推理引擎,77,418+ stars。采用 PagedAttention 显存优化技术,吞吐量比 HuggingFace Transformers 高 24 倍,是生产环境部署大模型推理的首选方案,支持 OpenAI 兼容 API

  • Pytorch

    Meta 开源的深度学习框架,100K+ stars。以动态计算图和 Pythonic 风格著称,在学术界和工业界都有广泛应用,支持分布式训练、移动端部署和 ONNX 导出