被OpenAI帶火的強(qiáng)化微調(diào)RFT技術(shù)解析
OpenAI年終大戲第二場推出了強(qiáng)化微調(diào)RFT (Reinforcement Fine-Tuning),它可以讓你用幾十到幾千個(gè)的訓(xùn)練數(shù)據(jù),為特定的復(fù)雜任務(wù)構(gòu)建專家模型,加強(qiáng)了模型如何處理類似問題的推理,微調(diào)后的o1-mini得分提高80%,反超o1正式版!
強(qiáng)化微調(diào)技術(shù)的一種實(shí)現(xiàn)方式:首先通過監(jiān)督式微調(diào)(Supervised Fine-Tuning)對模型進(jìn)行預(yù)熱,然后利用在線強(qiáng)化學(xué)習(xí),特別是PPO算法,進(jìn)一步微調(diào)模型。這種方法能夠自動(dòng)采樣多種推理路徑,并從真實(shí)答案中自然派生出獎(jiǎng)勵(lì)信號。
SFT和ReFT在CoT替代方案存在時(shí)的比較
強(qiáng)化微調(diào)(RFT)的兩個(gè)主要階段:預(yù)熱階段和強(qiáng)化學(xué)習(xí)階段。
- 預(yù)熱階段(Warm-up):
- 在這個(gè)階段,模型使用包含“問題(question)”和“思維鏈(Chain-of-Thought,CoT)”元組的數(shù)據(jù)集進(jìn)行微調(diào),通常持續(xù)1-2個(gè)epoch。
- 目的是使模型具備基本的問題解決能力,能夠生成適當(dāng)?shù)捻憫?yīng)。
- CoT生成過程被分解為一系列預(yù)測下一個(gè)詞(token)的動(dòng)作,直到生成結(jié)束符(<eos>)。
- 強(qiáng)化學(xué)習(xí)階段(Reinforcement Learning):
- 在這個(gè)階段,模型通過在線自我學(xué)習(xí)的方式提高性能,使用包含“問題(question)”和“答案(answer)”元組的數(shù)據(jù)集。
- 模型通過重復(fù)采樣響應(yīng)、評估響應(yīng)的答案正確性,并在線更新其參數(shù)。
- 使用PPO(Proximal Policy Optimization)算法進(jìn)行訓(xùn)練,其中價(jià)值模型(value model)V?是基于預(yù)熱階段后的政策模型πθ的最后隱藏狀態(tài)構(gòu)建的。
- 獎(jiǎng)勵(lì)函數(shù)在終端狀態(tài)時(shí)直接比較從狀態(tài)的CoT提取的答案和真實(shí)答案y,正確則返回1,否則返回0。對于數(shù)值型答案的數(shù)據(jù)集,還可以應(yīng)用部分獎(jiǎng)勵(lì)(partial reward)0.1。
- 總獎(jiǎng)勵(lì)是獎(jiǎng)勵(lì)函數(shù)得分和學(xué)習(xí)到的RL政策與初始政策之間的Kullback-Leibler(KL)散度的和。
GSM8K中的一個(gè)問題(x)、思維鏈(CoT)(e)和答案(y)的示例。SFT過程在訓(xùn)練數(shù)據(jù)上迭代多個(gè)周期。提出的ReFT從SFT預(yù)熱并在同一數(shù)據(jù)上執(zhí)行RL訓(xùn)練。
實(shí)驗(yàn)表明,RFT在GSM8K、MathQA和SVAMP等數(shù)據(jù)集上的性能顯著優(yōu)于SFT,并且可以通過多數(shù)投票和重新排名等策略進(jìn)一步提升性能
ReFT和基線模型在所有數(shù)據(jù)集上微調(diào)后的價(jià)值準(zhǔn)確度
SFT和ReFT在GSM8K數(shù)據(jù)集中第1、3和5周期的P-CoT響應(yīng)對同一個(gè)問題的反應(yīng)。綠色框架內(nèi)的反應(yīng)是正確的,而紅色框架內(nèi)的反應(yīng)是錯(cuò)誤的。
https://arxiv.org/pdf/2401.08967
Code: https://github.com/lqtrung1998/mwp_ReFT
本文轉(zhuǎn)載自??PaperAgent??
