自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

DeepSeek級AI?訓(xùn)練自己的推理模型僅需七個(gè)步驟

譯文 精選
人工智能
DeepSeek的R1模型在不需要人類反饋的情況下就能進(jìn)行更深思熟慮的推理,已顛覆了大語言模型(LLM)領(lǐng)域。這一重大突破背后的關(guān)鍵是群體相對策略優(yōu)化(GRPO),這是一種幫助模型自主開發(fā)推理能力的強(qiáng)化學(xué)習(xí)技術(shù)。

譯者 | 布加迪

審校 | 重樓

誰需要超級計(jì)算機(jī)?僅用15GB VRAM就可以訓(xùn)練你自己的功能強(qiáng)大的AI推理模型!

DeepSeek的R1模型在不需要人類反饋的情況下就能進(jìn)行更深思熟慮的推理,已顛覆了大語言模型(LLM)領(lǐng)域。這一重大突破背后的關(guān)鍵是群體相對策略優(yōu)化(GRPO),這是一種幫助模型自主開發(fā)推理能力的強(qiáng)化學(xué)習(xí)技術(shù)。與依賴值函數(shù)的近端策略優(yōu)化(PPO)不同,GRPO在不需要值函數(shù)的情況下就可以優(yōu)化響應(yīng),從而提高了效率。

開發(fā)更好的推理模型的競賽正如火如荼地進(jìn)行。但是對于我們這些GPU資源有限的人來說又該如何是好?

多虧了Unsloth,我們現(xiàn)在在消費(fèi)級GPU上僅用15GB的VRAM就可以訓(xùn)練15B參數(shù)模型。本文將介紹如何通過幾個(gè)步驟使用GRPO訓(xùn)練自己的推理模型。

GRPO簡介

 GRPO幫助AI模型通過比較答案來學(xué)習(xí)更好地思考。下面是它的工作原理:

  • 模型為一個(gè)問題編寫多個(gè)答案。
  • 每個(gè)答案都有一個(gè)分?jǐn)?shù)(比如答案正確、清晰、結(jié)構(gòu)合理的相應(yīng)得分)。
  • 得分求平均值,每個(gè)答案都與這個(gè)平均值進(jìn)行比較。
  • 超過平均分的答案會得到獎(jiǎng)勵(lì)。
  • 隨著時(shí)間的推移,模型逐漸學(xué)會生成得分更高的答案。

比如以數(shù)學(xué)為例:

  • 問:“2+2等于多少?”
  • 模型可能會輸出:“2+2=5”(錯(cuò)誤)或“2+2=4”(正確)。

GRPO獎(jiǎng)勵(lì)正確的答案,所以模型學(xué)會避免錯(cuò)誤。這項(xiàng)技術(shù)允許模型在不需要大量標(biāo)記數(shù)據(jù)集的情況下開發(fā)結(jié)構(gòu)化推理。

訓(xùn)練自己的推理模型的逐步指南

本指南介紹了如何使用GRPO訓(xùn)練一個(gè)針對推理進(jìn)行優(yōu)化的LLM,并將其部署在Hugging Face上。我們將為本文使用meta-llama/meta-Llama-3.1-8B-Instruct以及Unsloth提供的參考筆記本:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb。

第1步:環(huán)境設(shè)置

使用以下代碼安裝依賴項(xiàng):

%%capture
# Install base packages
!pip install unsloth vllm
!pip install --upgrade pillow

# Install specific TRL version for GRPO support
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b

關(guān)鍵組件:

  • unsloth:經(jīng)過優(yōu)化的訓(xùn)練框架
  • vllm:高吞吐量推理引擎
  • trl:Transformer強(qiáng)化學(xué)習(xí)庫

第2步:模型初始化

在所有函數(shù)之前使用PatchFastRL來修補(bǔ)GRPO及其他強(qiáng)化學(xué)習(xí)算法。這一步通過將特定的算法改進(jìn)集成到FastLanguageModel中,確保模型針對強(qiáng)化學(xué)習(xí)任務(wù)進(jìn)行了優(yōu)化。然后加載使用以下參數(shù)的Llama 3.1 8B Instruct,并運(yùn)用lora適配。

from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
import torch

# Enable GRPO patches
PatchFastRL("GRPO", FastLanguageModel)

