譯者 | 布加迪
審校 | 重樓
誰需要超級計(jì)算機(jī)?僅用15GB VRAM就可以訓(xùn)練你自己的功能強(qiáng)大的AI推理模型!
DeepSeek的R1模型在不需要人類反饋的情況下就能進(jìn)行更深思熟慮的推理,已顛覆了大語言模型(LLM)領(lǐng)域。這一重大突破背后的關(guān)鍵是群體相對策略優(yōu)化(GRPO),這是一種幫助模型自主開發(fā)推理能力的強(qiáng)化學(xué)習(xí)技術(shù)。與依賴值函數(shù)的近端策略優(yōu)化(PPO)不同,GRPO在不需要值函數(shù)的情況下就可以優(yōu)化響應(yīng),從而提高了效率。
開發(fā)更好的推理模型的競賽正如火如荼地進(jìn)行。但是對于我們這些GPU資源有限的人來說又該如何是好?
多虧了Unsloth,我們現(xiàn)在在消費(fèi)級GPU上僅用15GB的VRAM就可以訓(xùn)練15B參數(shù)模型。本文將介紹如何通過幾個(gè)步驟使用GRPO訓(xùn)練自己的推理模型。
GRPO簡介
GRPO幫助AI模型通過比較答案來學(xué)習(xí)更好地思考。下面是它的工作原理:
- 模型為一個(gè)問題編寫多個(gè)答案。
- 每個(gè)答案都有一個(gè)分?jǐn)?shù)(比如答案正確、清晰、結(jié)構(gòu)合理的相應(yīng)得分)。
- 得分求平均值,每個(gè)答案都與這個(gè)平均值進(jìn)行比較。
- 超過平均分的答案會得到獎(jiǎng)勵(lì)。
- 隨著時(shí)間的推移,模型逐漸學(xué)會生成得分更高的答案。
比如以數(shù)學(xué)為例:
- 問:“2+2等于多少?”
- 模型可能會輸出:“2+2=5”(錯(cuò)誤)或“2+2=4”(正確)。
GRPO獎(jiǎng)勵(lì)正確的答案,所以模型學(xué)會避免錯(cuò)誤。這項(xiàng)技術(shù)允許模型在不需要大量標(biāo)記數(shù)據(jù)集的情況下開發(fā)結(jié)構(gòu)化推理。
訓(xùn)練自己的推理模型的逐步指南
本指南介紹了如何使用GRPO訓(xùn)練一個(gè)針對推理進(jìn)行優(yōu)化的LLM,并將其部署在Hugging Face上。我們將為本文使用meta-llama/meta-Llama-3.1-8B-Instruct以及Unsloth提供的參考筆記本:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb。
第1步:環(huán)境設(shè)置
使用以下代碼安裝依賴項(xiàng):
%%capture
# Install base packages
!pip install unsloth vllm
!pip install --upgrade pillow
# Install specific TRL version for GRPO support
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b
關(guān)鍵組件:
- unsloth:經(jīng)過優(yōu)化的訓(xùn)練框架
- vllm:高吞吐量推理引擎
- trl:Transformer強(qiáng)化學(xué)習(xí)庫
第2步:模型初始化
在所有函數(shù)之前使用PatchFastRL來修補(bǔ)GRPO及其他強(qiáng)化學(xué)習(xí)算法。這一步通過將特定的算法改進(jìn)集成到FastLanguageModel中,確保模型針對強(qiáng)化學(xué)習(xí)任務(wù)進(jìn)行了優(yōu)化。然后加載使用以下參數(shù)的Llama 3.1 8B Instruct,并運(yùn)用lora適配。
from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
import torch
# Enable GRPO patches
PatchFastRL("GRPO", FastLanguageModel)
# Configuration
max_seq_length = 512 # Increase for complex reasoning chains
lora_rank = 32 # Balance between capacity and speed
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length = max_seq_length,
load_in_4bit = True, # False for LoRA 16bit
fast_inference = True, # Enable vLLM fast inference
max_lora_rank = lora_rank,
gpu_memory_utilization = 0.6, # Reduce if out of memory
)
model = FastLanguageModel.get_peft_model(
model,
r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
], # Remove QKVO if out of memory
lora_alpha = lora_rank,
use_gradient_checkpointing = "unsloth", # Enable long context finetuning
random_state = 3407,
)
關(guān)鍵參數(shù):
- load_in_4bit:將內(nèi)存使用減少4倍(量化)
- fast_inference:啟用vLLM的注意力優(yōu)化
- gpu_memory_utilization:控制VRAM分配緩沖區(qū)(本例中為60%)
- r = lora_rank:控制允許多少LoRA適配。我們將其設(shè)置為32(秩越大=智能化越高,但速度越慢)
第3步:數(shù)據(jù)集準(zhǔn)備
在這一步,我們準(zhǔn)備了一個(gè)數(shù)據(jù)集,在生成答案之前逐步訓(xùn)練我們的模型進(jìn)行推理。數(shù)據(jù)集的格式很重要,因?yàn)樗绊懩P腿绾螛?gòu)建響應(yīng)的結(jié)構(gòu)?;A(chǔ)筆記本最初使用GSM8K,這個(gè)數(shù)據(jù)集包含需要多步驟推理的85000個(gè)小學(xué)數(shù)學(xué)單詞問題。然而,我們將使用一個(gè)不同的數(shù)據(jù)集,它提供了跨多個(gè)領(lǐng)域的更廣泛的推理覆蓋,可以在這里找到:https://huggingface.co/datasets/KingNish/reasoning-base-20k。
數(shù)據(jù)字段:
- 用戶:用戶的查詢或問題語句。
- 助理:問題的正確答案。
- 推理:詳細(xì)的逐步推理過程解釋了如何得出正確的答案。
- 模板:預(yù)運(yùn)用的RChatML聊天模板。
我們使用結(jié)構(gòu)化的響應(yīng)模板為數(shù)據(jù)集格式化,以確保我們的模型學(xué)會將推理與最終答案分開。
import re
from datasets import load_dataset, Dataset
from difflib import SequenceMatcher
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
現(xiàn)在,加載Reasoning Base 20K數(shù)據(jù)集。
def get_reasoning_questions(split="train") -> Dataset:
data = load_dataset("KingNish/reasoning-base-20k", split=split)
data = data.map(lambda x: {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": x["user"]}
],
"reasoning": x["reasoning"],
"answer": x["assistant"]
})
return data
# Load dataset
dataset = get_reasoning_questions()
第4步:獎(jiǎng)勵(lì)函數(shù)設(shè)計(jì)(最重要的步驟)
獎(jiǎng)勵(lì)函數(shù)在訓(xùn)練針對推理加以優(yōu)化的模型中至關(guān)重要,因?yàn)樗鼈冎笇?dǎo)模型“良好”表現(xiàn)的含義是什么。正確的獎(jiǎng)勵(lì)設(shè)計(jì)確保了模型生成邏輯合理、格式良好且高質(zhì)量的響應(yīng)。我們的數(shù)據(jù)集需要一種與GSM8K不同的方法,因?yàn)槲覀兊捻憫?yīng)包含詳細(xì)的推理步驟,而不僅僅是數(shù)字答案。因此,我們的獎(jiǎng)勵(lì)函數(shù)評估以下多個(gè)方面:
- 內(nèi)容質(zhì)量→與參考答案在語義上的一致性
- 結(jié)構(gòu)合規(guī)→XML樣式的格式化
- 過程質(zhì)量→推理步驟的復(fù)雜性
在下面的示例代碼中,你將發(fā)現(xiàn)幾個(gè)獎(jiǎng)勵(lì)函數(shù),每個(gè)函數(shù)專注于響應(yīng)的不同方面。下面更詳細(xì)地介紹這些函數(shù):
(1)答案相關(guān)性獎(jiǎng)勵(lì)
這個(gè)函數(shù)測量模型的響應(yīng)在問題提示和參考答案(如果有)兩個(gè)方面涵蓋關(guān)鍵術(shù)語有多到位。這確保了模型至少提到或解決問題中的關(guān)鍵主題。
- 從問題、響應(yīng)和參考答案中提取關(guān)鍵術(shù)語。
- 如果超過30%的問題術(shù)語出現(xiàn)在響應(yīng)中,則加0.5分。
- 如果超過30%的參考答案出現(xiàn)在響應(yīng)中,則加0.5分。
- 確保模型正確且合乎邏輯地回答問題。
def answer_relevance_reward(prompts, completions, answer, **kwargs) -> list[float]:
responses = [completion[0]["content"] for completion in completions]
questions = [prompt[-1]["content"] for prompt in prompts]
def check_relevance(response, question, reference):
score = 0.0
# Extract key terms from question
question_terms = set(question.lower().split())
response_terms = set(response.lower().split())
reference_terms = set(reference.lower().split())
# 1) Check if response addresses key terms from question
if len(question_terms) > 0:
common_qr = question_terms.intersection(response_terms)
if len(common_qr) / len(question_terms) > 0.3:
score += 0.5
# 2) Check if response uses similar key terms as reference
if len(reference_terms) > 0:
common_rr = response_terms.intersection(reference_terms)
if len(common_rr) / len(reference_terms) > 0.3:
score += 0.5
return score
return [check_relevance(r, q, a) for r, q, a in zip(responses, questions, answer)]
(2)嚴(yán)格格式合規(guī)獎(jiǎng)勵(lì)
這個(gè)函數(shù)確保輸出嚴(yán)格遵循所需的XML樣式結(jié)構(gòu),以保持結(jié)構(gòu)化推理的一致輸出格式。如果格式正確,則獎(jiǎng)勵(lì)0.5,否則獎(jiǎng)勵(lì)0.0。
def strict_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r"^\n.*?\n\n\n.*?\n\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
return [0.5 if match else 0.0 for match in matches]
(3)軟格式合規(guī)獎(jiǎng)勵(lì)
這種更靈活的獎(jiǎng)勵(lì)函數(shù)允許較小的偏差,但仍然需要適當(dāng)?shù)腦ML樣式格式。如果匹配,獎(jiǎng)勵(lì)0.5分,否則獎(jiǎng)勵(lì)0.0分。如果嚴(yán)格格式過于僵硬,并且可能懲罰不影響可用性的小差異,這可能會有所幫助。
def soft_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r".*?\s*.*?"
responses = [completion[0]["content"] for completion in completions]
matches = [re.search(pattern, r, re.DOTALL) for r in responses]
return [0.5 if match else 0.0 for match in matches]
(4)XML標(biāo)記計(jì)數(shù)獎(jiǎng)勵(lì)(啟發(fā)式示例)
這個(gè)函數(shù)通過計(jì)算所需的標(biāo)記來評估響應(yīng)遵守預(yù)期XML結(jié)構(gòu)的程度。如果出現(xiàn)額外的內(nèi)容,它會懲罰,并提供部分給分,而不是二元獎(jiǎng)勵(lì)。
def count_xml(text) -> float:
count = 0.0
if text.count("\n") == 1:
count += 0.125
if text.count("\n\n") == 1:
count += 0.125
if text.count("\n\n") == 1:
count += 0.125
count -= len(text.split("\n\n")[-1]) * 0.001
if text.count("\n") == 1:
count += 0.125
count -= (len(text.split("\n")[-1]) - 1) * 0.001
return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
實(shí)際上,你常常希望將一些或所有這些不同的信號結(jié)合起來計(jì)算最終的獎(jiǎng)勵(lì)分?jǐn)?shù)。最初的筆記本使用int和正確性獎(jiǎng)勵(lì)函數(shù),因?yàn)閿?shù)據(jù)集包含單個(gè)數(shù)字答案。然而,鑒于我們的一般推理模型,更廣泛的評估方法必不可少。因此,我們使用了以下獎(jiǎng)勵(lì)函數(shù):
reward_funcs = [
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
answer_relevance_reward
]
第5步:GRPO訓(xùn)練配置與執(zhí)行
現(xiàn)在,設(shè)置GRPO訓(xùn)練器和所有配置。我將max_steps從250減少到150以節(jié)省時(shí)間,并將num_generations從6減少到4以節(jié)省內(nèi)存。然而,Unsloth建議至少跑300個(gè)步驟才能觀察到明顯的改善。所有其他配置保持不變,如下所示:
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
use_vllm = True, # use vLLM for fast inference!
learning_rate = 5e-6,
adam_beta1 = 0.9,
adam_beta2 = 0.99,
weight_decay = 0.1,
warmup_ratio = 0.1,
lr_scheduler_type = "cosine",
optim = "paged_adamw_8bit",
logging_steps = 1,
bf16 = is_bfloat16_supported(),
fp16 = not is_bfloat16_supported(),
per_device_train_batch_size = 1,
gradient_accumulation_steps = 1, # Increase to 4 for smoother training
num_generations = 4, # Decrease if out of memory
max_prompt_length = 256,
max_completion_length = 200,
# num_train_epochs = 1, # Set to 1 for a full training run
max_steps = 150,
save_steps = 150,
max_grad_norm = 0.1,
report_to = "none", # Can use Weights & Biases
output_dir = "outputs",
)
現(xiàn)在,不妨初始化并運(yùn)行GRPO訓(xùn)練器:
trainer = GRPOTrainer(
model = model,
processing_class = tokenizer,
reward_funcs = reward_funcs,
args = training_args,
train_dataset = dataset,
)
trainer.train()
訓(xùn)練日志讓我們得以深入了解獎(jiǎng)勵(lì)趨勢、損失值和響應(yīng)質(zhì)量改進(jìn)。最初,獎(jiǎng)勵(lì)會因隨機(jī)探索而波動(dòng),但隨著時(shí)間的推移會逐漸改善。我在Colab T4 GPU上運(yùn)行這個(gè)筆記本大約花了2小時(shí)7分鐘,150個(gè)步驟后的最終訓(xùn)練損失為0.0003475。
第6步:模型評估
鑒于我們已經(jīng)訓(xùn)練了模型,不妨比較基準(zhǔn)LLaMA 3.1 8B Instruct與使用GRPO訓(xùn)練的模型各自的性能。
GRPO訓(xùn)練前
text = tokenizer.apply_chat_template([
{"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)
from vllm import SamplingParams
sampling_params = SamplingParams(
temperature = 0.8,
top_p = 0.95,
max_tokens = 1024,
)
output = model.fast_generate(
[text],
sampling_params = sampling_params,
lora_request = None,
)[0].outputs[0].text
output
輸出:
There are 2 'r's in the word "strawberry".
基準(zhǔn)模型錯(cuò)誤地識別了“strawberry”中“r”的數(shù)量,暴露了事實(shí)推理方面的不足。
GRPO訓(xùn)練后
現(xiàn)在我們加載LoRA并進(jìn)行測試:
model.save_lora("grpo_saved_lora")
text = tokenizer.apply_chat_template([
{"role" : "system", "content" : SYSTEM_PROMPT},
{"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)
from vllm import SamplingParams
sampling_params = SamplingParams(
temperature = 0.8,
top_p = 0.95,
max_tokens = 1024,
)
output = model.fast_generate(
text,
sampling_params = sampling_params,
lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
output
輸出:
<reasoning>
To determine the number of 'r's in the word "strawberry," we need to spell it out and count the occurrences of 'r'. The word "strawberry" is spelled as S-T-R-A-W-B-E-R-R-Y. The letter 'r' appears in the 3rd, 8th, and 9th positions.
</reasoning>
<answer>
There are 3 'r's in the word "strawberry."
</answer>
GRPO訓(xùn)練后,模型的準(zhǔn)確率和推理能力有所提高,但仍然不夠完美。由于它在T4 GPU上訓(xùn)練僅用了2小時(shí),因此延長序列長度和訓(xùn)練時(shí)間將進(jìn)一步提升其表現(xiàn)。
第7步:部署和擴(kuò)展
一旦對模型進(jìn)行了微調(diào)和評估,下一步就是將其部署到實(shí)際使用場景中,確保它可以有效地?cái)U(kuò)展。部署需要將模型轉(zhuǎn)換成經(jīng)過優(yōu)化的格式,將其整合到推理服務(wù)器中,并通過API或應(yīng)用程序讓其可以訪問。為了確保有效的推理,我們保存了訓(xùn)練好的LoRA適配器,并將它們推送到Hugging Face Hub以便訪問。這允許其他人加載經(jīng)過微調(diào)的模型,不需要大量的計(jì)算資源。
# Just LoRA adapters
if True: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if True: model.push_to_hub_merged("kanwal-mehreen18/Llama3.1-8B-GRPO", tokenizer, save_method = "lora", token = "YOUR_HF_KEY")
將lora模型保存到https://huggingface.co/kanwal-mehreen18/Llama3.1-8B-GRPO。
Unsloth給出的最佳實(shí)踐
- 使用參數(shù)數(shù)量>1.5B的模型進(jìn)行可靠推理。
- 接受進(jìn)行12小時(shí)的訓(xùn)練以處理復(fù)雜任務(wù)。
- 結(jié)合多種獎(jiǎng)勵(lì)信號(3-5種函數(shù)最好)。
原文標(biāo)題:DeepSeek-Level AI? Train Your Own Reasoning Model in Just 7 Easy Steps!,作者:Kanwal Mehreen