用短输入模拟长样本,高效拓展LLM上下文窗口,北大联合MSRA提出PoSE
论文题目:PoSE: Efficient Context Window Extension of LLMs via Positional Skip-wise Training
论文链接:https://arxiv.org/abs/2309.10400
代码链接:https://github.com/dwzhu-pku/PoSE
一、研究简介
大型语言模型(LLMs)通常有一个预定义的上下文窗口大小,这限制了它们在长输入的场景中的使用。为了使 LLMs 适应更长的输入,通常需要用目标长度的样本对其进行微调(全长微调),由此导致训练成本十分昂贵。
举例来说,在 Positional Interpolation [1] 这份工作中,将 LLaMA 的上下文窗口从 2048 拓展到 8192 使用了 32 张 A100,对于更大的上下文窗口则使用了 128 张 A100。
为了将训练长度与目标长度解耦合,以实现高效的上下文窗口扩展,我们提出了一种称为位置跳跃式训练(Positional Skip-wisE training, PoSE)的方法,在原始的上下文窗口中模拟更长的训练样本。
如下图所示,我们将原始的上下文窗口分成几块,然后引入不同的 bias 项来调整每个块的位置编码。对于每一条训练样本,这些 bias 项和块的长度都会发生变化,因此通过大量的训练,模型能适应目标长度内的所有位置。
实验结果表明,PoSE 有以下三方面的优势:
- 训练的时空效率:由于只需要按照原始的上下文长度进行训练,PoSE避免了由目标上下文长度增加带来的平方级别的计算复杂度,使得训练对于内存和时间的开销都大大减小。
- 能支持极长的上下文:通过解耦合训练长度和目标长度,我们仅使用2k的训练窗口就成功将 LLaMA 拓展到 128k。
- 兼容所有基于 RoPE 的模型和位置插值策略:PoSE 的有效性在 LLaMA、GPT-J、Baichuan 等多种基础模型,和 Linear、NTK、YaRN 等多种插值策略上得到了验证。
二、技术背景
旋转位置编码 RoPE:RoPE 是当下主流的位置编码方式,被 LLaMA、GPT-J 等大语言模型所采用。给定一个 维的隐向量和位置 ,RoPE 通过如下方式编码位置信息:
其中 。此前的绝对位置编码多是直接作用在输入向量上,与之不同的是,RoPE 是在作用在每一层的 query 和 key 向量上。RoPE 可以看作是一种相对位置编码,给定位置处的 query 向量和位置处的 key 向量,注意力分数可以写成如下函数:
上下文窗口扩展:给定一个以为原始上下文窗口长度的大语言模型,我们的目标是其支持的上下文长度拓展到 ,使得在个输入内能较好地保持原有的性能。
位置插值(PI):为了将 LLM 的上下文窗口从 拓展到 ,一种直接的做法是使用长的输入文本 ,设定其位置编码为 , 对 LLM 进行微调。
然而,实践表明 [1] [2],这部分的位置在前向传播时会产生灾难性的离群值,从而导致训练无法达到预期的效果。这主要是因为模型在预训练时只见过
这些位置,无法很好的泛化到外推出去的这部分位置。
为了解决这个问题,Position Interpolation [1] 这份工作首先提出用“内插”代替“外推”,设定缩放因子,并将上述注意力公式修改为 (也就是将位置编码线性修改为 )。
这种方式可以减少离群值的出现,将上下文窗口拓展到了 32k。在此基础上,NTK 提出通过修改 来进行位置插值,取得了更好的效果。YaRN 则根据不同的维度,对上述线性插值和 NTK 进行了整合。
三、方法描述
尽管上述 Linear / NTK / YaRN 等插值方式能一定程度上解决位置外推的问题,他们仍然需要用目标长度的训练样本来训练模型(即全长微调)。
随着目标长度的增加,平方级别的计算复杂度带来的开销依旧是难以承受的。因此,在插值技术的基础上,我们提出调整原始的上下文窗口中的位置编码,来模拟更长的训练样本,从而实现高效的上下文扩展。
位置编码的调整主要有两个考量:
- 为了避免推理时遇到 out-of-distribution 的相对位置,调整后的位置编码应覆盖 所有这些相对位置;
- 用调整后的位置编码来微调 LLM 不应该损害其原有性能,因此调整后的位置编码的结构应该和预训练时尽可能接近。
第一步:我们将原上下文窗口 分成 N 个块 ,每个块的长度为,满足 。记 的起始位置编码为 ,则这个块的位置编码如下:
第二步:我们从离散均匀分布 中采样出跳跃偏置项 ,并施加到 的位置编码上:
为了避免块之间位置编码的重合,我们施加了 这一限制。值得注意的是,对于每条数据,我们会重新采样每个块的大小和跳跃偏置项。直观上来说,通过这种方式,我们扩大了原上下文窗口能覆盖的相对位置范围,并且位置编码的不连续只发生在块之间,因此尽可能地保持了预训练阶段的位置编码结构。
第三步:选定每个块内的内容。给定输入文本,我们用类似的方法来抽取每个块内的填充的内容:
我们也尝试了其它 的赋值方式,如,此时块间的内容也是连续的;或如,此时调整后的位置编码恰好对应训练数据在原始文本中的位置。实验结果表明,这几种赋值方式并没有明显的差别。
第四步:位置插值及超参初始化。我们使用位置插值来使训练更稳定。 和 设置成 0,N 设置为 2 。
四、实验分析
1. 实验设置
训练过程:我们主要使用 LLaMA-7B 作为基模型,对于所有设定都只训练 1000 步,训练时长度为 2k,batch size 为 64。我们使用 8 张 V100 进行训练,1 张 A100 进行推理。对于我们的方法和各个 baseline,我们都默认采用线性插值来使训练更稳定。
Baseline:
- 全长微调(Full-length fine-tuning)
- 随机位置(RandPos):给定目标长度 和原始长度 ,从 中随机采样 个位置,按升序排列,作为位置编码。
2. 主要结果
语言模型:
我们使用滑动窗口的方式来计算困惑度 PPL。在 GovReport 和 Proof-Pile 两个数据集上,PoSE 的性能和 Full-length 十分接近,远超未做窗口扩展的版本(Original)和随机位置的版本(RandPos)。且随着窗口长度从 2k 增加到 32k,PPL 呈下降趋势,说明拓展后的模型能充分利用更长的上下文信息。
密码检索:
在密码检索任务上,利用 PoSE 拓展到 16k 和 32k 的模型能分别在 16k 和 32k 的上下文内取得接近 100% 的密码检索准确率,说明模型能关注到目标长度内的每个位置。
时空效率:
在时空效率方面,全长微调的训练时长和内存消耗随目标长度的增加而迅速增长,相比之下,PoSE 需要的训练时间和内存较为稳定。并且在每个时间步上,性能和全长微调都很接近。
兼容性:
兼容性方面,PoSE 可以适配 LLaMA、LLaMA2、GPT-J、Baichuan2 等各种基于 RoPE 的基础模型,以及 Linear、NTK、YaRN 等各种插值策略,展现出较好的普适性。其中 NTK 在最后阶段会有一个 PPL 的突增,这主要是因为给定缩放因子 ,NTK 实际实现的缩放倍数会略小于 [3]。YaRN 解决了这个缺陷,取得了三者中最好的效果。
超长上下文拓展的潜力:
只使用 2k 的训练长度和 1000 步的训练步数,我们尝试了将 LLaMA 模型拓展到 128k。实验表明,在使用 YaRN 的情况下,模型在 128k 的窗口下仍然能保持较低的 PPL。
原窗口内的语言能力:
最后,我们分析了经由 PoSE 训练过后的模型在原窗口内的语言能力。可以看出,和全长微调以及原始模型相比,PoSE 模型能力的损失非常微小,这说明 PoSE 在拓展上下文窗口的同时较好地保持了模型的基础能力。
五、总结与讨论
本文提出了一种位置跳跃式训练(PoSE)来高效的拓展大语言模型的上下文窗口。通过调整位置编码,PoSE 在原始的上下文窗口中模拟更长的训练样本,以达到解耦合训练长度和目标长度的目的。
实验结果表明 PoSE 在和全长微调保持同等性能的情况下,大大缩小了训练所需的时空开销,并表现出良好的普适性和超长上下文扩展的潜力。我们相信 PoSE 将大大降低上下文窗口拓展的成本,使更多人可以参与到相关的研究中来,从而推动长上下文建模领域的快速发展。
PoSE 完成于 2023 年 9 月,我们相信这种位置跳跃的思路是 Long Context 的有效解决方案。结合近几个月来 Long Context 相关研究的进展,我们认为 PoSE 可能有以下一些方面值得进一步探究:
- 应用范围的拓展:融合到预训练或者 SFT 阶段中。PoSE 的实验主要是对基础模型进行轻量级的 post-pretrain,如果能直接融入预训练中,可能可以更好的解决推理时长文本位置 out-of-distribution 的问题;如果能用于 SFT 中,则可以将模型更好地适配到具体的下游任务或者 alignment 要求上。
- 效果的提升:探索更优的 skip 策略和合适的训练数据配比。一方面,实验中只将窗口分成了两块,且跳步由随机采样决定,如何在更多场景设计更科学合理的 skip 结构指的关注。
另一方面,PoSE 简单从 The Pile 数据集采样了部分超过 2k 长的样本作为训练数据,并没有特别关注数据的来源、长度等配比。根据 Fu et al. (2024) [4] 的结论,优化训练数据的分布对于长上下文建模能力的获取以及模型原始能力的保持都能有较大的帮助。
本文转载自PaperWeekly