Skip to content

FlashAttention: 优化计算过程中的临时显存占用

FlashAttention 的主要目标是加速 Attention 计算减少计算过程中的显存占用。它的优化对象是 Activation,更具体地说是 Attention 计算中产生的中间结果

1. 作用于哪个组件?

Activation (中间激活值)

具体来说,在标准的 Attention 实现中,计算 Softmax(QKᵀ/sqrt(d_k))V 需要一个非常耗费显存的中间步骤:

  1. 计算 P = QKᵀ,这里 Q 的维度是 (N, d_k)Kᵀ 的维度是 (d_k, N),所以会生成一个巨大的中间矩阵 P,维度为 (N, N),其中 N 是序列长度。
  2. 这个 (N, N) 的矩阵必须完整地存储在 GPU 的高带宽显存 (HBM) 中,然后才能进行 Softmax 等后续计算。

这个 (N, N) 的矩阵就是 FlashAttention 优化的目标,它既不是模型权重,也不是 KV Cache。

2. 大概能优化多少显存占用?

FlashAttention 将这部分中间结果的显存复杂度从 O(N²) 降低到了 O(N)

这是一个量级的降低,效果非常巨大。

  • 举例说明
    • 假设序列长度 N = 8192 (8k),数据类型为 FP16 (2字节)。
    • 标准 Attention 需要的中间显存 = 8192 * 8192 * 2 bytes128 MB
    • 这仅仅是一个注意力头在一个样本中的消耗。如果 Batch Size 是 8,模型有 32 个头,那么总消耗就是 128 MB * 8 * 3232 GB!这足以让高端 GPU 直接显存溢出 (Out of Memory)。
    • FlashAttention 通过分块计算(Tiling)技术,避免在 HBM 中生成这个完整的 (N, N) 矩阵。它将 Q, K, V 矩阵切分成小块,然后将这些小块加载到速度极快但容量很小的 SRAM (GPU 上的片上缓存) 中进行计算。每次只计算出一个小块的最终结果,然后写回 HBM。
    • 这样,它只需要存储与块大小相关的少量中间状态,其显存占用与 N 呈线性关系,而不是平方关系。对于上面的例子,FlashAttention 几乎将这 32 GB 的临时显存占用降低到接近于零

结论:FlashAttention 的主要贡献是让长序列(long context)训练和推理成为可能。如果没有它,处理几万甚至几十万长度的序列是完全不可想象的。

PagedAttention: 优化 KV Cache 的管理和利用率

PagedAttention 的主要目标是解决 KV Cache 的显存浪费问题,从而提高 GPU 的吞吐量(即同时处理更多请求)。

1. 作用于哪个组件?

KV Cache

KV Cache 是在模型进行自回归生成(autoregressive generation)时,为了避免重复计算而存储的过去所有 token 的 Key 和 Value 向量。随着生成序列的变长,KV Cache 会越来越大,成为推理过程中最主要的显存消耗来源。

传统的 KV Cache 管理方式存在严重的浪费:

  1. 内部碎片(Internal Fragmentation): 系统需要为每个请求预留一个连续的、能容纳其最大可能长度的显存空间。例如,即使一个请求当前只有 10 个 token,系统也可能为它预留了 2048 个 token 的空间。这中间 2048 - 10 = 2038 个 token 的空间就被浪费了。
  2. 过度预留(Over-reservation): 对于一批请求,系统必须按照这批请求中“最长”的那个来为“所有”请求分配空间,导致短请求占用了大量不必要的显存。

2. 大概能优化多少显存占用?

PagedAttention 通过借鉴操作系统中“虚拟内存”和“分页”的思想,极大地提高了 KV Cache 的显存利用率。它能将 KV Cache 的内存浪费(碎片)从 60%-80% 降低到 4% 以下

  • 工作原理

    1. 它将 KV Cache 空间分割成许多个固定大小的、不连续的物理块 (Block)
    2. 每个序列的 KV Cache 在逻辑上是连续的,但它在物理上可以存储在任意的、不相邻的块中。
    3. 系统通过一个“块表”(类似页表)来管理逻辑 token 位置到物理块位置的映射。
    4. 当序列需要生成新 token 时,系统就按需分配一个新的块,而不需要预先分配一大片连续空间。
  • 优化效果

    • 几乎没有内存浪费:按需分配块,使得显存利用率非常高。根据 vLLM 论文的数据,PagedAttention 可以达到 96% 的内存使用效率。
    • 提升吞吐量:由于显存利用率的大幅提升,同样一张 GPU 可以容纳更多的并发请求(更大的总 Batch Size)。这使得推理服务的吞吐量可以提升 2-4倍
    • 更灵活的采样:对于复杂的采样算法(如 beam search),多个候选序列可以高效地共享前面相同的 token 所对应的 KV Cache 块,实现了类似“Copy-on-Write”的机制,极大地节省了显存。

结论:PagedAttention 优化的不是单次计算的峰值显存,而是整个服务运行期间的平均显存占用管理效率。它的主要价值体现在高吞吐量的推理服务场景。

总结

  • FlashAttention 作用于 Activation,通过改变计算方式,解决了计算过程中 O(N²)临时显存问题,让长序列处理成为可能,主要优势体现在训练处理长文本的推理。
  • PagedAttention 作用于 KV Cache,通过改变存储管理方式,解决了显存碎片和浪费问题,极大地提高了显存利用率,从而将推理服务的吞吐量提升数倍,主要优势体现在高并发的线上服务

两者从不同角度优化了 Transformer 的显存问题,并且可以同时使用,共同构成当前大模型高效推理的基础。