在GSM8K上比GRPO快8倍!廈大提出CPPO,讓強(qiáng)化學(xué)習(xí)快如閃電
DeepSeek-R1 的成功離不開一種強(qiáng)化學(xué)習(xí)算法:GRPO(組相對(duì)策略優(yōu)化)。
不同于 PPO(近端策略優(yōu)化),GRPO 是直接根據(jù)組分?jǐn)?shù)估計(jì)基線,因此消除了對(duì) critic 模型的需求。但是,這又需要為每個(gè)問題都采樣一組完成結(jié)果,進(jìn)而讓訓(xùn)練過程的計(jì)算成本較高。
之后,GRPO 會(huì)使用一個(gè)基于規(guī)則的獎(jiǎng)勵(lì)函數(shù)來計(jì)算每個(gè)完成結(jié)果的獎(jiǎng)勵(lì),并計(jì)算每個(gè)完成結(jié)果的相對(duì)優(yōu)勢(shì)。
為了保證訓(xùn)練的穩(wěn)定性,GRPO 還會(huì)計(jì)算一組完成結(jié)果的策略模型、參考模型和舊策略模型的預(yù)測(cè)概率之比作為策略目標(biāo)函數(shù)的一部分,這又會(huì)進(jìn)一步提升強(qiáng)化學(xué)習(xí)的訓(xùn)練開銷。GRPO 巨大的訓(xùn)練開銷限制了其訓(xùn)練效率和可擴(kuò)展性。而在實(shí)踐中,提高訓(xùn)練效率是非常重要的。
總結(jié)起來,GRPO 訓(xùn)練的計(jì)算成本主要源自其核心設(shè)計(jì):為了進(jìn)行組內(nèi)比較,會(huì)為每個(gè)提示詞生成一大組完成結(jié)果。此外,GRPO 的前向計(jì)算會(huì)以完成數(shù)量的 3 倍的尺度擴(kuò)展。
那么,問題來了:在這個(gè)強(qiáng)化學(xué)習(xí)過程中,每個(gè)完成結(jié)果的貢獻(xiàn)都一樣嗎?
近日,廈門大學(xué)紀(jì)榮嶸團(tuán)隊(duì)研究發(fā)現(xiàn),每個(gè)完成結(jié)果的貢獻(xiàn)與其相對(duì)優(yōu)勢(shì)有關(guān)。也就是說,每個(gè)完成結(jié)果對(duì)策略模型訓(xùn)練的貢獻(xiàn)并不相等。如圖 1 所示,完成結(jié)果的數(shù)量增大時(shí),準(zhǔn)確度提升并不非常顯著,但訓(xùn)練時(shí)間卻會(huì)迅速增長(zhǎng)。
基于這一見解,他們發(fā)現(xiàn)可以通過對(duì)完成結(jié)果進(jìn)行剪枝來加速 GRPO。然后,他們提出了一種加速版的 GRPO:CPPO(Completion Pruning Policy Optimization / 完成剪枝策略優(yōu)化)。并且他們也已經(jīng)開源發(fā)布了該算法的代碼。
- 論文標(biāo)題:CPPO: Accelerating the Training of Group Relative Policy Optimization-Based Reasoning Models
- 論文地址:https://arxiv.org/pdf/2503.22342
- 項(xiàng)目地址:https://github.com/lzhxmu/CPPO
顧名思義,CPPO 會(huì)根據(jù)優(yōu)勢(shì)對(duì)完成結(jié)果進(jìn)行剪枝,這樣一來就可以提升強(qiáng)化學(xué)習(xí)過程的速度。
具體來說,一開始,策略模型會(huì)針對(duì)每個(gè)問題采樣一組完成結(jié)果。隨后,通過獎(jiǎng)勵(lì)函數(shù)計(jì)算每個(gè)完成結(jié)果的相對(duì)優(yōu)勢(shì)。然后,CPPO 會(huì)修剪掉絕對(duì)優(yōu)勢(shì)值較低的完成結(jié)果,僅保留絕對(duì)優(yōu)勢(shì)較高的完成結(jié)果來計(jì)算損失。此過程可大大減少訓(xùn)練所需的完成結(jié)果數(shù)量,從而加快訓(xùn)練過程。
此外,他們還觀察到,由于完成剪枝會(huì)導(dǎo)致 GPU 資源利用率不足,從而導(dǎo)致資源浪費(fèi)。為了解決這個(gè)問題,他們引入了一種動(dòng)態(tài)完成結(jié)果分配策略。該策略會(huì)用新問題的完成結(jié)果填充每個(gè)設(shè)備,從而充分利用 GPU 資源并進(jìn)一步提高訓(xùn)練效率。
實(shí)驗(yàn)證明,他們的方法是有效的。當(dāng)使用 Qwen-2.5 系列模型時(shí)(包括 Qwen-2.5-1.5B-Instruct 和 Qwen-2.5-7B-Instruct),在保證了準(zhǔn)確度相當(dāng)?shù)幕A(chǔ)上,CPPO 在 GSM8K 基準(zhǔn)上的速度比 GRPO 快 8.32 倍,在 MATH 基準(zhǔn)上快 3.51 倍。
或者用網(wǎng)友的話來說,快如閃電!
CPPO:完成剪枝策略優(yōu)化
要了解 CPPO,首先必須知道 GRPO,其公式如下:
其中,q 是從數(shù)據(jù)集分布 P (Q) 中采樣的問題,{o_1, o_2, ... , o_G} 是 G 個(gè)完成結(jié)果,π_θ 是策略模型,π_θ_old 是舊策略模型,π_θ_ref 是參考模型,? 和 β 是超參數(shù),A_i 是使用一組獎(jiǎng)勵(lì) {r_1, r_2, ... , r_G} 計(jì)算的優(yōu)勢(shì)。
相比于 GRPO,CPPO 引入了一個(gè)選擇性條件,該條件僅會(huì)包括表現(xiàn)出足夠高優(yōu)勢(shì)的完成結(jié)果。CPPO 的目標(biāo)公式如下:
其中 γ 是一個(gè)預(yù)定義的閾值,用于確保在梯度更新中僅保留絕對(duì)優(yōu)勢(shì)高于 γ 的完成結(jié)果。需要注意的是,當(dāng),或者
時(shí),clip 函數(shù)會(huì)被激活。
圖 2 展示了 CPPO 的概況:
統(tǒng)一單/多 GPU 設(shè)置
在多 GPU 訓(xùn)練場(chǎng)景中,該團(tuán)隊(duì)觀察到具有顯著優(yōu)勢(shì)的完成結(jié)果的數(shù)量因設(shè)備而異。在這種情況下,整體訓(xùn)練效率會(huì)有設(shè)備處理最多完成結(jié)果數(shù)量的瓶頸 —— 這種現(xiàn)象稱為「木桶效應(yīng)(bucket effect)」。為了緩解這種情況,對(duì)于每臺(tái) GPU,該團(tuán)隊(duì)的選擇是只保留每個(gè)問題具有最大絕對(duì)優(yōu)勢(shì)的 k 個(gè)完成結(jié)果,其中
其中 P ∈ (0, 1] 表示剪枝率。在此策略下修改后的 CPPO 為:
其中僅在具有最高絕對(duì)優(yōu)勢(shì)值的 k 個(gè)完成結(jié)果對(duì)應(yīng)的索引集 I 上進(jìn)行求和,即
CPPO 算法的流程如下:
- 舊策略模型為每個(gè)問題采樣一組完成結(jié)果;
- 獎(jiǎng)勵(lì)函數(shù)計(jì)算每個(gè)完成結(jié)果的獎(jiǎng)勵(lì);
- 計(jì)算每個(gè)完成結(jié)果的相對(duì)優(yōu)勢(shì);
- CPPO 保留 k 個(gè)具有最高絕對(duì)優(yōu)勢(shì)的完成結(jié)果;
- 根據(jù)選定的完成結(jié)果更新策略模型。
CPPO 和 GRPO 之間的關(guān)鍵區(qū)別是:CPPO 不會(huì)將所有完成結(jié)果用于策略模型、參考模型和舊策略模型的前向計(jì)算。相反,通過僅保留具有高絕對(duì)優(yōu)勢(shì)的完成結(jié)果進(jìn)行梯度更新,CPPO 可顯著降低前向傳遞期間的計(jì)算開銷,從而加速了訓(xùn)練過程。
通過動(dòng)態(tài)完成結(jié)果分配進(jìn)行并行處理
該團(tuán)隊(duì)還提出了一種新的動(dòng)態(tài)完成結(jié)果分配策略,以進(jìn)一步優(yōu)化 CPPO 的訓(xùn)練效率。
由于 GPU 內(nèi)存限制,傳統(tǒng)方法(如 GRPO 采用的方法)面臨固有的局限性。具體而言,單臺(tái)設(shè)備每批最多可以處理 B 個(gè)問題,每個(gè)問題生成 G 個(gè)候選完成結(jié)果。剪枝操作之后,每臺(tái)設(shè)備保留的完成結(jié)果總數(shù)減少到 B × k,進(jìn)而導(dǎo)致 GPU 利用率不理想,并行計(jì)算能力未得到充分利用。
為了解決這種低效率問題,該團(tuán)隊(duì)的方法是將來自其他問題的剪枝后的完成結(jié)果動(dòng)態(tài)分配到設(shè)備的處理管道中,如圖 3 所示。
此策略通過不斷用來自原始問題和新引入問題的高質(zhì)量完成結(jié)果填充其內(nèi)存,確保每個(gè)設(shè)備都能以滿負(fù)荷運(yùn)行。至關(guān)重要的是,所有新合并的完成結(jié)果都經(jīng)過相同的嚴(yán)格剪枝過程,以保持一致性和相關(guān)性。
這種方法的好處有兩個(gè):
- 通過充分利用設(shè)備的并行計(jì)算潛力,它能最大化 GPU 利用率。
- 它能使每臺(tái)設(shè)備每批處理更多的問題,從而減少實(shí)現(xiàn)收斂所需的總訓(xùn)練步驟數(shù)。
有這兩大優(yōu)勢(shì),CPPO 便可在保證訓(xùn)練質(zhì)量的同時(shí)提高訓(xùn)練效率。
CPPO 的實(shí)驗(yàn)效果
使用 Qwen2.5-1.5B-Instruct 和 Qwen2.5-7B-Instruct 模型,該團(tuán)隊(duì)在 GSM8K 和 MATH 數(shù)據(jù)集上對(duì) CPPO 進(jìn)行了實(shí)驗(yàn)評(píng)估。此外,為了評(píng)估模型的分布外推理能力,他們還引入了 AMC2023 和 AIME2024 作為測(cè)試基準(zhǔn)。
在 GSM8K 上的結(jié)果如表 1 所示,CPPO 在準(zhǔn)確度和加速比上都明顯優(yōu)于 GRPO。值得注意的是,CPPO 在各種剪枝率下都達(dá)到了與 GRPO 相當(dāng)甚至更高的準(zhǔn)確度。在 87.50% 的剪枝率下,CPPO 的準(zhǔn)確度達(dá)到 80.41%,比 GRPO 的 77.05% 高出 3.36%。
在效率方面,CPPO 大大加快了訓(xùn)練速度。在 93.75% 的剪枝率下,其加速比達(dá)到 8.32 倍。這些結(jié)果表明,CPPO 不僅能保持或提高準(zhǔn)確度,還可顯著提高訓(xùn)練效率。因此,CPPO 有潛力成為大規(guī)模推理模型訓(xùn)練的實(shí)用有效解決方案。
在 MATH 上的表現(xiàn)見表 2。可以看到,CPPO 可以很好地?cái)U(kuò)展到更大的模型 —— 在不犧牲準(zhǔn)確度的情況下在 MATH 上實(shí)現(xiàn)了高達(dá) 3.51 倍的加速。例如,在 87.5% 的修剪率下,CPPO 保持了與 GRPO (75.20%) 相當(dāng)?shù)臏?zhǔn)確度,同時(shí)還將訓(xùn)練時(shí)間減少了 3.51 倍。
此外,在 AMC2023 和 AIME2024 基準(zhǔn)上的評(píng)估表明,盡管 CPPO 僅在高絕對(duì)優(yōu)勢(shì)完成結(jié)果上進(jìn)行訓(xùn)練,但它仍保留了模型在分布外任務(wù)上的泛化能力。因此,CPPO 不僅在增強(qiáng)推理能力方面匹敵甚至超越了 GRPO,而且還很好地減少了訓(xùn)練時(shí)間,使其成為一種更有效的替代方案。
該團(tuán)隊(duì)也研究了 CPPO 的穩(wěn)定性和收斂性。圖 4 展示了在 GSM8K 和 MATH 數(shù)據(jù)集上訓(xùn)練時(shí)的獎(jiǎng)勵(lì)曲線。
總體而言,獎(jiǎng)勵(lì)曲線證明 CPPO 在提高收斂速度的同時(shí)可保證 GRPO 的訓(xùn)練穩(wěn)定性:CPPO 的獎(jiǎng)勵(lì)曲線不會(huì)崩潰或出現(xiàn)劇烈波動(dòng),這對(duì)于穩(wěn)定訓(xùn)練至關(guān)重要。這些結(jié)果表明 CPPO 具有穩(wěn)健而穩(wěn)定的訓(xùn)練穩(wěn)定性。此外,CPPO 的獎(jiǎng)勵(lì)曲線顯示出了明顯的上升趨勢(shì),能比 GRPO 更快地達(dá)到更高的獎(jiǎng)勵(lì)值。獎(jiǎng)勵(lì)值的更快增長(zhǎng)表明 CPPO 的收斂速度更快。
你有興趣在自己的強(qiáng)化學(xué)習(xí)訓(xùn)練流程中嘗試這種更快的 CPPO 嗎?