Meta對Transformer架構下手了:新注意力機制更懂推理
大型語言模型(LLM)很強已經(jīng)是一個不爭的事實,但它們有時仍然容易犯一些簡單的錯誤,表現(xiàn)出較弱的推理能力。
舉個例子,LLM 可能會因不相關的上下文或者輸入提示中固有的偏好或意見做出錯誤的判斷。后一種情況表現(xiàn)出的問題被叫做「阿諛奉承」,即模型與輸入保持一致。
有沒有方法來緩解這類問題呢?有些學者試圖通過添加更多監(jiān)督訓練數(shù)據(jù)或通過強化學習策略來解決,但這些無法從根本上解決問題。
近日 Meta 研究者在論文《System 2 Attention (is something you might need too)》中認為,根本問題在于 Transformer 本身固有的構建方式,尤其是其注意力機制。也就是說,軟注意力既傾向于將概率分配給大部分上下文(包括不相關的部分),也傾向于過度關注重復的 token。
因此,研究者提出了一種完全不同的注意力機制方法,即通過將 LLM 用作一個自然語言推理器來執(zhí)行注意力。具體來講,他們利用 LLM 遵循指令的能力,提示它們生成應該注意的上下文,從而使它們只包含不會扭曲自身推理的相關資料。研究者將這一過程稱為 System 2 Attention(S2A),他們將底層 transformer 及其注意力機制視為類似于人類 System 1 推理的自動操作。
當人們需要特意關注一項任務并且 System 1 可能出錯時,System 2 就會分配費力的腦力活動,并接管人類的工作。因此,這一子系統(tǒng)與研究者提出的 S2A 具有類似目標,后者希望通過額外的推理引擎工作來減輕上述 transformer 軟注意力的失敗。
論文地址:https://arxiv.org/pdf/2311.11829.pdf
研究者詳細描述了 S2A 機制的類別、提出該機制的動機以及幾個具體實現(xiàn)。在實驗階段,他們證實與基于標準注意力的 LLM 相比,S2A 可以產(chǎn)生更講事實、更少固執(zhí)己見或阿諛奉承的 LLM。
特別是在問題中包含干擾性觀點的修正后 TriviQA 數(shù)據(jù)集上,與 LLaMA-2-70B-chat 相比,S2A 將事實性從 62.8% 提高到 80.3%;在包含干擾性輸入情緒的長格式參數(shù)生成任務重,S2A 的客觀性提高了 57.4%,并且基本上不受插入觀點的影響。此外對于 GSM-IC 中帶有與主題不相關語句的數(shù)學應用題,S2A 將準確率從 51.7% 提高到了 61.3%。
這項研究得到了 Yann LeCun 的推薦。
System 2 Attention
下圖 1 展示了一個偽相關示例。當上下文包含不相關的句子時,即使是最強大的 LLM 也會改變關于簡單事實問題的答案,從而因為上下文中出現(xiàn)的 token 無意間增加了錯誤答案的 token 概率。
因此我們需要探究一種依賴更深入理解的更深思熟慮的注意力機制。為了與更底層的注意力機制區(qū)分開來,研究者將提出的系統(tǒng)稱為 S2A。他們探索了利用 LLM 本身來構建這樣一種注意力機制的方法,尤其是利用指令調(diào)整 LLM 通過移除不相關的文本來重寫上下文。
通過這種方式,LLM 可以在輸出響應之前對要關注的輸入部分做出深思熟慮的推理決定。使用指令調(diào)整的 LLM 還有另一個好處,即可以控制注意力焦點,這有點類似于人類控制自己注意力的方式。
S2A 包含兩個過程:
- 給定上下文 x,S2A 首先重新生成上下文 x ',從而刪除會對輸出產(chǎn)生不利影響的上下文的不相關部分。本文將其表示為 x ′ ~ S2A (x)。
- 給定 x ′ ,然后使用重新生成的上下文而不是原始上下文生成 LLM 的最終響應:y ~ LLM (x ′ )。
替代實現(xiàn)和變體
本文考慮了 S2A 方法的幾種變體。
無上下文和問題分離。在圖 2 的實現(xiàn)中,本文選擇重新生成分解為兩部分(上下文和問題)的上下文。圖 12 給出了該提示變體。
保留原始上下文在 S2A 中,在重新生成上下文之后,應該包含所有應該注意的必要元素,然后模型僅在重新生成的上下文上進行響應,原始上下文被丟棄。圖 14 給出了該提示變體。
指令式提示。圖 2 中給出的 S2A 提示鼓勵從上下文中刪除固執(zhí)己見的文本,并使用步驟 2(圖 13)中的說明要求響應不固執(zhí)己見。
強調(diào)相關性與不相關性。以上 S2A 的實現(xiàn)都強調(diào)重新生成上下文以提高客觀性并減少阿諛奉承。然而,本文認為還有其他需要強調(diào)的點, 例如,人們可以強調(diào)相關性與不相關性。圖 15 中的提示變體給出了這種方法的一個實例:
實驗
本文在三種設置下進行了實驗:事實問答、長論點生成以及對數(shù)學應用題的解決。此外,本文還使用 LLaMA-2-70B-chat 作為基礎模型,在兩種設置下進行評估:
- 基線:數(shù)據(jù)集中提供的輸入提示被饋送到模型,并以零樣本方式回答。模型生成可能會受到輸入中提供的虛假相關性的影響。
- Oracle Prompt:沒有附加意見或不相關句子的提示被輸入到模型中,并以零樣本的方式回答。
圖 5 (左) 展示了在事實問答上的評估結果。System 2 Attention 比原來的輸入提示有了很大的改進,準確率達到 80.3%—— 接近 Oracle Prompt 性能。
圖 6(左)顯示了長論點生成的總體結果,基線、Oracle Prompt 以及 System 2 Attention 都被評估為可以提供類似的高質量評估。圖 6(右)為細分結果:
圖 7 顯示了不同方法在 GSM-IC 任務上的結果。與 Shi 等人的研究結果一致,本文發(fā)現(xiàn)基線準確率遠低于 oracle。當不相關的句子與問題屬于同一主題時,這種影響甚至更大,如圖 7(右)所示。
了解更多內(nèi)容,請參考原論文。