補齊Transformer規(guī)劃短板,田淵棟團隊的Searchformer火了
最近幾年,基于 Transformer 的架構在多種任務上都表現(xiàn)卓越,吸引了世界的矚目。使用這類架構搭配大量數(shù)據(jù),得到的大型語言模型(LLM)等模型可以很好地泛化用于真實世界用例。
盡管有如此成功,但基于 Transformer 的架構和 LLM 依然難以處理規(guī)劃和推理任務。之前已有研究證明 LLM 難以應對多步規(guī)劃任務或高階推理任務。
為了提升 Transformer 的推理和規(guī)劃性能,近些年研究社區(qū)也提出了一些方法。一種最常見且有效的方法是模擬人類的思考過程:先生成中間「思維」,然后再輸出響應。比如思維鏈(CoT)提示法就是鼓勵模型預測中間步驟,進行按步驟的「思考」。思維樹(ToT)則使用了分支策略和評判方法,讓模型生成多個不同的思維路徑,然后從中選出最佳路徑。盡管這些技術通常是有效的,但也有研究表明,在很多案例中,這些方法會讓模型的性能下降,原因包括自我強制(self-enforcing)。
另一方面,在一個數(shù)據(jù)集上有效的技術可能無法很好地處理其它數(shù)據(jù)集,原因可能包括所涉及的推理類型發(fā)生了變化,比如從空間推理變成了數(shù)學推理或常識推理。
相較之下,傳統(tǒng)的符號式規(guī)劃和搜索技術卻能表現(xiàn)出很好的推理能力。此外,這些傳統(tǒng)方法計算得到的解決方案通常有形式上的保證,因為符號規(guī)劃算法通常遵循明確定義的基于規(guī)則的搜索過程。
為了讓 Transformer 具備復雜推理能力,Meta FAIR 田淵棟團隊近日提出了 Searchformer。
- 論文標題:Beyond A?: Better Planning with Transformers via Search Dynamics Bootstrapping
- 論文地址:https://arxiv.org/pdf/2402.14083.pdf
Searchformer 是一種 Transformer 模型,但針對迷宮導航和推箱子等多步規(guī)劃任務,它卻能計算出最優(yōu)規(guī)劃并且所用搜索步驟數(shù)也能遠少于 A? 搜索等符號規(guī)劃算法。
為了做到這一點,該團隊提出了一種新方法:搜索動態(tài)引導(search dynamics bootstrapping)。該方法首先是訓練一個 Transformer 模型來模仿 A? 的搜索過程(如圖 1 所示,然后對其進行微調(diào),使其能用更少的搜索步數(shù)找到最優(yōu)規(guī)劃。
更詳細地說,第一步,訓練一個模仿 A? 搜索的 Transformer 模型。這里,該團隊的做法是針對隨機生成的規(guī)劃任務實例運行 A* 搜索。在執(zhí)行 A? 時,該團隊會記錄執(zhí)行的計算和最優(yōu)規(guī)劃并將其整理成詞序列,即 token。這樣一來,所得到的訓練數(shù)據(jù)集就包含了 A? 的執(zhí)行軌跡并編碼了有關 A? 本身的搜索動態(tài)的信息。然后,訓練一個 Transformer 模型,讓其能針對任意規(guī)劃任務沿最優(yōu)規(guī)劃生成這些 token 序列。
第二步,使用專家迭代(expert iteration)方法進一步提升使用上述經(jīng)過搜索增強的序列(包含 A? 的執(zhí)行軌跡)訓練的 Searchformer。專家迭代方法可讓 Transformer 憑借更少的搜索步驟生成最優(yōu)解。這個過程會得到一種神經(jīng)規(guī)劃算法,其隱式地編碼在該 Transformer 的網(wǎng)絡權重之中,并且它有很高的概率以少于 A? 搜索的搜索步數(shù)找到最優(yōu)規(guī)劃。比如說,在執(zhí)行推箱子任務時,新模型能解答 93.7% 的測試任務,同時搜索步數(shù)比 A? 搜索平均少 26.8%。
該團隊表示:這為 Transformer 超越傳統(tǒng)符號規(guī)劃算法鋪平了道路。
實驗
為了更好地理解訓練數(shù)據(jù)和模型參數(shù)量對所得模型性能的影響,他們進行了一些消融研究。
他們使用了兩類數(shù)據(jù)集訓練模型:一種的 token 序列中只包含解(solution-only,其中只有任務描述和最終規(guī)劃);另一種則是搜索增強型序列(search-augmented,其中包含任務描述、搜索樹動態(tài)和最終規(guī)劃)。
實驗中,該團隊使用了 A? 搜索的一種確定性和非確定性變體來生成每個序列數(shù)據(jù)集。
迷宮導航
在第一個實驗中,該團隊訓練了一組編碼器 - 解碼器 Transformer 模型來預測 30×30 迷宮中的最優(yōu)路徑。
圖 4 表明,通過預測中間計算步驟,可在數(shù)據(jù)量少時獲得更穩(wěn)健的性能表現(xiàn)。
圖 5 給出了僅使用解訓練的模型的性能。
圖 6 展示了任務難度對每個模型的性能的影響。
整體而言,盡管當使用的訓練數(shù)據(jù)集足夠大和足夠多樣化時,僅使用解訓練的模型也能預測得到最優(yōu)規(guī)劃,但當數(shù)據(jù)量少時,經(jīng)過搜索增強的模型的表現(xiàn)明顯好得多,并且也能更好地擴展用于更困難的任務。
推箱子
為了測試能否在不同且更復雜的任務(具有不同的 token 化模式)上得到類似的結果,該團隊還生成了一個推箱子的規(guī)劃數(shù)據(jù)集進行測試。
圖 7 展示了每種模型針對每個測試任務生成正確規(guī)劃的概率。
可以看到,和上一個實驗一樣,通過使用執(zhí)行軌跡進行訓練,搜索增強型模型的表現(xiàn)優(yōu)于僅使用解訓練的模型。
Searchformer:通過引導方法提升搜索動態(tài)
最后一個實驗,該團隊研究了搜索增強型模型可以如何迭代提升,從而憑借更少的搜索步數(shù)計算出最優(yōu)規(guī)劃。這里的目標是在縮短搜索軌跡長度的同時依然得到最優(yōu)解。
圖 8 表明,新提出的搜索動態(tài)引導方法能夠迭代式地縮短 Searchformer 模型生成的序列的長度。