Appearance
FlashAttention: 优化计算过程中的临时显存占用
FlashAttention 的主要目标是加速 Attention 计算并减少计算过程中的显存占用。它的优化对象是 Activation,更具体地说是 Attention 计算中产生的中间结果。
1. 作用于哪个组件?
Activation (中间激活值)。
具体来说,在标准的 Attention 实现中,计算 Softmax(QKᵀ/sqrt(d_k))V
需要一个非常耗费显存的中间步骤:
- 计算
P = QKᵀ
,这里Q
的维度是(N, d_k)
,Kᵀ
的维度是(d_k, N)
,所以会生成一个巨大的中间矩阵P
,维度为(N, N)
,其中 N 是序列长度。 - 这个
(N, N)
的矩阵必须完整地存储在 GPU 的高带宽显存 (HBM) 中,然后才能进行 Softmax 等后续计算。
这个 (N, N)
的矩阵就是 FlashAttention 优化的目标,它既不是模型权重,也不是 KV Cache。
2. 大概能优化多少显存占用?
FlashAttention 将这部分中间结果的显存复杂度从 O(N²) 降低到了 O(N)。
这是一个量级的降低,效果非常巨大。
- 举例说明:
- 假设序列长度 N = 8192 (8k),数据类型为 FP16 (2字节)。
- 标准 Attention 需要的中间显存 =
8192 * 8192 * 2 bytes
≈ 128 MB。 - 这仅仅是一个注意力头在一个样本中的消耗。如果 Batch Size 是 8,模型有 32 个头,那么总消耗就是
128 MB * 8 * 32
≈ 32 GB!这足以让高端 GPU 直接显存溢出 (Out of Memory)。 - FlashAttention 通过分块计算(Tiling)技术,避免在 HBM 中生成这个完整的
(N, N)
矩阵。它将 Q, K, V 矩阵切分成小块,然后将这些小块加载到速度极快但容量很小的 SRAM (GPU 上的片上缓存) 中进行计算。每次只计算出一个小块的最终结果,然后写回 HBM。 - 这样,它只需要存储与块大小相关的少量中间状态,其显存占用与 N 呈线性关系,而不是平方关系。对于上面的例子,FlashAttention 几乎将这 32 GB 的临时显存占用降低到接近于零。
结论:FlashAttention 的主要贡献是让长序列(long context)训练和推理成为可能。如果没有它,处理几万甚至几十万长度的序列是完全不可想象的。
PagedAttention: 优化 KV Cache 的管理和利用率
PagedAttention 的主要目标是解决 KV Cache 的显存浪费问题,从而提高 GPU 的吞吐量(即同时处理更多请求)。
1. 作用于哪个组件?
KV Cache。
KV Cache 是在模型进行自回归生成(autoregressive generation)时,为了避免重复计算而存储的过去所有 token 的 Key 和 Value 向量。随着生成序列的变长,KV Cache 会越来越大,成为推理过程中最主要的显存消耗来源。
传统的 KV Cache 管理方式存在严重的浪费:
- 内部碎片(Internal Fragmentation): 系统需要为每个请求预留一个连续的、能容纳其最大可能长度的显存空间。例如,即使一个请求当前只有 10 个 token,系统也可能为它预留了 2048 个 token 的空间。这中间
2048 - 10 = 2038
个 token 的空间就被浪费了。 - 过度预留(Over-reservation): 对于一批请求,系统必须按照这批请求中“最长”的那个来为“所有”请求分配空间,导致短请求占用了大量不必要的显存。
2. 大概能优化多少显存占用?
PagedAttention 通过借鉴操作系统中“虚拟内存”和“分页”的思想,极大地提高了 KV Cache 的显存利用率。它能将 KV Cache 的内存浪费(碎片)从 60%-80% 降低到 4% 以下。
工作原理:
- 它将 KV Cache 空间分割成许多个固定大小的、不连续的物理块 (Block)。
- 每个序列的 KV Cache 在逻辑上是连续的,但它在物理上可以存储在任意的、不相邻的块中。
- 系统通过一个“块表”(类似页表)来管理逻辑 token 位置到物理块位置的映射。
- 当序列需要生成新 token 时,系统就按需分配一个新的块,而不需要预先分配一大片连续空间。
优化效果:
- 几乎没有内存浪费:按需分配块,使得显存利用率非常高。根据 vLLM 论文的数据,PagedAttention 可以达到 96% 的内存使用效率。
- 提升吞吐量:由于显存利用率的大幅提升,同样一张 GPU 可以容纳更多的并发请求(更大的总 Batch Size)。这使得推理服务的吞吐量可以提升 2-4倍。
- 更灵活的采样:对于复杂的采样算法(如 beam search),多个候选序列可以高效地共享前面相同的 token 所对应的 KV Cache 块,实现了类似“Copy-on-Write”的机制,极大地节省了显存。
结论:PagedAttention 优化的不是单次计算的峰值显存,而是整个服务运行期间的平均显存占用和管理效率。它的主要价值体现在高吞吐量的推理服务场景。
总结
- FlashAttention 作用于 Activation,通过改变计算方式,解决了计算过程中
O(N²)
的临时显存问题,让长序列处理成为可能,主要优势体现在训练和处理长文本的推理。 - PagedAttention 作用于 KV Cache,通过改变存储管理方式,解决了显存碎片和浪费问题,极大地提高了显存利用率,从而将推理服务的吞吐量提升数倍,主要优势体现在高并发的线上服务。
两者从不同角度优化了 Transformer 的显存问题,并且可以同时使用,共同构成当前大模型高效推理的基础。