# Configuration
max_seq_length = 512  # Increase for complex reasoning chains
lora_rank = 32        # Balance between capacity and speed

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

關(guān)鍵參數(shù):

  • load_in_4bit:將內(nèi)存使用減少4倍(量化)
  • fast_inference:啟用vLLM的注意力優(yōu)化
  • gpu_memory_utilization:控制VRAM分配緩沖區(qū)(本例中為60%)
  • r = lora_rank:控制允許多少LoRA適配。我們將其設(shè)置為32(秩越大=智能化越高,但速度越慢)

第3步:數(shù)據(jù)集準(zhǔn)備

在這一步,我們準(zhǔn)備了一個(gè)數(shù)據(jù)集,在生成答案之前逐步訓(xùn)練我們的模型進(jìn)行推理。數(shù)據(jù)集的格式很重要,因?yàn)樗绊懩P腿绾螛?gòu)建響應(yīng)的結(jié)構(gòu)?;A(chǔ)筆記本最初使用GSM8K,這個(gè)數(shù)據(jù)集包含需要多步驟推理的85000個(gè)小學(xué)數(shù)學(xué)單詞問題。然而,我們將使用一個(gè)不同的數(shù)據(jù)集,它提供了跨多個(gè)領(lǐng)域的更廣泛的推理覆蓋,可以在這里找到:https://huggingface.co/datasets/KingNish/reasoning-base-20k。

數(shù)據(jù)字段:

  • 用戶:用戶的查詢或問題語句。
  • 助理:問題的正確答案。
  • 推理:詳細(xì)的逐步推理過程解釋了如何得出正確的答案。
  • 模板:預(yù)運(yùn)用的RChatML聊天模板。

我們使用結(jié)構(gòu)化的響應(yīng)模板為數(shù)據(jù)集格式化,以確保我們的模型學(xué)會將推理與最終答案分開。

import re
from datasets import load_dataset, Dataset
from difflib import SequenceMatcher

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

現(xiàn)在,加載Reasoning Base 20K數(shù)據(jù)集。

def get_reasoning_questions(split="train") -> Dataset:
    data = load_dataset("KingNish/reasoning-base-20k", split=split)

    data = data.map(lambda x: {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": x["user"]}
        ],
        "reasoning": x["reasoning"],
        "answer": x["assistant"]
    })

    return data

# Load dataset
dataset = get_reasoning_questions()

第4步:獎(jiǎng)勵(lì)函數(shù)設(shè)計(jì)(最重要的步驟)

獎(jiǎng)勵(lì)函數(shù)在訓(xùn)練針對推理加以優(yōu)化的模型中至關(guān)重要,因?yàn)樗鼈冎笇?dǎo)模型“良好”表現(xiàn)的含義是什么。正確的獎(jiǎng)勵(lì)設(shè)計(jì)確保了模型生成邏輯合理、格式良好且高質(zhì)量的響應(yīng)。我們的數(shù)據(jù)集需要一種與GSM8K不同的方法,因?yàn)槲覀兊捻憫?yīng)包含詳細(xì)的推理步驟,而不僅僅是數(shù)字答案。因此,我們的獎(jiǎng)勵(lì)函數(shù)評估以下多個(gè)方面:

  • 內(nèi)容質(zhì)量→與參考答案在語義上的一致性
  • 結(jié)構(gòu)合規(guī)→XML樣式的格式化
  • 過程質(zhì)量→推理步驟的復(fù)雜性

在下面的示例代碼中,你將發(fā)現(xiàn)幾個(gè)獎(jiǎng)勵(lì)函數(shù),每個(gè)函數(shù)專注于響應(yīng)的不同方面。下面更詳細(xì)地介紹這些函數(shù):

(1)答案相關(guān)性獎(jiǎng)勵(lì)

這個(gè)函數(shù)測量模型的響應(yīng)在問題提示和參考答案(如果有)兩個(gè)方面涵蓋關(guān)鍵術(shù)語有多到位。這確保了模型至少提到或解決問題中的關(guān)鍵主題。

  • 從問題、響應(yīng)和參考答案中提取關(guān)鍵術(shù)語。
  • 如果超過30%的問題術(shù)語出現(xiàn)在響應(yīng)中,則加0.5分。
  • 如果超過30%的參考答案出現(xiàn)在響應(yīng)中,則加0.5分。
  • 確保模型正確且合乎邏輯地回答問題。
