大模型+蒙特卡洛樹搜索,一招讓LLaMa-3 8B奧數(shù)水平直逼GPT-4
這幾天,17 歲中專生姜萍在 2024 阿里巴巴全球數(shù)學(xué)競賽預(yù)選賽中取得全球第 12 名的新聞刷了屏。而同時(shí),AI 挑戰(zhàn)賽的成績顯示,在所有 563 支 AI 參賽隊(duì)伍中,最高分 34 分,平均分 18 分,趕上了人類選手平均水平。
AI 參與數(shù)學(xué)競賽的主要短板是邏輯推理能力弱,證明題很難拿到完整得分點(diǎn)。這也是 GPT-4、LLaMA 等當(dāng)前大語言模型(LLM)在需要策略和邏輯推理的任務(wù)中面臨的重大挑戰(zhàn)。
其中的一大障礙是輸出的準(zhǔn)確性和可信度,尤其是在需要保證精度的數(shù)學(xué)上下文中,LLM 在推理時(shí)往往容易產(chǎn)生幻覺。輸出結(jié)果表面上看似合理,但實(shí)際上不相關(guān)或事實(shí)不正確,最終導(dǎo)致不合理的推理過程。
雖然像 Self-Refine 這樣的重寫技術(shù)有助于緩解這種傾向,但依然可能導(dǎo)致現(xiàn)實(shí)世界復(fù)雜的數(shù)學(xué)問題產(chǎn)生誤導(dǎo)性或錯(cuò)誤的結(jié)果。
因此,為了應(yīng)對這些挑戰(zhàn),來自復(fù)旦大學(xué)、上海 AI Lab 的研究者提出了 MCT Self-Refine(MCTSr),將 LLM 與蒙特卡洛樹搜索(MCTS)算法相結(jié)合,并重點(diǎn)提高 LLM 在復(fù)雜數(shù)學(xué)推理任務(wù)(比如奧數(shù)競賽題)中的表現(xiàn)。
作為一種決策工具,MCTS 廣泛應(yīng)用于人工智能中需要戰(zhàn)略規(guī)劃的場景,通常用于游戲和復(fù)雜的問題解決環(huán)境。本文通過將 MCTS 的系統(tǒng)探索能力與 LLM 的 Self-Refine 和 Self-Evaluation 能力相結(jié)合, 旨在創(chuàng)建一個(gè)更強(qiáng)大的框架來應(yīng)對當(dāng)前 LLM 難以解決的復(fù)雜推理任務(wù)。
- 論文地址:https://arxiv.org/pdf/2406.07394
- 項(xiàng)目地址:https://github.com/trotsky1997/MathBlackBox
不過,在將 MCTS 與 LLM 集成過程中存在一些技術(shù)挑戰(zhàn)。傳統(tǒng)的 MCTS 策略可能與 LLM 輸出的隨機(jī)性和生成性不太吻合,后者通常涉及無限、連續(xù)的潛在動作空間。這種不一致需要在 MCTS 框架內(nèi)采用定制的期望計(jì)算和反向傳播方法,以更好地適應(yīng) LLM 的特有屬性。
此外,研究者還引入了一種動態(tài)剪枝策略,它結(jié)合了改進(jìn)的置信上限(UCB)公式,以優(yōu)化高風(fēng)險(xiǎn)任務(wù)中有效決策制定所需要的探索 - 利用平衡。
可以說,這項(xiàng)研究推進(jìn)了 LLM 在復(fù)雜推理挑戰(zhàn)中的應(yīng)用,為未來整合 AI 相關(guān)的技術(shù)創(chuàng)新奠定了基礎(chǔ),從而使得 LLM 驅(qū)動的應(yīng)用擁有了更強(qiáng)大的決策制定、推理準(zhǔn)確性和可靠性。
方法概覽
MCTSr 架構(gòu)圖如圖 1 所示:
MCTSr 工作流包括:
- 初始化:使用模型生成的答案和虛擬響應(yīng)建立根節(jié)點(diǎn),以最大限度地減少模型過度擬合趨勢;
- 選擇:該算法采用值函數(shù) Q 對所有未完全展開的答案進(jìn)行排序,并采用貪心策略選擇值最高的節(jié)點(diǎn)進(jìn)行進(jìn)一步的探索和優(yōu)化;
- Self-Refine :選擇好的答案 a 使用 Self-Refine 框架進(jìn)行優(yōu)化。最初,模型生成反饋 m,指導(dǎo)優(yōu)化過程以產(chǎn)生增強(qiáng)的答案 a ′;
- Self-Evaluation:精煉后的答案經(jīng)過評分從而采樣一個(gè)獎勵(lì)值,并計(jì)算其 Q 值。這涉及模型自我獎勵(lì)反饋和約束,如嚴(yán)格的評分標(biāo)準(zhǔn)和抑制滿分,以確保評分的可靠性和公平性;
- 反向傳播:將精煉答案的值反向傳播到其父節(jié)點(diǎn)和其他相關(guān)節(jié)點(diǎn),以更新樹的值信息。如果任何子節(jié)點(diǎn)的 Q 值發(fā)生變化,則更新父節(jié)點(diǎn)的 Q;
- UCT 更新:在所有節(jié)點(diǎn)的 Q 值更新完成后,確定一個(gè)候選節(jié)點(diǎn)集合 C,用于進(jìn)一步擴(kuò)展或選擇,然后使用 UCT 更新公式更新所有節(jié)點(diǎn)的 UCT 值,以備下一步的選擇階段。
迭代上述階段,直到滿足終止條件 T 為止。
Self-Refine
在 self-refine 階段, 模型通過多輪對話完善提示來優(yōu)化針對問題 P 的答案 a。首先,模型生成一個(gè)關(guān)于答案 a 的反思性或批判性評論 m。隨后,在 m 的指導(dǎo)下,模型修改答案 a,產(chǎn)生一個(gè)改進(jìn)版本 a',這種迭代的精煉方式提高了模型響應(yīng)質(zhì)量。
自評估
在數(shù)學(xué)問題 P 的答案精煉過程中,一個(gè)答案 a 的 Q 值被定義為將 a 進(jìn)一步精煉成更優(yōu)答案的預(yù)期質(zhì)量。這個(gè)定義是基于從 a 到其重寫形式的轉(zhuǎn)換具有馬爾可夫性質(zhì),即下一個(gè)狀態(tài)(即改寫后的答案)僅依賴于當(dāng)前狀態(tài)(即當(dāng)前的答案 a),而與之前的狀態(tài)無關(guān)。
此外,研究者還設(shè)計(jì)了三個(gè)約束:提示約束、滿分抑制、重復(fù)采樣。采樣后,計(jì)算 a 的 Q 值。
反向傳播
在所有葉節(jié)點(diǎn)的獎勵(lì)值經(jīng)過采樣和 Q 值更新完成后,然后將這些變化傳播至其父節(jié)點(diǎn)和祖節(jié)點(diǎn)。在這個(gè)更新過程中,如果節(jié)點(diǎn) a 的子節(jié)點(diǎn)集合 Children (a) 中任何元素的 Q 函數(shù)值發(fā)生變化,那么節(jié)點(diǎn) a 的 Q 函數(shù)值也將進(jìn)行更新。這樣的傳播確保了節(jié)點(diǎn)的 Q 值能夠反映其所有可能子節(jié)點(diǎn)的最新狀態(tài)和評估。
更新 UCT 和選擇
在更新了樹中所有節(jié)點(diǎn)的 Q 值之后,會進(jìn)入下一輪選擇階段。這個(gè)過程包括以下步驟:
- 候選節(jié)點(diǎn)選擇:在選擇節(jié)點(diǎn)時(shí),研究者無需從根節(jié)點(diǎn)開始,而是按層次順序遍歷樹中的節(jié)點(diǎn)。
- UCT 更新:借鑒 AlphaGo,該研究使用 UCT 和 UCB-1 方法來平衡節(jié)點(diǎn)的探索和利用;對于候選集 C 中的節(jié)點(diǎn) a,其 UCT_a 值為:
終止函數(shù)
提前終止:當(dāng)搜索結(jié)果的改進(jìn)開始減少或連續(xù)搜索產(chǎn)生重復(fù)結(jié)果時(shí),終止發(fā)生。
搜索約束:一旦展開次數(shù)達(dá)到預(yù)定限制或樹中的一個(gè)或多個(gè)節(jié)點(diǎn)滿足最大深度約束,搜索就會終止。
實(shí)驗(yàn)結(jié)果
為了評估 MCTSr 算法在解決數(shù)學(xué)問題中的有效性,研究者將 LLaMA3-8B 作為基礎(chǔ)模型,并使用 MCTSr 進(jìn)行增強(qiáng)。他們在 Zero-Shot CoT、Self-Refine、4-rollouts MCTSr 和 8-rollouts MCTSr 等幾種設(shè)置中,將 LLaMA3-8B 與 GPT-4、Claude 3 和 Gemini 1.5-Pro 等進(jìn)行了比較。
研究者在 GSM8K 和 GSM-hard 測試集(它們分別包含了典型和具有挑戰(zhàn)性的數(shù)學(xué)問題)上評估了上述方法,結(jié)果如下表 1 所示。
可以發(fā)現(xiàn),MCTSr 的 rollout 次數(shù)與成功率之間存在著直接相關(guān)性,并隨著迭代次數(shù)增加而顯著提升,在不太復(fù)雜的 GSM8K 中尤為明顯。不過對于更復(fù)雜的 GSM-Hard 測試集,即使 rollout 次數(shù)更高也會達(dá)到性能上限,表明當(dāng)前策略在解決復(fù)雜問題時(shí)存在局限性。
這些結(jié)果強(qiáng)調(diào)了 MCT-Self-refine 算法的穩(wěn)健性和潛在邊界,以及持續(xù)改進(jìn)的必要性,從而有效應(yīng)對更復(fù)雜的挑戰(zhàn)。
下表 2 展示了在 MATH 數(shù)據(jù)集上應(yīng)用不同復(fù)雜度級別的 MCT-Self-refine 算法的結(jié)果。數(shù)據(jù)集分為五個(gè)難度級別,從 Level 1(最簡單)到 Level 5(最具挑戰(zhàn)性)。
結(jié)果顯示,Level 1 的成功率最高,8 次 rollout 后,MCTSr 實(shí)現(xiàn)了 90.16% 的成功率,解決了 437 個(gè)問題中的 394 個(gè)。隨著 rollout 次數(shù)的增加,這一級別的成功率顯著提高。
在最具挑戰(zhàn)性的 Level 5 難度,8 次 rollout 后,MCTSr 的成功率為 34.06%,解決了 1324 個(gè)問題中的 451 個(gè)。這說明了隨著難度不斷增加,該算法在高度復(fù)雜的場景中性能受到限制。
所有級別的整體性能顯示,8 次 rollout 后,MCTSr 的累計(jì)成功率為 58.24%,解決了 5000 個(gè)問題中的 2912 個(gè)。這一成功率相較于 Zero-Shot CoT 的初始成功率 24.36% 有了顯著提高。這表明了,rollout 次數(shù)的增加與成功率的提高呈現(xiàn)出一致性,強(qiáng)調(diào)了 MCT-Self-refine 算法在提升不同數(shù)學(xué)復(fù)雜度級別的問題解決能力方面的有效性。
這些結(jié)果還驗(yàn)證了 MCT-Self-refine 算法在學(xué)術(shù)和問題解決上下文中的潛力,并強(qiáng)調(diào)了其對 MATH 數(shù)據(jù)集中不同復(fù)雜度級別問題的可擴(kuò)展性和適應(yīng)性。
下表 3 為 MCT-Self-refne 算法在奧數(shù)競賽的三個(gè)數(shù)據(jù)集上進(jìn)行了測試:AlME、GAIC Math Odyssey 和 OlympiadBench。
AIME:從 Zero-Shot CoT 的 2.36%(解決 22 個(gè)問題)到 MCTSr 的 11.79%(解決 110 個(gè)問題)。
GAIC Math Odyssey:成功率從 17.22%(解決 67 個(gè)問題)上升至 49.36%(解決 192 個(gè)問題)。
OlympiadBench:從 Zero-Shot CoT 的 1.25%(解決 16 個(gè)問題)提高到 MCTSr 的 7.76%(解決 99 個(gè)問題)。
這些結(jié)果證實(shí)了 MCT-Self-refine 算法在未見過的數(shù)學(xué)問題上的適用性,表明其在奧林匹克等競爭性學(xué)術(shù)環(huán)境中具有優(yōu)勢。
如表 4 所示。與當(dāng)前閉源大模型進(jìn)行比較時(shí),MCTSr 可以有效提升小參數(shù)開源模型(如 LLaMa-3)的數(shù)學(xué)推理能力到相當(dāng)?shù)乃健?/span>
更多技術(shù)細(xì)節(jié)和實(shí)驗(yàn)結(jié)果請參閱原論文。