對標(biāo)GPT-4代碼解釋器!港中大讓模型寫代碼解決數(shù)學(xué)難題,得分超越GPT-4
對標(biāo)GPT-4代碼解釋器,港中大最新研究放了個“大招”:
他們開發(fā)了一個叫做MathCoder的大模型,數(shù)學(xué)能力直接在競賽級“題庫”Math上超過GPT-4。
△ 形象為羊駝是因為MathCoder底層模型來自羊駝家族
做到這一點靠的就是無縫集成代碼的能力——
在遇到數(shù)學(xué)問題時,它不僅能用自然語言推理,還能自動編寫和執(zhí)行代碼來建模、推導(dǎo)公式與方程。
這樣的工作方式無疑和強大的GPT-4代碼解釋器一樣。
在實際評測中,MathCoder除了超過GPT-4,還順利在MATH和GSM8K兩大數(shù)據(jù)集上取得了開源LLM中的SOTA(打敗了8月份才誕生的WizardMath)
這個“新王”究竟是如何誕生的?
對標(biāo)GPT-4代碼解釋器
總的來看,港大這項研究為了提高大模型的數(shù)學(xué)推理能力,學(xué)習(xí)了GPT-4代碼解釋器的優(yōu)點和工作原理,提出了一種微調(diào)開源語言模型的方法。
該方法最終使大模型無縫集成代碼,利用代碼來解決數(shù)學(xué)問題。
具體而言,他們首先提出了一個可以生成高質(zhì)量數(shù)學(xué)題的數(shù)據(jù)集:MathCodeInstruct。
該數(shù)據(jù)集由兩部分組成:
種子數(shù)據(jù)(D0):主要基于GSM8K和MATH,并利用GPT-4收集答案。
插值數(shù)據(jù)(D1):讓GPT-4基于他們提出的一種叫做“問題插值提示”的方法生成。
如下圖所示:
示例1和2分別來自于GSM8K和MATH,1簡單,2難一些,GPT-4要做的“插值”就是生成比1難但比2更簡單的新問題。
基于以上兩類問題,最終MathCodeInstruct數(shù)據(jù)集一共收集了8萬道數(shù)學(xué)題。
如下表所示,這比業(yè)內(nèi)其他數(shù)據(jù)集規(guī)模稍小一些:
而與其他數(shù)據(jù)集相比,它的特點之一是同時彌補了GSM8K和MATH這兩大重要數(shù)據(jù)集中不足的部分,給出了一些難度范圍更廣的問題,增強了數(shù)據(jù)集的泛化能力。
特點之二是數(shù)據(jù)集中的每道題目同時包含基于自然語言推理的部分+基于代碼解決的部分(包括執(zhí)行代碼和代碼輸出結(jié)果)。
如下圖所示,這是對上面GPT-4生成的“插值”問題的解決思路:
在數(shù)據(jù)集準(zhǔn)備好以后,團隊便提出了一種定制的監(jiān)督微調(diào)和推理方法,最終在Llama-2和Code Llama上微調(diào)出了MathCoder。
具體而言,該方法使用特殊的token(<|text|>、<|code|>、<|execution|>)來識別訓(xùn)練數(shù)據(jù)集中哪一部分是自然語言、代碼還是結(jié)果,讓模型學(xué)習(xí)生成由這些特殊標(biāo)記劃分的自然語言和代碼。
在推理期間,該方法還會將動態(tài)執(zhí)行的結(jié)果附加到模型的先前預(yù)測中。
然后,繼續(xù)基于這個新版本的輸入自回歸預(yù)測下一個token,以及最后的執(zhí)行結(jié)果。
作者表示,通過這種方式,模型將能夠“看到”執(zhí)行結(jié)果,并不斷地繼續(xù)推理。
最終,該方法使微調(diào)模型MathCoder以類似GPT-4代碼解釋器的方式運行。
在評測中,MathCoder憑此直接在MATH和GSM8K這倆數(shù)據(jù)集上取得了45.2%和83.9%的好成績。
該成績證明:
其一,它超過了ChatGPT-3.5和PaLM-2等9個閉源模型,并在以數(shù)學(xué)競賽題為主的MATH集上超過GPT-4。
其二,它打敗了此前數(shù)學(xué)領(lǐng)域里最強的開源模型WizardMath,成為新的開源之最。
不過其三,模仿但還未超越,在這倆數(shù)據(jù)集上,MathCoder還是與GPT-4代碼解釋器(69.7%和97%高分)存在著一定的性能差距。
作者介紹
本研究一共10位作者,除了兩位來自香港城市大學(xué)以外,其余均來自香港中文大學(xué)。
共同一作一共有6位,分別是:Ke Wang、Houxing Ren、Aojun Zhou、Zimu Lu、Sichun Luo和Weikang Shi。
通訊作者為李鴻升,為港中大電子工程系副教授,同時也就職于上海人工智能研究室。
論文地址:https://arxiv.org/abs/2310.03731