DeepSeek關(guān)鍵RL算法GRPO,有人從頭跑通了,貢獻(xiàn)完整代碼
GRPO(Group Relative Policy Optimization)是 DeepSeek-R1 成功的基礎(chǔ)技術(shù)之一,我們之前也多次報(bào)道過(guò)該技術(shù),比如《DeepSeek 用的 GRPO 占用大量?jī)?nèi)存?有人給出了些破解方法》。
簡(jiǎn)單來(lái)說(shuō),GRPO 算法丟棄了 critic model,放棄了價(jià)值函數(shù)近似,轉(zhuǎn)而通過(guò)組內(nèi)樣本的相對(duì)比較來(lái)計(jì)算策略梯度,從而有效降低了訓(xùn)練的不穩(wěn)定性,同時(shí)提高了學(xué)習(xí)效率。
既然 GRPO 如此有效,那么,你知道如何從頭開(kāi)始實(shí)現(xiàn) GRPO 嗎?
近日,AI 工程師和技術(shù)作家 Andriy Burkov 發(fā)布了一份「從頭開(kāi)始寫(xiě) GRPO 代碼」的教程,其中介紹了如何基于 Qwen2.5-1.5B-Instruct 模型構(gòu)建一個(gè)使用 GRPO 的分布式強(qiáng)化學(xué)習(xí)流程。
不過(guò),在我們深入這份教程之前,先簡(jiǎn)單介紹一下它的作者。Andriy Burkov 算得上是 AI 領(lǐng)域的一位著名科普作家,在加拿大拉瓦爾大學(xué)取得了計(jì)算機(jī)科學(xué)博士學(xué)位,還曾發(fā)表過(guò)兩本頗受歡迎的 AI 主題著作:《100 頁(yè)語(yǔ)言模型書(shū)》和《100 頁(yè)機(jī)器學(xué)習(xí)書(shū)》;書(shū)中一步步詳實(shí)地介紹了相關(guān)概念,并附帶了簡(jiǎn)明的實(shí)現(xiàn)代碼。
接下來(lái)我們就來(lái)看看這份 GRPO 從頭實(shí)現(xiàn)教程吧。
教程地址:https://github.com/aburkov/theLMbook/blob/main/GRPO_From_Scratch_Multi_GPU_DataParallel_Qwen_2_5_1_5B_Instruct.ipynb
從頭編寫(xiě) GRPO 代碼
使用 Qwen2.5-1.5B-Instruct 的分布式實(shí)現(xiàn)
本教程將展示如何使用 GRPO 方法構(gòu)建分布式強(qiáng)化學(xué)習(xí)(RL)流程,從而可以針對(duì)數(shù)學(xué)、邏輯和編程任務(wù)對(duì)語(yǔ)言模型進(jìn)行微調(diào)。
首先需要明確,這些任務(wù)都存在一個(gè)唯一且正確的 ground truth 答案,可通過(guò)簡(jiǎn)單的字符串比較輕松加以驗(yàn)證。
GRPO 的發(fā)明者是 DeepSeek,最早是被用于微調(diào) DeepSeek 的 R1 和 R1-Zero 模型 —— 它們可通過(guò)學(xué)習(xí)生成思維鏈(CoT)來(lái)更好地解決數(shù)學(xué)和邏輯問(wèn)題。
本教程的目標(biāo)是將通用語(yǔ)言模型 Qwen2.5-1.5B-Instruct 轉(zhuǎn)換為數(shù)學(xué)問(wèn)題求解器。我們將從頭開(kāi)始編寫(xiě) GRPO 代碼,然后將其與幾個(gè)流行的庫(kù)和工具集成起來(lái),以實(shí)現(xiàn)分布式訓(xùn)練管道流程,包括:
- PyTorch:用于張量運(yùn)算和分布式訓(xùn)練。
- Hugging Face Transformers:用于加載預(yù)訓(xùn)練的語(yǔ)言模型和 tokenizer。
- FlashAttention2:優(yōu)化的注意力機(jī)制,有助于減少內(nèi)存使用量并提高訓(xùn)練速度。
- Weights & Biases (wandb):用于實(shí)驗(yàn)跟蹤、可視化和模型版本控制。
本教程分為幾個(gè)部分。首先是基本設(shè)置和導(dǎo)入,然后是數(shù)據(jù)格式化和答案提取、數(shù)據(jù)集準(zhǔn)備、評(píng)估函數(shù)、獎(jiǎng)勵(lì)函數(shù)、訓(xùn)練設(shè)置和執(zhí)行,最后加載和測(cè)試模型。此過(guò)程中,我們將從頭實(shí)現(xiàn) GRPO 算法。
Part 1:基礎(chǔ)設(shè)置和導(dǎo)入
首先是安裝并導(dǎo)入所有必要的模塊。下面是導(dǎo)入庫(kù)的一段代碼截圖。
部分代碼截圖。完整代碼塊參見(jiàn) GitHub。
運(yùn)行上述代碼(參考項(xiàng)目完整代碼),可以執(zhí)行以下任務(wù):
- 設(shè)置隨機(jī)種子:set_random_seed 函數(shù)通過(guò)為 Python 的隨機(jī)模塊、NumPy 和 PyTorch 設(shè)置種子,確??蓮?fù)現(xiàn)性;
- 環(huán)境變量配置:設(shè)置 WANDB_API_KEY 和 WANDB_PROJECT 環(huán)境變量,以啟用與 Weights & Biases 的實(shí)驗(yàn)跟蹤;
- 導(dǎo)入必要的庫(kù),包括 random、copy、re、torch 等等。
Part 2:數(shù)據(jù)格式以及答案提取
接下來(lái),項(xiàng)目定義了數(shù)據(jù)格式,以及模型如何從輸出和數(shù)據(jù)集中提取答案段落。
為了確保模型輸出格式一致,項(xiàng)目還定義了一個(gè)系統(tǒng)提示。該提示指示模型生成包含 < reasoning > 和 < answer > 標(biāo)簽的輸出。這一步通過(guò)兩個(gè)函數(shù)完成:
- extract_answer_from_model_output:此函數(shù)獲取模型的輸出文本,并提取 < answer > 標(biāo)簽內(nèi)的內(nèi)容;
- extract_answer_from_dataset:此函數(shù)從 GSM8K 數(shù)據(jù)集中提取預(yù)期答案,該數(shù)據(jù)集使用 “####” 分隔符來(lái)分隔答案:
部分代碼截圖。完整代碼塊參見(jiàn) GitHub。
Part 3:數(shù)據(jù)準(zhǔn)備
該項(xiàng)目使用 GSM8K 數(shù)據(jù)集進(jìn)行訓(xùn)練。項(xiàng)目使用了該數(shù)據(jù)集中的示例來(lái)訓(xùn)練模型,基于強(qiáng)化學(xué)習(xí)(RL)訓(xùn)練范式,讓模型生成多個(gè)問(wèn)題解答樣本,之后作者將這些解答與 GSM8K 示例中的標(biāo)準(zhǔn)答案進(jìn)行對(duì)比,如果匹配,就為 RL 算法(GRPO)提供高獎(jiǎng)勵(lì),然后更新模型權(quán)重,以增加模型下次獲得高獎(jiǎng)勵(lì)的可能性。
實(shí)驗(yàn)過(guò)程是這樣的。首先從 Hugging Face 加載數(shù)據(jù)集,然后格式化每個(gè)示例,包括系統(tǒng)提示和用戶提示。這段實(shí)現(xiàn)代碼中還定義了兩個(gè)輔助函數(shù):prepare_dataset 以及 build_prompt。
部分代碼截圖。完整代碼塊參見(jiàn) GitHub。
Part 4:評(píng)估函數(shù)
評(píng)估對(duì)于跟蹤模型的進(jìn)展至關(guān)重要。因此作者定義了一些函數(shù),從而可以在一組示例上對(duì)模型進(jìn)行評(píng)估。該項(xiàng)目的評(píng)估函數(shù)執(zhí)行以下任務(wù):
- token 化提示并生成響應(yīng):模型的輸出是在 token 化提示的基礎(chǔ)上生成的。
- 提取預(yù)測(cè)答案:從生成的響應(yīng)中提取答案。
- 將預(yù)測(cè)答案與預(yù)期答案進(jìn)行比較:這種比較是通過(guò)精確匹配以及數(shù)值等價(jià)檢查來(lái)完成的。
在這段代碼中,兩個(gè)輔助函數(shù) _extract_last_number 和 _extract_single_number 被用來(lái)從文本中提取數(shù)字。評(píng)估函數(shù) evaluate_model 使用這些輔助函數(shù)來(lái)確定預(yù)測(cè)答案是否正確:
部分代碼截圖。完整代碼塊參見(jiàn) GitHub。
Part 5:獎(jiǎng)勵(lì)函數(shù)
在強(qiáng)化學(xué)習(xí)中,獎(jiǎng)勵(lì)函數(shù)是必不可缺的,作者定義了兩個(gè)獎(jiǎng)勵(lì)函數(shù):
correctness_reward:這個(gè)函數(shù)根據(jù)生成的答案是否正確來(lái)分配獎(jiǎng)勵(lì)。采用兩種方式:精確的字符串匹配和數(shù)值等價(jià)檢查,將模型輸出的答案與預(yù)期答案進(jìn)行比較。完全匹配會(huì)獲得更高的獎(jiǎng)勵(lì)(2.0),而基于數(shù)值等價(jià)的匹配會(huì)獲得較小的獎(jiǎng)勵(lì)(1.5)。
format_reward:這個(gè)函數(shù)鼓勵(lì)模型遵循所需的類(lèi)似 XML 的輸出格式。它為生成文本中存在 < reasoning>、</reasoning>、<answer > 和 </answer > 標(biāo)簽提供小額獎(jiǎng)勵(lì)。
部分代碼截圖。完整代碼塊參見(jiàn) GitHub。
Part 6:從頭開(kāi)始實(shí)現(xiàn) DataParallel GRPO
這一節(jié),我們將從頭實(shí)現(xiàn) GRPO 算法的所有構(gòu)建模塊。首先,這里假設(shè)運(yùn)行代碼的機(jī)器至少有 2 臺(tái) GPU。為此,這里要使用 PyTorch 的 DataParallel API 來(lái)將策略模型放在多個(gè) GPU 核心上,每個(gè) GPU 核心都有該模型的一個(gè)副本。然后將批量數(shù)據(jù)分散在這些 GPU 核心上完成處理。
部分代碼截圖。完整代碼塊參見(jiàn) GitHub。
Part 7:訓(xùn)練設(shè)置和執(zhí)行
這一節(jié),我們將所有組件組合在一起,完成設(shè)置并開(kāi)始訓(xùn)練。
首先,加載預(yù)訓(xùn)練的模型和 tokenizer,準(zhǔn)備評(píng)估數(shù)據(jù),然后使用上面從頭實(shí)現(xiàn)的 train_with_grpo 進(jìn)行強(qiáng)化學(xué)習(xí)微調(diào)。
關(guān)鍵步驟包括:
- 模型和 tokenizer 初始化:使用優(yōu)化設(shè)置(使用 torch.bfloat16 和 FlashAttention2)加載模型 Qwen/Qwen2.5-1.5B-Instruct。tokenizer 也要加載,其填充 token 設(shè)置為序列末尾 token。使用 torch.bfloat16 加載模型會(huì)將其參數(shù)轉(zhuǎn)換為每個(gè)數(shù)值使用 16 位而不是 32 位的形式,這可將模型的內(nèi)存使用量減少一半,并且可加快在現(xiàn)代 GPU 上的訓(xùn)練速度。
- 初步評(píng)估:在微調(diào)之前,根據(jù)幾個(gè)示例對(duì)模型進(jìn)行評(píng)估,以確定基準(zhǔn)性能。
- 強(qiáng)化學(xué)習(xí)微調(diào):為從頭開(kāi)始實(shí)現(xiàn) GRPO 的訓(xùn)練函數(shù) train_with_grpo 配置適當(dāng)?shù)挠?xùn)練參數(shù)和獎(jiǎng)勵(lì)函數(shù)。然后,在剩余的訓(xùn)練數(shù)據(jù)上執(zhí)行強(qiáng)化學(xué)習(xí)訓(xùn)練。
- 最終評(píng)估和模型保存:強(qiáng)化學(xué)習(xí)微調(diào)后,再次評(píng)估模型,并保存最終模型。
下面的代碼會(huì)執(zhí)行以下功能:
- 確定設(shè)備(如果有 GPU 就用 GPU,否則就用 CPU)。
- 加載預(yù)訓(xùn)練版 Qwen2.5-1.5B-Instruct 模型和 tokenizer。tokenizer 的 pad token 設(shè)置為 eos_token。
- 保留一小部分?jǐn)?shù)據(jù)集用于評(píng)估,以提供基線。
- 通過(guò)啟用梯度檢查點(diǎn)和禁用 KV 緩存,優(yōu)化模型的內(nèi)存效率。
- 步驟 1:在微調(diào)之前評(píng)估模型,以建立基線準(zhǔn)確性。
- 步驟 2:使用 train_with_grpo 函數(shù)和我們定義的獎(jiǎng)勵(lì)函數(shù)(format_reward 和 correctness_reward,合并為 combined_reward)執(zhí)行強(qiáng)化學(xué)習(xí)微調(diào)。這里使用了多臺(tái) GPU 訓(xùn)練模型。
- 步驟 3:將最終的微調(diào)模型和 tokenizer 保存到磁盤(pán)。
GRPO 訓(xùn)練流程使用的超參數(shù)如下。
訓(xùn)練配置
以下參數(shù)設(shè)定了使用上面的 GRPO 算法實(shí)現(xiàn)強(qiáng)化學(xué)習(xí)微調(diào)運(yùn)行的配置:
- num_iteratinotallow=1:從當(dāng)前策略模型創(chuàng)建新參考模型的外部迭代次數(shù)。一次迭代是指在整個(gè)數(shù)據(jù)集上執(zhí)行一次通過(guò)。
- num_steps=500:訓(xùn)練循環(huán)將執(zhí)行最多 500 個(gè)步驟,每個(gè)步驟處理一批樣本。
- batch_size=7:在 8 臺(tái) GPU 的情況下,每個(gè)步驟每批處理 7 個(gè)樣本,每臺(tái) GPU 上放置 1 個(gè)樣本。使用一個(gè) GPU (0) 被 DataParallel 用作主節(jié)點(diǎn)來(lái)聚合梯度并收集輸出。
- num_generatinotallow=14:對(duì)于訓(xùn)練數(shù)據(jù)中的每個(gè)提示詞,訓(xùn)練器將生成 14 個(gè)不同的完成結(jié)果。這些生成結(jié)果將被用于計(jì)算指導(dǎo)強(qiáng)化學(xué)習(xí)更新的相對(duì)優(yōu)勢(shì)(或獎(jiǎng)勵(lì)信號(hào))。如果你的 GPU 的 VRAM 較少,請(qǐng)減少此數(shù)字。
- max_completion_length=400:在生成完成結(jié)果(序列的 response 部分)時(shí),生成上限為 400 個(gè) token。這限制了模型在 RL 階段生成的輸出的長(zhǎng)度。如果你的 GPU 的 VRAM 較少,請(qǐng)減少此數(shù)字。
- beta=0.04:GRPO 損失函數(shù)中 KL 散度懲罰的系數(shù)。這控制的是模型與參考模型的偏差程度。
- learning_rate=5e-6:RL 微調(diào)的學(xué)習(xí)率。為了實(shí)現(xiàn)穩(wěn)定的策略更新,這里使用了相對(duì)較低的學(xué)習(xí)率。
- mu=1:對(duì)每批 rollout 數(shù)據(jù)執(zhí)行的策略更新次數(shù)。在這里,我們每批只執(zhí)行一次更新。
- epsilnotallow=0.1:GRPO 的 PPO 組件的 clipping 參數(shù)。這可以防止策略在單次更新中發(fā)生太大的變化。
在微調(diào)之前和之后都會(huì)對(duì)模型進(jìn)行評(píng)估,以衡量準(zhǔn)確率的提高情況。最后,將微調(diào)后的模型保存到 grpo_finetuned_model 目錄中。
部分代碼截圖。完整代碼塊參見(jiàn) GitHub。
教程中還給出了詳細(xì)的執(zhí)行情況,可作參考。
下面我們也簡(jiǎn)單看看其訓(xùn)練過(guò)程。
首先,初始配置后,我們可以看到運(yùn)行 GRPO 之前的準(zhǔn)確度為 23.33%。
然后經(jīng)過(guò) 500 步的 1 輪 GRPO 迭代,下圖展示了相關(guān)的訓(xùn)練動(dòng)態(tài):
訓(xùn)練完成后,自然還需要對(duì)模型進(jìn)行新一輪的評(píng)估。這里采用了 30 個(gè)評(píng)估樣本來(lái)進(jìn)行評(píng)估,以下展示了其中一個(gè)模型回答正確的示例:
整體表現(xiàn)如何呢?可以看到,經(jīng)過(guò)一輪 GRPO 之后,Qwen-2.5-1.5B-Instruct 模型答對(duì)了 30 問(wèn)題中的 27 題,實(shí)現(xiàn)了 90% 的準(zhǔn)確度。相較于 GRPO 之前的 23.33%,可說(shuō)是實(shí)現(xiàn)了性能飛躍。
上面兩張圖展示了模型的學(xué)習(xí)過(guò)程動(dòng)態(tài),可以看到:平均獎(jiǎng)勵(lì)在 2.25 左右就趨于穩(wěn)定了(理論最大值為 0.8 + 2.0 = 2.8)。相比于另一處微調(diào)的 Qwen-2.5-0.5B-Instruct(獲得的平均獎(jiǎng)勵(lì)為 1.4),這個(gè)數(shù)字相當(dāng)高了,參閱:https://github.com/aburkov/theLMbook/blob/main/GRPO_Qwen_0_5_Instruct.ipynb
如果使用更大的模型并允許更長(zhǎng)的生成時(shí)間,模型正確解答問(wèn)題的能力還將進(jìn)一步提升。但是,如果要訓(xùn)練更大的模型,不僅需要將數(shù)據(jù)分布在多臺(tái) GPU 上,還需要將模型分開(kāi)放在多臺(tái) GPU 上,這需要用到 DeepSpeed 或 FSDP(完全分片數(shù)據(jù)并行)等模型并行工具。
下面加載和測(cè)試已經(jīng)微調(diào)的模型:
完整代碼見(jiàn)原筆記本
加載完成后測(cè)試一下,首先問(wèn)問(wèn) 1+1 等于幾:
可以看到,模型反復(fù)思考了很多次,終于認(rèn)定確實(shí)等于 2。
多次測(cè)試后還可以發(fā)現(xiàn),該模型沒(méi)有學(xué)會(huì)生成序列結(jié)束(EOS)token,因此即使在 </answer> token 之后,輸出序列仍會(huì)繼續(xù)。這是預(yù)期的行為,因?yàn)槲覀兪褂玫莫?jiǎng)勵(lì)函數(shù)中沒(méi)有包含一個(gè)用于停止生成的獎(jiǎng)勵(lì)。我們也沒(méi)有執(zhí)行監(jiān)督微調(diào)步驟 —— 該步驟可以讓模型學(xué)會(huì)在 </answer> 之后立即生成 EOS。
你對(duì)這篇代碼密集的教程怎么看?有沒(méi)有讓你產(chǎn)生在自己的電腦上實(shí)現(xiàn) GRPO 的想法?