Machine Learning | Building Large Models from Scratch: DeepSeek's GRPO
The release of DeepSeek-R1 has been a proud moment for domestically developed large models (it is remarkably strong). The GRPO algorithm, however, originates from the DeepSeekMath 7B model, which achieved excellent results on the MATH benchmark. The paper was published in February 2024 (https://huggingface.co/papers/2402.03300); its original abstract is quoted below:
Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.
(Figure: benchmark comparison data)
1. What is GRPO?
GRPO is an online learning algorithm. Its core idea is to estimate the baseline from the relative rewards within a group of sampled outputs, which removes the need for a separate value-function (critic) model. By iteratively improving on data generated by the model being trained, GRPO aims to maximize the advantage of the generated completions while keeping the model close to a reference policy. The figure below is the algorithm diagram from the paper:
(Figure: GRPO algorithm diagram from the paper)
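In the paper's notation (condensed here), for a question $q$ the old policy samples a group of $G$ outputs $o_1,\dots,o_G$ that receive rewards $r_1,\dots,r_G$, and GRPO maximizes a clipped, KL-regularized objective:

$$
\mathcal{J}_{\mathrm{GRPO}}(\theta)=\mathbb{E}\!\left[\frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\Big(\min\big(\rho_{i,t}\hat{A}_{i,t},\ \mathrm{clip}(\rho_{i,t},\,1-\varepsilon,\,1+\varepsilon)\,\hat{A}_{i,t}\big)-\beta\,\mathbb{D}_{\mathrm{KL}}\big[\pi_\theta\,\|\,\pi_{\mathrm{ref}}\big]\Big)\right],
\qquad
\rho_{i,t}=\frac{\pi_\theta(o_{i,t}\mid q,o_{i,<t})}{\pi_{\theta_{\mathrm{old}}}(o_{i,t}\mid q,o_{i,<t})},
$$

where, for outcome rewards, every token of output $i$ shares the group-normalized advantage $\hat{A}_{i,t}=\big(r_i-\mathrm{mean}(\{r_1,\dots,r_G\})\big)/\mathrm{std}(\{r_1,\dots,r_G\})$. The group mean plays the role of PPO's learned value baseline, and the KL term keeps the policy close to the reference model.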
GRPO is a variant of PPO (Proximal Policy Optimization, a reinforcement learning algorithm proposed by OpenAI in 2017 to address the training instability of policy-gradient methods). The main differences are:
- GRPO drops the value-function (critic) model.
- The reward computation changes: for each question q, a group of responses r is sampled, and each response is scored by the reward function.
The GRPO algorithm flow:
- Sample a group of outputs and compute a reward for each output
- Normalize the rewards within the group (see the sketch below)
- Compute the advantage from the normalized rewards
- Update the policy model by maximizing the objective
- Iterate, progressively improving the policy
(Figure: pseudocode from the paper)
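To make steps 1–3 above concrete, here is a minimal sketch (with made-up reward values) of how one group of rewards becomes per-completion advantages; because the baseline is just the group statistics, no value model is needed:

import numpy as np

def group_relative_advantages(rewards, eps=1e-8):
    """Standardize the rewards of one group of completions sampled for the same prompt."""
    r = np.asarray(rewards, dtype=np.float64)
    return (r - r.mean()) / (r.std() + eps)

# One prompt, G = 4 sampled completions; rewards might come from a correctness check.
print(group_relative_advantages([2.0, 0.0, 0.0, 0.5]))
# Completions scoring above the group mean get a positive advantage, the rest negative.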
2. Reward design
Hugging Face's TRL library provides GRPOTrainer, which can be used to train with GRPO directly; its parameters let you specify reward models and reward functions.
2.1 Reward model
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-3B-Instruct",
    reward_funcs="weqweasdas/RM-Gemma-2B",
    args=training_args,
    train_dataset=dataset,
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
)
Here, the reward_funcs parameter is passed a reward model (a model ID on the Hugging Face Hub).
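Passing a string makes TRL load that checkpoint as a reward model internally. As a sketch of the explicit equivalent (assuming, per the TRL documentation, that a PreTrainedModel is also accepted as a reward function), you could load it yourself as a one-label sequence classifier:

from transformers import AutoModelForSequenceClassification
from trl import GRPOTrainer

# Load the reward model explicitly; it scores each (prompt, completion) pair with a scalar.
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "weqweasdas/RM-Gemma-2B", num_labels=1
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-3B-Instruct",
    reward_funcs=reward_model,
    args=training_args,
    train_dataset=dataset,
)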
2.2 Reward functions
GRPOTrainer also allows user-defined reward functions, implemented as functions that return a list of floats.
- Reward function input: completions (the generated completions) and prompts (the prompts)
- Reward function output: a list of floats, where each float is the reward for the corresponding completion
(1) Reward function favoring longer completions
def completion_reward(completions, **kwargs):
    '''Reward function that gives longer completions a higher score.'''
    return [float(len(completion)) / 100 for completion in completions]

prompts = ["The sky is", "The sun is"]
completions = [" blue.", " in the sky."]
print("completion_reward: ", completion_reward(prompts=prompts, completions=completions))
(2) Reward function for correct formatting
import re  # needed for the regex match below

def format_reward(completions, **kwargs):
    '''Reward completions that follow the <think>...</think><answer>...</answer> format.'''
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, response) for response in responses]
    return [0.5 if match else 0.0 for match in matches]

prompts = [
    [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
    [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
]
completions = [
    [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
    [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
]
print("format_reward: ", format_reward(prompts=prompts, completions=completions))
Based on the reward examples above, you can design reward functions for different datasets, for example:
- check whether the content contains numbers (a hypothetical sketch follows this list)
- check whether the answer draws on the content of a web knowledge base
- ...
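For instance, a hypothetical reward for the first idea, checking whether the completion text contains any digit, might look like this (the function name is illustrative):

import re

def contains_number_reward(completions, **kwargs):
    '''Hypothetical reward: 0.5 if the completion text contains at least one digit, else 0.0.'''
    responses = [completion[0]["content"] for completion in completions]
    return [0.5 if re.search(r"\d", response) else 0.0 for response in responses]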
Then pass these functions to GRPOTrainer, as in the following code:
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        ...
        format_reward,
        completion_reward,
    ],
    args=training_args,
    train_dataset=data,
    ...
)
3. Training a model with GRPO
There are already many open-source reproductions of DeepSeek-R1-Zero on GitHub; if you are interested, take a look at the projects linked in the references below (e.g., TinyZero and open-r1), most of which keep the cost within about 500:
3.1 Training code
To demonstrate how to train a model with GRPO, this article also gives complete training code. The setup is as follows:
- use Qwen/Qwen2.5-3B-Instruct as the base model
- use swulling/gsm8k_chinese as the training dataset
import re
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
# The system prompt is kept in Chinese to match the Chinese GSM8K dataset;
# it means "Generate output in the following format:".
SYSTEM_PROMPT = """
按照如下格式生成:
<think>
...
</think>
<answer>
...
</answer>
"""
def process_data(data):
    '''Convert each record into a conversational prompt plus the reference answer.'''
    return data.map(
        lambda x: {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question_zh-cn"]},
            ],
            "answer": x["answer_only"],
        }
    )

def extract_answer(text):
    '''Extract the text between <answer> and </answer>.'''
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def correctness_reward(completions, answer, **kwargs):
    '''Reward 2.0 if the extracted answer exactly matches the reference answer.'''
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_answer(r) for r in responses]
    return [2.0 if response == str(ans) else 0.0 for response, ans in zip(extracted_responses, answer)]
def completion_reward(completions, **kwargs):
    '''Reward function that gives longer completions a higher score
    (handles both plain-string and chat-format completions).'''
    texts = [c[0]["content"] if isinstance(c, list) else c for c in completions]
    return [float(len(t)) / 100 for t in texts]

prompts = ["The sky is", "The sun is"]
completions = [" blue.", " in the sky."]
print("completion_reward: ", completion_reward(prompts=prompts, completions=completions))
def digit_reward(completions, **kwargs):
    '''Reward 0.5 if the extracted answer consists purely of digits.'''
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_answer(r) for r in responses]
    return [0.5 if response.isdigit() else 0.0 for response in extracted_responses]

def format_reward(completions, **kwargs):
    '''Reward completions that follow the <think>...</think><answer>...</answer> format.'''
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, response) for response in responses]
    return [0.5 if match else 0.0 for match in matches]
prompts = [
    [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
    [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
]
completions = [
    [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
    [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
]
print("format_reward: ", format_reward(prompts=prompts, completions=completions))
def mark_reward(completions, **kwargs):
    '''Tag-count reward (mitigates the sparsity of the format reward).'''
    def mark_num(text):
        reward = 0
        if text.count("<think>\n") == 1:
            reward += 0.125
        if text.count("</think>\n") == 1:
            reward += 0.125
        if text.count("<answer>\n") == 1:
            reward += 0.125
        if text.count("</answer>\n") == 1:
            reward += 0.125 * 2
        return reward
    responses = [completion[0]["content"] for completion in completions]
    return [mark_num(response) for response in responses]
if __name__ == "__main__":
    model_name = "Qwen/Qwen2.5-3B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model")
    model.cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    ds = load_dataset("swulling/gsm8k_chinese", cache_dir="./dataset")
    data = process_data(ds["train"])
    output_dir = "output"
    training_args = GRPOConfig(
        output_dir=output_dir,
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=1,
        bf16=True,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generations=16,
        max_prompt_length=256,
        max_completion_length=200,
        num_train_epochs=1,
        save_steps=100,
        max_grad_norm=0.1,
        log_on_each_node=False,
        use_vllm=False,
        report_to="tensorboard",
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            mark_reward,
            format_reward,
            digit_reward,
            completion_reward,
            correctness_reward,
        ],
        args=training_args,
        train_dataset=data,
        peft_config=LoraConfig(task_type="CAUSAL_LM"),
    )
    trainer.train()
    trainer.save_model(output_dir)
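After training, you can sanity-check the saved model with a quick generation pass. The sketch below is not part of the original training script: it assumes the LoRA adapter saved in output/ can be loaded directly through transformers with peft installed (otherwise load the base model and attach the adapter with peft.PeftModel.from_pretrained), and the test question is made up:

from transformers import AutoModelForCausalLM, AutoTokenizer

SYSTEM_PROMPT = """
按照如下格式生成:
<think>
...
</think>
<answer>
...
</answer>
"""

# Load the fine-tuned model (adapter) and the base tokenizer.
model = AutoModelForCausalLM.from_pretrained("output").cuda()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    # "Xiao Ming has 3 apples and buys 5 more; how many apples does he have now?"
    {"role": "user", "content": "小明有3個(gè)蘋(píng)果,又買(mǎi)了5個(gè),現(xiàn)在一共有幾個(gè)蘋(píng)果?"},
]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
output_ids = model.generate(input_ids, max_new_tokens=200, do_sample=False)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))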
3.2 Output
(1) Output of the base model Qwen/Qwen2.5-3B-Instruct:
User:
Using the given numbers [44, 19, 35], create an equation that equals 98. You may use the basic arithmetic operations (+, -, *, /) one or more times, but each number can only be used once. Show your reasoning in <think> </think> tags and return the final equation in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>. Think step by step inside the <think> tags.
Assistant:
<think>First, we need to figure out how to combine the given numbers to get 98. We can try different combinations, such as addition and subtraction. Noting that 98 can be decomposed into 90 plus 8, we can consider constructing these two numbers from 44 and 35.</think>
<think>Notice that the difference between 44 and 35 is 9; if we subtract this difference from 35, we get 8. That is, 35 - (44 - 35) = 8.</think>
<think>Now that we have 8 and 90, we can add them to get 98. So the final equation is 35 + (44 - 35).</think>
<answer> 35 + (44 - 35) </answer>
(2) Output of the trained model:
User:
Using the given numbers [44, 19, 35], create an equation that equals 98. You may use the basic arithmetic operations (+, -, *, /) one or more times, but each number can only be used once. Show your reasoning in <think> </think> tags and return the final equation in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>. Think step by step inside the <think> tags.
Assistant:
<think>
First, I need to use the numbers 44, 19, and 35 with the basic arithmetic operations (+, -, *, /) to build an equation equal to 98, using each number only once.
I first consider adding or multiplying the larger numbers to get close to 98. For example, 44 plus 19 is 63, and adding 35 gives 98. Concretely:
44 + 19 = 63
63 + 35 = 98
Therefore, the final equation can be written as 44 + 19 + 35 = 98.
</think>
<answer> 44 + 19 + 35 = 98 </answer>
References
(1) https://github.com/Jiayi-Pan/TinyZero
(2) https://github.com/huggingface/open-r1
(3) DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning (https://arxiv.org/pdf/2501.12948)
(4) https://zhuanlan.zhihu.com/p/20021693569
(5) https://zhuanlan.zhihu.com/p/19949917958
(6) https://blog.csdn.net/qq_38961840/article/details/145387854