6700萬參數(shù)比肩萬億巨獸GPT-4!微軟MIT等聯(lián)手破解Transformer推理密碼
「因果推理」絕對是當(dāng)前GenAI熱潮下的小眾領(lǐng)域,但是它有一個大佬級的堅定支持者——Yann LeCun。
他在推特上的日常操作之一,就是炮轟Sora等生成模型,并為自己堅信的因果推理領(lǐng)域搖旗吶喊。
甚至,早在2019年VentureBeat的采訪中,他就表達(dá)過這一觀點:我們需要在深度學(xué)習(xí)模型中引入事件的因果關(guān)系,才能增強(qiáng)泛化能力,減少訓(xùn)練數(shù)據(jù)使用。
對于當(dāng)前最流行的模型架構(gòu)Transformer,我們能教它因果推理嗎?
最近,來自微軟MIT等機(jī)構(gòu)的研究人員提出了一種訓(xùn)練大模型新范式——公理框架(Axiomatic Framework)。
論文中,作者從頭開始訓(xùn)練了6700萬參數(shù)的模型,僅使用了簡單的因果鏈作為訓(xùn)練數(shù)據(jù)。
令人驚訝的是,在推斷復(fù)雜圖表中的因果關(guān)系時,67M模型的表現(xiàn)超越了十億級參數(shù)LLM,甚至可以與GPT-4相媲美。
論文地址:https://arxiv.org/abs/2407.07612v1
微軟MIT等團(tuán)隊最新方法的提出,是受到了圖靈獎得主Judea Pearl啟發(fā)。
Pearl曾提出了結(jié)構(gòu)化因果規(guī)則中的因果無關(guān)性公理,即直接通過符號化公理示例來教Transformer模型學(xué)習(xí)被動數(shù)據(jù)(passive data)。
這種方法不同于傳統(tǒng)機(jī)器學(xué)習(xí)模型,使用由公理推導(dǎo)出的數(shù)據(jù)。
正如結(jié)果所示,通過公理訓(xùn)練,研究證明了Transformer模型可以學(xué)習(xí)因果,從而推斷因果關(guān)系,并從相關(guān)性中識別因果性。
這暗示了,像GPT-4等大模型的訓(xùn)練,可以通過網(wǎng)絡(luò)數(shù)據(jù)中的帶噪聲的公理化示例學(xué)習(xí)因果知識,而無需進(jìn)行干預(yù)實驗。
網(wǎng)友稱贊道,「研究者的觀點非常耐人尋味,因果推理一直是LLM的致命弱點,進(jìn)一步發(fā)展這一領(lǐng)域,勢在必行」。
「這類研究可能是通向半AGI的一條途徑」。
研究背景
因果推理(causal reasoning)是一種推理過程,遵守有特定因果性的預(yù)定義公理或規(guī)則。
圖靈獎得主Judea Pearl曾通過如下的「因果關(guān)系階梯」(ladder of causation)定義了可能的因果推理類型。
通常因果推理所用的公理或規(guī)則并不會被直接引入,模型學(xué)習(xí)的只是數(shù)據(jù)。公理或規(guī)則作為歸納偏差被納入模型,比如通過正則化、模型架構(gòu)或變量選擇等方式。
而這篇論文想要探討的,就是模型能否從被動的符號演示中直接學(xué)習(xí)公理或規(guī)則。作者將這種方法稱為「公理化訓(xùn)練」(axiomatic training)。
假設(shè)因果公理都可以以如下形式表示:<前提,假設(shè),結(jié)果>,其中結(jié)果只有「是」和「否」兩種形式。
這基本類似于亞里士多德提出的「三段論」格式,比如Judeal Pearl書中提出的「碰撞公理」(collider axiom)就可以表示為:
前提:?????, ???????, ???????
假設(shè):A是否導(dǎo)致C?
結(jié)論:是
這只是單個公理的表示,那么如何表達(dá)一個復(fù)雜系統(tǒng)中多個公理的組合呢?甚至,我們能用有限數(shù)量的公理表達(dá)任意因果模型嗎?
此處,論文引用了Judea Pearl和David Galles在1997年發(fā)表的一項研究,他們證明了,對于給定的穩(wěn)定概率因果模型,都存在一組有限公理,可以充分表征對應(yīng)的有向因果圖。
因果模型M=(X,U,F)被定義為內(nèi)部變量X、外部變量U和一組結(jié)構(gòu)方程F的集合,結(jié)構(gòu)方程描述了變量X和U之間的因果關(guān)系。
模型M的另一種等效表示方式就是有向圖G,用有向邊Vi?Vj表示兩個節(jié)點Vi和Vj之間的因果關(guān)系。
所謂的「穩(wěn)定概率」(stable probabilistic)因果模型,是指他們對模型作出的穩(wěn)定性假設(shè),指M中所有的不相關(guān)性(X ? Y|Z)都是穩(wěn)定的,寫作:
在穩(wěn)定性假設(shè)下,Galles和Pearl共描述了6個公理,而這篇論文主要關(guān)注傳遞性公理。對于穩(wěn)定概率的因果模型,給定系統(tǒng)中的變量X、Y、Z,傳遞性公理可以寫作:
將上述表達(dá)式通過取反進(jìn)一步簡化,可以寫出其含有因果相關(guān)性的版本:
其中表達(dá)式左側(cè)即為前提,右側(cè)即為假設(shè)。
這樣的公理可以派生出數(shù)千個合成的符號表達(dá)式,從而用于向Transformer模型「教授」特定公理。
公理化訓(xùn)練
訓(xùn)練數(shù)據(jù)
上述含有前提和假設(shè)的公理能映射到「是」或「否」的標(biāo)簽,一條訓(xùn)練數(shù)據(jù)就可以表示為{(P,H,L)}的元組形式。
給定一個真實的因果圖,就可以通過應(yīng)用傳遞性公理(一次或多次),枚舉出所有可能的N個元組{(P,H,L)},從而構(gòu)建出數(shù)據(jù)集D。
比如,因果圖中包含X1?X2?X3?…?Xn這樣的鏈拓?fù)鋾r,一個可能的前提是X1?X2∧X2?X3,相應(yīng)的假設(shè)X1?X3的標(biāo)簽為「是」,而另一個假設(shè)X3?X1標(biāo)簽就為「否」。
值得注意的是,論文中為了表達(dá)的清晰性,使用了數(shù)學(xué)語言進(jìn)行描述,但實際上用于訓(xùn)練的數(shù)據(jù)集只包含自然語言。
比如,上面例子中的前提應(yīng)該表達(dá)為「X1導(dǎo)致X2,且X2導(dǎo)致X3」。
數(shù)據(jù)擾動:泛化的關(guān)鍵
之前有研究表明,以「擾動」(perturbation)形式增加訓(xùn)練數(shù)據(jù)的可變性與多樣性,有助于提升模型的泛化能力。
因此,作者在不同層次上對訓(xùn)練數(shù)據(jù)引入結(jié)構(gòu)化擾動,以最大化數(shù)據(jù)集分布的多樣性。
1)節(jié)點名稱:傳遞鏈上每個節(jié)點的名稱都由1~3個字母/數(shù)字組成,長度和使用的特定字符是隨機(jī)生成的。
2)因果圖拓?fù)浣Y(jié)構(gòu):主要包含兩種類型
- 順序結(jié)構(gòu)(sequential):所有的因果邊方向都是從后向前,共同形成一個典型的「傳遞鏈」,比如X?Y?Z這種形式
- 隨機(jī)翻轉(zhuǎn)(random flipping):給定一個順序結(jié)構(gòu)的傳遞鏈,對其中一些邊進(jìn)行隨機(jī)翻轉(zhuǎn),從而引入復(fù)雜性。比如X?Y?Z可以被修改為X?Y?Z。
隨機(jī)翻轉(zhuǎn)可以在單一方向的鏈中添加分叉結(jié)構(gòu)(X?Y?Z,fork)和碰撞結(jié)構(gòu)(X?Y?Z,collider),它們是任何有向因果圖的基本構(gòu)建塊,有助于提升模型進(jìn)行跨結(jié)構(gòu)泛化的能力。
3)鏈長度:訓(xùn)練集中加入了長度不等的鏈,包含3~6節(jié)點。
損失函數(shù)
論文沒有采用訓(xùn)練Transformer模型常用的next token預(yù)測損失,而是根據(jù)給定數(shù)據(jù)集中每個元組的真實標(biāo)簽進(jìn)行定義,表示為:
位置編碼
除了訓(xùn)練數(shù)據(jù)和損失函數(shù)之外,另一個重要因素是位置編碼的選擇。
之前有研究表明,位置編碼機(jī)制對Transformer的序列長度泛化能力有明顯影響,但不同的研究似乎得出了互相矛盾的結(jié)果。
因此,作者在研究中分別嘗試了不同的方法,包括可學(xué)習(xí)位置編碼(LPE)、正弦位置編碼(SPE)和無位置編碼(NoPE)。
訓(xùn)練和評估的整體流程如圖1所示,Transformer模型在順序鏈和帶有隨機(jī)翻轉(zhuǎn)的鏈上訓(xùn)練,長度為3~6個節(jié)點。
之后,訓(xùn)練過的模型在具有>6個節(jié)點的更復(fù)雜結(jié)構(gòu)上進(jìn)行評估,其中節(jié)點平均的出度(out-degree)和入度(in-degree)都更大,序列更長,且引入了分支、反轉(zhuǎn)(reversal)等復(fù)雜變化。
實現(xiàn)細(xì)節(jié):架構(gòu)、分詞器和訓(xùn)練過程
具體來說,研究人員基于GPT-2的架構(gòu),訓(xùn)練了一個擁有6700萬參數(shù)的解碼器模型。
該模型有12個注意力層、8個注意力頭,以及512個嵌入維度。
值得一提的是,67M模型是在各種訓(xùn)練數(shù)據(jù)集上,從頭開始訓(xùn)練的。為了理解位置編碼(PE)的影響,他們考慮了正弦位置編碼(SPE)、可學(xué)習(xí)位置編碼(LPE)以及不使用位置編碼(NoPE)三種情況。
所有模型都使用AdamW優(yōu)化器進(jìn)行訓(xùn)練,學(xué)習(xí)率為1e-4,訓(xùn)練100個epoch。
由于訓(xùn)練數(shù)據(jù)集遵循特定結(jié)構(gòu),研究人員還開發(fā)了一個自定義分詞器(custom tokenizer)。
字母數(shù)字節(jié)點名稱在字符級別進(jìn)行分詞,而像「causes」、「cause」、「Does」、「Yes」「No」這樣的特殊術(shù)語則在詞級別進(jìn)行分詞。
簡言之,字符級分詞用于字母數(shù)字節(jié)點名稱,詞級分詞用于特殊術(shù)語。
這種方法可以避免在測試時,出現(xiàn)詞匯表外(OOV)token,因為測試集中的字母數(shù)字節(jié)點名稱可能與訓(xùn)練集中的不同。
采用這種方法后,6700萬參數(shù)Transformer模型的詞匯表大小為69。
實驗結(jié)果
復(fù)雜因果場景的泛化
研究人員首先展示了,通過公理化訓(xùn)練的Transformer模型在泛化到更大、更復(fù)雜的因果圖方面的表現(xiàn),并將其與預(yù)訓(xùn)練的大模型進(jìn)行了比較。
序列長度泛化
表1展示了不同模型在評估訓(xùn)練過程中,未見過的更長因果鏈時的準(zhǔn)確率。
在基線預(yù)訓(xùn)練語言模型中,GPT-4在標(biāo)準(zhǔn)和隨機(jī)翻轉(zhuǎn)的因果鏈上都取得了最高的準(zhǔn)確率。
令人驚訝的是,盡管TS2(NoPE)模型在訓(xùn)練過程中從未見過更長的序列,但它的表現(xiàn)能夠與萬億參數(shù)規(guī)模的GPT-4模型相媲美。
雖然訓(xùn)練時只用到了長度為3~6個節(jié)點的因果鏈,但序列長度為7~13時,TS2(NoPE)在標(biāo)準(zhǔn)和隨機(jī)翻轉(zhuǎn)的鏈上,獲得了比GPT-4更高或相當(dāng)?shù)臏?zhǔn)確率。
對于序列長度為14-15的情況下,其準(zhǔn)確率有所下降(標(biāo)準(zhǔn)鏈為0.85,隨機(jī)翻轉(zhuǎn)鏈為0.78),但仍然顯著高于Gemini-Pro 、Phi-3模型。
需要注意的是,隨機(jī)預(yù)測會得到50%的準(zhǔn)確率,這表明通過公理化訓(xùn)練的TS2(NoPE)模型,能夠?qū)⑵渫评砟芰Ψ夯礁L的序列上。
節(jié)點名稱轉(zhuǎn)變
對于在TS2數(shù)據(jù)集上訓(xùn)練的模型,研究人員還評估了其對變量名稱變化的泛化能力(圖 3)。
結(jié)果發(fā)現(xiàn),TS2(NoPE)對節(jié)點名稱的變化很穩(wěn)健,在引入新的、更長的名稱時仍能保持較高的準(zhǔn)確率。它還保持了對新節(jié)點名稱較長序列的通用性,其表現(xiàn)與GPT-4相似。
因果序列順序
與長度和節(jié)點名稱的變化不同,反轉(zhuǎn)(reversal)以及分支(branching)操作改變了因果結(jié)構(gòu),因此能更好地評估模型是否學(xué)習(xí)到了對結(jié)構(gòu)的準(zhǔn)確表示。
在表2b中,TS2(NoPE)在長度不超過8的因果鏈上,獲得的準(zhǔn)確率高于Gemini Pro、Phi-3。長度為9時,TS2(NoPE)的準(zhǔn)確率為0.73,與Gemini Pro(0.74)相當(dāng)。
在表2a中,研究者還觀察到對完全反轉(zhuǎn)序列進(jìn)行評估的類似模式。
在這項任務(wù)中,公理訓(xùn)練模型TS2(NoPE)在限制鏈長度為3-6時,表現(xiàn)優(yōu)于GPT-4。特別是,其準(zhǔn)確率(長度為 6 的鏈為0.94)大大高于Gemini Pro和Phi-3(分別為0.62和0.69)。
分支(Branching)
分支可能是最有挑戰(zhàn)性的任務(wù),因為它引入了在訓(xùn)練期間未見的新結(jié)構(gòu)。
雖然GPT-4在圖大小不斷增大的情況下獲得了最佳準(zhǔn)確率,但TS2(NoPE)模型在除一個節(jié)點外的所有圖大小上,都比Gemini Pro獲得了更高的準(zhǔn)確率。
即使在有12個節(jié)點和1.4個分支因子的圖形上進(jìn)行評估,TS2(NoPE)模型也能獲得70%的準(zhǔn)確率,明顯優(yōu)于隨機(jī)模型(50%)。
總結(jié)
在所有評估設(shè)置中,公理化訓(xùn)練模型TS2(NoPE)的性能明顯優(yōu)于隨機(jī)基線,即使因果鏈的長度超過其訓(xùn)練數(shù)據(jù)。
特別是,模型沒有在完全反轉(zhuǎn)的鏈上進(jìn)行訓(xùn)練,它的表現(xiàn)也與規(guī)模更大的GPT-4模型相當(dāng)(圖 2)。
在其他任務(wù)中,它的準(zhǔn)確性往往優(yōu)于或與Gemini Pro、Phi-3等十億參數(shù)規(guī)模的模型相當(dāng)。
這些結(jié)果表明,經(jīng)過公理訓(xùn)練的模型可以從簡單因果序列的演示中,學(xué)會推理更復(fù)雜的因果結(jié)構(gòu)。這表明公理訓(xùn)練在因果圖推理方面的潛力。
其他結(jié)果:數(shù)據(jù)多樣性和位置編碼的作用
位置編碼的作用
比較不同位置編碼選擇的模型性能,研究人員發(fā)現(xiàn)沒有位置編碼的模型在更長的序列(最長到15個節(jié)點的鏈)和復(fù)雜的、未見過的圖結(jié)構(gòu)上都能很好地泛化,盡管它們僅在3-6個節(jié)點的鏈上進(jìn)行訓(xùn)練。
使用正弦位置編碼(SPE)和可學(xué)習(xí)位置編碼(LPE)的模型在更長的鏈上表現(xiàn)也不錯,但當(dāng)節(jié)點名稱長度增加時表現(xiàn)較差,即使是在節(jié)點數(shù)較少的鏈上也是如此(圖3)。
這種使用SPE和LPE的泛化失敗,突出了模型無法處理訓(xùn)練集中序列的微小擾動。
此外,SPE在不同的結(jié)構(gòu)維度上表現(xiàn)不佳(如分支)以及基于順序的設(shè)置(shuffling和反轉(zhuǎn))。
可學(xué)習(xí)的位置編碼在長度達(dá)9的線性鏈上表現(xiàn)良好,但之后急劇下降。
總的來說,研究結(jié)果擴(kuò)展了早期關(guān)于不使用位置編碼(NoPE)有效性的研究,將其應(yīng)用于理解因果序列的任務(wù),并在測試時泛化到更長的長度和復(fù)雜的結(jié)構(gòu)。
數(shù)據(jù)擾動的重要性
除了位置編碼外,訓(xùn)練數(shù)據(jù)中序列的多樣性也起著重要作用。
僅在因果鏈上,訓(xùn)練的模型可以泛化到較長的鏈(表 1),但不能泛化到其他DAG結(jié)構(gòu)(見圖4中的翻轉(zhuǎn),圖2中的反轉(zhuǎn),表3中的分支)。
在TS1或TS1上訓(xùn)練的模型在所有情況下都具有通用性,包括隨機(jī)翻轉(zhuǎn)、順序排列和分支;因此突出了通過隨機(jī)翻轉(zhuǎn)在邊水平上納入可變性的影響。
不過,在不同任務(wù)中,研究發(fā)現(xiàn)TS2的準(zhǔn)確率高于TS1,即使TS1因隨機(jī)翻轉(zhuǎn)而產(chǎn)生了更多變化。
這表明,雖然擾動有助于結(jié)構(gòu)泛化,但過度的擾動可能會阻礙結(jié)構(gòu)泛化。
使用公理訓(xùn)練從相關(guān)性推斷因果關(guān)系
接下來,作者研究這種能力是否可以轉(zhuǎn)移到其他因果任務(wù)上。
為此,研究人員將公理化訓(xùn)練應(yīng)用于一個任務(wù),該任務(wù)是從觀察數(shù)據(jù)中的相關(guān)性陳述推斷因果關(guān)系。
如圖5所示,每個數(shù)據(jù)實例包括用自然語言描述的3到6個節(jié)點圖的相關(guān)關(guān)系;目標(biāo)是推斷假設(shè)的真值,判斷任何給定節(jié)點之間是否存在直接或間接關(guān)系,以及可能存在的碰撞節(jié)點和混雜因素。
這個任務(wù)比應(yīng)用傳遞性公理要困難得多。
由于任務(wù)的復(fù)雜性,結(jié)果發(fā)現(xiàn)像Gemini Pro、Phi-3這樣的預(yù)訓(xùn)練模型的表現(xiàn)與隨機(jī)猜測相似(準(zhǔn)確率為52%)。
雖然GPT-4的表現(xiàn)稍好一些,但其性能仍然較低(準(zhǔn)確率為58%)。
值得注意的是,研究者的小型Transformer模型表現(xiàn)優(yōu)于所有基線模型,準(zhǔn)確率達(dá)到64%,比GPT-4高出6%。
通過進(jìn)一步探索不同的訓(xùn)練設(shè)置,公理化訓(xùn)練的Transformer模型可能會在這類因果推理任務(wù)上得到進(jìn)一步的優(yōu)化。
總的來說,研究人員認(rèn)為公理化訓(xùn)練是教Transformer模型學(xué)習(xí)因果關(guān)系的一種很有前景的方法。
受Judea Pearl愿景的啟發(fā),這項工作代表著一個潛在的新科學(xué)前沿——因果關(guān)系研究和語言模型的交叉點上。