next-token被淘汰!Meta實(shí)測「多token」訓(xùn)練方法,推理提速3倍,性能大漲10%+
當(dāng)前,大型語言模型,例如GPT和Llama,主要是根據(jù)「前文的單詞序列」對「下一個(gè)token」進(jìn)行預(yù)測的方式來訓(xùn)練。
但你有沒有想過一個(gè)問題,為什么不對后文的tokens同時(shí)進(jìn)行預(yù)測呢?
最近,Meta、巴黎高科路橋大學(xué)、巴黎薩克雷大學(xué)的研究人員就聯(lián)合提出了一種新的訓(xùn)練方法,即一次性預(yù)測多個(gè)未來tokens,可以提高模型的樣本效率。
論文鏈接:??https://arxiv.org/pdf/2404.19737??
具體來說,在訓(xùn)練語料庫的每一個(gè)位置,要求模型使用n個(gè)獨(dú)立的輸出頭網(wǎng)絡(luò)來預(yù)測緊隨其后的n個(gè)token,其中所有輸出頭都基于同一個(gè)模型主干。
研究人員將多token預(yù)測視作是一種輔助訓(xùn)練任務(wù),實(shí)驗(yàn)發(fā)現(xiàn)該方法不僅能夠提升模型在各種下游任務(wù)上的表現(xiàn),而且不會(huì)增加訓(xùn)練時(shí)間,對代碼生成和自然語言生成任務(wù)都是有益的。
隨著模型尺寸的增大,該方法的優(yōu)勢變得更加明顯,尤其是在進(jìn)行多epochs訓(xùn)練時(shí)。
在編程等生成性任務(wù)的基準(zhǔn)測試中,使用多token預(yù)測訓(xùn)練的模型的性能提升尤為顯著,能夠穩(wěn)定地超過傳統(tǒng)單token預(yù)測模型。
例如,13B參數(shù)的模型在HumanEval基準(zhǔn)測試中解決問題的能力比同等規(guī)模的單token模型高出12%,在MBPP基準(zhǔn)測試中高出17%
此外,通過在小型算法任務(wù)上的實(shí)驗(yàn),研究人員發(fā)現(xiàn)多token預(yù)測對于提升模型的歸納頭(induction heads)和算法推理能力是有益的。
而且,使用多token預(yù)測訓(xùn)練的模型在推理時(shí)速度更快,最高可達(dá)三倍,即便是在處理大規(guī)模數(shù)據(jù)批次時(shí)也是如此。
多token預(yù)測
標(biāo)準(zhǔn)語言模型通過執(zhí)行一個(gè)「下一個(gè)token預(yù)測」任務(wù)來對大型文本語料庫進(jìn)行學(xué)習(xí),任務(wù)目標(biāo)是最小化交叉熵?fù)p失,其中模型需要最大化「在給定之前token序列歷史的條件下,預(yù)測下一個(gè)token」的概率。
研究人員將「單token預(yù)測」任務(wù)泛化為「多token預(yù)測」,在訓(xùn)練預(yù)料上的每個(gè)位置,模型需要一次性預(yù)測未來n個(gè)tokens,交叉熵?fù)p失改寫為:
為了使問題可解,假設(shè)大型語言模型使用一個(gè)共享的主干網(wǎng)絡(luò)來生成觀察到的上下文的潛表征z,然后再把該表征送入到n個(gè)獨(dú)立的頭網(wǎng)絡(luò),以并行的方式預(yù)測每一個(gè)未來token
多token預(yù)測的交叉熵?fù)p失可以分解為兩部分:在給定token序列下的潛表征,以及在該潛表征條件下,預(yù)測n個(gè)未來token
在實(shí)踐中,該架構(gòu)包括一個(gè)共享Transformer主干模型,根據(jù)上下文詞序列來生成潛表征,n個(gè)獨(dú)立的、基于Transformer層的輸出頭,以及一個(gè)共享的unembedding矩陣。
節(jié)省內(nèi)存
在訓(xùn)練多token預(yù)測器時(shí),一個(gè)關(guān)鍵問題是GPU顯存占用過多。
在當(dāng)前的大型語言模型(LLMs)中,詞匯表的大小V通常遠(yuǎn)遠(yuǎn)大于潛在表示的維度d,因此logit vectors就成了GPU內(nèi)存使用的瓶頸。
如果簡單地實(shí)現(xiàn)多token預(yù)測器,將所有的logit vectors及其梯度都存儲(chǔ)在內(nèi)存中,會(huì)導(dǎo)致內(nèi)存使用量迅速增加,因?yàn)槊總€(gè)向量的形狀都是 (n, V),這種方式會(huì)極大地限制模型可同時(shí)處理的批次大小,并增加GPU顯存的平均使用量。
研究人員提出了一種內(nèi)存高效的實(shí)現(xiàn)方法,通過調(diào)整前向傳播和反向傳播操作的順序來減少內(nèi)存使用。
具體來說,在通過共享主干網(wǎng)絡(luò)fs 完成前向傳播后,模型會(huì)按順序?qū)γ總€(gè)獨(dú)立的輸出頭部fi 執(zhí)行前向和反向傳播,并在主干網(wǎng)絡(luò)處累積梯度,每個(gè)輸出頭部fi的logit向量和梯度在計(jì)算后就會(huì)被釋放,無需一直占用內(nèi)存,直到所有頭部的計(jì)算完成。
這意味著,除了主干網(wǎng)絡(luò)的梯度外,不需要長期存儲(chǔ)其他任何梯度,從而顯著降低了GPU內(nèi)存的使用。
通過這種方式,模型的內(nèi)存復(fù)雜度從O(nV+d)降低到了O(V+d),在不犧牲運(yùn)行時(shí)間的情況下,顯著減少了GPU的峰值內(nèi)存使用。
推理階段Inference
在推理時(shí),該模型的最基礎(chǔ)用法是使用「下一個(gè)token預(yù)測頭」(next-token prediction head)進(jìn)行「基本next-token自回歸預(yù)測」,同時(shí)丟棄所有其他頭網(wǎng)絡(luò)。
也可以利用額外的輸出頭網(wǎng)絡(luò)進(jìn)行自推理解碼,對從下一個(gè)token預(yù)測頭網(wǎng)絡(luò)的解碼進(jìn)行加速:
1. 區(qū)塊并行解碼(blockwise parallel decoding),一種推理解碼的變體方法,可以并行地預(yù)測多個(gè)token,而不需要額外的草稿模型;
2. 使用類似美杜莎(Medusa)樹注意力機(jī)制的推測解碼,可以提高解碼速度和效率。
實(shí)驗(yàn)結(jié)果
研究人員總共進(jìn)行了七個(gè)大規(guī)模實(shí)驗(yàn)來證明多token預(yù)測損失的有效性。
為了公平對比next-token預(yù)測器和n-token預(yù)測器,實(shí)驗(yàn)中的模型參數(shù)量均相同,也就是說,在預(yù)測未來頭網(wǎng)絡(luò)中添加n-1層時(shí),同時(shí)也會(huì)從共享模型主干中移除n-1層。
1. 性能隨模型尺寸增大而提升
為了研究模型尺寸的影響,研究人員從零開始訓(xùn)練了「六個(gè)」模型,尺寸范圍覆蓋了從300M到13B參數(shù)量,至少使用了91B tokens的代碼。
從評估結(jié)果中可以看到,在MBPP和HumanEval上的實(shí)驗(yàn)表明,在相同的計(jì)算量下,使用多token預(yù)測,可以在固定數(shù)據(jù)集上獲得更好的性能。
研究人員認(rèn)為,該特性只有在大規(guī)模數(shù)據(jù)、大尺寸模型上才能體現(xiàn)出來,這也可能是多token預(yù)測一直沒有在大型語言模型訓(xùn)練上廣泛應(yīng)用的原因。
2. 更快的推理速度
研究人員使用異構(gòu)批量大小的xFormers實(shí)現(xiàn)貪婪自我推測解碼(self-speculative decoding),并測量了最佳的4-tokens預(yù)測模型(7B參數(shù))在補(bǔ)全代碼和自然語言數(shù)據(jù)時(shí)的解碼速度。
可以看到,該方法在代碼生成任務(wù)上速度提升了3.0倍,文本生成的速度提升了2.7倍,在8字節(jié)預(yù)測模型上,推理速度提升了6.4倍。
使用多token預(yù)測進(jìn)行預(yù)訓(xùn)練時(shí),額外的頭網(wǎng)絡(luò)可以比單個(gè)next-token預(yù)測模型的微調(diào)更準(zhǔn)確,從而讓模型充分發(fā)揮自推測解碼的全部潛力。
3. 用多字節(jié)預(yù)測來學(xué)習(xí)全局pattern
為了展示next-token預(yù)測任務(wù)能夠捕捉到局部模式,研究人員采取了極端情況,即字節(jié)級分詞(byte-level tokenization),通過訓(xùn)練一個(gè)7B參數(shù)的字節(jié)級Transformer模型來處理314B個(gè)byte,大約相當(dāng)于116B個(gè)tokens
8-byte預(yù)測模型與next-byte預(yù)測相比取得了顯著的性能提升,在MBPP pass@1上解決了超過67%的問題,在HumanEval pass@1上解決了20%的問題。
因此,多字節(jié)預(yù)測是一個(gè)非常有前景的方法,可以讓字節(jié)級模型的訓(xùn)練更高效。
自推測解碼可以實(shí)現(xiàn)8字節(jié)預(yù)測模型的6倍速度提升,完全彌補(bǔ)了在推理時(shí)「更長字節(jié)序列」的成本,甚至比next-token預(yù)測模型快近兩倍。
盡管訓(xùn)練所用的數(shù)據(jù)量少了1.7倍,但8字節(jié)預(yù)測模型的性能仍然能接近基于token的模型。
4. 尋找最優(yōu)的n值
為了更好地理解預(yù)測token數(shù)量的影響,研究人員在7B尺寸的模型(訓(xùn)練數(shù)據(jù)包含了200B個(gè)代碼token)上進(jìn)行了全面的消融實(shí)驗(yàn),在不同實(shí)驗(yàn)設(shè)置中嘗試了 n = 1, 2, 4, 6和8
實(shí)驗(yàn)結(jié)果顯示,使用4個(gè)未來token進(jìn)行訓(xùn)練時(shí),在HumanEval和MBPP的所有pass at 1, 10和100指標(biāo)上均超越了其他對比模型:MBPP的改進(jìn)分別為+3.8%, +2.1%和+3.2%,HumanEval的改進(jìn)分別為+1.2%, +3.7%和+4.1%
有趣的是,在APPS/Intro上,n = 6時(shí)的性能提升分別為+0.7%, +3.0%和+5.3%
最佳的窗口尺寸很可能取決于輸入數(shù)據(jù)的分布。至于字節(jié)級模型,最佳窗口大小在基準(zhǔn)測試中更為一致(8字節(jié))。
5. 多epochs訓(xùn)練
在進(jìn)行機(jī)器學(xué)習(xí)模型訓(xùn)練時(shí),多tokens訓(xùn)練方法在處理相同數(shù)據(jù)集的多個(gè)訓(xùn)練周期時(shí),對于預(yù)測下一個(gè)token的任務(wù)仍然顯示出了優(yōu)勢。
雖然隨著訓(xùn)練周期的增加,優(yōu)勢略有下降,但在MBPP數(shù)據(jù)集上的pass@1指標(biāo)上,仍然觀察到了2.4%的提升;在HumanEval數(shù)據(jù)集上的pass@100指標(biāo)上,提升更是達(dá)到了3.2%
結(jié)果表明,即使在多次訓(xùn)練后,多tokens訓(xùn)練方法仍然能夠帶來一定的性能提升。
但對于APPS/Intro數(shù)據(jù)集來說,當(dāng)訓(xùn)練token數(shù)量達(dá)到200B時(shí),使用窗口大小為4的訓(xùn)練方法已經(jīng)不再是最優(yōu)的選擇,可能需要調(diào)整窗口大小或采用其他策略來進(jìn)一步提高模型性能。
6. 微調(diào)多token預(yù)測器
在機(jī)器學(xué)習(xí)領(lǐng)域,預(yù)訓(xùn)練模型通過多token預(yù)測損失函數(shù)進(jìn)行訓(xùn)練,相較于傳統(tǒng)的單token預(yù)測模型,該方法在后續(xù)的微調(diào)階段展現(xiàn)出了更好的性能。
研究人員在CodeContests數(shù)據(jù)集上對具有7B參數(shù)的模型進(jìn)行了微調(diào)測試,將一個(gè)能夠預(yù)測接下來4個(gè)token的模型與基礎(chǔ)的單token預(yù)測模型進(jìn)行了比較,并嘗試了一種將4 tokens預(yù)測模型去除額外預(yù)測頭后,按照傳統(tǒng)的單token預(yù)測目標(biāo)進(jìn)行微調(diào)的設(shè)置。
實(shí)驗(yàn)結(jié)果顯示,在pass@k指標(biāo)上,無論采用哪種微調(diào)方式,4-tokens預(yù)測模型的表現(xiàn)都超過了單token預(yù)測模型,也表明4-tokens預(yù)測模型在理解任務(wù)、解決問題以及產(chǎn)生多樣化答案的能力上更為出色。
實(shí)驗(yàn)結(jié)果還表明,在4-tokens預(yù)測預(yù)訓(xùn)練的基礎(chǔ)上進(jìn)行單token預(yù)測微調(diào),可能是一個(gè)綜合性能最佳的策略,與先使用輔助任務(wù)進(jìn)行預(yù)訓(xùn)練,然后進(jìn)行特定任務(wù)微調(diào)的經(jīng)典機(jī)器學(xué)習(xí)范式相吻合。
7. 在自然語言上的多token預(yù)測
研究人員訓(xùn)練了參數(shù)量為7B的模型,并使用了三種不同的預(yù)測損失方法:預(yù)測4token、2-token以及單個(gè)token,并在6個(gè)標(biāo)準(zhǔn)的自然語言處理(NLP)基準(zhǔn)測試中進(jìn)行了性能評估。
在摘要任務(wù)中,使用了8個(gè)不同的基準(zhǔn)測試,并通過ROUGE指標(biāo)來自動(dòng)評估生成文本的質(zhì)量,結(jié)果顯示,2-token和4-token的性能都比單token預(yù)測基線的表現(xiàn)更好。
本文轉(zhuǎn)自 新智元 ,作者:新智元
