LLM為何頻頻翻車算術(shù)題?研究追蹤單個神經(jīng)元,「大腦短路」才是根源
由于缺少對運行邏輯的解釋,大模型一向被人稱為「黑箱」,但近來的不少研究已能夠在單個神經(jīng)元層面上解釋大模型的運行機制。
例如Claude在2023年發(fā)表的一項研究,將大模型中大約500個神經(jīng)元分解成約4000個可解釋特征。
而10月28日的一項研究,以算術(shù)推理作為典型任務,借鑒類似的研究方法,確定了大模型中的一個模型子集,能解釋模型大部分的基本算術(shù)邏輯行為。
圖片
論文地址:https://arxiv.org/abs/2410.21272
該研究首先定位了Llama3-8B/70B, Pythia-6.9B及GPT-J四個模型中負責算術(shù)計算的模型子集。
如圖1所示,少數(shù)注意力頭對大模型面對算術(shù)問題的正確率有顯著影響。第一個 MLP(多層感知機) 明顯影響操作數(shù)和操作符位置,而中間層和后期層的 MLP 將token信息投影到最后位置,提升正確答案的出現(xiàn)概率。
圖1:Llama3-8B中發(fā)現(xiàn)算術(shù)相關(guān)的模型子集
該研究聚焦于單個神經(jīng)元層面,發(fā)現(xiàn)了一組重要的神經(jīng)元,它們實現(xiàn)了簡單的啟發(fā)式算法。只需要關(guān)注特定的極少量神經(jīng)元,就能正確預測大模型進行算術(shù)運算的結(jié)果(圖2)。
圖2:大模型中每層神經(jīng)元中只需要關(guān)注1.5%的少數(shù)子集,就能預測大模型進行四則運算的結(jié)果
舉個例子:當輸入的提示詞為“ 226?68= ”時,神經(jīng)元24|12439在結(jié)果介于150和180的減法提示下顯示出高激活值,可被視為一個啟發(fā)式算法。而每個啟發(fā)式算法識別一個數(shù)值輸入模式,并輸出相應的答案。
具體可分為兩種不同的激活模式:第一種直接啟發(fā)式指的是在某些神經(jīng)元中,激活模式取決于兩個操作數(shù),值向量編碼了算術(shù)計算的預期結(jié)果(圖 3b,c)。
第二種間接激活模式取決于單個操作數(shù)對應的神經(jīng)元中,值向量通常編碼下游處理的特征,而不是直接的計算結(jié)果(圖3a)。
圖3:啟發(fā)式方法的可視化
神經(jīng)元和運算的因果聯(lián)系
該如何確認特定神經(jīng)元和相關(guān)數(shù)學運算之間存在因果關(guān)系?一種常見的方法是消融分析,即將大模型大模型中特定的神經(jīng)元敲除,看看模型的效果會有何改變,結(jié)果如圖4所示。
圖4:四則運算中敲除對應的算術(shù)神經(jīng)元后模型的性能對比
去掉了對應神經(jīng)元后,模型的運算準確性無論加減乘除都顯著下降。
不僅如此,相比去除特定算術(shù)神經(jīng)元時造成的性能下降,可以發(fā)現(xiàn),去除隨機神經(jīng)元的影響相對較小,而且這種效應在模型8B和70B不同參數(shù)量中普遍存在。
圖5:敲除與算術(shù)相關(guān)的啟發(fā)式算法的神經(jīng)元(實線)相比與算術(shù)無關(guān)的相同數(shù)量的隨機神經(jīng)元(虛線)
上述結(jié)果表明,可僅根據(jù)其相關(guān)啟發(fā)式算法來識別對特定對大模型進行算術(shù)重要的神經(jīng)元,也證明了屬于幾個啟發(fā)式算法的神經(jīng)元與提示正確完成之間的因果關(guān)系。
此外,該結(jié)果還支持了啟發(fā)式算法集合的主張:即每個啟發(fā)式算法僅略微提高正確答案的幾率,但它們結(jié)合在一起,使得大模型以高概率產(chǎn)生算術(shù)題的正確答案。
大模型為何做不對算術(shù)題
Llama3-8B模型無法可靠地對每道算術(shù)題時給出正確的回答?;趩l(fā)式規(guī)則,該研究闡述了模型為何會做錯,可能的機制共有兩種:
第一,由于參數(shù)量的限制,大模型缺乏足夠的算術(shù)神經(jīng)元,無法針對每一種情況都給出應對。
第二種原因是,可能存在回憶不完整的情況,比如某個啟發(fā)式規(guī)則對應的神經(jīng)元沒有在運算時被觸發(fā)。
圖片
圖6:隨機抽取了50個正確完成和50個錯誤完成的算術(shù)題目,考察大模型中被正確和錯誤激活的算術(shù)神經(jīng)元個數(shù)
如圖6所示,在大模型回答正確及錯誤時,激活的算術(shù)神經(jīng)元個數(shù)不存在差異,這不支持前述的第一種算術(shù)神經(jīng)元個數(shù)不足的假設(shè)。
然而,在大模型回答正確的情況下,更多比例的正確神經(jīng)元被激活了,而回答錯誤的案例中,應當被激活的神經(jīng)元激活概率反而較小。
這意味著大模型在特定算術(shù)題上失敗的主要原因是對能得出正確答案的神經(jīng)元缺少泛化能力,而不是算術(shù)神經(jīng)元的數(shù)量不足。
「算術(shù)神經(jīng)元」何時誕生
由于其訓練檢查點可供公眾獲取,該研究采用Pythia-6.9B來考察大模型過程中算術(shù)神經(jīng)元的出現(xiàn)階段。
結(jié)果顯示,大模型在訓練過程中逐漸發(fā)展其最終的算術(shù)啟發(fā)式機制,且算術(shù)神經(jīng)元在模型訓練早期就已出現(xiàn)。
圖7 :啟發(fā)式的算術(shù)神經(jīng)元的百分比隨著訓練增加
在模型訓練的不同階段,移除特定的啟發(fā)式神經(jīng)元會大幅降低模型在所有訓練檢查點的準確性,這表明算術(shù)準確性主要來自啟發(fā)式,即使在早期階段也是如此。算術(shù)啟發(fā)式神經(jīng)元與大模型算術(shù)能力的因果關(guān)系在整個訓練過程中都存在。
圖8:不同階段敲除算術(shù)神經(jīng)元對大模型進行算術(shù)運算準確性的影響
結(jié)論
理解大模型如何進行數(shù)學運算,不僅可以打開大模型內(nèi)部運行的黑箱,解釋它們?yōu)楹卧诤唵蔚臄?shù)學題上翻車,例如最著名的「9.11和9.8哪個大」。
這項研究告訴我們,并不是因為大模型缺少相關(guān)訓練,而是激活了錯誤的啟發(fā)式神經(jīng)元,例如將這個問題當成了詢問哪個版本更大。
理解了大模型的算術(shù)運算,是依賴于啟發(fā)式方法集,而非單純的依靠記憶(背題目)或?qū)W會規(guī)則,這表明提高大模型的數(shù)學能力可能需要訓練和架構(gòu)的根本性改變,而不是像激活引導這樣的小修小補。
對訓練過程的分析結(jié)果指出,大模型在訓練早期就學會了這些啟發(fā)式方法,并隨時間推移逐漸強化。這可能會導致模型過度擬合到早期的簡單策略,因此可作為之后優(yōu)化方向的參考。