拯救Transformer推理能力!DeepMind新研究TransNAR:給模型嵌入「算法推理大腦」
如今的NLP領(lǐng)域,已然是Transformer架構(gòu)的天下。
從Bert到GPT,再到Llama、Claude,LLM模型使用Transformer已經(jīng)是再正常不過的事情。
Transformer的「大一統(tǒng)」局面正是由于其簡單、高效的架構(gòu),以及在理解自然語言方面無與倫比的泛化能力。
然而,隨著研究的逐漸深入,Transformer的一個致命缺陷也逐漸暴露出來——無法勝任算法推理任務(wù),尤其是不能進行精確、穩(wěn)健的推理。
這嚴重限制了模型在數(shù)學(xué)、代碼等領(lǐng)域下游任務(wù)的應(yīng)用,近年來對Transformer的各種調(diào)優(yōu)、修改似乎也收效甚微。
于是DeepMind的研究人員想到了混合架構(gòu)——將Transformers的語言理解能力與基于圖神經(jīng)網(wǎng)絡(luò)(GNN)的神經(jīng)算法推理器(NAR)的穩(wěn)健性結(jié)合起來,提升其算法推理能力。
他們最近在arxiv上的一篇論文就提出了這個名為TransNAR的架構(gòu),但遺憾的是,目前還沒有公布源代碼。
論文地址:https://arxiv.org/abs/2406.09308
神經(jīng)算法推理(NAR)由本文作者之一Petar Veleckovic在2021年與人合著的一篇論文中提出,并被接收為Patterns期刊的opinion paper。
論文地址:https://arxiv.org/abs/2105.02761
NAR被稱為「構(gòu)建能執(zhí)行算法的神經(jīng)網(wǎng)絡(luò)的藝術(shù)」。作者提出,算法與深度學(xué)習(xí)的本質(zhì)不同,但如果神經(jīng)網(wǎng)絡(luò)能夠更好地模仿算法,它甚至可能具備算法的強泛化性。
更進一步,神經(jīng)網(wǎng)絡(luò)若能表示出算法中連續(xù)空間內(nèi)的元素,就會使已知算法更接近現(xiàn)實世界的問題,提出的解決方案可能超過人類科學(xué)家。
如上圖所示,NAR的整體想法是訓(xùn)練出一個高維隱空間中的處理器網(wǎng)絡(luò)P(processor network),旨在不斷逼近算法的運行結(jié)果A(x)。
但由于算法的輸入和輸出一般是圖、樹、矩陣等抽象、結(jié)構(gòu)化的形式,這與深度學(xué)習(xí)模型高維、嘈雜且多變的輸入很不兼容,因此還需要訓(xùn)練編碼器f和解碼器g,將抽象形式轉(zhuǎn)換為自然形式。
NAR發(fā)布后,有多項研究證實了它有同時執(zhí)行多種算法的能力,也能部署在各種下游任務(wù)中。更重要的是,它的泛化能力似乎遠遠優(yōu)于Transformer架構(gòu)。
原則上,NAR可以擴展到比訓(xùn)練數(shù)據(jù)的分布大幾個數(shù)量級的系統(tǒng)上,有時這個數(shù)量級能達到1.8萬倍。
在使用適當(dāng)?shù)臍w納偏差(inductive biases)時,即使輸入比訓(xùn)練集大6倍,NAR也能在高度復(fù)雜的算法任務(wù)中保持完美的泛化能力。
找到了Transformer和NAR這兩種十分強大且各有所長的架構(gòu),下面最關(guān)鍵的問題就是如何進行相應(yīng)的調(diào)整和修改,使這兩個似乎完全不相容的模型真正實現(xiàn)溝通和Embedding交換。
TransNAR:用預(yù)訓(xùn)練NAR增強Transformer
如何實現(xiàn)NAR+Transformer的有效溝通?作者從多模態(tài)LLM中找到了靈感。
多模態(tài)LLM可以同時接收文本和圖像兩種模態(tài)的輸入,TransNAR也是如此。一邊是算法運行需要的圖結(jié)構(gòu),一邊是描述問題的自然語言。
作者的設(shè)想是,將預(yù)訓(xùn)練的NAR作為Transformer中編碼的調(diào)制器(modulator),二者通過embedding溝通,同時借鑒VLM和Flamingo模型中所用的交叉注意算子,融合不同模態(tài)的信息。
TransNAR接受雙重輸入,包括文本形式的算法問題規(guī)范(T個token)及其對應(yīng)的圖表征(N個節(jié)點),并輸出問題的文本答案。其中輸入的圖表征遵循算法推理基準CLRS-30的格式。
我們可以假設(shè),編碼完成后,文本輸入存儲在T ∈ R^(T×k)中,圖輸入存儲在G ∈ R^(N×l)中。
TransNAR的前向傳播過程如下:
首先,我們通過設(shè)置T^(0) = T和G^(0) = G來正確初始化輸入。
接下來,為了計算第(t+1)步的表征,文本(token)表征被輸入到Transformer的當(dāng)前層:
其中,Qt,Kt ∈ Rk×d_k,Vt ∈ Rk×k分別是鍵、查詢和值矩陣的變換,F(xiàn)FN是一個前饋神經(jīng)網(wǎng)絡(luò)。
以類似的方式,圖表征被輸入到NAR層,例如實現(xiàn)一個標(biāo)準的max-MPNN:
其中,ψ,? : Rk × Rk → Rk分別是可學(xué)習(xí)的消息函數(shù)和更新函數(shù),max是逐元素最大值聚合。
需要注意的是,方程2僅簡要提供了節(jié)點之間的成對交互——實際上,這里的NAR是一個Triplet-GMPNN,它還包含三元組交互和一個門控機制。
此外,還需注意,NAR的可學(xué)習(xí)部分沒有時間步索引——每一步都應(yīng)用相同的共享函數(shù)。這很好地契合了圖算法計算的迭代和重復(fù)性質(zhì)。
一旦兩個流都準備好它們的表征Θt+1和Gt+1,圖中的節(jié)點嵌入將對Transformer的token嵌入進行條件設(shè)置,從而產(chǎn)生Transformer流中TransNAR塊的最終結(jié)果:
其中,Qt×,Kt× ∈ Rk×d_k, Vtx ∈ Rk×k分別是交叉注意力的鍵、查詢和值變換。在結(jié)束這一層之前,對Gt+1不進行額外的變換。
這個過程會一直重復(fù),直到最后的第Nl層,在這一層中,從TN_l讀取最終的文本輸出。
最終輸出通過最后一層生成的預(yù)測頭轉(zhuǎn)換為token logits,并通過標(biāo)準的下一個token預(yù)測來監(jiān)督訓(xùn)練。
在開始TransNAR微調(diào)之前,首先預(yù)訓(xùn)練NAR,使其能夠穩(wěn)健地執(zhí)行CLRS-30覆蓋的三十個算法。這種方法已知可以在圖空間中實現(xiàn)高達4倍輸入規(guī)模的分布外泛化。
在微調(diào)過程中,NAR的參數(shù)通常保持凍結(jié)狀態(tài),因為額外的梯度會削弱模型的原有穩(wěn)健性特性。同樣的原因,圖嵌入不會執(zhí)行交叉注意力。
LLM本身可以在大規(guī)模數(shù)據(jù)集上進行預(yù)訓(xùn)練,以建立其一般語言先驗,即使在開始時隨機初始化LM,也能獲得相同的實驗結(jié)果。
實驗設(shè)置
在實驗中,作者展示了TransNAR為大語言模型架構(gòu)中的分布外推理帶來的顯著優(yōu)勢。
Transformer架構(gòu)和初始化
論文使用Chinchilla家族的一個decoder-only架構(gòu)、6層的Transformer模型,首先在MassiveText上進行了預(yù)訓(xùn)練,參數(shù)量有70M,上下文大小為2048。
為了探究初始化設(shè)置的影響,作者設(shè)計了兩個變體進行消融實驗。
第一個變體中,Transformer權(quán)重用預(yù)訓(xùn)練的結(jié)果初始化,模擬微調(diào)場景;第二個變體則是完全隨機的初始化。這兩個模型分別被標(biāo)記為「預(yù)訓(xùn)練」和「未訓(xùn)練」。
隨機位置編碼
之前DeepMind的一篇論文論證過,隨機位置編碼可以增強Transformer的長度泛化與推理穩(wěn)健性。
論文地址:https://arxiv.org/abs/2305.16843
作者也提到,隨機位置嵌入確實在基線模型和TransNAR上都帶來了顯著增益,因此本文中的所有實驗也都使用隨機位置嵌入。
預(yù)訓(xùn)練NAR
論文使用CLRS-30基準中的問題預(yù)訓(xùn)練了一個多任務(wù)、基于MPNN的NAR,輸入問題規(guī)模最多達16個。
由于CLRS-30的標(biāo)準圖結(jié)構(gòu)表達,這樣訓(xùn)練出來的NAR有很強的分布外(OOD)泛化能力,有時在4倍大小的圖上仍保持競爭力,這種豐富的知識表達正是文本模型可資利用的。
結(jié)合節(jié)點和邊緣的跨注意力貢獻
在上述的算法描述中,我們將NAR模型的圖輸入限于N個節(jié)點,但作者注意到了之前的研究曾嘗試過,同時對圖的節(jié)點和邊生成隱變量表達,也許可以添加有用的互補信息。
于是實驗中引入圖中邊的特征E(t) ∈ RN×N×k,并再次應(yīng)用公式3讓Θ(t)對E(t)進行交叉注意力。
作者也嘗試其他方法,希望將E(t)和G(t)結(jié)合起來,比如拼接后加線性層組合、向量求和、2層MLP,或者用Gram-Schmidt過程使二者的貢獻正交化,但這些都沒有給原始方法帶來提升。
數(shù)據(jù)集
訓(xùn)練數(shù)據(jù)使用CLRS-Text基準,即CLRS-30基準的文本版本,以確定性的方式直接從基于圖的CLRS-30中派生,因此這兩個數(shù)據(jù)集傳達的是完全相同的信息。
表1展示了該數(shù)據(jù)集的幾個樣本,以及它們的輸入大小和token數(shù)量。
由于語言模型上下文長度的限制,實驗選擇用規(guī)模為4、8、12的問題訓(xùn)練,并在規(guī)模為110、12、14的問題上評估。
值得注意的是,與當(dāng)前的評估環(huán)境相比,CLRS-Text是對LM最具挑戰(zhàn)性的長程推理任務(wù)之一——相比小學(xué)數(shù)學(xué),復(fù)雜度顯著提高。
CLRS-Text的挑戰(zhàn)性主要源于它允許顯式控制分布外泛化。然而,每個問題都有清晰的多項式時間解法,這意味當(dāng)今典型LLM的參數(shù)量應(yīng)該足以解決這些問題。
該數(shù)據(jù)集每種算法的每種輸入規(guī)模包含一萬個樣本,總共240萬個數(shù)據(jù)點,其中70%用于訓(xùn)練、30%用于驗證。
訓(xùn)練細節(jié)
實驗將batch大小設(shè)置為256訓(xùn)練了7個epoch,并使用Adam優(yōu)化器,學(xué)習(xí)率為10-4。
如前所述,在所有Chinchilla Transformer的旋轉(zhuǎn)位置編碼(RoPE)之上應(yīng)用隨機位置編碼,最大長度為8192,且訓(xùn)練期間保持NAR凍結(jié)。
評估指標(biāo)
作者提出,合適的評估指標(biāo)應(yīng)該反映模型在特定樣本上失敗的原因,且需要度量型輸出與正確答案的接近程度。因此,使用精確字符串匹配來計算模型準確性是絕對不可行的。
論文選擇的性能指標(biāo)包括以下三個:
1. 形狀分數(shù):一個二元指標(biāo),用于判斷輸出是否具有正確的形狀。例如,在排序任務(wù)中,輸出應(yīng)與輸入有完全相同的元素數(shù)量。或者,如果輸出是一個矩陣,我們需要確保其形狀與輸入和任務(wù)一致。
2. 解析分數(shù):一個二元指標(biāo),用于判斷輸出是否不含任何非法字符。例如,在對數(shù)字列表進行排序的任務(wù)中,輸出不應(yīng)包含任何字母。
3. CLRS分數(shù):輸出中與真實答案匹配的元素百分比,也常用于CLRS-30測試。形狀分數(shù)為0時,CLRS分數(shù)也會自動置零。
這種多方面的指標(biāo)設(shè)計能夠捕捉到LLM在文本上進行推理任務(wù)的各種失敗模式。
比如在某個問題規(guī)模上過度專門化訓(xùn)練(導(dǎo)致輸出的形狀不正確)、無法處理看不見的數(shù)字組合(導(dǎo)致解析錯誤),由于推理錯誤造成的答案不一致則由CLRS分數(shù)反映。
結(jié)果
實驗結(jié)果顯示,TransNAR整體上顯著優(yōu)于Transformer模型,在動態(tài)規(guī)劃、幾何、圖、貪心算法、排序、字符串等任務(wù)上的OOD推理能力都有大幅提升。
并且在大多數(shù)單個算法上,無論是在分布內(nèi)還是分布外都表現(xiàn)更佳。
特別值得注意的是,這種方法不僅增強了Transformer原有的OOD泛化能力,還激發(fā)了一些模型先前完全不具備的能力。
比如Graham掃描(graham_scan)、最長公子串長度(lcs_length)、強連通分量(scc)這些經(jīng)典問題中,基線模型得分為零或接近零,但TransNAR卻實現(xiàn)了突破。
分析形狀分數(shù)可以進一步解釋,為什么TransNAR表現(xiàn)如此出色。
首先,回顧一下,如果形狀不匹配,CLRS得分必然為零。
從形狀得分來看,將Transformer的輸出建立在NAR嵌入基礎(chǔ)上顯著提高了答案中形狀正確的比例——這表明TransNAR緩解了一種特定的LLM故障模式。
此外,通過對比「預(yù)訓(xùn)練」和「未訓(xùn)練」兩種初始化方式的分數(shù),可以看到模型較好的穩(wěn)定性和可用性。在隨機初始化時,也能訓(xùn)練到與微調(diào)相當(dāng)?shù)乃疁省?/span>
然而,在一些算法中,TransNAR仍未能超越基線,且在分布內(nèi)和分布外都是如此。
這些算法包括二分搜索、尋找最大子數(shù)組、最小值和快速選擇等,都涉及在輸入列表中按照索引搜索特定元素。
這暗示了TransNAR的一種故障模式:模型無法泛化到訓(xùn)練數(shù)據(jù)中未見過的新索引邊界。因此,使用索引提示或許是一條有前景的改進途徑。
另一種可能的解釋是,NAR最終計算出的隱藏狀態(tài)難以在交叉注意力層以可泛化的方式被解碼。如果原因在此,解決途徑可以是增加交叉注意力的容量,或者采用漸進式解碼。
此外,TransNAR在架構(gòu)上有一個本質(zhì)的局限性,就是必需一個能得出ground truth的模擬器或者數(shù)據(jù)標(biāo)簽,用于將輸入的文本轉(zhuǎn)換為圖結(jié)構(gòu),再作為模型輸入。
但是作者強調(diào),TransNAR的概念對于未來研究是有借鑒意義的??梢钥紤]將這種混合架構(gòu)的想法移植到單模態(tài)LLM,或者將TransNAR訓(xùn)練后獲得的知識提煉出來注入到普通的Transformer中。