200多行代碼,超低成本復(fù)現(xiàn)DeepSeek R1「Aha Moment」!復(fù)旦大學(xué)開源
本文是復(fù)旦大學(xué)知識工場實驗室肖仰華教授、梁家卿青年副研究員科研團(tuán)隊的最新研究成果,他們用簡潔的代碼高效復(fù)現(xiàn)了 R1-zero 的自發(fā)反思能力。
在關(guān)于 DeepSeek 的文章中,我們會多次聽到「Aha Moment」這個詞。它指的是模型在訓(xùn)練過程中經(jīng)歷的一種頓悟時刻,表現(xiàn)為模型突然展現(xiàn)出類似人類的自我反思和策略調(diào)整能力。
DeepSeek 論文中提到的 Aha Moment。
DeepSeek-R1-zero 經(jīng)過強(qiáng)化學(xué)習(xí)實現(xiàn)了大模型頓悟時刻的自發(fā)涌現(xiàn),引發(fā)了大量對其方案的解讀與復(fù)現(xiàn)工作。
其中,基于 GRPO( Group Relative Policy Optimization)強(qiáng)化學(xué)習(xí)方案尤其受到關(guān)注。業(yè)界先后開源了多個基于 GRPO 算法的 R1-zero 復(fù)現(xiàn)項目。然而,這些復(fù)現(xiàn)項目嚴(yán)重依賴一些復(fù)雜代碼框架,有著較高的代碼實現(xiàn)復(fù)雜度,對部署環(huán)境存在較高依賴,資源利用率不高,代碼可讀性與可維護(hù)性仍存在改進(jìn)空間。
對此,復(fù)旦大學(xué)知識工場實驗室肖仰華教授、梁家卿青年副研究員科研團(tuán)隊基于 GRPO 算法思想高效復(fù)現(xiàn)了 R1-zero 自發(fā)反思能力。目前,該項目(Simple-GRPO)的第一版代碼實現(xiàn)已經(jīng)開源并提交 Github。
代碼地址:https://github.com/lsdefine/simple_GRPO。
該項目相對于現(xiàn)有開源的 R1-zero 復(fù)現(xiàn)具有以下優(yōu)勢:
代碼簡潔,依賴簡單,只需要 200 多行;
資源消耗低,通過模型解耦與分離進(jìn)一步降低算力需求,該項目支持在一張 A800 (80G) 加一張 3090 (24G) 完成 7B 模型的訓(xùn)練。根據(jù) AutoDL 平臺計費(fèi)標(biāo)準(zhǔn),一張 A800 (80G) 5.98 元 / 時,一張 3090 (24G) 1.32 元 / 時。以項目作者經(jīng)驗,模型在這樣的算力平臺下,訓(xùn)練 1h 模型就能出現(xiàn) aha moment,折合人民幣 7.3 元,單次實驗成本壓縮至奶茶價格區(qū)間。
項目介紹
本項目代碼簡單,GRPO 算法實現(xiàn)僅有 200 多行代碼,且僅依賴基礎(chǔ)的深度學(xué)習(xí)代碼庫,如 deepspeed 和 torch,而無需 ray 等復(fù)雜框架。具體實現(xiàn)細(xì)節(jié)如下:
① 參考模型分離:
在實現(xiàn)過程中,參考模型(reference model)被解耦,允許其在不同的 GPU 上運(yùn)行(甚至可以運(yùn)行在 3090 顯卡上)。這避免了將參考模型和訓(xùn)練模型放在同一個 GPU 上,防止 torch 的多進(jìn)程機(jī)制創(chuàng)建多個副本,避免顯存浪費(fèi),并使得在 A800(80G)上訓(xùn)練 7B 模型成為可能。
② 核心損失計算:
損失計算公式基于 Hugging Face 的 trl 實現(xiàn)。
③ 訓(xùn)練環(huán)境與過程
項目在 1 張 A800(80G)顯卡上用 Zero-Stage 2 做優(yōu)化,使用另一張 A800(80G)顯卡進(jìn)行參考模型的推理,參考模型分離使得 GRPO 的訓(xùn)練更加高效。
在以上訓(xùn)練環(huán)境中,Qwen2.5-3B 訓(xùn)練 60 步需要 12 分 34 秒,Qwen2.5-7B 訓(xùn)練 60 步需要 16 分 40 秒。其中在前 30 步優(yōu)化中,Qwen2.5-7B 和 Qwen2.5-3B 的輸出內(nèi)容中,均出現(xiàn)了「頓悟時刻」現(xiàn)象。示例如下:
- Qwen2.5-3B(步驟 20)
其中文翻譯如下:
「<think> ... 要找出卡姆登畫的雞蛋比阿諾德多多少,我們從卡姆登畫的雞蛋數(shù)量中減去阿諾德畫的雞蛋數(shù)量。所以,21-28 = -7。然而,這個結(jié)果在問題的背景下沒有意義,因為卡姆登畫的雞蛋不可能比阿諾德少。讓我們重新考慮最初的解決方案步驟:我們應(yīng)該驗證關(guān)于卡姆登和莎拉的雞蛋之間關(guān)系的初始假設(shè)是否正確 。... </think> <answer>-7</answer>」
- Qwen2.5-7B(步驟 20)
其中文翻譯如下:
「<think> ... 因此,Joanne 從普通樹上摘了 350 個蘋果。但是,這似乎不對,因為總數(shù)應(yīng)該是 500 個,而我們已經(jīng)計算了 150 個,只剩下 350 個普通樹上的蘋果,這與總數(shù)相矛盾。讓我們重新評估一下 。... </think> <answer>350</answer>」
實驗結(jié)果
使用 Qwen2.5-3B 和 Qwen2.5-7B 作為基礎(chǔ)模型,測試了模型訓(xùn)練過程中正確率(左圖)和格式遵循能力(右圖)的變化情況,比較符合預(yù)期。
- Qwen2.5-3B:
在 GSM8K 和 Math 混合數(shù)據(jù)集進(jìn)行訓(xùn)練,從上圖可以看出,Qwen2.5-3B 的準(zhǔn)確率在經(jīng)歷 5 步的優(yōu)化后能穩(wěn)定在 60% 以上,最高能達(dá)到 70% 左右;格式遵循能力在 30 步以后接近 100%.
- Qwen2.5-7B
在 GSM8K 數(shù)據(jù)集上進(jìn)行訓(xùn)練,從上圖可以看出,Qwen2.5-7B 的無論是準(zhǔn)確率還是格式遵循能力都能在三十步以內(nèi)快速收斂,準(zhǔn)確率(左圖)始終保持在 90% 以上,格式遵循能力(右圖)到達(dá) 100%.
改進(jìn)方向
近期本項目將進(jìn)一步推出以下方向的優(yōu)化版本,敬請關(guān)注。
組內(nèi)答案同質(zhì)性問題
根據(jù) GRPO 算法中的分組策略,當(dāng)組內(nèi)答案全部正確或全為錯誤時,獎勵函數(shù)無法有效分配差異化獎勵,強(qiáng)化學(xué)習(xí)將缺乏對比性的訓(xùn)練信號,導(dǎo)致模型難以收斂。后續(xù)將在訓(xùn)練過程中實時監(jiān)控答案分布,對同質(zhì)化的答案進(jìn)行重新采樣和分組,以提供有效的對比信號。
長思維鏈(CoT)顯存占用問題
當(dāng)模型生成較長的思維鏈(CoT)時,由于文本序列長度較長,顯存占用會顯著增加。對此,后續(xù)考慮拆分組別,減小批次大小,或?qū)﹂L序列分階段處理,以減小訓(xùn)練過程中的 GPU 內(nèi)存開銷,提升訓(xùn)練效率。