MiniCache 和 PyramidInfer 等 6 种优化 LLM KV Cache 的最新工作
一、背景
在 LLM 推理中,常常会采用 KV Cache 来缓存之前 Token 的中间结果,以显著减少重复计算,从而降低自回归生成中的延迟。然而,KV Cache 的大小与序列长度成正比,在处理长序列时会面临极大的挑战。尤其当前许多模型开始支持几百 K 甚至几 M 的序列长度,进一步凸显了 KV Cache 的问题,因此很多研究工作致力于降低 KV Cache 的占用。
本文中简单介绍几个最新的工作,包括 SnapKV、YOCO、CLA、Layer-Condensed KV Cache、MiniCache 以及 PyramidInfer,它们都试图降低缓解 KV Cache 的压力。关于 GQA、MQA、DeepSeek MLA 以及量化相关的工作我们已经在之前进行了介绍,这里不再赘述。
二、KV Cache 大小
KV Cache 的大小与模型配置(层数,hidden_size,Attention head 个数等)以及序列长度、Batch Size 成正比。其中单个 Token 对应的 KV Cache 大小与模型配置相关,并且是固定的,这里将其称为单位 KV Cache 计算公式为:
sum_token = (hidden_size / num_attention_heads * num_key_value_heads) * num_hidden_layers * 2(k, v)
而总的 KV Cache 大小为:
sum = sum_token * seq_len * batch_size
batch_size 和 seq_len 越大,KV Cache 越大,如下图所示为 LLaMA2-7B 模型的 batch_size 和 seq_len 对应的 KV Cache 大小(默认 FP16 精度):
- 当 batch_size * seq_len 为32K时,比如 batch_size 为 1,seq_len 为 32K,其 KV Cache 大小为16GB,甚至超过模型权重大小 14GB。
- 当 batch_size * seq_len 为128K时,比如 batch_size 为 1,seq_len 为 128K,其 KV Cache 大小为 64GB,加上模型权重 14GB 甚至快要超过 A100 GPU 的 80GB 显存限制。
三、SnapKV
[2404.14469] SnapKV: LLM Knows What You are Looking for Before Generation 的核心思路比较简单,如下图 Figure 1 所示,在 Prefill 阶段不是保留所有输入 Token 的 KV Cache,而是采用稀疏化的方式,针对每个 Attention Head 将 Prompt 分为 Prefix 和 Window 两部分;然后,通过 Window 中 Token 与 Prefix 中 Token 的 Attention Score 来选择稀疏化的 Token;最后,将它们的 KV Cache 和 Window 中 Token 的 KV Cache 一起作为 Prompt 的 KV Cache。需要说明的是:每个 Attention Head 中从 Prefix 里挑选的 Token 可能不同。此外,Decoding 阶段也不会再更新 Prompt 的 KV Cache。
SnapKV 在处理 16K Token 的输入时,可以获得 3.6x 的加速,内存效率提升 8.2x。同时在 16 个长序列数据集上保持了与基线模型相当的精度。此外,使用 Huggingface 可以在单个 A100-80GB GPU 上处理 380K 上下文 Token 的任务。
四、YOCO
在 [2405.05254] You Only Cache Once: Decoder-Decoder Architectures for Language Models 中,作者只保留一层全局的 KV Cache。这种设计可以大大降低 GPU 显存的需求,加快 Prefill 阶段。如下图所示,YOCO 模型与常规 Decoder-Only LLM 的区别有几点:
- 前 L/2 层(Self-Decoder)使用Efficient Self-Attention,实际上就是滑动窗口 Self-Attention或作者之前论文提出的Multi-Scale Retention。其只用保存窗口内的 KV Cache 即可。
- 第 L/2 层的 KV Cache 作为Global KV Cache。也就是只有一层有全局 KV Cache。
- 后 L/2 层(Cross-Decoder)使用Global Cross Attention,对应的 KV 为上一步的 Global KV Cache,也就是后续所有 L/2 层的 Cross Attention 的 KV Cache 都是相同的。
五、CLA
[2405.12981] Reducing Transformer Key-Value Cache Size with Cross-Layer Attention 中作者同样采用 Cross-Attention 机制来降低 KV Cache。不同的是作者并非采用固定层作为 Cross-Attention 的输入,而是采用相邻层,如下图左图所示。最简单的方式就是隔层共享,称作 CLA2,实际也可以每 3 层共享,称作 CLA3,如下图右图所示。此外,这种方法与 MQA 和 GQA 等修改 Attention Head 的方案是兼容的。CLA2 显存减小 2x,CLA3 显存减小 3x。
作者训练 1B 和 3B 参数模型模型实验表明,CLA 相比传统的 MQA 在显存占用、准确性方面可以实现帕累托改进,从而实现更长的序列长度和更大的 Batch Size。(PS:但并不意味着可以优于现在广泛采用的 GQA?)
六、Layer-Condensed KV Cache
在 [2405.10637] Layer-Condensed KV Cache for Efficient Inference of Large Language Models 中,作者同样采用了仅计算和缓存少量层 KV Cache 的方案,从而显著节约显存消耗并提升吞吐量。如下图 Figure 1 所示,仅保留最后一个 Transfomer Block 层的 KV Cache,当生成后续 Token 时其对应的 KV Cache 都从最后一层取。
七、MiniCache
在 [2405.14366] MiniCache: KV Cache Compression in Depth Dimension for Large Language Models 中,作者观察到 KV Cache 在 LLM 中的深层部分的相邻层之间表现出了高度相似性,可以基于这些相似性对 KV Cache 进行压缩。此外,作者还引入了 Token 保留策略,对高度不同的 KV Cache 不进行合并。并且这种方法可以与其他的 KV Cache 量化方案正交使用。
作者在 LLaMA-2、LLaMA-3、Phi-3、Mistral 和 Mixtral 等模型上进行实验,在 ShareGPT 数据集上,采用 4 Bit MiniCache LLaMA–7B 与 FP16 全量 KV Cache 相比实现了 5.02x 的压缩比,推理吞吐提高约 5 倍,显存占用减少 41%,同时性能几乎无损。
如下图 Figure 3 所示为其压缩策略和保留策略:
如下图 Figure A 所示为其详细的执行流程:
- 1. 获取 KV Cache:在 Prefill 阶段,逐层生成 KV Cache。
- 2. 跨层合并:当到达合并开始层 S 时,将当前层 L 的 KV Cache 与前一层 L-1 的 KV Cache 进行合并,以减少冗余。
- 3. 缓存:将合并后的 KV Cache 存储起来,以便将来使用。
- 4. 删除:在 Decoding 阶段,删除不必要的或冗余的 KV Cache,以优化内存使用。
- 5. 加载和生成:获取所需的 KV Cache,用于生成输出。
- 6. 恢复:对获取的 KV Cache 应用误差抑制机制,包括 rescaling 和 retention recovery,以最小化合并和压缩过程中引入的误差。
- 7. 更新:在恢复阶段后,使用最终的 KV Cache 更新共享的 KV Cache。
八、PyramidInfer
在 [2405.12532] PyramidInfer: Pyramid KV Cache Compression for High-throughput LLM Inference 中,作者发现影响未来生成的关键 KV 的数量逐层减少,并且可以通过注意力权重的一致性来提取这些关键 KV。基于这些发现,作者提出了 PyramidInfer,通过逐层保留关键上下文来压缩 KV Cache。PyramidInfer 在不牺牲性能的情况下计算更少的 KV,并节约大量显存。实验结果表明,与 Accelerate 相比,PyramidInfer 的吞吐提高了 2.2 倍,KV Cache 的显存占用减少了 54% 以上。
如下图 Figure 2 所示为 PyramidInfer 与 StreamingLLM 和 H2O 的区别,PyramidInfer 中 KV Cache 会逐层递减,越往后越稀疏(PS:如果是这样,那么 Layer-Condensed KV Cache 中只保留最后一层的方案是不是不太合理):
PyramidInfer 的执行过程如下图 Figure 6 所示:
- 在 Prefill 阶段,PyramidInfer 只保留每层的关键上下文(Pivotal Context, PvC)来压缩 KV Cache。
- 在 Decoding 阶段,PyramidInfer 根据新的最近的 Token 来更新 PvC。
如下图 Table 1 所示,PyramidInfer 在使用更少 KV Cache 的情况下获得更快的推理速度:
如下图 Figure 11 所示,作者进一步测试了 PyramidInfer 在更多 Batch Size 下的表现,其在比较小 Batch Size 时几乎没有加速,主要是因为减少 KV Cache 还需要一些额外的计算;而在比较大的 Batch Size 能获得更大的加速比。而 Full Cache 当 Batch Size 大于 32 吞吐反而降低:(PS:这个降低不太符合预期,通常来说随着 Batch Size 的增加,计算密度会更高,相应的吞吐也应该更高,而且在 32 左右还远没有到 Compute Bound)。
九、参考链接
- https://arxiv.org/abs/2404.14469
- https://arxiv.org/abs/2405.05254
- https://arxiv.org/abs/2405.12981
- https://arxiv.org/abs/2405.10637
- https://arxiv.org/abs/2405.14366
- https://arxiv.org/abs/2405.12532
本文转载自 AI闲谈,作者: AI闲谈