訓(xùn)練提速17%,第四范式開源強(qiáng)化學(xué)習(xí)研究框架,支持單、多智能體訓(xùn)練
OpenRL 是由第四范式強(qiáng)化學(xué)習(xí)團(tuán)隊(duì)開發(fā)的基于 PyTorch 的強(qiáng)化學(xué)習(xí)研究框架,支持單智能體、多智能體、自然語(yǔ)言等多種任務(wù)的訓(xùn)練。OpenRL 基于 PyTorch 進(jìn)行開發(fā),目標(biāo)是為強(qiáng)化學(xué)習(xí)研究社區(qū)提供一個(gè)簡(jiǎn)單易用、靈活高效、可持續(xù)擴(kuò)展的平臺(tái)。目前,OpenRL 支持的特性包括:
- 簡(jiǎn)單易用且支持單智能體、多智能體訓(xùn)練的通用接口
- 支持自然語(yǔ)言任務(wù)(如對(duì)話任務(wù))的強(qiáng)化學(xué)習(xí)訓(xùn)練
- 支持從 Hugging Face 上導(dǎo)入模型和數(shù)據(jù)
- 支持 LSTM,GRU,Transformer 等模型
- 支持多種訓(xùn)練加速,例如:自動(dòng)混合精度訓(xùn)練,半精度策略網(wǎng)絡(luò)收集數(shù)據(jù)等
- 支持用戶自定義訓(xùn)練模型、獎(jiǎng)勵(lì)模型、訓(xùn)練數(shù)據(jù)以及環(huán)境
- 支持 gymnasium 環(huán)境
- 支持字典觀測(cè)空間
- 支持 wandb,tensorboardX 等主流訓(xùn)練可視化工具
- 支持環(huán)境的串行和并行訓(xùn)練,同時(shí)保證兩種模式下的訓(xùn)練效果一致
- 中英文文檔
- 提供單元測(cè)試和代碼覆蓋測(cè)試
- 符合 Black Code Style 和類型檢查
目前,OpenRL 已經(jīng)在 GitHub 開源:
項(xiàng)目地址:https://github.com/OpenRL-Lab/openrl
OpenRL 初體驗(yàn)
OpenRL 目前可以通過(guò) pip 進(jìn)行安裝:
pip install openrl
也可以通過(guò) conda 安裝:
conda install -c openrl openrl
OpenRL 為強(qiáng)化學(xué)習(xí)入門用戶提供了簡(jiǎn)單易用的接口, 下面是一個(gè)使用 PPO 算法訓(xùn)練 CartPole 環(huán)境的例子:
# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
env = make ("CartPole-v1", env_num=9) # 創(chuàng)建環(huán)境,并設(shè)置環(huán)境并行數(shù)為 9
net = Net (env) # 創(chuàng)建神經(jīng)網(wǎng)絡(luò)
agent = Agent (net) # 初始化智能體
agent.train (total_time_steps=20000) # 開始訓(xùn)練,并設(shè)置環(huán)境運(yùn)行總步數(shù)為 20000
使用 OpenRL 訓(xùn)練智能體只需要簡(jiǎn)單的四步:創(chuàng)建環(huán)境 => 初始化模型 => 初始化智能體 => 開始訓(xùn)練!
在普通筆記本電腦上執(zhí)行以上代碼,只需要幾秒鐘,便可以完成該智能體的訓(xùn)練:
此外,對(duì)于多智能體、自然語(yǔ)言等任務(wù)的訓(xùn)練,OpenRL 也提供了同樣簡(jiǎn)單易用的接口。例如,對(duì)于多智能體任務(wù)中的 MPE 環(huán)境,OpenRL 也只需要調(diào)用幾行代碼便可以完成訓(xùn)練:
# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
def train ():
# 創(chuàng)建 MPE 環(huán)境,使用異步環(huán)境,即每個(gè)智能體獨(dú)立運(yùn)行
env = make (
"simple_spread",
env_num=100,
asynchrnotallow=True,
)
# 創(chuàng)建 神經(jīng)網(wǎng)絡(luò),使用 GPU 進(jìn)行訓(xùn)練
net = Net (env, device="cuda")
agent = Agent (net) # 初始化訓(xùn)練器
# 開始訓(xùn)練
agent.train (total_time_steps=5000000)
# 保存訓(xùn)練完成的智能體
agent.save ("./ppo_agent/")
if __name__ == "__main__":
train ()
下圖展示了通過(guò) OpenRL 訓(xùn)練前后智能體的表現(xiàn):
加載配置文件
此外,OpenRL 還同時(shí)支持從命令行和配置文件對(duì)訓(xùn)練參數(shù)進(jìn)行修改。比如,用戶可以通過(guò)執(zhí)行 python train_ppo.py --lr 5e-4 來(lái)快速修改訓(xùn)練時(shí)候的學(xué)習(xí)率。
當(dāng)配置參數(shù)非常多的時(shí)候,OpenRL 還支持用戶編寫自己的配置文件來(lái)修改訓(xùn)練參數(shù)。例如,用戶可以自行創(chuàng)建以下配置文件 (mpe_ppo.yaml),并修改其中的參數(shù):
# mpe_ppo.yaml
seed: 0 # 設(shè)置 seed,保證每次實(shí)驗(yàn)結(jié)果一致
lr: 7e-4 # 設(shè)置學(xué)習(xí)率
episode_length: 25 # 設(shè)置每個(gè) episode 的長(zhǎng)度
use_recurrent_policy: true # 設(shè)置是否使用 RNN
use_joint_action_loss: true # 設(shè)置是否使用 JRPO 算法
use_valuenorm: true # 設(shè)置是否使用 value normalization
最后,用戶只需要在執(zhí)行程序的時(shí)候指定該配置文件即可:
python train_ppo.py --config mpe_ppo.yaml
訓(xùn)練與測(cè)試可視化
此外,通過(guò) OpenRL,用戶還可以方便地使用 wandb 來(lái)可視化訓(xùn)練過(guò)程:
OpenRL 還提供了各種環(huán)境可視化的接口,方便用戶對(duì)并行環(huán)境進(jìn)行可視化。用戶可以在創(chuàng)建并行環(huán)境的時(shí)候設(shè)置環(huán)境的渲染模式為 "group_human",便可以同時(shí)對(duì)多個(gè)并行環(huán)境進(jìn)行可視化:
env = make ("simple_spread", env_num=9, render_mode="group_human")
此外,用戶還可以通過(guò)引入 GIFWrapper 來(lái)把環(huán)境運(yùn)行過(guò)程保存為 gif 動(dòng)畫:
from openrl.envs.wrappers import GIFWrapper
env = GIFWrapper (env, "test_simple_spread.gif")
智能體的保存和加載
OpenRL 提供 agent.save () 和 agent.load () 接口來(lái)保存和加載訓(xùn)練好的智能體,并通過(guò) agent.act () 接口來(lái)獲取測(cè)試時(shí)的智能體動(dòng)作:
# test_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.envs.wrappers import GIFWrapper # 用于生成 gif
def test ():
# 創(chuàng)建 MPE 環(huán)境
env = make ( "simple_spread", env_num=4)
# 使用 GIFWrapper,用于生成 gif
env = GIFWrapper (env, "test_simple_spread.gif")
agent = Agent (Net (env)) # 創(chuàng)建 智能體
# 保存智能體
agent.save ("./ppo_agent/")
# 加載智能體
agent.load ('./ppo_agent/')
# 開始測(cè)試
obs, _ = env.reset ()
while True:
# 智能體根據(jù) observation 預(yù)測(cè)下一個(gè)動(dòng)作
action, _ = agent.act (obs)
obs, r, done, info = env.step (action)
if done.any ():
break
env.close ()
if __name__ == "__main__":
test ()
執(zhí)行該測(cè)試代碼,便可以在同級(jí)目錄下找到保存好的環(huán)境運(yùn)行動(dòng)畫文件 (test_simple_spread.gif):
訓(xùn)練自然語(yǔ)言對(duì)話任務(wù)
最近的研究表明,強(qiáng)化學(xué)習(xí)也可以用于訓(xùn)練語(yǔ)言模型, 并且能顯著提升模型的性能。目前,OpenRL 已經(jīng)支持自然語(yǔ)言對(duì)話任務(wù)的強(qiáng)化學(xué)習(xí)訓(xùn)練。OpenRL 通過(guò)模塊化設(shè)計(jì),支持用戶加載自己的數(shù)據(jù)集 ,自定義訓(xùn)練模型,自定義獎(jiǎng)勵(lì)模型,自定義 wandb 信息輸出以及一鍵開啟混合精度訓(xùn)練等。
對(duì)于對(duì)話任務(wù)訓(xùn)練,OpenRL 提供了同樣簡(jiǎn)單易用的訓(xùn)練接口:
# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.configs.config import create_config_parser
def train ():
# 添加讀取配置文件的代碼
cfg_parser = create_config_parser ()
cfg = cfg_parser.parse_args ()
# 創(chuàng)建 NLP 環(huán)境
env = make ("daily_dialog",env_num=2,asynchrnotallow=True,cfg=cfg,)
net = Net (env, cfg=cfg, device="cuda")
agent = Agent (net)
agent.train (total_time_steps=5000000)
if __name__ == "__main__":
train ()
可以看出,OpenRL 訓(xùn)練對(duì)話任務(wù)和其他強(qiáng)化學(xué)習(xí)任務(wù)一樣,都是通過(guò)創(chuàng)建交互環(huán)境的方式進(jìn)行訓(xùn)練。
加載自定義數(shù)據(jù)集
訓(xùn)練對(duì)話任務(wù),需要對(duì)話數(shù)據(jù)集。這里我們可以使用 Hugging Face 上的公開數(shù)據(jù)集(用戶可以替換成自己的數(shù)據(jù)集)。加載數(shù)據(jù)集,只需要在配置文件中傳入數(shù)據(jù)集的名稱或者路徑即可:
# nlp_ppo.yaml
data_path: daily_dialog # 數(shù)據(jù)集路徑
env: # 環(huán)境所用到的參數(shù)
args: {'tokenizer_path': 'gpt2'} # 讀取 tokenizer 的路徑
seed: 0 # 設(shè)置 seed,保證每次實(shí)驗(yàn)結(jié)果一致
lr: 1e-6 # 設(shè)置 policy 模型的學(xué)習(xí)率
critic_lr: 1e-6 # 設(shè)置 critic 模型的學(xué)習(xí)率
episode_length: 20 # 設(shè)置每個(gè) episode 的長(zhǎng)度
use_recurrent_policy: true
上述配置文件中的 data_path 可以設(shè)置為 Hugging Face 數(shù)據(jù)集名稱或者本地?cái)?shù)據(jù)集路徑。此外,環(huán)境參數(shù)中的 tokenizer_path 用于指定加載文字編碼器的 Hugging Face 名稱或者本地路徑。
自定義訓(xùn)練模型
在 OpenRL 中,我們可以使用 Hugging Face 上的模型來(lái)進(jìn)行訓(xùn)練。為了加載 Hugging Face 上的模型,我們首先需要在配置文件 nlp_ppo.yaml 中添加以下內(nèi)容:
# nlp_ppo.yaml
# 預(yù)訓(xùn)練模型路徑
model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog
use_share_model: true # 策略網(wǎng)絡(luò)和價(jià)值網(wǎng)絡(luò)是否共享模型
ppo_epoch: 5 # ppo 訓(xùn)練迭代次數(shù)
data_path: daily_dialog # 數(shù)據(jù)集名稱或者路徑
env: # 環(huán)境所用到的參數(shù)
args: {'tokenizer_path': 'gpt2'} # 讀取 tokenizer 的路徑
lr: 1e-6 # 設(shè)置 policy 模型的學(xué)習(xí)率
critic_lr: 1e-6 # 設(shè)置 critic 模型的學(xué)習(xí)率
episode_length: 128 # 設(shè)置每個(gè) episode 的長(zhǎng)度
num_mini_batch: 20
然后在 train_ppo.py 中添加以下代碼:
# train_ppo.py
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent
from openrl.configs.config import create_config_parser
from openrl.modules.networks.policy_value_network_gpt import (
PolicyValueNetworkGPT as PolicyValueNetwork,
)
def train ():
# 添加讀取配置文件的代碼
cfg_parser = create_config_parser ()
cfg = cfg_parser.parse_args ()
# 創(chuàng)建 NLP 環(huán)境
env = make ("daily_dialog",env_num=2,asynchrnotallow=True,cfg=cfg,)
# 創(chuàng)建自定義神經(jīng)網(wǎng)絡(luò)
model_dict = {"model": PolicyValueNetwork}
net = Net (env, cfg=cfg, model_dict=model_dict)
# 創(chuàng)建訓(xùn)練智能體
agent = Agent (net)
agent.train (total_time_steps=5000000)
if __name__ == "__main__":
train ()
通過(guò)以上簡(jiǎn)單幾行的修改,用戶便可以使用 Hugging Face 上的預(yù)訓(xùn)練模型進(jìn)行訓(xùn)練。如果用戶希望分別自定義策略網(wǎng)絡(luò)和價(jià)值網(wǎng)絡(luò),可以寫好 CustomPolicyNetwork 以及 CustomValueNetwork 后通過(guò)以下方式從外部傳入訓(xùn)練網(wǎng)絡(luò):
model_dict = {
"policy": CustomPolicyNetwork,
"critic": CustomValueNetwork,
}
net = Net (env, model_dict=model_dict)
自定義獎(jiǎng)勵(lì)模型
通常,自然語(yǔ)言任務(wù)的數(shù)據(jù)集中并不包含獎(jiǎng)勵(lì)信息。因此,如果需要使用強(qiáng)化學(xué)習(xí)來(lái)訓(xùn)練自然語(yǔ)言任務(wù),就需要使用額外的獎(jiǎng)勵(lì)模型來(lái)生成獎(jiǎng)勵(lì)。在該對(duì)話任務(wù)中,我們可以使用一個(gè)復(fù)合的獎(jiǎng)勵(lì)模型,它包含以下三個(gè)部分:
●意圖獎(jiǎng)勵(lì):即當(dāng)智能體生成的語(yǔ)句和期望的意圖接近時(shí),智能體便可以獲得更高的獎(jiǎng)勵(lì)。
●METEOR 指標(biāo)獎(jiǎng)勵(lì):METEOR 是一個(gè)用于評(píng)估文本生成質(zhì)量的指標(biāo),它可以用來(lái)衡量生成的語(yǔ)句和期望的語(yǔ)句的相似程度。我們把這個(gè)指標(biāo)作為獎(jiǎng)勵(lì)反饋給智能體,以達(dá)到優(yōu)化生成的語(yǔ)句的效果。
●KL 散度獎(jiǎng)勵(lì):該獎(jiǎng)勵(lì)用來(lái)限制智能體生成的文本偏離預(yù)訓(xùn)練模型的程度,防止出現(xiàn) reward hacking 的問(wèn)題。
我們最終的獎(jiǎng)勵(lì)為以上三個(gè)獎(jiǎng)勵(lì)的加權(quán)和,其中 KL 散度獎(jiǎng)勵(lì)的系數(shù)是隨著 KL 散度的大小動(dòng)態(tài)變化的。想在 OpenRL 中使用該獎(jiǎng)勵(lì)模型,用戶無(wú)需修改訓(xùn)練代碼,只需要在 nlp_ppo.yaml 文件中添加 reward_class 參數(shù)即可:
# nlp_ppo.yaml
reward_class:
id: NLPReward # 獎(jiǎng)勵(lì)模型名稱
args: {
# 用于意圖判斷的模型的名稱或路徑
"intent_model": rajkumarrrk/roberta-daily-dialog-intent-classifier,
# 用于計(jì)算 KL 散度的預(yù)訓(xùn)練模型的名稱或路徑
"ref_model": roberta-base, # 用于意圖判斷的 tokenizer 的名稱或路徑
}
OpenRL 支持用戶使用自定義的獎(jiǎng)勵(lì)模型。首先,用戶需要編寫自定義獎(jiǎng)勵(lì)模型 (需要繼承 BaseReward 類)。接著,用戶需要注冊(cè)自定義的獎(jiǎng)勵(lì)模型,即在 train_ppo.py 添加以下代碼:
# train_ppo.py
from openrl.rewards.nlp_reward import CustomReward
from openrl.rewards import RewardFactory
RewardFactory.register ("CustomReward", CustomReward)
最后,用戶只需要在配置文件中填寫自定義的獎(jiǎng)勵(lì)模型即可:
reward_class:
id: "CustomReward" # 自定義獎(jiǎng)勵(lì)模型名稱
args: {} # 用戶自定義獎(jiǎng)勵(lì)函數(shù)可能用到的參數(shù)
自定義訓(xùn)練過(guò)程信息輸出
OpenRL 還支持用戶自定義 wandb 和 tensorboard 的輸出內(nèi)容。例如,在該任務(wù)的訓(xùn)練過(guò)程中,我們還需要輸出各種類型獎(jiǎng)勵(lì)的信息和 KL 散度系數(shù)的信息, 用戶可以在 nlp_ppo.yaml 文件中加入 vec_info_class 參數(shù)來(lái)實(shí)現(xiàn):
# nlp_ppo.yaml
vec_info_class:
id: "NLPVecInfo" # 調(diào)用 NLPVecInfo 類以打印 NLP 任務(wù)中獎(jiǎng)勵(lì)函數(shù)的信息
# 設(shè)置 wandb 信息
wandb_entity: openrl # 這里用于指定 wandb 團(tuán)隊(duì)名稱,請(qǐng)把 openrl 替換為你自己的團(tuán)隊(duì)名稱
experiment_name: train_nlp # 這里用于指定實(shí)驗(yàn)名稱
run_dir: ./run_results/ # 這里用于指定實(shí)驗(yàn)數(shù)據(jù)保存的路徑
log_interval: 1 # 這里用于指定每隔多少個(gè) episode 上傳一次 wandb 數(shù)據(jù)
# 自行填寫其他參數(shù)...
修改完配置文件后,在 train_ppo.py 文件中啟用 wandb:
# train_ppo.py
agent.train (total_time_steps=100000, use_wandb=True)
然后執(zhí)行 python train_ppo.py –config nlp_ppo.yaml,稍后,便可以在 wandb 中看到如下的輸出:
從上圖可以看到,wandb 輸出了各種類型獎(jiǎng)勵(lì)的信息和 KL 散度系數(shù)的信息。
如果用戶還需要輸出其他信息,還可以參考 NLPVecInfo 類 和 VecInfo 類來(lái)實(shí)現(xiàn)自己的 CustomVecInfo 類。然后,需要在 train_ppo.py 中注冊(cè)自定義的 CustomVecInfo 類:
# train_ppo.py # 注冊(cè)自定義輸出信息類
VecInfoFactory.register ("CustomVecInfo", CustomVecInfo)
最后,只需要在 nlp_ppo.yaml 中填寫 CustomVecInfo 類即可啟用:
# nlp_ppo.yaml
vec_info_class:
id: "CustomVecInfo" # 調(diào)用自定義 CustomVecInfo 類以輸出自定義信息
使用混合精度訓(xùn)練加速
OpenRL 還提供了一鍵開啟混合精度訓(xùn)練的功能。用戶只需要在配置文件中加入以下參數(shù)即可:
# nlp_ppo.yaml
use_amp: true # 開啟混合精度訓(xùn)練
對(duì)比評(píng)測(cè)
下表格展示了使用 OpenRL 訓(xùn)練該對(duì)話任務(wù)的結(jié)果。結(jié)果顯示使用強(qiáng)化學(xué)習(xí)訓(xùn)練后,模型各項(xiàng)指標(biāo)皆有所提升。另外,從下表可以看出,相較于 RL4LMs , OpenRL 的訓(xùn)練速度更快(在同樣 3090 顯卡的機(jī)器上,速度提升 17% ),最終的性能指標(biāo)也更好:
最后,對(duì)于訓(xùn)練好的智能體,用戶可以方便地通過(guò) agent.chat () 接口進(jìn)行對(duì)話:
# chat.py
from openrl.runners.common import ChatAgent as Agent
def chat ():
agent = Agent.load ("./ppo_agent", tokenizer="gpt2",)
history = []
print ("Welcome to OpenRL!")
while True:
input_text = input ("> User:")
if input_text == "quit":
break
elif input_text == "reset":
history = []
print ("Welcome to OpenRL!")
continue
response = agent.chat (input_text, history)
print (f"> OpenRL Agent: {response}")
history.append (input_text)
history.append (response)
if __name__ == "__main__":
chat ()
執(zhí)行 python chat.py ,便可以和訓(xùn)練好的智能體進(jìn)行對(duì)話了:
總結(jié)
OpenRL 框架經(jīng)過(guò)了 OpenRL-Lab 的多次迭代并應(yīng)用于學(xué)術(shù)研究和 AI 競(jìng)賽,目前已經(jīng)成為了一個(gè)較為成熟的強(qiáng)化學(xué)習(xí)框架。OpenRL-Lab 團(tuán)隊(duì)將持續(xù)維護(hù)和更新 OpenRL,歡迎大家加入我們的開源社區(qū),一起為強(qiáng)化學(xué)習(xí)的發(fā)展做出貢獻(xiàn)。更多關(guān)于 OpenRL 的信息,可以參考:
- OpenRL 官方倉(cāng)庫(kù):https://github.com/OpenRL-Lab/openrl/
- OpenRL 中文文檔:https://openrl-docs.readthedocs.io/zh/latest/
致謝
OpenRL 框架的開發(fā)吸取了其他強(qiáng)化學(xué)習(xí)框架的優(yōu)點(diǎn):
- Stable-baselines3: https://github.com/DLR-RM/stable-baselines3
- pytorch-a2c-ppo-acktr-gail:https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail
- MAPPO: https://github.com/marlbenchmark/on-policy
- Gymnasium: https://github.com/Farama-Foundation/Gymnasium
- DI-engine:https://github.com/opendilab/DI-engine/
- Tianshou: https://github.com/thu-ml/tianshou
- RL4LMs: https://github.com/allenai/RL4LMs
未來(lái)工作
目前,OpenRL 還處于持續(xù)開發(fā)和建設(shè)階段,未來(lái) OpenRL 將會(huì)開源更多功能:
- 支持智能體自博弈訓(xùn)練
- 加入離線強(qiáng)化學(xué)習(xí)、模范學(xué)習(xí)、逆強(qiáng)化學(xué)習(xí)算法
- 加入更多強(qiáng)化學(xué)習(xí)環(huán)境和算法
- 集成 Deepspeed 等加速框架
- 支持多機(jī)分布式訓(xùn)練
OpenRL Lab 團(tuán)隊(duì)
OpenRL框架是由OpenRL Lab團(tuán)隊(duì)開發(fā),該團(tuán)隊(duì)是第四范式公司旗下的強(qiáng)化學(xué)習(xí)研究團(tuán)隊(duì)。第四范式長(zhǎng)期致力于強(qiáng)化學(xué)習(xí)的研發(fā)和工業(yè)應(yīng)用。為了促進(jìn)強(qiáng)化學(xué)習(xí)的產(chǎn)學(xué)研一體化,第四范式成立了OpenRL Lab研究團(tuán)隊(duì),目標(biāo)是先進(jìn)技術(shù)開源和人工智能前沿探索。成立不到一年,OpenRL Lab團(tuán)隊(duì)已經(jīng)在AAMAS發(fā)表過(guò)三篇論文,參加谷歌足球游戲 11 vs 11比賽并獲得第三的成績(jī)。團(tuán)隊(duì)提出的TiZero智能體,實(shí)現(xiàn)了首個(gè)從零開始,通過(guò)課程學(xué)習(xí)、分布式強(qiáng)化學(xué)習(xí)、自博弈等技術(shù)完成谷歌足球全場(chǎng)游戲智能體的訓(xùn)練:
截止 2022 年 10 月 28 日,Tizero 在及第評(píng)測(cè)平臺(tái)上排名第一: