無比喻,不論文!用「畫家流水線」的方式理解Transformer中間層
盡管Transformer架構(gòu)已經(jīng)主宰了當今幾乎所有的大模型,但我們依舊對它的工作原理知之甚少。
而且,基于Transformer的預訓練LLM動輒有幾十億參數(shù),很難直接對模型進行可解釋性分析。
同時,模型中間層由N個相同的塊堆疊在一起,它們之間唯一的區(qū)別只有層次位置和權(quán)重值,這就讓理解中間層更加困難。
然而,最近發(fā)表的一篇論文卻給出了一個十分通俗易懂的比喻——「畫家流水線」。
論文地址:https://arxiv.org/pdf/2407.09298v1
有著「東京AI夢之隊」之稱的Sakana AI,聯(lián)合IBM前AI負責人Satya Nitta創(chuàng)始的Emergence AI,兩個團隊的研究人員用一種新的「打開方式」來解釋Transformer架構(gòu)的中間層。
值得一提的是,這篇論文作者之一Llion Jones同樣也是當年Transformer架構(gòu)的共同創(chuàng)建者之一。
那么,「畫家流水線」這個比喻該如何理解呢?
首先,輸入被看作是一張畫布,輸入通過N個組成中間層的塊的過程,就像是畫布在「畫家流水線」上進行傳遞的過程。
有些畫家擅長畫鳥,而有些畫家則更擅長畫魚。每個畫家從前面的畫家手中接過畫布,然后決定是在畫上添幾筆,還是直接傳給后面的畫家。
在這個類比中,非常重要的一點是,每個畫家都使用相同的「詞匯」來理解畫作,因此一個畫家可以在流水線上從前一個畫家手中接過畫作,但不會因為對畫面理解不同而造成災難。
畫家們也可以重新排序(調(diào)整圖層的前后順序),甚至可以同時添加筆觸,就像N個塊可以并行運行。
這個類比并不是一個嚴謹?shù)睦碚?,但可以提供一個幫助我們思考Transformer層的有趣視角。
在這個類比的啟發(fā)下,研究人員提出了一些假設,并通過實驗來驗證這些假設是否成立——
- 不同層是否使用相同的表征空間?
- 所有的層都是有必要的嗎?
- 中間層是否都在執(zhí)行相同的功能?
- 層的順序重要嗎?
- 我們能并行運行各層嗎?
- 順序是否對與某些特定任務而言更重要
- 循環(huán)是否有助于并行層?
- 哪些變體對性能的損害最小?
實驗
主要用于實驗包括兩種預訓練LLM,分別是decoder-only架構(gòu)的Llama2-7B,以及encoder-only架構(gòu)的BERT。Llama2-7B有70億個參數(shù)和32層(每層含2.02億個參數(shù)),BERT僅有24層和3.4億個參數(shù)。
在下述所有實驗過程中,模型都是凍結(jié)的。除了對BERT進行GLUE基準測試時進行了標準的微調(diào)步驟,參數(shù)沒有經(jīng)過任何修改。
評估過程采用了ARC(科學考試題)、HellaSwag(常識)、GSM8K(數(shù)學應用題)、LAMBADA(單詞預測)等常用基準。
其中LAMBADA任務可以衡量模型困惑度(perplexity),任務最接近預訓練時的原始token預測。
結(jié)果發(fā)現(xiàn),Transformer的中間層有一定程度的一致性,但不冗余,而且對數(shù)學、推理任務而言,各層的運行順序比在語義任務中有更重要的影響。
各層「說同一種語言」?
Transformer中的不同層是否共享相同的表示空間?
為了回答這個問題,論文采用的方法是讓模型跳過特定層或調(diào)換相鄰層的順序,觀察會不會出現(xiàn)災難性后果。
圖2中展示了Llama 2 7B在跳過或調(diào)換一些層后,模型整體在Open-LAMADA基準上的表現(xiàn)。
可以看到,除了起始和末端的幾層,模型對這兩種架構(gòu)修改都表現(xiàn)出了相當強的魯棒性。
因此可以得出初步結(jié)論:1)中間層共享同一個表示空間,2)表示空間與「外層」(第一層和最后幾層)不同。
為了進一步驗證,論文還進入模型內(nèi)部,測量了不同層中隱藏狀態(tài)內(nèi)激活函數(shù)的余弦相似度(圖3),表明這種一致性在三個模型的所有中間層都成立。
上圖還可以很清晰看到,模型各層自然形成了4~5個不同的相似組,比如Llama 2 13B模型中分別是:第0層,1-3層、中間層,以及最后的1層或2層。
據(jù)此,Transformer中的所有層可以被大致分為三類:起始層、中間層和結(jié)束層。
此外,圖3中的矩陣也能和圖2中的模型分數(shù)相對應,更能有力證明,中間層之間共享語義表達空間。
所有層都必要?
為了進一步檢驗中間層的重定向空間是否真正共享(除了具有接近的余弦相似性),研究人員嘗試跳過多個層。
也就是說,將第N層的輸出直接送入第N+M層的輸入(其中M>1),從而「跳過」M-1層。
在不進行任何微調(diào)的情況下,這個實驗是要看看N+M層能否理解來自N層的激活,盡管它在訓練中只接受了來自N+M-1層的輸入。
結(jié)果顯示,Llama2-7B和BERT-Large的許多基準性能都出現(xiàn)了一定程度的下降。
那么,所有層都有必要嗎?這一問題已經(jīng)有了答案。
No! 并非所有層都是必要的,至少有幾個中間層可以跳過,而不會發(fā)生災難性故障。
左圖:Llama2-7B跳過N層~32-N層的基準測試結(jié)果(歸一化);右圖:BERT跳過N層~24-N 層的基準測試結(jié)果(未歸一化)
中間層功能相同嗎?
如果中間層共享一個共同的表征空間,這是否意味著這些層是多余的呢?
為了驗證這一點,研究人員重新進行了上一小節(jié)的「跳過」實驗。
但不同的是,這次不是直接跳過M個中間層,而是用模型最中心的的一層代替全部M個層(Llama是第16層,BERT是第12層),相當于在這一層上循環(huán)T-2N+1次,其中T是層的總數(shù)。
結(jié)果表明,隨著被替換層數(shù)M的增加,基準測試結(jié)果迅速下降。
在研究人員所嘗試的所有測試中,這一項測試的變化是最嚴重的,比直接跳過一些層還要嚴重得多。
因此,中間層功能相同嗎?這一問題的答案是——
No! 在中間層之間共享權(quán)重是災難性的,這表明中間層在執(zhí)行不同的功能。
用中心層替換M個中間層(左側(cè)經(jīng)過歸一化,右側(cè)未經(jīng)歸一化)
順序重要嗎?
之前的實驗表明,中間層共享一個表征空間,但對這個空間執(zhí)行不同的操作。
那么另一個問題來了——這些操作的執(zhí)行順序有多重要?
論文進行了兩組實驗來檢驗這個問題。首先,以與預訓練完全相反的順序運行中間層,如下圖所示:
第二組則是以隨機順序運行中間層,最終結(jié)果是取10個隨機種子進行實驗后的均值。
圖6和圖7分別展示了中間層完全翻轉(zhuǎn)和隨機順序的結(jié)果,雖然都出現(xiàn)了一定程度的性能下降,但兩者的結(jié)果都優(yōu)于直接跳過的情況。
所以,中間層順序重要嗎?這一問題的答案是——
比較重要。改變中間層的執(zhí)行順序,無論是隨機打亂或者完全翻轉(zhuǎn),都會導致模型性能退化。
并行運行
如果層本身的存在比它們的執(zhí)行順序更重要,那么我們是否可以獨立運行各層,最后合并它們的結(jié)果呢?
比如像下圖中,將原本堆疊在一起的中間層展開,并行運行后取各層輸出的平均值,傳遞給最后的N個層。
實驗結(jié)果顯示,GSM8K(數(shù)學應用題)基準中,模型性能有劇烈的變化,直線下降,其他基準分數(shù)的下滑則平緩得多。
我們暫且可以下這樣一個結(jié)論:并行運行是可行的,但解決數(shù)學問題除外。
要理解這種性能下降,可以用我們的「畫家流水線」進行類比:某些中間層只有在看到合適輸入時,才能對結(jié)果有所貢獻,就像一個擅長畫車輪的畫家,只有在畫面上看到汽車車身時,才更有可能畫出輪子。
如果是這種情況,將中間層并行運行的過程迭代多次應該會提高性能。
如下圖所示,論文將多個并行層的平均輸出再作為輸入反饋回去,如此進行一定次數(shù)的循環(huán)。
圖9顯示了循環(huán)3次的結(jié)果,與圖8中沒有循環(huán)的方案相比,性能曲線的確相對平緩,尤其是在圖右BERT模型未經(jīng)歸一化的分數(shù)上更加明顯。
圖10更清楚直觀地展示了,并行的中間層數(shù)和循環(huán)次數(shù)如何影響性能,其中紅框圈出了每列上的最高值。
除了29層和31層(接近Llama 2 7B的總層數(shù)32)得出例外的結(jié)果,從5層到27層都呈現(xiàn)出一致的趨勢:最佳迭代次數(shù)大致與并行化層數(shù)呈線性比例。
實驗結(jié)果總結(jié)
將上述所有實驗結(jié)果放到同一張圖中(圖11),我們就能比較不同變體對模型性能的影響程度。
左圖(Llama2)取各基準的中值,右圖(BERT)取各基準的平均值
「隨機化層順序」和「循環(huán)并行」分別在Llama2和BERT-Large上造成了最少的性能下降,「中間重復」方案(用中心層運行多次代替整個中間層)則在兩個模型上都造成了最嚴重的滑坡。
討論
自從Transformer發(fā)布后,大多數(shù)工作都在關(guān)注架構(gòu)的修改和優(yōu)化,以達到性能提升或參數(shù)減少。這篇論文則提供了另一種視角,調(diào)查了層并行化和重用的影響。
基于「Transformer層即畫家」這個類比,我們開頭提出的幾個問題都通過實驗得到了答案,最后得到了3個有趣的發(fā)現(xiàn):
- 所有Transformer層可以大致分為三類:起始層、中間層和結(jié)束層,其中中間層占比最大;
- 中間層具有一定程度的一致性,但并不冗余;
- 與語義任務相比,各層的執(zhí)行順序?qū)?shù)學和推理任務更為重要。
為什么Transformer架構(gòu)面對各種架構(gòu)修改時能表現(xiàn)出如此強大的魯棒性?作者表示將在之后的工作中再深入研究。
一個可能的假設是,訓練過程中的殘差連接是各層共享相同表征的必要條件。
我們已經(jīng)知道,殘差連接有助于解決梯度消失問題,然而相比沒有殘差連接的Transformer,加上殘差會降低性能。
如果能在沒有殘差的Transformer上重新運行上述架構(gòu)的變體,看看是否會破壞完全無殘差模型所取得的微薄收益,那將會非常有趣。
對于未來的其他工作,研究人員還計劃「解凍」模型,并研究Transformer是否需要(以及需要多長時間)通過微調(diào)來適應上述的架構(gòu)變化。
雖然本文的目的是更好地理解Transformer的中間層,而非引入新模型,但根據(jù)實驗結(jié)果,中間層并行或者干脆跳過都可以用適度的準確性損失換取更低的推理延遲。
作者團隊
本文作者分別來自兩家AI初創(chuàng)公司:Sakana AI和Emergence AI。
Sakana AI在今年年初剛剛獲得3000萬美元的種子輪融資,由Lux Capital領(lǐng)投,并得到了硅谷頂級風投公司Khosla Ventures以及Jeaf Dean、Alexandr Wang等大佬的支持。
公司研發(fā)的重點是基于自然啟發(fā)的新型基礎模型,創(chuàng)始團隊也是星光熠熠,一半成員來自「AI黃埔軍?!埂雀璐竽X和DeepMind。
相比于關(guān)注基礎研究的Sakana,Emergence AI更關(guān)注應用,專門從事LLM驅(qū)動的multi-agent系統(tǒng)研發(fā)。
公司聯(lián)合創(chuàng)始Satya Nitta曾擔任IBM研究院「AI解決方案」領(lǐng)域的全球主管,其中的許多研究人員和工程師也同樣來自谷歌、Meta、微軟、亞馬遜和Allen AI等頂尖機構(gòu)。
Emergence上個月剛剛從Learn Capital獲得9720萬美元的資金,以及額外的總計超過一億美元的信貸額度,未來的發(fā)展也是前途可期。