def answer_relevance_reward(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    questions = [prompt[-1]["content"] for prompt in prompts]

    def check_relevance(response, question, reference):
        score = 0.0
        # Extract key terms from question
        question_terms = set(question.lower().split())
        response_terms = set(response.lower().split())
        reference_terms = set(reference.lower().split())

        # 1) Check if response addresses key terms from question
        if len(question_terms) > 0:
            common_qr = question_terms.intersection(response_terms)
            if len(common_qr) / len(question_terms) > 0.3:
                score += 0.5

        # 2) Check if response uses similar key terms as reference
        if len(reference_terms) > 0:
            common_rr = response_terms.intersection(reference_terms)
            if len(common_rr) / len(reference_terms) > 0.3:
                score += 0.5

        return score

    return [check_relevance(r, q, a) for r, q, a in zip(responses, questions, answer)]

(2)嚴(yán)格格式合規(guī)獎(jiǎng)勵(lì)

這個(gè)函數(shù)確保輸出嚴(yán)格遵循所需的XML樣式結(jié)構(gòu),以保持結(jié)構(gòu)化推理的一致輸出格式。如果格式正確,則獎(jiǎng)勵(lì)0.5,否則獎(jiǎng)勵(lì)0.0。

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"^\n.*?\n\n\n.*?\n\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

(3)軟格式合規(guī)獎(jiǎng)勵(lì)

這種更靈活的獎(jiǎng)勵(lì)函數(shù)允許較小的偏差,但仍然需要適當(dāng)?shù)腦ML樣式格式。如果匹配,獎(jiǎng)勵(lì)0.5分,否則獎(jiǎng)勵(lì)0.0分。如果嚴(yán)格格式過于僵硬,并且可能懲罰不影響可用性的小差異,這可能會有所幫助。

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r".*?\s*.*?"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.search(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

(4)XML標(biāo)記計(jì)數(shù)獎(jiǎng)勵(lì)(啟發(fā)式示例)

這個(gè)函數(shù)通過計(jì)算所需的標(biāo)記來評估響應(yīng)遵守預(yù)期XML結(jié)構(gòu)的程度。如果出現(xiàn)額外的內(nèi)容,它會懲罰,并提供部分給分,而不是二元獎(jiǎng)勵(lì)。

def count_xml(text) -> float:
    count = 0.0
    if text.count("\n") == 1:
        count += 0.125
    if text.count("\n\n") == 1:
        count += 0.125
    if text.count("\n\n") == 1:
        count += 0.125
        count -= len(text.split("\n\n")[-1]) * 0.001
    if text.count("\n") == 1:
        count += 0.125
        count -= (len(text.split("\n")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

實(shí)際上,你常常希望將一些或所有這些不同的信號結(jié)合起來計(jì)算最終的獎(jiǎng)勵(lì)分?jǐn)?shù)。最初的筆記本使用int和正確性獎(jiǎng)勵(lì)函數(shù),因?yàn)閿?shù)據(jù)集包含單個(gè)數(shù)字答案。然而,鑒于我們的一般推理模型,更廣泛的評估方法必不可少。因此,我們使用了以下獎(jiǎng)勵(lì)函數(shù):

reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        answer_relevance_reward
    ]

第5步:GRPO訓(xùn)練配置與執(zhí)行

現(xiàn)在,設(shè)置GRPO訓(xùn)練器和所有配置。我將max_steps從250減少到150以節(jié)省時(shí)間,并將num_generations從6減少到4以節(jié)省內(nèi)存。然而,Unsloth建議至少跑300個(gè)步驟才能觀察到明顯的改善。所有其他配置保持不變,如下所示:

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 150,
    save_steps = 150,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

現(xiàn)在,不妨初始化并運(yùn)行GRPO訓(xùn)練器:

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = reward_funcs,
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

訓(xùn)練日志讓我們得以深入了解獎(jiǎng)勵(lì)趨勢、損失值和響應(yīng)質(zhì)量改進(jìn)。最初,獎(jiǎng)勵(lì)會因隨機(jī)探索而波動(dòng),但隨著時(shí)間的推移會逐漸改善。我在Colab T4 GPU上運(yùn)行這個(gè)筆記本大約花了2小時(shí)7分鐘,150個(gè)步驟后的最終訓(xùn)練損失為0.0003475。

