小紅書(shū)提出大模型推理加速算法 HASS 刷新 SOTA
在大模型推理領(lǐng)域,投機(jī)采樣是一種被廣泛使用的無(wú)損加速算法。近期一些投機(jī)采樣的工作將大模型的上下文信息(例如 hidden states 和 KV cache)引入草稿模型,可以充分利用大模型的知識(shí)來(lái)提升加速比,但這類算法也會(huì)帶來(lái)訓(xùn)練和解碼的上下文不一致問(wèn)題。此外,我們也發(fā)現(xiàn)現(xiàn)有算法在訓(xùn)練和解碼的目標(biāo)上也存在一定的不一致現(xiàn)象。小紅書(shū)中臺(tái)算法團(tuán)隊(duì)提出的 HASS 算法在目標(biāo)和上下文上對(duì)齊了草稿模型的訓(xùn)練和解碼階段,達(dá)到了普通推理速度的 2.81~4.05 倍,相比 SOTA 方法 EAGLE-2 提升 8%~20%,相關(guān)技術(shù)已應(yīng)用在小紅書(shū)實(shí)際業(yè)務(wù)場(chǎng)景中。
論文地址
https://arxiv.org/pdf/2408.15766
01 背景
生成式大語(yǔ)言模型(LLMs)在各種任務(wù)上表現(xiàn)出令人驚嘆的能力。然而,由于其固有的自回歸解碼機(jī)制,人們難以在這些模型上高效推理,這限制了它們?cè)跁r(shí)間敏感場(chǎng)景中的應(yīng)用。投機(jī)采樣技術(shù)通過(guò)利用額外的資源來(lái)增加并發(fā)性,提供了一種大模型推理加速的解決方案。
投機(jī)采樣(Speculative Sampling)
投機(jī)采樣是一種先起草再驗(yàn)證的解碼范式。在每一步解碼時(shí),先高效地生成多個(gè)草稿 token,再使用目標(biāo) LLM 并行地驗(yàn)證這些 token 來(lái)加速推理。 表示想要加速推理的目標(biāo) LLM,
表示基于前綴
從目標(biāo) LLM 生成下一個(gè) token 的條件概率分布(簡(jiǎn)寫(xiě)為
)。
表示一個(gè)更高效的草稿模型,
表示基于前綴
從草稿模型生成下一個(gè) token 的條件概率分布(簡(jiǎn)寫(xiě)為
)。投機(jī)采樣分為如下3步:
1.使用更高效的草稿模型 來(lái)生成
個(gè)草稿 token。
2.使用目標(biāo) LLM 來(lái)并行地驗(yàn)證這些草稿 token 以及它們從
被生成的概率,接受能使得輸出分布和
一致的所有草稿 token。
3.如果某個(gè)草稿 token 被拒絕后,從修正后的分布中采樣一個(gè)額外的 token 替代它;如果所有的草稿 token都被接受,額外增加一個(gè)新的 token。
具體驗(yàn)證過(guò)程如下:從 中采樣一個(gè)草稿 token
,如果
則接受
;否則將以
的概率拒絕并從修正分布
中重新采樣一個(gè) token 接受。
經(jīng)證明,對(duì)于任意的 和
,如此得到的 token 總是與目標(biāo) LLM 分布一致。目標(biāo) LLM 的每一次前向推理至少產(chǎn)生一個(gè)新的 token,而至多產(chǎn)生
個(gè)新的 token,生成的個(gè)數(shù)取決于目標(biāo) LLM 和草稿模型的對(duì)齊程度。
投機(jī)采樣的實(shí)際性能取決于兩個(gè)因素:草稿模型的解碼成本及其與目標(biāo) LLM 的對(duì)齊程度。為了獲得與目標(biāo) LLM 高度對(duì)齊的高效草稿模型,近期的工作提出利用目標(biāo) LLM 的上下文信息。例如,EAGLE 使用目標(biāo) LLM 的 hidden states 作為草稿模型的輸入特征。然而,這些方法在訓(xùn)練和解碼階段引入了不一致的上下文,如圖 2 所示。在訓(xùn)練期間,草稿模型總是能獲取到目標(biāo) LLM 在先前時(shí)間步的 hidden states。但在解碼期間,草稿模型卻無(wú)法獲取到未被驗(yàn)證時(shí)間步的目標(biāo) LLM 的 hidden states,這導(dǎo)致了訓(xùn)練和解碼階段的上下文不一致。這一問(wèn)題可以看作是投機(jī)采樣中在特征層面的 exposure bias。
訓(xùn)練和解碼階段之間還存在目標(biāo)上的不一致。在解碼階段,草稿模型的目標(biāo)是生成目標(biāo) LLM 會(huì)賦予高概率的 token。在這種情況下,草稿模型應(yīng)更關(guān)注于召回這些高概率 token,而對(duì)它們之間的具體順序則可以稍微放松。另外,大部分 LLM 在應(yīng)用時(shí)采取核采樣或 top-k 采樣。在這些解碼策略中,高概率 token 對(duì)輸出起著更重要的作用。因此,為了獲得高效的草稿模型,它的訓(xùn)練目標(biāo)應(yīng)考慮到解碼階段的這些特性。據(jù)我們所知,現(xiàn)有的涉及訓(xùn)練草稿模型的投機(jī)采樣方法普遍忽視了這些解碼目標(biāo)。
02 方法
為解決上述的訓(xùn)練和解碼階段不一致問(wèn)題,我們提出了協(xié)調(diào)投機(jī)采樣(HASS),旨在通過(guò)訓(xùn)練階段學(xué)習(xí)協(xié)調(diào)的表征來(lái)解決上述問(wèn)題。我們的方法包含兩部分:(1)為了讓草稿模型在訓(xùn)練階段感知到解碼目標(biāo),HASS 將推薦系統(tǒng)中的排序蒸餾思想擴(kuò)展到投機(jī)采樣,即協(xié)調(diào)目標(biāo)蒸餾;(2)為了解決訓(xùn)練和解碼間的上下文不一致,我們提出了一種多步的對(duì)齊訓(xùn)練策略,即協(xié)調(diào)上下文對(duì)齊。結(jié)合這兩部分,HASS 顯著提高了 LLM 的推理速度。在無(wú)需額外推理開(kāi)銷的情況下,也保持了草稿模型訓(xùn)練的高效。
協(xié)調(diào)目標(biāo)蒸餾(Harmonized Objective Distillation)
HASS 通過(guò)引入推薦系統(tǒng)中的排序蒸餾思想,優(yōu)先考慮草稿模型解碼時(shí)更重要的一些 token。具體來(lái)說(shuō),排序蒸餾的目標(biāo)是訓(xùn)練學(xué)生模型,使其對(duì)教師模型中排名靠前的項(xiàng)賦予更高的排序。在投機(jī)采樣中,草稿模型是學(xué)生模型,而目標(biāo) LLM 是教師模型。具有類似特性的草稿模型在解碼階段將獲得更高的接收率。設(shè) K 個(gè)概率最高的 token 組成的集合為 ,其中
代表整個(gè)詞匯表。HASS 在訓(xùn)練時(shí)使用以下的 Top-K 蒸餾損失:
其中 和
分別表示目標(biāo) LLM 和草稿模型預(yù)測(cè)下一個(gè)詞的條件概率分布。在結(jié)合 EAGLE 時(shí),訓(xùn)練階段可以從目標(biāo) LLM 的 hidden states 中獲取
,這意味著結(jié)合 Top-K 損失訓(xùn)練有著和 EAGLE 一樣的訓(xùn)練效率。
協(xié)調(diào)上下文對(duì)齊(Harmonized Context Alignment)
HASS 采用了多步的對(duì)齊訓(xùn)練策略,使草稿模型在訓(xùn)練和解碼階段的上下文保持一致。具體來(lái)說(shuō),HASS 將訓(xùn)練過(guò)程分為 n 步,使草稿模型能夠利用與解碼階段一致的上下文特征。過(guò)程如下:
- 第一步與 EAGLE 的訓(xùn)練相同。在時(shí)間步 t+1,草稿模型以目標(biāo)LLM的特征
作為輸入并生成草稿模型特征
。這一步中,注意力掩碼與因果掩碼一致,不做修改。
- 第二步利用了來(lái)自第一步的特征。在時(shí)間步 t+1 的自注意力機(jī)制中,使用
來(lái)生成 query。key 和 value 由
生成,其中
表示拼接操作,
表示早于時(shí)間步 t 的特征。注意力掩碼被修改以確保
看到的前一個(gè)特征始終是
,如圖 3 中的“HASS Training Step 2“所示。
- 對(duì)于第 j 步(j ≥ 3),前一步生成的特征
用于生成時(shí)間步 t+1 的query,而 key 和 value 由
生成。
HASS 的訓(xùn)練開(kāi)銷是 EAGLE 的 n 倍,但解碼開(kāi)銷不變。后續(xù)實(shí)驗(yàn)證明,HASS 的加速效果在 n 值較小時(shí)就會(huì)收斂,因此是訓(xùn)練高效的,具體實(shí)現(xiàn)請(qǐng)參考論文的附錄部分。
03 實(shí)驗(yàn)
主要實(shí)驗(yàn)
如表 1、2 所示,HASS 在所有的數(shù)據(jù)集和目標(biāo) LLM 上都表現(xiàn)出了最高的接受長(zhǎng)度和最優(yōu)的加速比。大部分方法在 HumanEval 數(shù)據(jù)集上加速效果最好,因?yàn)榇a生成任務(wù)中的固定模版對(duì)于草稿模型更易生成從而加速。盡管 PLD 和 Lookahead 無(wú)需訓(xùn)練,但是它們的性能都顯著弱于 EAGLE、EAGLE-2 和 HASS。
協(xié)調(diào)目標(biāo)蒸餾的消融實(shí)驗(yàn)
我們改變了 Top-K 損失的 K 和權(quán)重,結(jié)果如圖 4 所示。使用 Top-K 損失訓(xùn)練(權(quán)重大于 0)時(shí),總是能提升草稿模型的接受長(zhǎng)度。當(dāng) K 值很小時(shí)(K=1)會(huì)導(dǎo)致性能下降,可能是因?yàn)椴莞迥P瓦^(guò)度關(guān)注概率最高的 token 而忽視了其他潛在 token。在 K=5 時(shí),草稿模型的接受長(zhǎng)度最大。
我們還嘗試了更多關(guān)注高概率 token 的損失函數(shù)以替換 Top-K 損失,結(jié)果如表 3 所示。BiLD 損失在 T=0 時(shí)表現(xiàn)最好,Top-K 損失在 T=1 時(shí)表現(xiàn)最好??傮w上,Top-K 損失的表現(xiàn)最好。
協(xié)調(diào)上下文對(duì)齊的消融實(shí)驗(yàn)
我們改變了協(xié)調(diào)上下文對(duì)齊的對(duì)齊步數(shù),將用 Top-K 損失訓(xùn)練后的 EAGLE-2 權(quán)重作為基準(zhǔn),結(jié)果如表 4 所示。在不使用協(xié)調(diào)上下文對(duì)齊時(shí)(EAGLE-2+Top-K),草稿模型的效果最差。用 3 或 4 步協(xié)調(diào)上下文對(duì)齊訓(xùn)練的草稿模型總體上能獲得最優(yōu)的接受長(zhǎng)度。當(dāng)對(duì)齊步數(shù)增加到 5 步時(shí),接受長(zhǎng)度反而會(huì)下降,這可能是因?yàn)椴莞迥P偷哪芰τ邢?,?dāng)過(guò)度關(guān)注后幾步的 token 生成時(shí)就會(huì)導(dǎo)致在前幾步的預(yù)測(cè)精度下降。
我們畫(huà)出了 HASS 和 EAGLE-2 在每一步生成token時(shí)的接受率曲線,如圖 5 所示。可見(jiàn)在后幾步生成 token 時(shí),HASS 的接受率顯著高于 EAGLE-2,驗(yàn)證了協(xié)調(diào)上下文對(duì)齊的有效性。
但在 LLaMA2-Chat 13B 和 LLaMA3-Instruct 70B 上,HASS 的第一步接受率相比 EAGLE-2 下降了。這可能是因?yàn)椴莞迥P完P(guān)注后幾步的 token 生成而忽視了第一步的,但第一步的接受率對(duì)于接受長(zhǎng)度非常關(guān)鍵。因此我們考慮調(diào)整訓(xùn)練時(shí)每一步對(duì)齊的損失權(quán)重,來(lái)強(qiáng)調(diào)前幾步的重要性。具體的,我們對(duì)于第 j 步的訓(xùn)練損失乘上權(quán)重 ,結(jié)果如表 5 和圖 6 所示。當(dāng)
從 1.0 降到 0.5 時(shí),草稿模型的接受長(zhǎng)度不斷提高。其在第一步的接受率也對(duì)應(yīng)增長(zhǎng),而后幾步的接受率有所下降。當(dāng)
下降到 0.3 時(shí),訓(xùn)練過(guò)程過(guò)分強(qiáng)調(diào)了第一步 token 生成,導(dǎo)致了接受長(zhǎng)度下降。我們將在多步對(duì)齊間取得平衡的探索留到后續(xù)工作中。
04 作者簡(jiǎn)介
- 樂(lè)凡
小紅書(shū)中臺(tái)算法工程師,目前主要負(fù)責(zé)大語(yǔ)言模型的相關(guān)研究和應(yīng)用。
- 曉丹
小紅書(shū)中臺(tái)算法工程師,目前主要負(fù)責(zé)大語(yǔ)言模型的相關(guān)研究和應(yīng)用。
- 特圖
小紅書(shū)中臺(tái)算法基礎(chǔ)模型方向負(fù)責(zé)人,主要研究方向:多模態(tài)大模型 x 內(nèi)容分發(fā)技術(shù)。
- 瑞格
小紅書(shū)中臺(tái)算法團(tuán)隊(duì)負(fù)責(zé)人。