NeurIPS 2024 | 自我糾錯(cuò)如何使OpenAI o1推理能力大大加強(qiáng)?北大、MIT團(tuán)隊(duì)給出理論解釋
自我糾錯(cuò)(Self Correction)能力,傳統(tǒng)上被視為人類特有的特征,正越來越多地在人工智能領(lǐng)域,尤其是大型語言模型(LLMs)中得到廣泛應(yīng)用,最近爆火的OpenAI o1模型[1]和Reflection 70B模型[2]都采取了自我糾正的方法。
傳統(tǒng)的大語言模型,因?yàn)樵谳敵龃鸢傅臅r(shí)候是逐個(gè)Token輸出,當(dāng)輸出長(zhǎng)度較長(zhǎng)時(shí),中間某些Token出錯(cuò)是必然發(fā)生。但即使LLM后來知道前面輸出的Token錯(cuò)了,它也得用更多錯(cuò)誤來“圓謊”,因?yàn)闆]有機(jī)制讓它去修正前面的錯(cuò)誤。
而OpenAI o1在“慢思考”也就是生成Hidden COT的過程中,通過分析OpenAI官網(wǎng)給出的Hidden COT例子可以發(fā)現(xiàn),在解決字謎問題的思考過程中,o1首先發(fā)現(xiàn)了每?jī)蓚€(gè)連續(xù)的明文字母會(huì)映射到一個(gè)秘文字母,于是便嘗試使用奇數(shù)字母來構(gòu)建明文,但是經(jīng)過驗(yàn)證發(fā)現(xiàn)并不合理(Not directly);接著又重新修正答案最終成功解出字謎。
圖1 OpenAI o1 官網(wǎng)示例(部分Hidden CoT)
Reflection 70B的關(guān)鍵技術(shù)也包括錯(cuò)誤識(shí)別和錯(cuò)誤糾正。他們用到了一種名為 Reflection-Tuning(反思微調(diào)) 的技術(shù),使得模型能夠在最終確定回復(fù)之前,先檢測(cè)自身推理的錯(cuò)誤并糾正。在實(shí)際的執(zhí)行過程中,這會(huì)用到一種名為思考標(biāo)簽(thinking tag)的機(jī)制。模型會(huì)在這個(gè)標(biāo)簽內(nèi)部進(jìn)行反思,直到它得到正確答案或認(rèn)為自己得到了正確答案。
頻頻應(yīng)用于大語言模型的自我糾錯(cuò)技術(shù)為何有效?為什么糾錯(cuò)過程可以讓模型把原本答錯(cuò)的問題重新答對(duì)?
為了探究這一問題,北大王奕森團(tuán)隊(duì)與MIT合作,從理論上分析了大語言模型自我糾錯(cuò)能力背后的工作機(jī)理。
- 論文題目:A Theoretical Understanding of Self-Correction through In-context Alignment
- 論文地址:https://openreview.net/pdf?id=OtvNLTWYww
- 代碼地址:https://github.com/yifeiwang77/Self-Correction
作者團(tuán)隊(duì)將自我糾錯(cuò)的過程抽象為對(duì)齊任務(wù),從上下文學(xué)習(xí)(In-context learning)的角度對(duì)自我糾錯(cuò)進(jìn)行了理論分析。值得一提的是,他們并沒有使用線性注意力機(jī)制下的線性回歸任務(wù)進(jìn)行理論分析,而是使用真實(shí)世界LLM在用的softmax多頭注意力機(jī)制的transformer結(jié)構(gòu),并利用Bradley-Terry 模型和 Plackett-Luce 模型(LLM對(duì)齊的實(shí)際選擇,用于RLHF和DPO)設(shè)計(jì)對(duì)齊任務(wù)進(jìn)行研究。受理論啟發(fā),他們提出了一種簡(jiǎn)單的自我糾錯(cuò)策略--上下文檢查(Check as Context),并通過實(shí)驗(yàn),在消除大語言模型中存在的潛在偏見以及防御越獄攻擊中效果顯著。
理論分析:自我糾錯(cuò)實(shí)際上是一種上下文對(duì)齊?
不同于類似監(jiān)督學(xué)習(xí)的標(biāo)準(zhǔn)上下文示例(請(qǐng)求,回答),自我糾錯(cuò)示例可以形成一個(gè)三元組形式(請(qǐng)求,回答,獎(jiǎng)勵(lì)),這類似于通過獎(jiǎng)勵(lì)指示好壞樣本的 LLM 對(duì)齊。因此,作者團(tuán)隊(duì)提出將自我糾錯(cuò)形式化為一種“上下文對(duì)齊”(In-context Alignment),即通過提供一系列自我糾錯(cuò)步驟的上下文,優(yōu)化LLM的最終輸出,以獲得更高的獎(jiǎng)勵(lì)。
對(duì)齊的過程通常包括:對(duì)于問題,收集個(gè)不同的模型回答,然后由人類或評(píng)估模型(在本文中,評(píng)估模型即該 LLM 本身)對(duì)這 個(gè)回答給出排序偏好。接著,使用一般的對(duì)齊模型(如Bradley-Terry (BT,n=2) or Plackett-Luce (PL loss, general n))進(jìn)行建模:
其中為獎(jiǎng)勵(lì)模型。
針對(duì)transformer模型,作者采用了帶有softmax多頭注意力機(jī)制的transformer結(jié)構(gòu),其前向傳播更新可以分為兩部分
- 多頭注意力(MHSA)層:
- FFN層:
獎(jiǎng)勵(lì)函數(shù) 被設(shè)置為負(fù)均方誤差(MSE)損失,即:
在該設(shè)置下,參數(shù)的梯度下降可等價(jià)于對(duì)數(shù)據(jù)的更新:
作者證明了多層transformer(包含3-head softmax attention和relu激活函數(shù)的FFN)可以利用自我糾錯(cuò)樣本生成更優(yōu)獎(jiǎng)勵(lì)的回答。具體而言,作者證明了存在模型權(quán)重,使得transformer可以通過在前向傳播的過程中執(zhí)行對(duì)其內(nèi)部獎(jiǎng)勵(lì)模型參數(shù)的梯度下降,來生成更符合對(duì)齊目標(biāo)的更優(yōu)回答。
這是首次在理論上表明 LLM 可以在上下文中實(shí)現(xiàn)對(duì)齊的分析。該理論適用于多種自我糾錯(cuò)方法,因?yàn)樵u(píng)估可以來自人類、外部驗(yàn)證者或 LLM 本身。
圖2 關(guān)于上下文對(duì)齊的驗(yàn)證實(shí)驗(yàn),分別涉及TF和GD的比較(a)、不同獎(jiǎng)勵(lì)噪聲p的影響(b)、模型深度的影響(c)、以及不同注意力機(jī)制的效果(d)、(e)、(f)。
作者也通過設(shè)置驗(yàn)證實(shí)驗(yàn)來檢驗(yàn)其理論導(dǎo)出的種種結(jié)論,以及各個(gè) transformer 結(jié)構(gòu)模塊對(duì) LLM 執(zhí)行上下文對(duì)齊能力的影響,作者發(fā)現(xiàn)了很多有趣的結(jié)論:
- 通過觀察比較LLM在執(zhí)行上下文對(duì)齊時(shí)前向傳播的損失與梯度下降的損失曲線,LLM執(zhí)行上下文對(duì)齊時(shí)的前傳行為與梯度下降損失曲線幾乎相同。(圖2(a))
- 評(píng)價(jià)的質(zhì)量直接影響自我糾錯(cuò)的質(zhì)量(圖2(b))。
- 對(duì)多樣本的排序需要更深的模型層數(shù),在達(dá)到一定深度后(15層),增加更多的層數(shù)并不能帶來更高的收益。(圖2(c))
- Softmax注意力機(jī)制對(duì)從評(píng)價(jià)中分析回答優(yōu)劣排序至關(guān)重要,而linear注意力則做不到這一點(diǎn)。具體來說,softmax 注意力機(jī)制可以有效地選取最優(yōu)回答 并為各樣本生成加權(quán)平均所需的權(quán)重。(圖2(d))
- 多頭注意力機(jī)制對(duì)token角色的區(qū)分很重要。具體而言,多頭注意力機(jī)制可以將生成的回答與正樣本拉近,與負(fù)樣本拉遠(yuǎn)。實(shí)驗(yàn)表明,3個(gè)attention head是上下文對(duì)齊任務(wù)中最優(yōu)選擇。(圖2(e))
- FFN對(duì)于token角色的轉(zhuǎn)變很重要。在經(jīng)過一個(gè)MHSA層后,F(xiàn)FN可以將上一輪的正樣本屏蔽掉,從而使次優(yōu)樣本變成下一輪迭代的最優(yōu)樣本。(圖2(f))
自我糾錯(cuò)策略:上下文檢查
作者使用上下文檢查(Check as Context,CaC)作為L(zhǎng)LM完成自我糾錯(cuò)的方法,在兩個(gè)現(xiàn)實(shí)世界的對(duì)齊任務(wù)中探索了自我糾錯(cuò):緩解社會(huì)偏見和防范越獄攻擊。
圖3 BBQ數(shù)據(jù)集上使用CaC的示例。
具體而言,首先對(duì)模型請(qǐng)求問題獲得回答初始回答,然后對(duì)該回答進(jìn)行評(píng)估,得到獎(jiǎng)勵(lì)。之后將初始回答,評(píng)估送入上下文,并重新請(qǐng)求問題,得到改正后的回答。此過程可多次重復(fù)以迭代改進(jìn)回答,最終以最后一輪的模型回答作為模型的最終輸出。
消除LLM社會(huì)偏見
本文使用 BBQ(Bias Benchmark for QA)數(shù)據(jù)集,在 vicuna-7B 和 Llama2-7b-chat 模型上測(cè)試了 CaC 方法的效果。此外,還在 BBQ 上研究了模型大小、評(píng)估質(zhì)量和糾錯(cuò)輪數(shù)對(duì)糾錯(cuò)效果的影響。主要結(jié)論如下:
- 多數(shù)情況下,自我糾錯(cuò)后的正確率高于原正確率(圖4)
- 正確率提升與自我評(píng)估的準(zhǔn)確率高度相關(guān)(圖4(c): ),甚至呈線性關(guān)系(圖5(a))。
- 采用不同的評(píng)價(jià)方式效果依次提升:僅使用對(duì)/錯(cuò)評(píng)價(jià) < 自然語言評(píng)價(jià) < 包含 CoT 的對(duì)/錯(cuò)評(píng)價(jià)。這是因?yàn)?CoT 不僅能提高評(píng)價(jià)準(zhǔn)確性,還能為模型提供額外的自然語言信息。(圖5(b))
- 更大的模型有更好的糾錯(cuò)能力(圖5(c)(d))
- 當(dāng)評(píng)價(jià)的正確率足夠高時(shí),更多的糾錯(cuò)輪數(shù)可以帶來更好的糾錯(cuò)效果。(圖5(e))
圖4 CaC對(duì)于不同種類的偏見的修正
圖5 BBQ上關(guān)于模型大小、評(píng)估質(zhì)量以及糾錯(cuò)輪數(shù)的消融實(shí)驗(yàn)
同時(shí),在防御越獄攻擊的實(shí)驗(yàn)中,CaC也是所有測(cè)試的防御手段中最低的。
更多文章細(xì)節(jié),請(qǐng)參考原文:https://openreview.net/pdf?id=OtvNLTWYww