34B參數量超越GPT-4!「數學通用大模型」MAmmoTH開源:平均準確率最高提升29%
數學推理問題是語言模型繞不過的痛點,在各種黑科技的加持下,開源模型的推理性能依然不夠看。
最近,滑鐵盧大學、俄亥俄州立大學、香港科技大學、愛丁堡大學的研究人員聯合開源了一個專為「通用數學問題」定制的大模型MAmmoTH和一個指令調優(yōu)數據集MathInstruct.
論文鏈接:https://arxiv.org/pdf/2309.05653.pdf
項目鏈接:https://tiger-ai-lab.github.io/MAmmoTH/
MathInstruct由13個具有中間原理的數學數據集編譯而成,其中6個為新數據集,混合了思想鏈(CoT)和思想程序(PoT),并確保覆蓋了廣泛的數學領域。
CoT和PoT的混合不僅可以釋放工具使用的潛力,而且還允許模型針對不同的數學問題進行不同的思維過程。
因此,MAmmoTH系列在所有尺度上的9個數學推理數據集上的表現大大優(yōu)于現有的開源模型,平均準確率提高了12%至29%。
其中MAmmoTH-7B模型在MATH(競賽級數據集)上的準確率達到了35%,超過了最好的開源7B模型(WizardMath)25%,MAmmoTH-34B模型在MATH上的準確率達到了46%,甚至超過了GPT-4的CoT結果。
數學推理領域新王:MAmmoTH
在數學推理任務上,開源和閉源的大型語言模型(LLM)之間存在巨大的性能差距,目前基準數據集上的sota仍然是GPT-4,PaLM-2和Claude等閉源模型,其他開源模型如Llama,Falcon和OPT等仍然遠遠落后。
為了彌補性能差距,主要的研究方法有兩類:
1. 如Galactica,MINERVA等模型,繼續(xù)使用數學相關的網絡數據對語言模型進行訓練,可以提高模型的通用科學推理能力,但計算成本會更高;
2. 如拒絕采樣微調(RFT)和WizardMath等,使用特定領域數據集對模型進行微調,雖然可以提高領域內性能,但無法適用于更廣泛的數學推理任務。
在解決數學問題時,現有方法通常會采用思維鏈(CoT)方法引導語言模型循序漸進地用自然語言描述來解決數學問題。
雖然在大多數數學主題下表現出很好的通用性,但在需要精確或復雜的數學計算、算法推理的問題下(如求解二次方程根,計算矩陣特征值)表現不佳。
相比之下,思維程序(PoT, Program-of-Thought)方法和PAL利用外部工具(即Python解釋器)大大簡化了數學求解過程,將計算過程卸載到外部Python解釋器,以解決復雜的數學和算法推理過程(例如,用sympy求解二次方程或用numpy計算矩陣特征值)。
然而,PoT在處理更抽象的推理場景方面有所欠缺,尤其是在沒有內置API的情況下,常識推理、形式邏輯和抽象代數的推理能力會更差。
方法概述
研究人員的目標是編制一個高質量、多樣化的數學指令調整(instruction-tuning)數據集列表。
1. 覆蓋不同數學領域和復雜度
更全面的數據集可以讓模型接觸到多樣化的數學知識,提升模型的多功能性。
研究人員將選擇范圍縮小到幾個被廣泛采用的高質量數據集,包括GSM8K、math、AQuA、Camel和TheoremQA.
還可以注意到,現有的數據集缺乏對大學水平的數學知識的覆蓋,如抽象代數和形式邏輯,所以研究人員選擇使用GPT-4來合成TheoremQA問題中的思維鏈(CoT)原理,利用網絡上找到的數個種子樣例,通過自我指導(self-instruct)創(chuàng)建問題和CoT的數據對。
2. 混合CoT和PoT
現有的研究方法大多只關注CoT,并且數據集中也只包含有限的解題思路,導致CoT和PoT的數據量十分不均衡。
為了解決該問題,研究人員利用GPT-4來補充選定數據集的PoT解題思路,通過對比合成程序的執(zhí)行結果以及人工標注的答案進行過濾,確保生成數據的高質量。
遵循上述方法,最后得到了26萬條指令、回復數據對,涵蓋了廣泛的核心數學領域,如算術、代數、概率、微積分和幾何等,混合了CoT和PoT基本原理,并提供多種語言、多個難度級別的數據,足以證明數據集的高品質和獨特性。
訓練步驟
研究人員統一了MathInstruct中的所有子集,將指令數據集的結構標準化為Alpaca模型的格式,使得模型無需考慮原始數據集的格式,在微調階段統一處理數據即可。
研究人員選擇開源模型Llama-2和Code Llama作為基礎模型,在7B、13B、34B和70B尺寸的模型上進行微調。
實驗部分
評估數據集
研究人員選擇了不同數學領域下的樣本,對模型的通用數學推理能力進行評估:
領域內數據集包括GSM8K,MATH,AQuA-RAT,NumGLUE;領域外數據集包括SVAMP,Mathematics,SimulEq,SAT-Math和SimulEq,涵蓋了小學、高中和大學水平的數學問題,部分數據集甚至包括形式邏輯和常識推理。
問題類型為開放式問題和多選題,其中開放式問題(如GSM8K、數學)采用PoT解碼,因為大多數問題都可以由程序解決;多項選擇題(如AQuA、MMLU)采用CoT解碼。
CoT解碼不需要觸發(fā)詞,PoT需要觸發(fā)短語「讓我們寫個程序來解決這個問題」(Let’s write a program to solve the problem)。
實驗結果
總的來說,MAmmoTH和MAmmoTH-Coder在不同的模型尺寸上均優(yōu)于SoTA模型,并且在領域外(OOD)數據集上的增益要顯著優(yōu)于領域內(IND)數據集,展現出了該模型作為數學通才模型的潛力,甚至在幾個數據集上,MAmmoTH-Coder-34B和MAmmoTH-70B甚至超過了閉源模型。
在領域內數據的評估,MAmmoTH模型的主要競爭對手是WizardMath和Platypus,其中WizardMath的訓練深度依賴于GSM8K和MATH數據集,Platypus在更廣泛的文本和數學推理數據集上對LLM進行微調。
相比之下,MAmmoTH實現了全面的改進,并且更擅長解決復雜數學問題,相比WizardMath(MATH數據的sota)的增益最高超過了25%
在領域外數據評估中,主要競爭模型依然是Platypus,不過MAmmoTH可以實現比領域內數據更高的性能提升,展現出對未知數學問題的通用能力。
值得注意的是,MAmmoTH-7B還將WizardMath-7B在MMLU-Math上的CoT性能大幅提高了9%,其中包含大量沒有在訓練數據集中涵蓋的主題。
不同基礎模型之間的對比
可以發(fā)現,Code-Llama作為基礎模型時的效果始終優(yōu)于Llama-2,尤其是在領域外數據集上,二者之間的性能差異甚至達到了5%,其中MAmmoTH-Coder(34B)在領域外數據集上的平均性能實際上高于MAmmoTH(70B)
研究人員認為,MAmmoTH-Coder從Code-Llama的持續(xù)代碼訓練中受益匪淺,不僅增強了PoT能力,還提高了Llama的通用推理技能。