7B羊駝戰(zhàn)勝540B“谷歌版GPT”,MIT用博弈論調(diào)教大模型,無需訓(xùn)練就能完成
基于博弈論,MIT提出了一種新的大模型優(yōu)化策略。
在其加持之下,7B參數(shù)的Llama在多個(gè)數(shù)據(jù)集上超越了540B的“谷歌版GPT”PaLM。
而且整個(gè)過程無需對(duì)模型進(jìn)行額外訓(xùn)練,消耗的算力資源更低。
圖片
這種基于博弈論制定的優(yōu)化策略被稱為均衡排名(Equilibrium Ranking)。
研究團(tuán)隊(duì)將大模型語言解碼過程轉(zhuǎn)化為正則化不完全信息博弈。
這個(gè)詞可以拆解成“正則化”和“不完全信息博弈”兩部分,我們將在原理詳解部分展開介紹。
在博弈過程中,模型不斷對(duì)生產(chǎn)的答案進(jìn)行優(yōu)化,讓生成結(jié)果更加符合事實(shí)。
實(shí)驗(yàn)結(jié)果表明,在多個(gè)測(cè)試數(shù)據(jù)集上,均衡排名優(yōu)化方式的效果顯著優(yōu)于其他方式,甚至其他模型。
那么,均衡排序方法具體是如何將博弈論應(yīng)用到大模型當(dāng)中的呢?
讓大模型“自我博弈”
前面提到,研究人員將大模型進(jìn)行語言解碼的過程直接變成了“正則化不完全信息博弈”過程。
不完全信息博弈是整個(gè)方法的核心,正則化則是一種避免出錯(cuò)的機(jī)制,我們先來看這種博弈。
具體而言,他們?cè)O(shè)計(jì)了生成器(G)和判別器(D)兩個(gè)模塊,它們掌握著不同的信息,扮演不同角色。
生成器根據(jù)環(huán)境(N)隨機(jī)給出的“正確性參數(shù)”生成答案;判別器則只負(fù)責(zé)判斷生成器的答案是否正確,而不看環(huán)境參數(shù)。
如果判別器的判斷與環(huán)境參數(shù)一致,兩者都得到1分獎(jiǎng)勵(lì),否則都不得分。
圖片
在執(zhí)行重復(fù)的生成和判別當(dāng)中,模型的目標(biāo)是達(dá)到納什均衡。
在納什均衡策略組合下單方面改變自己的策略,而其他玩家策略不變,都不會(huì)提高自身的收益。
舉個(gè)例子,張三和李四一起決定晚餐吃什么,選項(xiàng)有火鍋和燒烤,其他已知條件如下:
- 張三對(duì)火鍋的滿意度是2分(很喜歡),對(duì)燒烤的滿意度為1分(還可以)
- 李四對(duì)燒烤的滿意度是2分,對(duì)火鍋的滿意度為1分
- 兩個(gè)人都不想自己?jiǎn)为?dú)吃飯,因此單獨(dú)吃飯時(shí)滿意度均為0分
此時(shí),兩人的選擇共有四種方式,對(duì)應(yīng)的滿意度得分如下表:
圖片
這一情境下,兩人選擇相同時(shí)即為最佳策略,此時(shí)只要任何一個(gè)人單方面改變策略,兩人的滿意度將同時(shí)變?yōu)?。
回到均衡排名優(yōu)化法當(dāng)中,生成器和判別器會(huì)先初始化策略,二者的依據(jù)分別基于問題或答案。
這一環(huán)境下的納什均衡如下表所示:
圖片
初始化完成后,生成器和判別器會(huì)進(jìn)行多輪博弈,逐步更新策略,直到迭代終止。
每一次博弈結(jié)束后,分別計(jì)算判別器和生成器的得分和最優(yōu)策略得分的差值,稱為“后悔值”。
然后逐步進(jìn)行迭代,直到后悔值收斂,逼近納什均衡。
達(dá)到納什均衡后,生成器和判別器的策略便確定,會(huì)分別對(duì)候選答案進(jìn)行打分,然后進(jìn)行排序選出最佳答案。
在納什均衡條件下,二者的評(píng)分應(yīng)當(dāng)是一致的,如果不一致,答案便會(huì)被剔除。
不過由于給生成器和判斷器打分的標(biāo)準(zhǔn)是與環(huán)境信息的一致性,而不是客觀事實(shí),因此單純追求達(dá)到納什均衡,不一定能保證答案合理。
為了避免二者同時(shí)出錯(cuò)的情況出現(xiàn),開發(fā)者還引入了正則化糾錯(cuò)機(jī)制。
首先是向生成器和判別器基于客觀事實(shí)的先驗(yàn)策略,而不是任由其隨機(jī)初始化。
這些先驗(yàn)策略是生成器和判別器生成策略的“金科玉律”,引導(dǎo)了策略的優(yōu)化方向。
圖片
在此還有一種KL懲罰策略,當(dāng)新的策略出現(xiàn)時(shí),會(huì)計(jì)算其與初始策略的KL散度(又叫相對(duì)熵)。
KL散度描述了二者之間的相關(guān)性,數(shù)值越大,相關(guān)性越低。
假設(shè)P(x)和Q(x)分別是隨機(jī)變量X上的兩個(gè)概率分布,則在離散和連續(xù)的情形下,KL散度分別為:
圖片
這一結(jié)果會(huì)加入到生成新策略的函數(shù)當(dāng)中,避免了最終生成的結(jié)果偏離客觀事實(shí)。
如下式所示,獎(jiǎng)勵(lì)函數(shù)U中包含了KL散度項(xiàng),并設(shè)置了懲罰系數(shù)λ(>0)。
圖片
當(dāng)KL散度越大,也就是和客觀事實(shí)偏差越大時(shí),模型獲得的獎(jiǎng)勵(lì)分?jǐn)?shù)將會(huì)降低。
這樣一來,當(dāng)生成器和判別器結(jié)果一致卻不符合事實(shí)時(shí),相關(guān)結(jié)果不會(huì)獲得高評(píng)分,也就不會(huì)成為最終答案。
憑借著這樣的策略,研究團(tuán)隊(duì)用更低的消耗讓7B的Llama取得了優(yōu)異的成績(jī)。
部分能力超越“谷歌版GPT”
總的來說,均衡排序優(yōu)化后的Llama在常識(shí)推理、閱讀理解、數(shù)學(xué)和對(duì)話任務(wù)中的表現(xiàn)都十分出色。
選擇題方面,同樣是Llama,經(jīng)均衡排名方法優(yōu)化之后,模型在MMLU等多個(gè)數(shù)據(jù)集上的成績(jī)都排在比較靠前的位置。
圖片
問答題方面,均衡排名策略優(yōu)化后的13B Llama在TruthfulQA數(shù)據(jù)集中取得了最佳成績(jī),7B版本也與第一名相差無幾。
圖片
除了文本相關(guān)的理解和推理,模型在數(shù)學(xué)方面也達(dá)到了較高水平。
7B Llama模型的諸多優(yōu)化方式中,均衡排序取得了GSM8K測(cè)試的最好成績(jī)。
圖片
均衡排序方法不僅是諸多Llama優(yōu)化方式中的佼佼者,優(yōu)化后的Llama成績(jī)也超過了其他模型。
在ARC數(shù)據(jù)集的Challenge分集和RACE數(shù)據(jù)集的High分集上,Llama-7B+均衡排序的準(zhǔn)確率分別為58.3%和56.4%,顯著超越了PaLM-540B的53.0%和49.1%。
更多具體細(xì)節(jié),可以到原論文中一探究竟。