
DeepSeek如何用MTP逆天改命?
DeepSeek-V3 的 Multi-Token Prediction 到底在做什么?这个问题在大模型面试中经常被问到,属于 DeepSeek 的高频面试题。
所以这篇文章我们就来看看,如果你在面试现场被问到这个问题,应该如何作答?
1.面试官心理分析
首先老规矩,我们还是来分析一下面试官的心理,面试官问这个问题,它其实主要是想考察你 3 个方面:
- 第一,为什么要做 MTP?你是否知道这个算法背后的动机?
- 第二,之前的工作 MTP 是怎么做的?DeepSeek 肯定不是这个方法的首创,那之前的研究,前因后果你是否清楚呢?
- 第三,DeepSeek 的 MTP 是怎么做的,它的设计相比之前的,有什么不同之处?
好,了解了面试官的心理之后,接下来我们就沿着面试官的心理预期,来回答一下这道题目!
2.面试题解析
首先第一个问题:为什么要做 MTP?
我们都知道,当前主流的大模型都是 decoder-only 的架构,每生成一个 token,都要频繁的跟访存交互,加载 KV-Cache,再完成前向计算。
那对于这样的访存密集型任务,通常会因为访存效率而形成推理的瓶颈,针对这种 token-by-token 生成效率的瓶颈,业界有很多方法来优化,比如减少存储空间,减少访存次数等等。
那 MTP 也是优化训练和推理效率的方法之一,它的核心动机是:通过解码阶段的优化,将 next 1-token 的生成,转变成 multi-token 的生成,以提升训练和推理的性能。
对于训练阶段,一次生成多个后续 token,可以一次学习多个位置的 label,这样可以增加样本的利用效率,提高训练速度;而在推理阶段,通过一次生成多个 token,可以实现成倍的解码加速,来提升推理性能。
好,到这里我们就回答了第一个问题:为什么要用 MTP?接着我们再来看看,DeepSeek 之前的 MTP 都是如何做的?业界经过了哪些探索?
其实最早做 MTP 方法的是 Google 在 18 年发表的这篇论文《Blockwise Parallel Decoding for Deep Autoregressive Models》。
其思想很简单,我们看这张图:
可以看到,logits 上接了多个输出头,这样训练的时候可以同时预测出多个未来的 token,也就是分别预测下个 token,再下个 token,再再下个 token,以此类推。
好,理解了网络细节,我们再看并行解码过程就很好理解了,整个推理过程看这张图:
可以看到,解码过程主要分成三步:
阶段 1:predict,利用 k 个 Head 一次生成 k 个 token,每个 Head 生成一个 token。
阶段 2:verify,将原始的序列和生成的 k 个 token 拼接,组成 sequence_input 和 label 的 Pair 对。
Pair<sequence_input, label>
大家看图中的 verify 阶段,黑框里是 sequence_input,箭头指向的是要验证的 label。
我们将组装的 k 个 Pair 对组成一个 batch,一次性发给 Head1 做校验,检查 Head1 生成的 token 是否跟 label 一致。
然后是阶段 3:accept,选择 Head1 预估结果与 label 一致的最长的 k 个 token,作为可接受的结果。
最优情况下,所有辅助 Head 预测结果跟 Head1 完全一样,也就是相当于一个 step 正确解码出了多个 token,这可以极大的提升解码效率。
实际上在 24 年,meta 也发表过一篇大模型 MTP 的工作,这是当时的论文,其结构跟 Google 那篇差别不大,这里我们就不再单独赘述。
感兴趣的同学可以去看看这篇论文《Better & Faster Large Language Models via Multi-token Prediction》。
好,了解了 MTP 在业界的发展,我们再来看看,DeepSeek 是怎么做 MTP 的?
这里直接说改进,DeepSeek 的 MTP 设计,看这张图:
实际上它在论文实现上保留了序列推理的 causal chain,也就是存在从一个 head 连接到后继 head 的箭头。其他的思路跟 google 那篇论文差不多。
另外在训练的时候,同样采用的是 teacher forcing 的思想,也就是 input 会输入真实的 token,而在实际预测解码的阶段,采用的是 free running 的思想,也就是直接用上一个 step 解码的输出,来作为下一个 step 的输入。
本文转载自丁师兄大模型,作者: 丁师兄