第6步:模型評估

鑒于我們已經(jīng)訓(xùn)練了模型,不妨比較基準(zhǔn)LLaMA 3.1 8B Instruct與使用GRPO訓(xùn)練的模型各自的性能。

GRPO訓(xùn)練前

text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

輸出:

There are 2 'r's in the word "strawberry".

基準(zhǔn)模型錯(cuò)誤地識別了“strawberry”中“r”的數(shù)量,暴露了事實(shí)推理方面的不足。

GRPO訓(xùn)練后

現(xiàn)在我們加載LoRA并進(jìn)行測試:

model.save_lora("grpo_saved_lora")


text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

輸出:

<reasoning>
To determine the number of 'r's in the word "strawberry," we need to spell it out and count the occurrences of 'r'. The word "strawberry" is spelled as S-T-R-A-W-B-E-R-R-Y. The letter 'r' appears in the 3rd, 8th, and 9th positions. 
</reasoning>
<answer> 
There are 3 'r's in the word "strawberry." 
</answer>

GRPO訓(xùn)練后,模型的準(zhǔn)確率和推理能力有所提高,但仍然不夠完美。由于它在T4 GPU上訓(xùn)練僅用了2小時(shí),因此延長序列長度和訓(xùn)練時(shí)間將進(jìn)一步提升其表現(xiàn)。

第7步:部署和擴(kuò)展

一旦對模型進(jìn)行了微調(diào)和評估,下一步就是將其部署到實(shí)際使用場景中,確保它可以有效地?cái)U(kuò)展。部署需要將模型轉(zhuǎn)換成經(jīng)過優(yōu)化的格式,將其整合到推理服務(wù)器中,并通過API或應(yīng)用程序讓其可以訪問。為了確保有效的推理,我們保存了訓(xùn)練好的LoRA適配器,并將它們推送到Hugging Face Hub以便訪問。這允許其他人加載經(jīng)過微調(diào)的模型,不需要大量的計(jì)算資源。

# Just LoRA adapters
if True: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if True: model.push_to_hub_merged("kanwal-mehreen18/Llama3.1-8B-GRPO", tokenizer, save_method = "lora", token = "YOUR_HF_KEY")

將lora模型保存到https://huggingface.co/kanwal-mehreen18/Llama3.1-8B-GRPO。

Unsloth給出的最佳實(shí)踐

  • 使用參數(shù)數(shù)量>1.5B的模型進(jìn)行可靠推理。
  • 接受進(jìn)行12小時(shí)的訓(xùn)練以處理復(fù)雜任務(wù)。
  • 結(jié)合多種獎(jiǎng)勵(lì)信號(3-5種函數(shù)最好)。

原文標(biāo)題:DeepSeek-Level AI? Train Your Own Reasoning Model in Just 7 Easy Steps!,作者:Kanwal Mehreen

責(zé)任編輯:姜華 來源: 51CTO內(nèi)容精選
相關(guān)推薦

2024-05-07 08:00:00

自然語言處理機(jī)器學(xué)習(xí)

2025-03-06 09:55:49

2010-04-09 09:55:43

Oracle sqlp

2025-01-21 11:53:53

2014-03-12 15:23:20

2022-08-02 20:22:01

SaaS安全網(wǎng)絡(luò)攻擊

2025-02-24 08:40:00

開源模型訓(xùn)練

2023-12-21 18:01:58

Docker容器部署

2025-03-05 00:22:00

2023-04-25 12:45:09

2023-07-10 13:28:43

智能建筑工具

2025-04-01 09:54:09

AI算法大模型AI

2022-02-15 11:03:40

SD-WAN軟件定義WAN

2023-03-06 08:48:52

2023-06-01 13:09:09

智能建筑數(shù)字孿生

2015-12-23 09:48:32

2025-03-11 10:51:35

DifyDeepSeek大模型

2022-05-30 15:44:33

模型訓(xùn)練GAN

2023-11-01 18:01:02

改進(jìn)WakaTime編程

2024-02-19 00:21:45

開源圖片
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號