無(wú)損加速最高5x,EAGLE-2讓RTX 3060的生成速度超過(guò)A100
李堉暉:北京大學(xué)智能學(xué)院碩士,受張弘揚(yáng)老師和張超老師指導(dǎo),研究方向?yàn)榇竽P图铀俸蛯?duì)齊,正在尋找25屆工作機(jī)會(huì)
魏芳蕓:微軟亞研院研究員,研究方向?yàn)榫呱碇悄堋D像生成和AI agents
張超:北京大學(xué)智能學(xué)院研究員,研究方向?yàn)橛?jì)算機(jī)視覺和機(jī)器學(xué)習(xí)
張弘揚(yáng):滑鐵盧大學(xué)計(jì)算機(jī)學(xué)院、向量研究院助理教授,研究方向?yàn)長(zhǎng)LM加速和AI安全
自回歸解碼已經(jīng)成為了大語(yǔ)言模型(LLMs)的事實(shí)標(biāo)準(zhǔn),大語(yǔ)言模型每次前向計(jì)算需要訪問(wèn)它全部的參數(shù),但只能得到一個(gè)token,導(dǎo)致其生成昂貴且緩慢。
今日,一篇題為《EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees》的論文提出了動(dòng)態(tài)草稿樹投機(jī)采樣,依據(jù)草稿模型的置信度動(dòng)態(tài)調(diào)整草稿樹的結(jié)構(gòu),最高可以將大語(yǔ)言模型的推理速度提高5倍,同時(shí)不改變大語(yǔ)言模型的輸出分布,確保無(wú)損。
- 論文鏈接:https://arxiv.org/pdf/2406.16858
- 項(xiàng)目鏈接:https://github.com/SafeAILab/EAGLE
- Demo鏈接:https://huggingface.co/spaces/yuhuili/EAGLE-2
EAGLE-2在多輪對(duì)話數(shù)據(jù)集MT-bench上的加速效果(上圖為貪婪生成,下圖為采樣生成):
使用EAGLE-2,2張RTX 3060($300)的推理速度可以超過(guò)A100($10000)。
背景
投機(jī)采樣使用一個(gè)小的模型快速生成草稿,原始的大語(yǔ)言模型可以通過(guò)一次前向計(jì)算驗(yàn)證草稿的正確性,將正確的草稿作為輸出,從而一次生成多個(gè)token,并確保無(wú)損。EAGLE是投機(jī)采樣的一種改進(jìn)。它在更有規(guī)律的特征層面而不是token層面進(jìn)行自回歸,同時(shí)輸入采樣結(jié)果(超前一個(gè)時(shí)間步的token)消除了不確定性,明顯提升了草稿模型的準(zhǔn)確率。
到目前為止,EAGLE在第三方測(cè)試Spec-Bench(https://github.com/hemingkx/Spec-Bench/blob/main/Leaderboard.md)中排名第一。
思路
EAGLE和Medusa等方法使用靜態(tài)的草稿樹,隱式地假設(shè)草稿token的接受率和上下文無(wú)關(guān),下面是一個(gè)簡(jiǎn)單的例子
上文是“10+2”時(shí),下一個(gè)token難以預(yù)測(cè),EAGLE在這個(gè)位置添加兩個(gè)候選token以增加草稿命中率,“10+2=”和“10+2+”有一個(gè)正確即可。當(dāng)上文是“10+2=”時(shí),下一個(gè)token明顯是“1”,但是EAGLE使用靜態(tài)的草稿結(jié)構(gòu),仍然添加兩個(gè)候選“1”和“3”,“10+2=3”不可能通過(guò)大語(yǔ)言模型的檢查,存在浪費(fèi)。EAGLE-2旨在解決這一問(wèn)題,如下圖所示,當(dāng)上文是“10+2=”時(shí),EAGLE-2只增加一個(gè)候選token“1”,將節(jié)約出的token用于讓草稿樹更深,這樣“10+2=12”通過(guò)大語(yǔ)言模型的檢查,EAGLE-2可以一次生成更多的token。
EAGLE-2的作者們?cè)贏lpaca數(shù)據(jù)集上進(jìn)行了簡(jiǎn)單的測(cè)試,下圖顯示了不同位置的草稿token的接受率,左圖中的P1-P6代表位置,與右圖的橫軸坐標(biāo)對(duì)應(yīng)。實(shí)驗(yàn)結(jié)果顯示,在相同的位置上的草稿token的接受率也有較大的差異,這說(shuō)明了使用動(dòng)態(tài)草稿樹可能取得比靜態(tài)草稿樹更好的效果。
上述例子中,EAGLE-2根據(jù)預(yù)測(cè)草稿token的難易程度決定草稿樹的結(jié)構(gòu),精確計(jì)算難易程度(接受率)需要原始大語(yǔ)言模型的計(jì)算結(jié)果,這違背了投機(jī)采樣減少對(duì)原始大語(yǔ)言模型訪問(wèn)的初衷。幸運(yùn)的是,EAGLE的草稿模型的置信度與接受率(難易程度)高度正相關(guān)。下圖顯示了草稿模型不同置信度區(qū)間的草稿token的平均接受率,紅色虛線連接(0,0)和(1,1)。由此可見,草稿模型的置信度可以作為接受率的有效近似。
方法
EAGLE-2包括兩個(gè)階段,擴(kuò)展和重排,擴(kuò)展階段加深加大草稿樹,重排階段修剪草稿樹,丟棄部分節(jié)點(diǎn)(token)。
為了保證無(wú)損,一個(gè)草稿token被接受的前提是它的祖先節(jié)點(diǎn)都被接受,所以EAGLE-2將一個(gè)節(jié)點(diǎn)的價(jià)值定義為它和它祖先的接受率的乘積,用置信度的乘積來(lái)近似。
在擴(kuò)展階段,EAGLE-2選擇草稿樹最后一層價(jià)值最高的m個(gè)節(jié)點(diǎn)(token)進(jìn)行擴(kuò)展。這些token被送入草稿模型,然后將草稿模型的輸出作為子節(jié)點(diǎn)連接到輸入節(jié)點(diǎn),加深加大草稿樹。在重排階段,EAGLE-2按照價(jià)值對(duì)整棵草稿樹進(jìn)行重排序,保留前n個(gè)節(jié)點(diǎn)(token)。草稿token的置信度在0-1之間,兩個(gè)節(jié)點(diǎn)價(jià)值相同時(shí)優(yōu)先保留淺層節(jié)點(diǎn),因此重排后保留的草稿樹一定是連通的,保證了語(yǔ)義上的連貫性。重排后草稿樹變小,降低了原始大語(yǔ)言模型驗(yàn)證的計(jì)算量。為了保證計(jì)算結(jié)果的正確性,還需要調(diào)整attention mask,確保每一個(gè)token只能看到它的祖先節(jié)點(diǎn),不受其他分支的影響。下面是一個(gè)簡(jiǎn)單的例子。
擴(kuò)展(Expand)階段的黃色框表示被選中進(jìn)行擴(kuò)展的節(jié)點(diǎn),綠色框?yàn)橐赃@些節(jié)點(diǎn)為輸入時(shí)草稿模型的預(yù)測(cè)。重排(Rerank)階段的藍(lán)色框表示被保留的節(jié)點(diǎn),之后它們被展平成一維作為原始大語(yǔ)言模型的輸入。EAGLE-2根據(jù)樹的結(jié)構(gòu)調(diào)整attention mask,比如,”a”只能看到它的祖先“It”和“is”,看不到另一個(gè)分支的“has”。EAGLE-2也同時(shí)調(diào)整位置編碼,確保和標(biāo)準(zhǔn)自回歸解碼的一致性。
實(shí)驗(yàn)
EAGLE-2在多輪對(duì)話、代碼、數(shù)學(xué)推理、指令遵循、問(wèn)答、總結(jié)六項(xiàng)任務(wù)上分別使用MT-bench、Humaneval、GSM8K、Alpaca、CNN/DM、Natural Questions數(shù)據(jù)集進(jìn)行了實(shí)驗(yàn),與6種先進(jìn)的投機(jī)采樣方法(SpS、PLD、Medusa、Lookahead、Hydra、EAGLE)進(jìn)行了比較。
表格中的Speedup為加速比,τ 為平均接受長(zhǎng)度,也就是原始大語(yǔ)言模型每次前向計(jì)算能生成的token數(shù)。EAGLE-2每次前向計(jì)算能生成大約4-5個(gè)token,而自回歸解碼每次生成1個(gè)token,因此EAGLE-2明顯加速了大語(yǔ)言模型的生成,加速比為2.5x-5x。加速比和接受長(zhǎng)度在代碼生成任務(wù)(Humaneval數(shù)據(jù)集)上最高,這是因?yàn)榇a中存在大量確定性的模板,草稿更容易命中。在所有任務(wù)和大語(yǔ)言模型上,EAGLE-2的加速比和平均接受長(zhǎng)度都是最高的,明顯優(yōu)于其他方法。
應(yīng)用
EAGLE-2也在工業(yè)界得到應(yīng)用,集成至Intel/intel-extension-for-transformers等。