告別偏科,能玩轉(zhuǎn)多模態(tài)、多任務(wù)、多領(lǐng)域的強(qiáng)化智能體終于來(lái)了
隨著 Llama 3 發(fā)布,未來(lái)大模型的參數(shù)量已飆升至驚人的 4000 億。盡管每周幾乎都有一個(gè)聲稱(chēng)性能超強(qiáng)的大模型出來(lái)炸場(chǎng),但 AI 應(yīng)用還在等待屬于它們的「ChatGPT 時(shí)刻」。其中,AI 智能體無(wú)疑是最被看好的賽道。
就連吳恩達(dá)都說(shuō),GPT-4 加上 AI 智能體,可能提前達(dá)到 GPT-5 的效果。
不過(guò),我們熟知的智能體往往有點(diǎn)「偏科」。例如,第一個(gè) AI 軟件工程師 Devin,專(zhuān)精于代碼。會(huì)打游戲的智能體往往也只能在某一個(gè)游戲里秀操作。尋找一個(gè)能夠同時(shí)擅長(zhǎng)多個(gè)領(lǐng)域,并能在其中無(wú)縫切換的通用模型仍是機(jī)器學(xué)習(xí)研究中的一個(gè)關(guān)鍵目標(biāo)。
為了解決這個(gè)問(wèn)題,研究者們對(duì)于智能體如何結(jié)合計(jì)算機(jī)視覺(jué)(CV)和自然語(yǔ)言處理(NLP)任務(wù)進(jìn)行了廣泛探索,但將強(qiáng)化學(xué)習(xí)(RL)任務(wù)整合進(jìn)來(lái)的研究相對(duì)較少。這是由于 RL 任務(wù)本質(zhì)上是異質(zhì)的,這使得將 RL 任務(wù)與對(duì)話(huà)和圖像識(shí)別等其他任務(wù)結(jié)合起來(lái)更加困難。這要求智能體能融會(huì)貫通不同領(lǐng)域任務(wù)中的不同模態(tài)、任務(wù)復(fù)雜性和數(shù)據(jù)類(lèi)型。要達(dá)到全能型智能體,主要需要解決以下問(wèn)題:(1)如何設(shè)計(jì)一個(gè)能夠處理多種數(shù)據(jù)類(lèi)型和模態(tài)的統(tǒng)一模型結(jié)構(gòu)?(2)如何有效地平衡不同任務(wù)的學(xué)習(xí)進(jìn)度和優(yōu)先級(jí)?(3)如何確保智能體制定合適的學(xué)習(xí)目標(biāo),以避免不同任務(wù)之間的干擾和負(fù)向遷移?
來(lái)自 Hugging Face、法國(guó)國(guó)家信息與自動(dòng)化研究所(INRIA)和波爾多大學(xué)的四位研究者提出了智能體中的「六邊形戰(zhàn)士」——Jack of All Trades (JAT)。JAT 是一個(gè)基于 Transformer 的多模態(tài)通用強(qiáng)化學(xué)習(xí)智能體框架。在此框架下,智能體能夠通過(guò)同一套參數(shù)應(yīng)對(duì)不同復(fù)雜度的多種任務(wù),化身既會(huì)打游戲,又能控制機(jī)器人的全能高手。論文同時(shí)發(fā)布了大量 RL 智能體與 JAT 數(shù)據(jù)集。這是首個(gè)用于通用智能體訓(xùn)練的數(shù)據(jù)集 JAT 數(shù)據(jù)集,包含了由專(zhuān)家智能體收集的數(shù)十萬(wàn)條軌跡。
- 論文名稱(chēng):《Jack of All Trades, Master of Some, a Multi-Purpose Transformer Agent》
- 論文鏈接:https://huggingface.co/papers/2402.09844
- 代碼鏈接:https://github.com/huggingface/jat
- 項(xiàng)目鏈接:https://huggingface.co/jat-project/jat
- 數(shù)據(jù)集:https://huggingface.co/datasets/jat-project/jat-dataset
模型架構(gòu)
JAT 的核心結(jié)構(gòu)基于 Transformer,使用了 EleutherAI 的 GPT-Neo 實(shí)現(xiàn)。JAT 最大的創(chuàng)新點(diǎn)在于其嵌入機(jī)制,從本質(zhì)上解決了數(shù)據(jù)類(lèi)型不同的問(wèn)題。JAT 模型將觀察嵌入與其對(duì)應(yīng)的獎(jiǎng)勵(lì)值和動(dòng)作嵌入交錯(cuò)排列,形成一個(gè)序列。
圖 1.JAT 網(wǎng)絡(luò)架構(gòu)。對(duì)于序列中的決策任務(wù),一方面輸入觀察嵌入與獎(jiǎng)勵(lì)值,另一方面行動(dòng)嵌入被編碼并被交錯(cuò)放置。模型使用因果掩碼自回歸地生成下一個(gè)嵌入,并根據(jù)預(yù)期的模態(tài)進(jìn)行解碼。
因此,每個(gè)嵌入要么對(duì)應(yīng)一個(gè)與獎(jiǎng)勵(lì)相關(guān)聯(lián)的觀察嵌入,要么對(duì)應(yīng)一個(gè)動(dòng)作嵌入。JAT 如何進(jìn)一步對(duì)這些信息進(jìn)行編碼呢?這要取決于數(shù)據(jù)的類(lèi)型。如果觀察嵌入或動(dòng)作嵌入的數(shù)據(jù)類(lèi)型是圖像,那么 JAT 將使用 CNN。如果是連續(xù)向量,則使用線性層。如果是離散值,則使用線性投影層。模型的輸出也遵循相同的邏輯,具體取決于預(yù)測(cè)目標(biāo)的數(shù)據(jù)類(lèi)型。預(yù)測(cè)基于因果推理進(jìn)行,將觀察嵌入向后移動(dòng)一個(gè)時(shí)間步,確保智能體可以根據(jù)所有先前的觀察和動(dòng)作嵌入來(lái)預(yù)測(cè)下一個(gè)動(dòng)作嵌入。
這種嵌入設(shè)計(jì)讓研究團(tuán)隊(duì)在訓(xùn)練智能體執(zhí)行 NLP 和 CV 任務(wù)時(shí)興致盎然。對(duì)于和文本相關(guān)的任務(wù),作者讓 JAT 模型采用 GPT-2 的分詞策略,將文本轉(zhuǎn)換為一個(gè)整數(shù)序列,然后通過(guò)一個(gè)查找表映射到一個(gè)嵌入向量序列。對(duì)于和圖像有關(guān)的任務(wù),JAT 模型將選擇 ViT 方法,將圖像切割成小塊后,通過(guò)線性層轉(zhuǎn)換為嵌入向量序列。JAT 模型再將圖像和文本的向量序列拼接在一起,形成一個(gè)統(tǒng)一的序列,輸入到 Transformer 中。
考慮到數(shù)據(jù)的模態(tài)變來(lái)變?nèi)?,JAT 如何計(jì)算損失函數(shù)呢?它將針對(duì)每種模態(tài)分別計(jì)算 loss。對(duì)于圖像和連續(xù)值,它使用均方誤差(MSE)損失。對(duì)于離散值,它使用交叉熵?fù)p失。最終的損失是序列中每種元素?fù)p失的平均值。那么,這是否意味著 JAT 在預(yù)測(cè)動(dòng)作嵌入和觀察嵌入時(shí)的權(quán)重是相同的呢?實(shí)際上不是,在此后的章節(jié)中將一步探討這個(gè)問(wèn)題。
實(shí)驗(yàn)結(jié)果
研究團(tuán)隊(duì)共采用了 157 個(gè)訓(xùn)練任務(wù)來(lái) JAT 評(píng)估。他們將這些任務(wù)分為 10 類(lèi),并記錄了 JAT 的總獎(jiǎng)勵(lì)值。
JAT 模型在最終的檢查點(diǎn)上達(dá)到了 65.8% 的專(zhuān)家得分,說(shuō)明 JAT 能夠在非常廣泛的任務(wù)上達(dá)到專(zhuān)家水平。以下具體列出了 JAT 在四個(gè)常見(jiàn)的智能體訓(xùn)練環(huán)境中的得分:
- 對(duì)于 Atari 57,應(yīng)用 JAT 模型的智能體實(shí)現(xiàn)了專(zhuān)家分?jǐn)?shù)的 14.1%,這相當(dāng)于人類(lèi)表現(xiàn)的 37.6%。Atari 視頻游戲廣泛被用作評(píng)估和開(kāi)發(fā)強(qiáng)化學(xué)習(xí)算法的基準(zhǔn)環(huán)境,其中《吃豆人》是一款標(biāo)志性游戲。在這一系列的 21 款游戲中,JAT 智能體的表現(xiàn)已經(jīng)超越了人類(lèi)玩家。值得注意的是, JAT 只用了單一網(wǎng)絡(luò)就在所有 Atari 視頻游戲中達(dá)到了這種水平;
- 對(duì)于 BabyAI,應(yīng)用 JAT 模型的智能體達(dá)到了專(zhuān)家分?jǐn)?shù)的 99.0%,只有一個(gè)任務(wù)的表現(xiàn)未能超過(guò)專(zhuān)家水平的 50%;
- 對(duì)于 Meta-World,應(yīng)用 JAT 模型的智能體達(dá)到了專(zhuān)家分?jǐn)?shù)的 65.5%;
- 對(duì)于 MuJoCo,應(yīng)用 JAT 模型的智能體達(dá)到了專(zhuān)家分?jǐn)?shù)的 84.8%。
JAT 智能體在 Atari 57 基線上和人類(lèi)表現(xiàn)的對(duì)比
這些 JAT 智能體都可以通過(guò)項(xiàng)目主頁(yè)下載,進(jìn)一步測(cè)試和體驗(yàn)。更多細(xì)節(jié)請(qǐng)參閱論文原文。
專(zhuān)家智能體和 JAT 數(shù)據(jù)集
專(zhuān)家策略
傳統(tǒng)的強(qiáng)化學(xué)習(xí)往往在單一環(huán)境中尋找專(zhuān)家策略,即在一個(gè)特定任務(wù)中尋找讓模型表現(xiàn)最優(yōu)的方法。構(gòu)建跨領(lǐng)域的多功能智能體,也離不開(kāi)這種方法。論文作者選擇了 Atari、BabyAI、Meta-World 和 MuJoCo 一系列性質(zhì)不同,難度各異的訓(xùn)練環(huán)境,直到訓(xùn)練出表現(xiàn)最好的智能體。這一系列采用 JAT 框架的專(zhuān)家智能體已經(jīng)在項(xiàng)目主頁(yè)上發(fā)布。
JAT 數(shù)據(jù)集
論文作者隨論文同步發(fā)布了 JAT 數(shù)據(jù)集,這是首個(gè)針對(duì)通用智能體訓(xùn)練的專(zhuān)項(xiàng)數(shù)據(jù)集。其中包含了數(shù)十萬(wàn)條由上述專(zhuān)家智能體收集的軌跡數(shù)據(jù)。使用起來(lái)也很方便,可以像加載 Hugging Face 平臺(tái)上的其他數(shù)據(jù)集一樣簡(jiǎn)單。以下是調(diào)用代碼示例:
JAT 數(shù)據(jù)集不僅包含強(qiáng)化學(xué)習(xí)的數(shù)據(jù),還整合了來(lái)自維基百科等文本數(shù)據(jù)集,以及 Oscar、OK-VQA、Conceptual Captions 等針對(duì)視覺(jué)任務(wù)的數(shù)據(jù)集,提供了更豐富的數(shù)據(jù)類(lèi)型選擇。
增加模型預(yù)測(cè)觀察嵌入的能力
智能體學(xué)得更好更快了
在訓(xùn)練強(qiáng)化學(xué)習(xí)智能體時(shí),主要目標(biāo)是使其在未曾遇到的任務(wù)中實(shí)現(xiàn)獎(jiǎng)勵(lì)最大化。然而,如果要求智能體預(yù)測(cè)未來(lái)可能遇到的情境,這一額外任務(wù)會(huì)促進(jìn)還是阻礙其學(xué)習(xí)過(guò)程呢?
關(guān)于這個(gè)問(wèn)題存在兩種相反的觀點(diǎn)。一方面,學(xué)會(huì)預(yù)判可能會(huì)讓智能體對(duì)環(huán)境有更深入的理解,從而學(xué)得更好更快。另一方面,這可能會(huì)分散智能體對(duì)其主要目標(biāo)的注意力,導(dǎo)致在預(yù)測(cè)觀察嵌入和行動(dòng)嵌入時(shí)都表現(xiàn)平庸。
為了得到問(wèn)題的答案,論文作者進(jìn)行了一個(gè)實(shí)驗(yàn),使用了一個(gè)結(jié)合了觀察損失和行動(dòng)損失的損失函數(shù),并通過(guò)權(quán)重參數(shù) k 來(lái)平衡這兩種損失。
研究團(tuán)隊(duì)在 95% 的置信區(qū)間內(nèi),針對(duì)選定任務(wù),測(cè)量了預(yù)判將如何影響模型學(xué)習(xí)。每項(xiàng)任務(wù)進(jìn)行了 100 次評(píng)估,基于這些評(píng)估得到了 k 值的范圍。結(jié)果表明,適當(dāng)選擇 k 值可以顯著提升智能體的表現(xiàn)。
當(dāng) k 值過(guò)高(高于 0.5)時(shí),預(yù)測(cè)觀察嵌入的額外任務(wù)阻礙了學(xué)習(xí)過(guò)程。但當(dāng) k 值較低時(shí),對(duì)學(xué)習(xí)的影響可以忽略不計(jì),且智能體的表現(xiàn)與沒(méi)有額外預(yù)判任務(wù)時(shí)的表現(xiàn)相似。
研究團(tuán)隊(duì)發(fā)現(xiàn),當(dāng) k=0.005 時(shí),存在一個(gè)最佳臨界點(diǎn)。這意味著,只要平衡得當(dāng),為智能體增加預(yù)測(cè)觀察嵌入的任務(wù),實(shí)際上可以提高智能體的學(xué)習(xí)效率。這一發(fā)現(xiàn)對(duì)于設(shè)計(jì)類(lèi)似的智能體具有重要意義,突顯了輔助目標(biāo)在提升智能體學(xué)習(xí)效率方面的潛在價(jià)值。
未來(lái)展望
JAT 項(xiàng)目為通用智能體研究領(lǐng)域開(kāi)辟了全新的方向。研究團(tuán)隊(duì)表示目前只是初步探索,以下幾點(diǎn)思路可供未來(lái)研究者深入挖掘:
改進(jìn)數(shù)據(jù)的質(zhì)量:盡管填補(bǔ)了之前少有通用智能體訓(xùn)練數(shù)據(jù)集的空缺,JAT 數(shù)據(jù)集仍處于初級(jí)階段。其中的專(zhuān)家軌跡僅來(lái)自每個(gè)環(huán)境中的一名專(zhuān)家智能體,這可能導(dǎo)致一些誤差。雖然研究團(tuán)隊(duì)已盡力讓智能體達(dá)到最優(yōu)表現(xiàn),但某些環(huán)境仍具挑戰(zhàn)性。在這些環(huán)境中,智能體仍有很大進(jìn)步空間。收集到更多數(shù)據(jù),訓(xùn)練更多的專(zhuān)家智能體,將在很大程度上解決這些問(wèn)題。
使用離線強(qiáng)化學(xué)習(xí):JAT 智能體是仿照基線一比一地訓(xùn)練出來(lái)的。這意味著,其一,智能體無(wú)法利用次優(yōu)的軌跡;其二,JAT 智能體無(wú)法超越專(zhuān)家。論文選擇了這種方法是因?yàn)樗容^簡(jiǎn)單,但研究團(tuán)隊(duì)相信,使用離線強(qiáng)化學(xué)習(xí)可以提高智能體的性能,同時(shí),實(shí)現(xiàn)起來(lái)也不會(huì)過(guò)于復(fù)雜。
發(fā)揮更智能的多任務(wù)采樣策略的全部潛力:目前,JAT 智能體均勻地從所有任務(wù)中采樣數(shù)據(jù),但這種方法可能限制了它的全部潛力。通過(guò)動(dòng)態(tài)調(diào)整采樣率,專(zhuān)注于最具挑戰(zhàn)性的任務(wù),或許也可以加速智能體的學(xué)習(xí)過(guò)程,并解鎖顯著的性能提升。
本文轉(zhuǎn)自 機(jī)器之心 ,作者:機(jī)器之心
