A Complete Hands-On Guide to Training DeepSeek R1 Locally: Unblocking Every Step of the Pipeline
Author | asher
Most articles on reproducing DeepSeek R1 focus on reward design, how the training metrics evolve, and benchmark results. Very few dig into local training, the crucial first step that opens the door to any deeper exploration.
You might think that following the readme is enough to get training running. Not so fast! The moment you actually try, you hit all kinds of friction with your own environment: the model either refuses to train or trains painfully slowly, and the problems pile up fast.
To tackle these local-adaptation issues, this post uses HuggingFace's open-source open-r1 project for a full hands-on walkthrough: how to get a Qwen-14B-based DeepSeek R1 reproduction running on 8x A100 (40G), which environment setup works, the pitfalls I hit along the way, and how to modify the code to fit your own task data, so you can start training DeepSeek R1 on custom data as quickly as possible.
I. Environment Setup Made Easy
1. GPU Driver and CUDA Compatibility
open-r1 explicitly requires CUDA 12.4, so first check the GPU driver version on your machine (nvidia-smi will show it). If the driver is too old, you will have to upgrade it before CUDA 12.4 works. In my tests, any driver version above 470 runs fine; mine is 535.
# Check that PyTorch can see the GPU and which CUDA version it was built against
import torch
print(torch.cuda.is_available())  # should print True
print(torch.version.cuda)         # should print 12.4 for the cu124 wheels
2. Quick Environment Installation
Compared with the uv workflow in the readme, I still prefer managing virtual environments with conda (a quick import check after the steps confirms everything landed correctly):
1. conda create -n openr1 python=3.11
2. pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
3. pip install vllm==0.7.2
4. pip install flash-attn
5. cd into the open-r1 directory and run pip install -e ".[dev]"
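If the installation went through, a minimal import check like the one below (my own sanity script, not part of open-r1) should run cleanly and report the pinned versions:
# Quick sanity check that the key packages import and match the versions installed above
import torch, vllm, flash_attn
print(torch.__version__)          # expect 2.5.1+cu124
print(torch.cuda.is_available())  # expect True
print(vllm.__version__)           # expect 0.7.2
print(flash_attn.__version__)     # whichever version pip resolved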
II. Training Pitfalls and How to Dodge Them
1. The Many Causes of OOM
Take GRPO training as an example: Qwen-14B on A100s hits OOM very easily, and for more than one reason. A GRPO job splits into two parts, model training (7 GPUs) and model inference (1 GPU), and the OOM errors can come from either side. Let me go through the causes one by one:
- Training-side OOM: 7 A100 (40G) cards cannot handle training the 14B model as-is. Fix: edit recipes/accelerate_configs/zero3.yaml and enable offload.
- Inference-side OOM: with a vLLM version below 0.7.3, OOM is very common. Edit recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml and lower vllm_gpu_memory_utilization; 0.2 works for the 14B model, 0.5 for the 7B model.
- Inference-side OOM: the max_model_len given to vLLM is too long, so the KV cache needs too much memory. Fix: edit recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml and lower vllm_max_model_len. Note that this parameter covers prompt plus model output, so it should not be too short; 4k-8k is a reasonable range. Its default is read from the base model's config, e.g. 32768 for Qwen-14B (see the KV-cache sizing sketch right after this list).
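To make the max_model_len point concrete, here is a rough back-of-the-envelope helper I use (not part of open-r1; it reads the model's config via transformers and ignores vLLM's block allocation and batching, so it only shows how per-sequence KV-cache memory scales with vllm_max_model_len):
from transformers import AutoConfig
def kv_cache_gb(model_path, max_model_len, dtype_bytes=2):  # bf16 -> 2 bytes per value
    cfg = AutoConfig.from_pretrained(model_path)
    num_kv_heads = getattr(cfg, "num_key_value_heads", cfg.num_attention_heads)
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    # K and V tensors for every layer, for every token of one sequence
    bytes_per_token = 2 * cfg.num_hidden_layers * num_kv_heads * head_dim * dtype_bytes
    return bytes_per_token * max_model_len / 1024**3
print(kv_cache_gb("XXX/models/Qwen2.5-14B-Instruct", 8000))   # per-sequence KV cache at 8k tokens
print(kv_cache_gb("XXX/models/Qwen2.5-14B-Instruct", 32768))  # roughly 4x more at the 32768 default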
So how do you tell whether an OOM came from training or from inference? Just look at the GPU index in the error message. By default the last card is the one used for inference, so if the error says GPU 7 is out of memory, the problem is on the inference side and adjusting the two parameters above is enough.
For training Qwen-14B on 8x A100 (40G), I have already tuned the corresponding config files and put them at the end of this article for reference.
2. The Reward Function's Parameter Names Matter
One thing to watch when designing a reward function: the parameter names you declare are not arbitrary, they must match the dataset's column names. In the reward function below, for example, completions holds the model-generated content and ground_truth holds the values of the dataset's "ground_truth" column; the parameter name ground_truth has to line up with that column name.
import re
def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
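To make the column-to-parameter mapping concrete, here is a hypothetical mini-example: the trainer forwards the extra dataset columns to each reward function as keyword arguments, so a "ground_truth" column arrives as the ground_truth parameter:
# Hypothetical toy inputs: two completions scored against the "ground_truth" column
completions = ["The answer is \\boxed{42}", "The answer is \\boxed{7}"]
ground_truth = ["42", "9"]
print(reward_func(completions, ground_truth))  # [1.0, 0.0]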
III. Getting DeepSeek R1 Training Started Without Getting Lost
1. Data First! Preparing Your Own Task Data
Build your task dataset offline as data.json. Keep the field names problem and solution, the same as in the official sample data, and you will save yourself a lot of code changes:
{"problem": "Classify the text into neutral, negative, or positive\nText: I think the food was okay.\nSentiment:\n", "solution": "positive"}
{"problem": "Classify the text into neutral, negative, or positive\nText: I think the food was shit.\nSentiment:\n", "solution": "negative"}
2. A Small Tweak! Switching the Data Loading
Change the data loading in grpo.py from pulling a hub dataset to reading the offline file:
dataset = load_dataset("json", data_files="XXX/data.json")
dataset = dataset["train"].train_test_split(test_size=0.02)
3. Make It Your Own! Writing a Custom Reward Function
Note that the solution parameter in the function signature must match the dataset field name:
def accuracy_reward_ours(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        gold_parsed = sol  # ground truth read from the dataset
        if len(gold_parsed) != 0:
            # Extract the predicted answer from the model output text
            answer_parsed = re.findall("<answer>(.*?)</answer>", content)
            if len(answer_parsed) > 0:
                answer_parsed = answer_parsed[0]
                reward = float(1 if answer_parsed == gold_parsed else 0)  # compare prediction with ground truth
            else:
                reward = float(0)
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("Failed to parse gold solution: ", sol)
        rewards.append(reward)
    return rewards
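A hypothetical smoke test before wiring the function into training; the GRPO trainer passes completions in conversational format, which is why the function indexes completion[0]["content"]:
completions = [[{"role": "assistant", "content": "<think>...</think><answer>positive</answer>"}]]
solution = ["positive"]
print(accuracy_reward_ours(completions, solution))  # [1.0]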
4. One Command! Launch DeepSeek R1 Training
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
--num_processes=7 src/open_r1/grpo.py \
--config recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml \
&> /workspace/user_code/Qwen2.5-14B-Instruct.log
IV. The Full Config That Gets the 14B Model Running R1 Smoothly on A100s
recipes/accelerate_configs/zero3.yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: "cpu"
  offload_param_device: "cpu"
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
recipes/Qwen2.5-14B-Instruct/grpo/config_simple_rl.yaml
# Model arguments
model_name_or_path: XXX/models/Qwen2.5-14B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2
# Data training arguments
dataset_name: XXX/dataset/data.json
# Num processes is less by 1 as vLLM is using 1 GPU
num_processes: 7
# GRPO trainer config
reward_funcs:
- accuracy_ours
- format
bf16: true
use_vllm: true
vllm_device: cuda:7
vllm_gpu_memory_utilization: 0.2 # for vLLM versions below 0.7.3
vllm_max_model_len: 8000
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen-2.5-7B-Simple-RL
hub_strategy: every_save
learning_rate: 3.0e-06
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 7
num_train_epochs: 1
output_dir: XXX/Qwen-2.5-7B-Instruct-RL
overwrite_output_dir: true
per_device_eval_batch_size: 8
per_device_train_batch_size: 8
push_to_hub: false
report_to: "none"
save_strategy: "steps"
save_steps: 100
save_total_limit: 2
seed: 42
warmup_ratio: 0.1
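One last bit of arithmetic I find useful when touching these values (my own sanity check, not from the repo): as far as I understand, the TRL-based GRPO trainer requires the global generation batch (num_processes * per_device_train_batch_size) to be divisible by num_generations, which is why num_generations is set to 7 to pair with the 7 training processes.
num_processes = 7                  # 8 GPUs minus the one reserved for vLLM
per_device_train_batch_size = 8
gradient_accumulation_steps = 8
num_generations = 7
gen_batch = num_processes * per_device_train_batch_size  # completions generated per step
print(gen_batch, gen_batch % num_generations == 0)       # 56 True
print(gen_batch * gradient_accumulation_steps)           # 448 completions per optimizer step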