如何在LLM训练过程中精妙设计SFT与RL步骤—— LLM训练框架推荐 原创

发布于 2025-5-27 06:48
浏览
0收藏

一种可以“自适应切换SFT与RL”的训练框架分享。

大家应该都还记得,DeepSeek-R1的“SFT->RL->增强SFT->增强RL”这种左脚踩右脚直接起飞的操作,这说明监督微调(SFT)与强化学习(RL)交替训练的训练范式确实可以提高模型性能。

很多大佬也有自己做小规模实验,在进行新的训练范式探索:

  • 预训练后做两次SFT接一次RL
  • 预训练后先RL再SFT
  • ....

那么如何设计训练框架能实现效果最优呢?

本篇分享一种可以“自适应切换SFT与RL”的训练框架;这是念空科技联合上海交通大学计算机学院投的新论文 《Step-wise Adaptive Integration of Supervised Fine-tuning and Reinforcement Learning for Task-Specific LLMs》。

如何在LLM训练过程中精妙设计SFT与RL步骤—— LLM训练框架推荐-AI.x社区


下面是一个快捷目录。

1. 待解决的问题

2. 论文方法

3. 实验结果

4. 其他可发散的点

一、待解决的问题

目前这种 “固定步骤的SFT和RL交替” 静态混合训练方法可能会带来一些问题,比如,一种训练范式直接切换到另一种时,可能会导致模型下降;不同阶段任务着重训练的知识不同,模型很可能灾难性遗忘或者陷入局部最优等,最终影响训练的连续性和稳定性。

这篇论文主要解决的就是如何设计训练步骤的问题:如何设计一个最优的训练框架来保证LLM的训练稳定性。

二、论文方法

论文提出了一个名为SASR(Step-wise Adaptive Integration of Supervised Fine-tuning and Reinforcement Learning)的逐步自适应混合训练框架,通过理论统一监督微调(SFT)和强化学习(RL),并动态平衡两者在整个优化过程中的比例。

如何在LLM训练过程中精妙设计SFT与RL步骤—— LLM训练框架推荐-AI.x社区

主要包含两个阶段:


如何在LLM训练过程中精妙设计SFT与RL步骤—— LLM训练框架推荐-AI.x社区


第一阶段:Warm-up Phase

首先使用小规模的(问题,链式思考)数据对进行SFT,以建立模型的基本推理能力。这些数据对包括输入问题的标记序列和对应的链式思考推理路径,帮助模型学习结构化的问题解决策略。

在第一阶段中通过最小化负对数似然(NLL)损失来最大化真实序列的似然,从而更新模型参数。

loss长这样,at是思维链中的token第t个token标记,st是步骤t中的上下文状态,包括之前所有生成的标记。

如何在LLM训练过程中精妙设计SFT与RL步骤—— LLM训练框架推荐-AI.x社区

第二阶段:Hybrid Training Phase

在Warm-up之后,逐步开始自适应混合训练,把SFT和GRPO结合起来。

GRPO通过组间比较扩展策略优化,通过采样当前和旧策略的输出,并根据相对优势将它们分为高优势组和低优势组,然后结合优势最大化和KL正则化来更新策略。

另外此阶段根据当前模型的训练状态来动态调整SFT和GRPO的比例。具体来说,通过比较当前梯度范数与Warm-up阶段记录的梯度范数,动态更新两者的比例。

loss长这样, πθold 是更新前的上一个策略,πref 表示参考策略(通常是初始 SFT 模型),ε控制策略更新的裁剪范围,β调整 KL 正则化的强度。比率 πθ πθold 衡量每个step的新策略与旧策略的偏差程度。

如何在LLM训练过程中精妙设计SFT与RL步骤—— LLM训练框架推荐-AI.x社区

那么如何进行动态比例的分配呢?主要通过监测训练过程中的梯度范数和模型策略相对于原始数据分布的KL散度,当模型与原始数据分布的偏差较大时,增加SFT的权重;当模型接近原始数据分布时,增加GRPO的权重。

最终整体损失函数 L(θ)如下

如何在LLM训练过程中精妙设计SFT与RL步骤—— LLM训练框架推荐-AI.x社区


这里引入了 I(t) 作为状态函数,它根据当前模型的训练状态 t 返回训练范式决策变量 I(t)。

与传统的 Hybrid方法在一个 epoch 内使用固定的训练范式相比,SASR 采用更细粒度的训练步骤 s 作为训练单元,可实现更灵活的自适应调整。

下面这段伪代码可以辅助大家很快理解他的思路。

如何在LLM训练过程中精妙设计SFT与RL步骤—— LLM训练框架推荐-AI.x社区

另外论文还进行了理论分析与实验验证,建立了SFT损失的梯度范数与KL散度之间的关系,证明了SASR在避免SFT引起的过拟合、缓解RL导致的模型坍塌以及克服静态混合训练的局限的优势。

三、实验结果

模型设计了三个实验:

如何在LLM训练过程中精妙设计SFT与RL步骤—— LLM训练框架推荐-AI.x社区


  • GSM8K(小学水平数学算术)+ DeepSeek-R1-Distill-Qwen-1.5B模型:模型的准确率从63.8%提高到80.3%,接近GPT-4o的水平
  • KK(逻辑推理)+ Qwen2.5-1.5B-Instruct模型:平均准确率提升9%,超过了GPT-4o
  • MATH(数学竞赛、公式)+ Qwen2.5-0.5B-Instruct模型:平均准确率提升了9%,超过了GPT-4o

四、其他可发散的点

这篇论文感觉还是有很多可以继续去发散的,比如跟除了GPRO的其他强化学习算法结合,推广到多模态,改进动态调整策略等等。有想法的朋友们可以一起交流一下~

参考文献

[1] ​​​https://arxiv.org/pdf/2505.13026​



本文转载自​瓦力算法学研所​,作者:喜欢瓦力的卷卷

©著作权归作者所有,如需转载,请注明出处,否则将追究法律责任
已于2025-5-27 06:48:51修改
收藏
回复
举报
回复
相关推荐