Binary Block Masking:加快稀疏 Attention 的一种新方法
一、背景
我们在之前的文章中简单介绍了 Sample Packing 相关的技术方案及涉及的问题,也在看其中 Attention 计算带来的各种挑战。机缘巧合正好看到一篇文章试图解决相应的 Attention 计算问题,这里进行简单介绍。
对应的论文为:[2409.15097] Efficiently Dispatching Flash Attention For Partially Filled Attention Masks
相关工作可以参考我们之前的文章:
- Sample Packing:长序列 LLM 训练的 Attention 问题及优化
- Sample Packing 综述:LLM 效果与效率的 Tradeoff
- 万字综述 10+ 种 LLM 投机采样推理加速方案
- LLM 投机解码 & 美杜莎(Medusa)实现
- LLM 推理的 Attention 计算和 KV Cache 优化:PagedAttention、vAttention 等
- 大规模分布式 AI 模型训练系列——序列并行
二、摘要
Transformer 已广泛应用于各种应用场景,会产生稀疏或部分填充的注意力矩阵,比如旨在降低 Attention 二次复杂度的 Attention Mask、Sample Packing 或其他技术引入的稀疏 Attention,例如用于 Medusa 投机采样快速验证的 Tree Attention。尽管这些矩阵具有固定的稀疏性,但 SOTA 的 FlashAttention 仍然将其当做稠密矩阵处理。
本文中,作者提出了 Binary Block Masking (BinBlkMsk) ,通过使 FlashAttention 具备 Mask 感知能力来增强 FlashAttention。作者也进一步提出了两种优化方案,在真实场景的 Attention Mask 实验表明,运行速度可以提升 9x。
- 一种针对具有连续非零 Mask:Dense Binary Block Masking。
- 一种针对极度稀疏 Mask:Binary Block Masking with RCM。
PS:其实上述的 Baseline 有问题,FlashAttention 可以通过 Varlen 方案支持 Sample Packing 引入的 Block Diagonal Mask,避免当做稠密矩阵计算。此外,Pytorch 官方也发布了 FlexAttention,可以支持各种 Attention Mask 变体,本文作者也并没有对比。
三、方案
3.1 Binary Block Masking
以 Sample Packing 引入的 Block Diagonal Mask 为例,如下图所示,Max Sequence Length 为 16,拼接了 3 个 Sample,长度分别为 4,5,7。因此会形成一个 16x16 的 Global Mask,其中包含 3 个 Causal Mask。
在 FlashAttention 实际的计算中会将其划分为不同的 Block 计算。本文提出的 Binary Block Masking 是通过一个额外的 Binary Block Matrix 来标记哪些 Block 全是 0,而哪些中有非 0 值。如下图,假设 Block 的大小为 2x2,则会生成一个 8x8 Binary Matrix,如下图红框和蓝框对应的位置为 1,其他位置为 0。在计算时,如果 Binary Block Matrix 中为 0 的 Mask 可以忽略不计算。
3.2 Dense Binary Block Masking
在自然语言场景中,通常上述的 Attention Mask 的 Block 中会存在很多连续全 1 的 Block,只在边界的 Block 里既有 0 又有 1。对于全为 1 的 Block,在计算时不用再读取 Attention Mask,减少访存并提升计算效率。因此作者使用两个额外的数组 total_ones 和 offset 来标识这些连续全 1 Block,如下图所示,total_ones[7] 为 2,表示最后一行有 2 个全 1 Block;offset[7] 为 5,表示最后一行连续全 1 Block 的起始索引为 5。
- total_ones 存储全 1 Block 的个数。
- offset 存储 Binary Block Matrix 中每行连续全 1 Block 的起始位置。
- 右图 Binary Block Matrix 中红色 1表示 Block 中全是 1,黑色 1 表示 Block 中部分为 1。
PS:实际计算中,对于所有 Layer,所有 Attention Head 都只需计算一次。
3.3 Binary Block Masking with RCM
在处理非常稀疏的注意力掩码时,如果使用标准的 Binary Block Masking 方法,可能会存在效率问题。这是因为即使整个块中只有一个非零值,也必须处理整个块,导致计算资源的浪费。作者使用 Reverse Cuthill-McKee (RCM) 算法来有效地重新组织 Mask 矩阵的结构,从而减少必须处理的 Block 的数量。如下图 Figure 3 所示:
如下图所示,作者展示了一个合成掩码的结果,其中 RCM 预处理减少了 90% 的 Block 数量,从而获得显著的性能提升。Binary Block Masking with RCM 的速度明显快于 Binary Block Masking。
四、实验和结果
4.1 配置
作者主要评估三种 Mask 类型,包括 Medusa([2401.10774] Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads) 里面的 Tree Attention Mask,Sample Packing 中的 Block Diagonal Mask 以及 LongFormer([2004.05150] Longformer: The Long-Document Transformer) 中的稀疏 Mask。测试时使用 Batch Size 为 4,32 个 Attention Head。FlashAttention 中的 Block Size 固定为 (128, 32)。
此外,测试的 FlashAttention 方案并不是开源的 FlashAttention 实现,而是作者基于 Triton 实现的,以方便对比。
4.2 预处理
如下图 Table 1 所示,作者使用 Triton 实现了基于 Attention Mask 生成 Binary Block Mask 的 Kernel,并测试了相应的时间。虽然 Binary Block Mask 的生成时间和 Batch Size 为 1,1 个 Head 的计算时间差不多,但是整体只用计算一次,影响比较小。
4.3 Medusa Tree Attention
如下图 Figure 3 所示为对 Medusa Tree Attention 的优化效果。左图为 Tree Attention 的 Mask,其中 K 表示 Medusa 中 Head 个数,sk 表示每个 Head 候选 Token 个数。(b) 表示固定候选 Token 为 3,不同 Head 个数下的速度;(c) 表示固定 4 个 Head,不同候选 Token 的速度。
- Naive Attention Masking:所有 Block 都计算。
- Bash FlashAttention:Causal Mask,只计算下三角的 Block,当 Head 或者候选 Token 很多时 FlashAttention 的 Causal Mask 时间大概是 Naive 的一半(计算量大约是一半)。
- Binary Block Masking:只计算部分 Block,速度远快于上述两种方案。
4.4 Sample Packing Block Diagonal Mask
作者同样评估了 Sample Packing 场景的 Block Diagonal Mask,如下图(a)表示标准 Causal Mask 的组合;(b)表示 Input 双向 Attention,而 Output Causal Mask 的情况。从(b)和(c)可以得出类似 Tree Mask 的结论,当序列比较长时,本文的 Binary Block Masking 可以带来 9x 加速。不过需要说明的是:Base FlashAttention 的时间几乎是 Naive Attention Masking 的一半,也可以推导出作者使用的 FlashAttention 方案确实计算了整个下三角 Block,而没有使用 Varlen 方案。
4.5 LongFormer Attention Mask
作者同样评估了 LongFormer 场景的各种稀疏 Mask,如下图分别为 Sliding Window Attention,Dilated Attention 以及 Global Attention 等场景,可以得出类似的结论。需要说明的是,此时 Base FlashAttention 相比 Naive Attention 的提升比较有限。
4.6 Pytorch FlexAttention
Pytorch 在 2.5.0 版本引入了 FlexAttention(FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention),可以很容易支持各种 Attention Mask 变种,比如标准 Causal Mask、Sliding Window 、Prefix Mask 以及 Block Diagonal Mask 等,相比 FlashAttention 灵活了很多,本文提到的各种 Attention 基本也都能实现。
如下表所示,我们同样做了 Sample Packing 的实验,随着 Sequence 中 Sample 分布的不同,计算的耗时甚至可能差 10x,与本文中 9x 的提升也能对上。
五、参考链接
- https://arxiv.org/abs/2409.15097
- https://arxiv.org/abs/2401.10774
- https://arxiv.org/abs/2004.05150
- https://pytorch.org/blog/flexattention/
本文转载自 AI闲谈,作者: AI闲谈