HybridFlow:基于 Ray 構(gòu)建靈活且高效的 RLHF 編程框架
一、RLHF 是什么
1. LLM 訓(xùn)練流程
借用 Andrej Karpathy 在 Microsoft build 2023 上的大語(yǔ)言模型訓(xùn)練的圖。
模型訓(xùn)練總體上分為四個(gè)階段:預(yù)訓(xùn)練(Pre-training)、微調(diào)(Fine-tuning)、獎(jiǎng)勵(lì)建模(Reward Modeling)和強(qiáng)化學(xué)習(xí)訓(xùn)練(Reinforcement Learning Training)。
- 預(yù)訓(xùn)練階段Pre-training)。從互聯(lián)網(wǎng)收集大量數(shù)據(jù),如維基百科、雜志、報(bào)紙和小說(shuō)等,以擬合下一個(gè) token。預(yù)訓(xùn)練的模型一般不可以直接用,如果和他對(duì)話,大概率會(huì)重復(fù)輸入內(nèi)容。
- 微調(diào)階段(Fine-tuning)。即 SFT,收集多人對(duì)話和示范,以提升預(yù)訓(xùn)練模型的能力。微調(diào)后的模型即可部署使用。
今天重點(diǎn)介紹的是強(qiáng)化學(xué)習(xí)階段,包括獎(jiǎng)勵(lì)建模和強(qiáng)化學(xué)習(xí)訓(xùn)練。
- 獎(jiǎng)勵(lì)建模階段(Reward Modeling)。比如同一個(gè) prompt,可以收集多個(gè)輸出,并由人類(lèi)標(biāo)注出好的、壞的,組成一個(gè)個(gè) pair 對(duì)。收集到多個(gè)這樣的 pair 后,就可以對(duì)獎(jiǎng)勵(lì)模型進(jìn)行訓(xùn)練。
- 強(qiáng)化學(xué)習(xí)訓(xùn)練(Reinforcement Learning Training)。有了獎(jiǎng)勵(lì)模型后,可以使用 PPO、DPO 等算法進(jìn)行強(qiáng)化學(xué)習(xí)訓(xùn)練。這些算法的輸入是 prompt。
2. 為什么需要 RLHF?
使用 RLHF 可以通過(guò)降低模型輸出不正確或有害內(nèi)容的概率,來(lái)提升模型輸出符合人類(lèi)偏好的概率?;ヂ?lián)網(wǎng)上有海量的數(shù)據(jù),通過(guò)預(yù)訓(xùn)練或者微調(diào)可以擬合下一個(gè) token,也就是模型可以判斷輸出什么,但是沒(méi)有負(fù)反饋來(lái)表達(dá)不應(yīng)輸出什么。例如,當(dāng)詢問(wèn) GPT 如何制造炸彈時(shí),如果沒(méi)有 RLHF,模型就會(huì)輸出制造炸彈的步驟,而有了 RLHF 之后,模型就可以拒絕回答這類(lèi)有害內(nèi)容。
3. 如何將 LLM 訓(xùn)練表述為強(qiáng)化學(xué)習(xí)問(wèn)題?
強(qiáng)化學(xué)習(xí)問(wèn)題有五個(gè)要素:策略、狀態(tài)、動(dòng)作、環(huán)境和獎(jiǎng)勵(lì),可以將大語(yǔ)言模型映射到這些概念上,即:
- 策略:LLM 模型,輸入是狀態(tài),輸出是所有可選的動(dòng)作(token)的概率。
- 狀態(tài):用戶輸入的 prompt 和 LLM 已經(jīng)解碼出來(lái)的那部分內(nèi)容的組合。
- 動(dòng)作:模型輸出的 token(動(dòng)作空間即詞表大小)。
- 環(huán)境:reward model,輸入是狀態(tài),輸出是獎(jiǎng)勵(lì)分?jǐn)?shù) r。
- 獎(jiǎng)勵(lì):reward model 給出的評(píng)價(jià)分?jǐn)?shù),也可以看做是人類(lèi)對(duì)當(dāng)前模型輸出的評(píng)價(jià),是一個(gè)客觀的值。
因此,LLM 接受一個(gè) prompt,輸出一個(gè)長(zhǎng)度為 256 response 的過(guò)程,可視為智能體在環(huán)境中行動(dòng)了 256 步,最終由 reward model 給出相應(yīng)的獎(jiǎng)勵(lì)?;谝陨媳硎觯覀兙涂梢岳?PPO 算法對(duì)策略(LLM)進(jìn)行優(yōu)化。
4. LLM PPO 算法流程
如上圖,PPO 算法的工作流程分為推理階段和訓(xùn)練階段,推理階段負(fù)責(zé)生成經(jīng)驗(yàn),訓(xùn)練階段負(fù)責(zé)經(jīng)驗(yàn)回放,這兩個(gè)階段交替進(jìn)行。
左邊推理階段,給定一個(gè) prompt,首先經(jīng)過(guò) rollout 模型生成 response。然后,將 prompt 和 response 拼成一個(gè) sequence,這個(gè) sequence 同時(shí)過(guò) actor、initial、critic、 reward 4 個(gè)模型進(jìn)行 inference,生成對(duì)應(yīng) logits、value、score。其中,actor、initial、critic 這三個(gè)模型是直接從微調(diào)模型SFT初始化的,reword 模型是從 SFT 訓(xùn)練出來(lái)的。這幾個(gè)模型的作用是:
- actor 模型是我們最終要訓(xùn)練的目標(biāo)模型;
- initial 模型是一個(gè)參考模型,它是 freeze 的,不會(huì)用作訓(xùn)練;
- critic 和 reward 模型,是用來(lái)評(píng)判 actor 模型的;
- rollout 模型,它和 actor 模型是同參數(shù)的,也就是他們參數(shù)量完全一樣,可以共享一份參數(shù),也可以進(jìn)行分離部署。rollout 模型是負(fù)責(zé)生成的,可以使用現(xiàn)在很多的推理框架來(lái)單獨(dú)部署。
右邊訓(xùn)練階段,可以看到推理階段生成的這些橙色框的內(nèi)容,就是我們所說(shuō)的經(jīng)驗(yàn)。根據(jù)這經(jīng)驗(yàn)五元組,來(lái)對(duì) critic 和 actor 模型進(jìn)行更新。
RLHF 算法整體比較復(fù)雜:多階段,有推理有訓(xùn)練;多模型,有 5 個(gè)模型,甚至可能更多;框架中需要重點(diǎn)關(guān)注 actor 模型,它與 rollout 模型同參數(shù),在訓(xùn)練完成后,要做參數(shù)同步,更新給推理階段的 rollout 模型。
二、RLHF 的挑戰(zhàn)與動(dòng)機(jī)
RLHF 面臨的挑戰(zhàn)包括模型規(guī)模的增大、模型參數(shù)同步的耗時(shí)增加以及研究對(duì)靈活性的要求。大型模型的特點(diǎn)和新挑戰(zhàn)使得我們難以直接將現(xiàn)有的強(qiáng)化學(xué)習(xí)框架應(yīng)用于語(yǔ)言模型的 RLHF。
1. 現(xiàn)有的 RL 框架是否能用于 LLM RLHF?
Ray 里有一個(gè)傳統(tǒng)的 RL 框架 Rlib,它可以直接用于大語(yǔ)言模型的 RLHF 嗎?答案是可以,但是也不一定可以。
說(shuō)可以是因?yàn)閷⒋笳Z(yǔ)言模型的強(qiáng)化學(xué)習(xí)訓(xùn)練,對(duì)齊到傳統(tǒng)的強(qiáng)化學(xué)習(xí)的五要素中,都能映射上。如上圖右邊所示,gym.vector,有人或車(chē)走動(dòng),映射到 RL 就是一個(gè)自回歸生成過(guò)程(Autoregressive Generation);獎(jiǎng)勵(lì) reward,可以通過(guò)一個(gè)模型以及當(dāng)前模型和初始模型的 KL 散度來(lái)計(jì)算。
說(shuō)不可以是因?yàn)榇笳Z(yǔ)言模型的一些特點(diǎn)或者一些新的挑戰(zhàn),使得難以將 Rlib 應(yīng)用于 RLHF。具體的挑戰(zhàn)總結(jié)起來(lái)有三點(diǎn)。
2. 挑戰(zhàn) 1:模型規(guī)模的增大
模型規(guī)模變大是最直觀的一個(gè)挑戰(zhàn)。傳統(tǒng)的 RL 算法中,模型大小從 100K~100M。但是,LLM 模型大小遠(yuǎn)超傳統(tǒng)模型:7B~405B。在這么大的模型規(guī)模下,對(duì)每一個(gè)計(jì)算點(diǎn),都需要單獨(dú)針對(duì)性地做優(yōu)化。比如:
- rollout:用來(lái)做生成 generate,需要一些專門(mén)針對(duì)推理做優(yōu)化的框架,如 vLLM、TensorRT-LLM、SGLang;
- actor/cirtic:用來(lái)進(jìn)行訓(xùn)練的,需要針對(duì)訓(xùn)練的框架,如 Megatron-LM、Deepspeed、FSDP 等;
這樣在這個(gè)算法中,需要用到若干種框架的組合:比如,推理時(shí)用 vLLM;訓(xùn)練時(shí),大一點(diǎn)的模型用 TP、PP、DP 并行技術(shù),也可以用 Deepspeed、FSDP 等更 native 的方式。
3. 挑戰(zhàn) 2:模型參數(shù)同步的耗時(shí)增加
隨著模型規(guī)模增大,模型參數(shù)的同步變得更加耗時(shí)。比如,英偉達(dá) A100 顯卡的跨機(jī)器的 RDMA 帶寬可達(dá) 50GB/s,它同步一個(gè) Llama 405B 耗時(shí)仍需 60 秒左右。這在 PPO 算法中,占比較高。在傳統(tǒng)模型中,這個(gè)同步耗時(shí)可以忽略不計(jì),但在大語(yǔ)言模型中,需要針對(duì)性優(yōu)化。
4. 挑戰(zhàn) 3:對(duì)靈活性的要求
由于 LLM RLHF 領(lǐng)域還處于高速發(fā)展中,過(guò)去一年涌現(xiàn)了非常多的新算法,包括 PPO 算法的一些變種,比如 ReMax、Safe-RLHF、GRPO;還有針對(duì) PPO 的改進(jìn),比如 DPO、KTO、ReST 等。其中,DPO 將 RLHF 簡(jiǎn)化成兩個(gè)模型;KTO 更進(jìn)一步,不需要有 pair 對(duì);Google 提出的 Rest,進(jìn)行模型自學(xué)習(xí),而不需要外部獎(jiǎng)勵(lì)模型。
在這樣的背景下,對(duì)于 RL 算法研究者來(lái)說(shuō),一個(gè)僅僅可以調(diào)整模型參數(shù),喂數(shù)據(jù)進(jìn)行訓(xùn)練的算法框架并不能滿足需求。研究者需要一個(gè)編程框架,以支持實(shí)現(xiàn)各種不同的算法,也就是需要可以復(fù)用的算法模塊,并通過(guò)算法模塊靈活組合來(lái)實(shí)現(xiàn)不同的 RL 算法。
三、HybridFlow 設(shè)計(jì)與實(shí)現(xiàn)
HybridFlow 針對(duì)前文所述挑戰(zhàn)進(jìn)行了設(shè)計(jì)和實(shí)現(xiàn)。
1. HybridFlow 核心設(shè)計(jì)思想
在分布式系統(tǒng)中采用 Single-Controller 還是 Multi-Controller 是個(gè)重要命題。Google Pathways 在論文中對(duì)這兩點(diǎn)做了非常詳細(xì)的闡述。
Single-Controller 由一個(gè)中心控制器驅(qū)動(dòng)所有 worker 工作,這個(gè)中心控制器向所有 worker 發(fā)送數(shù)據(jù)和計(jì)算指令,worker 可以運(yùn)行不同的程序(MPMD)。它的優(yōu)點(diǎn)是靈活、異構(gòu),即每個(gè) worker 可以執(zhí)行不同的算法;缺點(diǎn)是因?yàn)槭菃吸c(diǎn)控制,所以會(huì)引入額外的通信開(kāi)銷(xiāo)。
Multi-Controller 無(wú)中心控制器,各個(gè) worker 自驅(qū),各自計(jì)算,通過(guò)一些通信原語(yǔ)進(jìn)行同步,worker 運(yùn)行相同的程序(SPMD)。它的優(yōu)點(diǎn)是高效;缺點(diǎn)是因?yàn)橥瑯?gòu)所以不夠靈活。
典型的 Single-Controller,比如 Spark、TensorFlow1.0 的 non-SPMD 的模式。
Multi-Controller 則更普遍,基本上主流的開(kāi)源 RLHF 分布式訓(xùn)練框架,比如 Megatron-LM、Deepspeed、FSDP,以及 vLLM 都是 Multi-Controller 模式的。這是因?yàn)橹髁?RLHF 框架基本上是從預(yù)訓(xùn)練框架演進(jìn)而來(lái),而預(yù)訓(xùn)練基本上都是 Multi-Controller 模式。
那么,Multi-Controller 模式在處理 RLHF 的多模型的異構(gòu)場(chǎng)景時(shí),有哪些缺點(diǎn)呢?
首先,Multi-Controller 是 SPMD 模式的,所有 worker 運(yùn)行相同的程序,這時(shí)推理和訓(xùn)練的框架以及工作流都不一樣,要實(shí)現(xiàn)這樣的 SPMD 算法難度就很高。
其次,靈活性不足。比如算法研究者想改一個(gè)算法,需要改每個(gè)模塊中 3D 并行的策略,設(shè)計(jì)同步點(diǎn)。
通過(guò)我們的觀察分析,可以將 RL 的數(shù)據(jù)流拆成兩部分:
第一部分是 RL 的數(shù)據(jù)流,就是在不同的組件間交換數(shù)據(jù)。如前文提到的,先在 rollout 模型做一些推理,拿到結(jié)果后,再把數(shù)據(jù)發(fā)送給其它 4 個(gè)模型做推理,然后再把經(jīng)驗(yàn)給到訓(xùn)練側(cè)的 actor 和 critic。
第二部分是組件內(nèi)部的數(shù)據(jù)流。每個(gè)組件都是分布式的,組件內(nèi)部是 SPMD 的工作模式,通過(guò) NCCL 進(jìn)行數(shù)據(jù)交換。
因?yàn)檫@樣兩層的數(shù)據(jù)流架構(gòu),所以我們將其命名為 HybridFlow,即通過(guò) Single-Controller 實(shí)現(xiàn) RL 的數(shù)據(jù)流,通過(guò) Multi-Controller 實(shí)現(xiàn)各個(gè)組件的計(jì)算和通信流。
HybridFlow 對(duì)于各個(gè)組件,可復(fù)用已有框架,不用重復(fù)造輪子。比如,訓(xùn)練使用 Megatron-LM/Deepspeed/FSDP,推理使用 vLLM/TensorRT-LLM/SGLang。同時(shí),通過(guò)組件的抽象和封裝,使得 Single-Controller 可以像單進(jìn)程一樣去調(diào)用組件,而不需要關(guān)注組件內(nèi)部的并行策略。
2. HybridFlow 實(shí)現(xiàn)
HybridFlow 的核心是 worker group 的概念,它的本質(zhì)是將一組 SPMD 進(jìn)程進(jìn)行抽象,使其可以像單進(jìn)程一樣被調(diào)用。
舉例來(lái)說(shuō),如上圖,用 Tensor 并行加數(shù)據(jù)并行來(lái)實(shí)現(xiàn)一個(gè)簡(jiǎn)單的 Transformer 的 FFN layer。假設(shè) TP、DP 各 2 個(gè):首先,在 DP 維度進(jìn)行數(shù)據(jù)拆分,拆成 batch 0 和 batch 1。然后,在 DP group 內(nèi)部,把相同數(shù)據(jù)發(fā)給所有的 rank。所有 worker 收到數(shù)據(jù)后,就進(jìn)行推理。有結(jié)果后,再把數(shù)據(jù)聚合回來(lái),也就是從每一個(gè) DP 的 rank 0 去把數(shù)據(jù)取回。
Worker group 負(fù)責(zé)了數(shù)據(jù)分發(fā)(data dispatch)和聚合(data aggregation)兩個(gè)重要功能。
如前文所介紹,第一層RL數(shù)據(jù)流是 Single-Controller 模式的,由 driver 進(jìn)程去驅(qū)動(dòng)所有的 worker group 工作。即數(shù)據(jù)的分發(fā)和聚合都在 driver 側(cè),通過(guò) ray object store 來(lái)完成,并提供了簡(jiǎn)單的 dispatch function 和 collect function。
根據(jù)整體測(cè)試的結(jié)果,在千卡集群下,ray object store 基本可以滿足要求。
HybridFlow 封裝的一些組件接口如上圖所示,包括 actor/rollout 的混合訓(xùn)練推理引擎,reward、reference 以及 critic 模型。
在 HybridFlow 框架基礎(chǔ)上,實(shí)現(xiàn) RLHF 算法就比較簡(jiǎn)單了。對(duì)每個(gè)模型,每個(gè)組件創(chuàng)建對(duì)應(yīng)的 work group,共創(chuàng)建 4 個(gè) group;然后,對(duì)于每個(gè)算法,調(diào)用對(duì)應(yīng) group 的接口,進(jìn)行訓(xùn)練。
3. HybridFlow 優(yōu)點(diǎn)
有了 HybridFlow 的框架和流程后,算法工程師再對(duì)算法做微調(diào)就比較簡(jiǎn)單了。上圖右邊展示了一個(gè)完整的 PPO 算法流程。當(dāng)進(jìn)行 PPO 算法變種時(shí),只需要改其中一些步驟:比如 ReMax、GRPO 算法,不需要 critic 計(jì)算 advantages,把這行刪掉就可以;Safe-RLHF 算法需要多個(gè) reward 模型,增加對(duì)應(yīng)的 group 就可以。
WorkerGroup 可以進(jìn)行靈活組合,使用 Ray PlacementGroup 調(diào)度,通過(guò)設(shè)置 actor 的 placement_group、GPU 數(shù)量等,靈活控制 WorkerGroup 是否共享同一個(gè) placement_group, 從而實(shí)現(xiàn)了不同 RLHF 組件邏輯隔離、物理共享,提高了資源利用率。
實(shí)驗(yàn)結(jié)果表明,相比業(yè)界其他 RLHF 框架(Deepspeed-Chat、OpenRLHF、Nemo-Aligner),HybridFlow 在 7B 到 70B,模型吞吐量提高了 1.6 倍到 20 倍不等。
我們的論文已被 USENIX ATC 2024 接收,代碼已在 GitHub 上開(kāi)源,歡迎大家關(guān)注。
- Paper: https://arxiv.org/abs/2409.19256, EuroSys’ 25
- Github: https://github.com/volcengine/verl HybridFlow 在 GitHub 上開(kāi)源項(xiàng)目名 veRL。