基于 DeepSeek GRPO 的 1.5B Rust 代碼生成模型訓(xùn)練實(shí)戰(zhàn)
群組相對策略優(yōu)化(Group Relative Policy Optimization,GRPO)已被證明是一種有效的算法,可用于訓(xùn)練大語言模型(LLMs),使其具備推理能力并在基準(zhǔn)測試中持續(xù)提升性能表現(xiàn)。DeepSeek-R1 展示了如何通過監(jiān)督式微調(diào)(Supervised Fine-Tuning)與 GRPO 技術(shù)的結(jié)合,引導(dǎo)模型達(dá)到與 OpenAI 的 o1 等頂尖模型相競爭的水平。
為了進(jìn)一步探索其實(shí)踐應(yīng)用,我們嘗試將這些技術(shù)應(yīng)用于現(xiàn)實(shí)場景中。本文將概述如何利用 GRPO、自定義數(shù)據(jù)集和獎勵函數(shù)來訓(xùn)練我們自己的定制化小型 LLM。下圖是后續(xù)將展示的部分模型訓(xùn)練過程曲線預(yù)覽。觀察模型逐步學(xué)習(xí)生成代碼塊、優(yōu)化可編譯的有效代碼,最終生成能通過單元測試的代碼,這一過程頗具趣味性。
image.png
若希望直接上手實(shí)踐,可訪問該 GitHub 倉庫[1]。
1.為什么選擇 Rust?
Rust 似乎是強(qiáng)化學(xué)習(xí)(RL)的絕佳舞臺,因?yàn)榭梢允褂?Rust 編譯器和 Cargo 工具鏈。Rust 編譯器不僅能夠提供清晰的錯誤信息,還具備嚴(yán)格的類型檢查。
在本項(xiàng)目中,我們首先想要驗(yàn)證的假設(shè)是:能否將 Cargo 作為反饋機(jī)制來教導(dǎo)模型成為更優(yōu)秀的“程序員”。第二個實(shí)驗(yàn)旨在探索用 GRPO 訓(xùn)練,語言模型的最小可行規(guī)模是多少。這些實(shí)驗(yàn)特意限制在單節(jié)點(diǎn) H100 上運(yùn)行,既控制了成本又展示了實(shí)驗(yàn)的可實(shí)現(xiàn)性。
而且作為 Oxen.ai 的 Rust 開發(fā)團(tuán)隊(duì),我們也正在探索一些有趣的跨界應(yīng)用 ?? x ??。
2.為何參數(shù)量選擇 1.5B?
最近,有大量研究探索了小型語言模型在特定任務(wù)上的性能邊界。當(dāng)存在明確的反饋機(jī)制(如數(shù)學(xué)題的正確答案或程序的輸出結(jié)果)時,模型規(guī)??稍诒3指吒偁幜Φ耐瑫r大大縮小。
微軟的 rStar-Math[2] 論文表明,在可驗(yàn)證的數(shù)學(xué)問題領(lǐng)域,模型可以進(jìn)行推理。其 1.5B 參數(shù)量的模型性能超越了 GPT-4o 和 o1-preview。
image.png
我的假設(shè)是:在編碼任務(wù)中我們同樣能實(shí)現(xiàn)類似水平的性能,因?yàn)樵擃I(lǐng)域同樣存在類似的可驗(yàn)證獎勵機(jī)制 —— 代碼能否通過編譯?能否通過單元測試?
3.小型語言模型的優(yōu)勢
小型代碼模型具備諸多優(yōu)勢,包括成本效益、處理效率、數(shù)據(jù)隱私性,以及針對自有代碼庫/編碼規(guī)范進(jìn)行定制的能力。此外,這本身也是項(xiàng)有趣的挑戰(zhàn)。
我們的終極愿景是,讓這類小型模型能完整執(zhí)行所有 cursor 能夠完成的任務(wù):預(yù)測下一段代碼、補(bǔ)全中間的代碼內(nèi)容,并在智能體的優(yōu)化迭代循環(huán)中持續(xù)優(yōu)化代碼。不過,還是讓我們先從最基礎(chǔ)開始吧。
4.明確定義待解決問題
要讓生成的代碼能夠通過單元測試,有多種不同的方法,我們最終嘗試了幾種方案。其中一種看似簡單的方法是,提供一組可驗(yàn)證的單元測試,要求生成的代碼必須通過這些測試。這將形成一套黃金標(biāo)準(zhǔn)的可驗(yàn)證答案集。
image.png
在嘗試了該流程后,我們發(fā)現(xiàn)了兩個主要問題。首先,若禁止模型在編寫代碼時查看單元測試,它就無法感知需要遵循的接口規(guī)范。在使用預(yù)構(gòu)建的、已驗(yàn)證的單元測試進(jìn)行評估時,許多錯誤最終表現(xiàn)為代碼與單元測試之間的類型不匹配或命名不一致。
image.png
其次,若允許模型在編碼時查看單元測試,就會犧牲開發(fā)者體驗(yàn)。除非您是一名嚴(yán)格的"測試驅(qū)動開發(fā)"實(shí)踐者(Test Driven Developer),否則您可能更希望直接輸入提示詞,而非預(yù)先設(shè)計函數(shù)定義或單元測試。
我們沒有試圖想出更聰明的辦法,而是最終選擇了優(yōu)化方案,使其更加簡潔,將這個問題重構(gòu)為要求模型在單次響應(yīng)中同時生成代碼和單元測試。
image.png
不過,單次生成方案存在模型"作弊"的風(fēng)險:例如通過 println! 輸出而非 assert 斷言,生成可通過編譯但無實(shí)際功能的代碼。我們將在后續(xù)環(huán)節(jié)為此添加防護(hù)機(jī)制。
最終,我們通過一個詳細(xì)的系統(tǒng)提示詞(system prompt),為模型提供任務(wù)指引。
image.png
系統(tǒng)提示詞(system prompt)以特定格式和風(fēng)格為模型提供上下文,明確模型響應(yīng)用戶查詢的預(yù)期形式。
5.數(shù)據(jù)集
在開始訓(xùn)練之前,我們需要一個數(shù)據(jù)集。最開始時,我們發(fā)現(xiàn)專門針對 Rust 語言的數(shù)據(jù)集較少,現(xiàn)有的大語言模型(LLM)基準(zhǔn)測試大多聚焦于 Python。因此,我們首先將包含 Python 編程問題的提示詞數(shù)據(jù)集轉(zhuǎn)換為 Rust 語言版本。
我們從 Ace-Code-87k 數(shù)據(jù)集中隨機(jī)選取了 20,000 個提示詞樣本,使用 Qwen 2.5 Coder 32B Instruct 模型生成對應(yīng)的 Rust 代碼及單元測試。隨后通過編譯器和測試框架運(yùn)行這些代碼,篩選出所有能通過單元測試的提示詞、代碼、單元測試三元組(prompt,code,unit_test triples),最終獲得 16,500 組有效數(shù)據(jù)。該數(shù)據(jù)集被劃分為 15,000 組訓(xùn)練數(shù)據(jù)、1,000 組測試數(shù)據(jù)和 500 組評估數(shù)據(jù)。
最終數(shù)據(jù)集[3]如下所示:
您可以通過以下模型運(yùn)行記錄跟蹤整個流程:
1)將代碼翻譯為 Rust 語言:https://www.oxen.ai/ox/mbrp-playground/evaluations/ce45630c-d9e8-4fac-9b41-2d41692076b3
2)編寫 Rust 代碼:https://www.oxen.ai/ox/mbrp-playground/evaluations/febc562a-9bd4-4e91-88d7-a95ee676a5ed
3)編寫 Rust 單元測試:https://www.oxen.ai/ox/mbrp-playground/evaluations/b886ddd6-b501-4db8-8ed6-0b719d0ac595
有趣的是,在最終的 GRPO 訓(xùn)練方案中,我們移除了原始數(shù)據(jù)中的黃金標(biāo)準(zhǔn) Rust 代碼和單元測試列。在強(qiáng)化學(xué)習(xí)的不斷迭代學(xué)習(xí)過程中,我們僅需將提示詞作為輸入即可完成訓(xùn)練,這為未來數(shù)據(jù)的收集提供了便利。盡管我們舍棄了用于訓(xùn)練的代碼和單元測試數(shù)據(jù),但知道僅需這些提示詞就可以解決這個問題仍具有重要參考價值 —— 這一特性將在后續(xù)章節(jié)詳細(xì)解析。
6.設(shè)定 baseline
完成待解決問題的明確定義并構(gòu)建完數(shù)據(jù)集后,我們需設(shè)定 baseline 以評估初始模型性能。我們將使用 Qwen/Qwen2.5-Coder-1.5B-Instruct 模型進(jìn)行訓(xùn)練引導(dǎo)。
并統(tǒng)計以下指標(biāo):構(gòu)建通過率(譯者注:代碼編譯成功)、Clippy 檢查通過率(譯者注:通過 Rust 官方靜態(tài)分析工具 Clippy 檢測)、單元測試通過率。
image.png
由于 Clippy 檢測依賴于代碼的成功編譯,這兩項(xiàng)指標(biāo)通常呈現(xiàn)強(qiáng)相關(guān)性。您可在 Oxen.ai 平臺查閱原始數(shù)據(jù)[4],進(jìn)一步分析具體案例。
7.與 SOTA 模型進(jìn)行對比,孰強(qiáng)孰弱?
我們同時希望了解主流大型基礎(chǔ)模型在該任務(wù)上的表現(xiàn),以此為我們設(shè)定理論上的優(yōu)化目標(biāo)。經(jīng)測試發(fā)現(xiàn)表現(xiàn)最佳的模型是 GPT4.5。
GPT4.5 的構(gòu)建通過率高達(dá) 98%,他自己編寫的單元測試的通過率達(dá) 87%,表現(xiàn)令人矚目[5]。
image.png
Claude 3.7 Sonnet 以微弱差距位居次席。
image.png
至此,所有準(zhǔn)備工作就緒 —— 已完成明確定義待解決問題、數(shù)據(jù)集構(gòu)建、baseline 設(shè)定及目標(biāo)確立,終于可以進(jìn)入激動人心的模型訓(xùn)練環(huán)節(jié)!
8.獎勵函數(shù)設(shè)計
GRPO(群組相對策略優(yōu)化)的精妙之處在于,其獎勵機(jī)制可通過簡單的 Python 函數(shù)實(shí)現(xiàn)。工程師只需定義獎勵規(guī)則,模型便能自主優(yōu)化策略。摩根士丹利的 Will Brown 將這種方法稱為 Rubric Engineering,它為工程師提供了通過強(qiáng)化學(xué)習(xí)引導(dǎo)模型的便捷方法。
GRPO 獎勵函數(shù)的輸入?yún)?shù)包括 prompts(提示詞)、responses(模型響應(yīng))及 target answers(目標(biāo)答案)。實(shí)際上,部分評分標(biāo)準(zhǔn)甚至無需 target answers(目標(biāo)答案),僅需基于生成內(nèi)容本身的屬性(如格式規(guī)范)進(jìn)行評分。典型的評分維度包括:
- 正確性(直接字符串匹配)
- 響應(yīng)長度(token 數(shù)量)
- 響應(yīng)格式(XML、JSON、代碼等)
- 外部工具調(diào)用(本例中為 Cargo 工具鏈)
- 使用 LLM 進(jìn)行評判(是否真實(shí)、是否有幫助、是否有危害等)
在本項(xiàng)目中,我們將結(jié)合格式驗(yàn)證與 Cargo 構(gòu)建工具的執(zhí)行結(jié)果構(gòu)建獎勵函數(shù)。初始代碼基于 @willccbb 的優(yōu)質(zhì)代碼實(shí)現(xiàn):
https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb
9.我們的“評分標(biāo)準(zhǔn)”
以下是我們針對該問題設(shè)計的“評分標(biāo)準(zhǔn)”。流程大致為:接收用戶提示詞 → LLM 生成代碼及單元測試 → 通過多個評分函數(shù)評估模型響應(yīng)。
獎勵函數(shù)可以是一個簡單的正則表達(dá)式,驗(yàn)證代碼中是否包含有效的測試模塊。
image.png
您會注意到,輸入?yún)?shù) prompts 和 completions 都是列表,輸出的也是浮點(diǎn)數(shù)列表(list[float])。這是因?yàn)?GRPO 有一個名為 num_generations 的參數(shù),用于控制每個 prompts 的生成次數(shù)。此外,在生成過程中可能會有多組 prompts + completions。GRPO 算法會為每個 prompts 生成 N 種不同響應(yīng),每個模型響應(yīng)都會通過“評分標(biāo)準(zhǔn)”進(jìn)行驗(yàn)證。下圖展示了模型在嘗試不同迭代方案時各評分項(xiàng)的通過情況:
image.png
隨著時間的推移,模型將逐步掌握如何通過更多你定義的“評分標(biāo)準(zhǔn)”。借助 trl 庫,你可以定義多種獎勵函數(shù)并傳入 GRPOTrainer 進(jìn)行模型訓(xùn)練。
image.png
設(shè)計獎勵函數(shù)堪稱是搭建 GRPO 訓(xùn)練流程中最具成就感的環(huán)節(jié) —— 雙關(guān)語警告??。
10.Cargo 獎勵函數(shù)
正如文章開頭所述,我們將利用 Cargo 工具鏈構(gòu)建獎勵機(jī)制。由于獎勵函數(shù)可定義為純 Python 函數(shù),因此我們可直接通過 subprocess 模塊調(diào)用 Cargo 工具。
完整代碼實(shí)現(xiàn)詳見 GitHub:https://github.com/Oxen-AI/GRPO-With-Cargo-Feedback/blob/main/train.py
我們定義了一個 RustTool 類,支持執(zhí)行 cargo build、cargo clippy 或 cargo test 命令。其 run() 方法將返回包含工具執(zhí)行結(jié)果(通過/失敗狀態(tài)信息、錯誤信息)的字典。
image.png
該工具應(yīng)用于為每個測試用例臨時創(chuàng)建的 Rust 項(xiàng)目目錄。項(xiàng)目的初始化與清理流程包括:創(chuàng)建目錄→寫入 main.rs 和 Cargo.toml 文件→在其中寫入代碼與單元測試→執(zhí)行后清理。欲知更多詳情,可查閱完整代碼中的 setup_and_test_rust_project 函數(shù)。
在具備初始化和清理 Rust 小型項(xiàng)目的能力后,需將其與 GRPO 獎勵函數(shù)對接。獎勵函數(shù)接收一批提示詞(prompts)與生成內(nèi)容(completions),要求對每組 prompt+completion 進(jìn)行評分。
image.png
持久化記錄 prompts 和 completions 有助于幫助我們理解 GRPO 算法的內(nèi)部機(jī)制。每個 prompt 將生成 N 個 completions。假如我們設(shè)置 N=4,這樣模型就有 4 次生成內(nèi)容的機(jī)會。
以下方日志中的 task_8347 為例,模型嘗試了 4 次函數(shù)代碼的實(shí)現(xiàn),最終成功一次。GRPO 會對正確的解決方案進(jìn)行獎勵,從而促使模型性能持續(xù)提升。使用紅色標(biāo)記圈出來的部分是三個單元測試失敗的案例,綠色標(biāo)記的部分則為測試通過的案例:
image.png
使用相同的工具與邏輯,我們?yōu)?cargo build、test 和 clippy 分別構(gòu)建獎勵函數(shù)。同時設(shè)置驗(yàn)證機(jī)制確保代碼與單元測試不是空的,且單元測試必須包含 assert! 斷言,防止模型通過 println! 語句作弊。
所有訓(xùn)練結(jié)果均實(shí)時記錄至 Oxen.ai,以便繪制模型訓(xùn)練曲線(如文章開篇所示)。在模型訓(xùn)練過程中,我們會通過滾動平均(譯者注:Rolling Average,是一種統(tǒng)計方法,用于對時間序列數(shù)據(jù)中的短期波動進(jìn)行平滑處理。) 的方式對數(shù)據(jù)進(jìn)行平滑處理、計算,以此觀察模型在特定獎勵指標(biāo)上的改進(jìn)趨勢。
image.png
例如編譯通過率剛開始在 30-40% 之間波動,隨著訓(xùn)練逐步攀升至 70%:
image.png
單元測試通過率的提升相對滯后,且通過率的波動范圍更大:
image.png
需要觀察模型在每個獎勵維度(如代碼編譯通過率等)上的改進(jìn)情況,定期抽查模型接收的提示詞(輸入)及其生成的代碼(輸出)。為此我們編寫了一個 @experiment.log 裝飾器,在獎勵函數(shù)執(zhí)行時自動將結(jié)果寫入 jsonl 文件并提交至 Oxen.ai:
image.png
該裝飾器每次調(diào)用函數(shù)時將結(jié)果寫入指定文件,訓(xùn)練循環(huán)(training loop)中的回調(diào)函數(shù)每 N 步將數(shù)據(jù)提交至 Oxen.ai。由于數(shù)據(jù)按時間順序存儲,需翻頁至末尾查看訓(xùn)練后期的生成樣本。部分輸出示例:https://www.oxen.ai/ox/Rust/file/GRPO_82_2025-03-02_22-49-17_Qwen2.5-Coder-1.5B-Instruct/outputs/GRPO_82_2025-03-02_22-49-17_Qwen2.5-Coder-1.5B-Instruct/cargo_test_rewards.jsonl?page=497&ref=ghost.oxen.ai
11.訓(xùn)練效果評估
此前設(shè)定的 baseline —— 未經(jīng)訓(xùn)練的 Qwen/Qwen2.5-Coder-1.5B-Instruct 模型在代碼編譯通過率和單元測試通過率上分別僅為 61% 和 22%。
image.png
經(jīng)過一輪 GRPO 訓(xùn)練后,編譯通過率提升至 80%,單元測試通過率達(dá) 37% ??。通過一輪訓(xùn)練即實(shí)現(xiàn) 20% 與 15% 的絕對準(zhǔn)確率提升。
image.png
僅僅定義了幾個獎勵函數(shù)和使用較小規(guī)模的數(shù)據(jù)集,這樣的訓(xùn)練效果已經(jīng)相當(dāng)不錯了。您可通過此處[6]探索 1.5B 模型的原始輸出結(jié)果。
此次實(shí)驗(yàn)成果令人鼓舞。整個訓(xùn)練耗時約 24 小時,成本小于 100 美元(下圖是 H100 實(shí)例的每小時成本,僅供參考):
image.png
該實(shí)驗(yàn)表明:GRPO 算法對普通開發(fā)者來說也比較實(shí)用 —— 開發(fā)者可靈活定義任意獎勵函數(shù),即使在小模型上也能實(shí)現(xiàn)性能的明顯提升。
Thanks for reading!
Hope you have enjoyed and learned new things from this blog!