DeepSeek如何用MTP逆天改命?
DeepSeek-V3 的 Multi-Token Prediction 到底在做什么?這個(gè)問(wèn)題在大模型面試中經(jīng)常被問(wèn)到,屬于 DeepSeek 的高頻面試題。
所以這篇文章我們就來(lái)看看,如果你在面試現(xiàn)場(chǎng)被問(wèn)到這個(gè)問(wèn)題,應(yīng)該如何作答?
1.面試官心理分析
首先老規(guī)矩,我們還是來(lái)分析一下面試官的心理,面試官問(wèn)這個(gè)問(wèn)題,它其實(shí)主要是想考察你 3 個(gè)方面:
- 第一,為什么要做 MTP?你是否知道這個(gè)算法背后的動(dòng)機(jī)?
- 第二,之前的工作 MTP 是怎么做的?DeepSeek 肯定不是這個(gè)方法的首創(chuàng),那之前的研究,前因后果你是否清楚呢?
- 第三,DeepSeek 的 MTP 是怎么做的,它的設(shè)計(jì)相比之前的,有什么不同之處?
好,了解了面試官的心理之后,接下來(lái)我們就沿著面試官的心理預(yù)期,來(lái)回答一下這道題目!
2.面試題解析
首先第一個(gè)問(wèn)題:為什么要做 MTP?
我們都知道,當(dāng)前主流的大模型都是 decoder-only 的架構(gòu),每生成一個(gè) token,都要頻繁的跟訪存交互,加載 KV-Cache,再完成前向計(jì)算。
那對(duì)于這樣的訪存密集型任務(wù),通常會(huì)因?yàn)樵L存效率而形成推理的瓶頸,針對(duì)這種 token-by-token 生成效率的瓶頸,業(yè)界有很多方法來(lái)優(yōu)化,比如減少存儲(chǔ)空間,減少訪存次數(shù)等等。
那 MTP 也是優(yōu)化訓(xùn)練和推理效率的方法之一,它的核心動(dòng)機(jī)是:通過(guò)解碼階段的優(yōu)化,將 next 1-token 的生成,轉(zhuǎn)變成 multi-token 的生成,以提升訓(xùn)練和推理的性能。
對(duì)于訓(xùn)練階段,一次生成多個(gè)后續(xù) token,可以一次學(xué)習(xí)多個(gè)位置的 label,這樣可以增加樣本的利用效率,提高訓(xùn)練速度;而在推理階段,通過(guò)一次生成多個(gè) token,可以實(shí)現(xiàn)成倍的解碼加速,來(lái)提升推理性能。
好,到這里我們就回答了第一個(gè)問(wèn)題:為什么要用 MTP?接著我們?cè)賮?lái)看看,DeepSeek 之前的 MTP 都是如何做的?業(yè)界經(jīng)過(guò)了哪些探索?
其實(shí)最早做 MTP 方法的是 Google 在 18 年發(fā)表的這篇論文《Blockwise Parallel Decoding for Deep Autoregressive Models》。
其思想很簡(jiǎn)單,我們看這張圖:
可以看到,logits 上接了多個(gè)輸出頭,這樣訓(xùn)練的時(shí)候可以同時(shí)預(yù)測(cè)出多個(gè)未來(lái)的 token,也就是分別預(yù)測(cè)下個(gè) token,再下個(gè) token,再再下個(gè) token,以此類(lèi)推。
好,理解了網(wǎng)絡(luò)細(xì)節(jié),我們?cè)倏床⑿薪獯a過(guò)程就很好理解了,整個(gè)推理過(guò)程看這張圖:
可以看到,解碼過(guò)程主要分成三步:
階段 1:predict,利用 k 個(gè) Head 一次生成 k 個(gè) token,每個(gè) Head 生成一個(gè) token。
階段 2:verify,將原始的序列和生成的 k 個(gè) token 拼接,組成 sequence_input 和 label 的 Pair 對(duì)。
Pair<sequence_input, label>
大家看圖中的 verify 階段,黑框里是 sequence_input,箭頭指向的是要驗(yàn)證的 label。
我們將組裝的 k 個(gè) Pair 對(duì)組成一個(gè) batch,一次性發(fā)給 Head1 做校驗(yàn),檢查 Head1 生成的 token 是否跟 label 一致。
然后是階段 3:accept,選擇 Head1 預(yù)估結(jié)果與 label 一致的最長(zhǎng)的 k 個(gè) token,作為可接受的結(jié)果。
最優(yōu)情況下,所有輔助 Head 預(yù)測(cè)結(jié)果跟 Head1 完全一樣,也就是相當(dāng)于一個(gè) step 正確解碼出了多個(gè) token,這可以極大的提升解碼效率。
實(shí)際上在 24 年,meta 也發(fā)表過(guò)一篇大模型 MTP 的工作,這是當(dāng)時(shí)的論文,其結(jié)構(gòu)跟 Google 那篇差別不大,這里我們就不再單獨(dú)贅述。
感興趣的同學(xué)可以去看看這篇論文《Better & Faster Large Language Models via Multi-token Prediction》。
好,了解了 MTP 在業(yè)界的發(fā)展,我們?cè)賮?lái)看看,DeepSeek 是怎么做 MTP 的?
這里直接說(shuō)改進(jìn),DeepSeek 的 MTP 設(shè)計(jì),看這張圖:
實(shí)際上它在論文實(shí)現(xiàn)上保留了序列推理的 causal chain,也就是存在從一個(gè) head 連接到后繼 head 的箭頭。其他的思路跟 google 那篇論文差不多。
另外在訓(xùn)練的時(shí)候,同樣采用的是 teacher forcing 的思想,也就是 input 會(huì)輸入真實(shí)的 token,而在實(shí)際預(yù)測(cè)解碼的階段,采用的是 free running 的思想,也就是直接用上一個(gè) step 解碼的輸出,來(lái)作為下一個(gè) step 的輸入。
本文轉(zhuǎn)載自???丁師兄大模型??,作者: 丁師兄
