标准回答
问题根源
标准注意力先算 S=QKᵀ(N×N),softmax 后再乘 V。这个 N×N 矩阵要写回 HBM 再读回来,序列越长访存量越大。注意力本身算力不高,真正的瓶颈是 HBM 读写——它是访存受限算子。
FlashAttention 做了什么
它不改数学,只改访存路径。把 Q、K、V 分成小块(tiling)载入片上 SRAM,在 SRAM 内完成 QKᵀ、softmax、乘 V 的整条链路,只把最终输出写回 HBM,绝不物化完整注意力矩阵。这样显存占用从 O(N²) 降到 O(N),HBM 读写量大幅下降。
online softmax 是关键
softmax 通常要先知道整行最大值才能稳定计算,但分块时只能看到局部。FlashAttention 用 online softmax:维护 running max 与 running sum,每来一个新块就按新最大值重新缩放已累积的结果,最终结果与一次性 softmax 完全等价。这保证了「精确」而非近似。
结果是同样输出、更少 HBM 访问、更省显存,长序列收益尤其明显。详见 LLM 推理加速实战。
常见误区
⚠️ 常见踩坑
误以为 FlashAttention 是近似算法或减少了计算量——它输出与标准注意力逐位相同,省的是 HBM 访存;也不要把它和稀疏注意力混为一谈,后者才真正改变了计算的 token 对。
追问
追问 1:为什么说注意力是 memory-bound 而 GEMM 是 compute-bound?
衡量标准是算术强度(FLOPs/字节)。注意力每读一个元素只做很少运算、要反复读写大矩阵,访存主导;大矩阵乘法计算密度高,数据复用充分,受算力上限约束。FlashAttention 通过提高数据复用把注意力从前者推向后者。
追问 2:FlashAttention-2 相比 v1 改进在哪?
主要优化并行与工作划分:减少非矩阵乘的冗余缩放运算,把并行维度从 batch/head 扩展到序列长度,并改善 warp 间的工作分配,使 GPU 占用率更高,前向后向都更快。
追问 3:它对反向传播有什么影响?
反向不保存 N×N 注意力矩阵,而是用前向存下的 softmax 归一化统计量(max 与 sum)按需重算,以少量重计算换大量显存,使长序列训练可行。
延伸学习
与本题相关的知识库文章、术语、工具与行业资讯。
📚 知识库
🛠️ AI 工具