性能大漲20%!中科大「狀態(tài)序列頻域預(yù)測(cè)」方法:表征學(xué)習(xí)樣本效率max
強(qiáng)化學(xué)習(xí)算法(Reinforcement Learning, RL)的訓(xùn)練過(guò)程往往需要大量與環(huán)境交互的樣本數(shù)據(jù)作為支撐。然而,現(xiàn)實(shí)世界中收集大量的交互樣本通常成本高昂或者難以保證樣本采集過(guò)程的安全性,例如無(wú)人機(jī)空戰(zhàn)訓(xùn)練和自動(dòng)駕駛訓(xùn)練。
為了提升強(qiáng)化學(xué)習(xí)算法在訓(xùn)練過(guò)程中的樣本效率,一些研究者們借助于表征學(xué)習(xí)(representation learning),設(shè)計(jì)了預(yù)測(cè)未來(lái)狀態(tài)信號(hào)的輔助任務(wù),使得表征能從原始的環(huán)境狀態(tài)中編碼出與未來(lái)決策相關(guān)的特征。
基于這個(gè)思路,該工作設(shè)計(jì)了一種預(yù)測(cè)未來(lái)多步的狀態(tài)序列頻域分布的輔助任務(wù),以捕獲更長(zhǎng)遠(yuǎn)的未來(lái)決策特征,進(jìn)而提升算法的樣本效率。
該工作標(biāo)題為State Sequences Prediction via Fourier Transform for Representation Learning,發(fā)表于NeurIPS 2023,并被接收為Spotlight。
作者列表:葉鳴軒,匡宇飛,王杰*,楊睿,周文罡,李厚強(qiáng),吳楓
論文鏈接:https://openreview.net/forum?id=MvoMDD6emT
代碼鏈接:https://github.com/MIRALab-USTC/RL-SPF/
研究背景與動(dòng)機(jī)
深度強(qiáng)化學(xué)習(xí)算法在機(jī)器人控制[1]、游戲智能[2]、組合優(yōu)化[3]等領(lǐng)域取得了巨大的成功。但是,當(dāng)前的強(qiáng)化學(xué)習(xí)算法仍存在「樣本效率低下」的問(wèn)題,即機(jī)器人需要大量與環(huán)境交互的數(shù)據(jù)才能訓(xùn)得性能優(yōu)異的策略。
為了提升樣本效率,研究者們將目光投向于表征學(xué)習(xí),希望訓(xùn)得的表征能從環(huán)境的原始狀態(tài)中提取出充足且有價(jià)值的特征信息,從而提升機(jī)器人對(duì)狀態(tài)空間的探索效率。
基于表征學(xué)習(xí)的強(qiáng)化學(xué)習(xí)算法框架
在序列決策任務(wù)中,「長(zhǎng)期的序列信號(hào)」相對(duì)于單步信號(hào)包含更多有利于長(zhǎng)期決策的未來(lái)信息。啟發(fā)于這一觀點(diǎn),一些研究者提出通過(guò)預(yù)測(cè)未來(lái)多步的狀態(tài)序列信號(hào)來(lái)輔助表征學(xué)習(xí)[4,5]。然而,直接預(yù)測(cè)狀態(tài)序列來(lái)輔助表征學(xué)習(xí)是非常困難的。
現(xiàn)有的兩類方法中,一類方法通過(guò)學(xué)習(xí)單步概率轉(zhuǎn)移模型來(lái)逐步地產(chǎn)生單個(gè)時(shí)刻的未來(lái)狀態(tài),以間接預(yù)測(cè)多步的狀態(tài)序列[6,7]。但是,這類方法對(duì)所訓(xùn)得的概率轉(zhuǎn)移模型的精度要求很高,因?yàn)槊坎降念A(yù)測(cè)誤差會(huì)隨預(yù)測(cè)序列長(zhǎng)度的增加而積累。
另一類方法通過(guò)直接預(yù)測(cè)未來(lái)多步的狀態(tài)序列來(lái)輔助表征學(xué)習(xí)[8],但這類方法需要存儲(chǔ)多步的真實(shí)狀態(tài)序列作為預(yù)測(cè)任務(wù)的標(biāo)簽,所耗存儲(chǔ)量大。因此,如何有效從環(huán)境的狀態(tài)序列中提取有利于長(zhǎng)期決策的未來(lái)信息,進(jìn)而提升連續(xù)控制機(jī)器人訓(xùn)練時(shí)的樣本效率是需要解決的問(wèn)題。
為了解決上述問(wèn)題,我們提出了一種基于狀態(tài)序列頻域預(yù)測(cè)的表征學(xué)習(xí)方法(State Sequences Prediction via Fourier Transform, SPF),其思想是利用「狀態(tài)序列的頻域分布」來(lái)顯式提取狀態(tài)序列數(shù)據(jù)中的趨勢(shì)性和規(guī)律性信息,從而輔助表征高效地提取到長(zhǎng)期未來(lái)信息。
狀態(tài)序列中的結(jié)構(gòu)性信息分析
我們從理論上證明了狀態(tài)序列存在「兩種結(jié)構(gòu)性信息」,一是與策略性能相關(guān)的趨勢(shì)性信息,二是與狀態(tài)周期性相關(guān)的規(guī)律性信息。
馬爾科夫決策過(guò)程
在具體分析兩種結(jié)構(gòu)性信息之前,我們先介紹產(chǎn)生狀態(tài)序列的馬爾科夫決策過(guò)程(Markov Decision Processes,MDP)的相關(guān)定義。
我們考慮連續(xù)控制問(wèn)題中的經(jīng)典馬爾可夫決策過(guò)程,該過(guò)程可用五元組 表示。其中, 為相應(yīng)的狀態(tài)、動(dòng)作空間, 為獎(jiǎng)勵(lì)函數(shù), 為環(huán)境的狀態(tài)轉(zhuǎn)移函數(shù), 為狀態(tài)的初始分布, 為折扣因子。此外,我們用 表示策略在狀態(tài) 下的動(dòng)作分布。
我們將 時(shí)刻下智能體所處的狀態(tài)記為 ,所選擇的動(dòng)作記為 .智能體做出動(dòng)作后,環(huán)境轉(zhuǎn)移到下一時(shí)刻狀態(tài) 并反饋給智能體獎(jiǎng)勵(lì) 。我們將智能體與環(huán)境交互過(guò)程中所得到狀態(tài)、動(dòng)作對(duì)應(yīng)的軌跡記為 ,軌跡服從分布 。
強(qiáng)化學(xué)習(xí)算法的目標(biāo)是最大化未來(lái)預(yù)期的累積回報(bào),我們用 表示當(dāng)前策略 和 環(huán)境模型 下的平均累積回報(bào),并簡(jiǎn)寫為 ,定義如下:
顯示了當(dāng)前策略 的性能表現(xiàn)。
趨勢(shì)性信息
下面我們介紹狀態(tài)序列的「第一種結(jié)構(gòu)性特征」,其涉及狀態(tài)序列和對(duì)應(yīng)獎(jiǎng)勵(lì)序列之間的依賴關(guān)系,能顯示出當(dāng)前策略的性能趨勢(shì)。
在強(qiáng)化學(xué)習(xí)任務(wù)中,未來(lái)的狀態(tài)序列很大程度上決定了智能體未來(lái)采取的動(dòng)作序列,并進(jìn)一步?jīng)Q定了相應(yīng)的獎(jiǎng)勵(lì)序列。因此,未來(lái)的狀態(tài)序列不僅包含環(huán)境固有的概率轉(zhuǎn)移函數(shù)的信息,也能輔助表征捕獲反映當(dāng)前策略的走向趨勢(shì)。
啟發(fā)于上述結(jié)構(gòu),我們證明了以下定理,進(jìn)一步論證了這一結(jié)構(gòu)性依賴關(guān)系的存在:
定理一:若獎(jiǎng)勵(lì)函數(shù)只與狀態(tài)有關(guān),那么對(duì)于任意兩個(gè)策略 和 ,他們的性能差異可以被這兩個(gè)策略所產(chǎn)生的狀態(tài)序列分布差異所控制:
上述公式中, 表示在指定策略和轉(zhuǎn)移概率函數(shù)條件下狀態(tài)序列的概率分布, 表示 范數(shù)。
上述定理表明,兩個(gè)策略的性能差異越大,其對(duì)應(yīng)的兩個(gè)狀態(tài)序列的分布差異也越大。這意味著好策略和壞策略會(huì)產(chǎn)生出兩個(gè)差異較大的狀態(tài)序列,這進(jìn)一步說(shuō)明狀態(tài)序列所包含的長(zhǎng)期結(jié)構(gòu)性信息能潛在影響搜索性能優(yōu)異的策略的效率。
另一方面,在一定條件下,狀態(tài)序列的頻域分布差異也能為對(duì)應(yīng)的策略性能差異提供上界,具體如以下定理所示:
定理二:若狀態(tài)空間有限維且獎(jiǎng)勵(lì)函數(shù)是與狀態(tài)有關(guān)的n次多項(xiàng)式,那么對(duì)于任意兩個(gè)策略 和 ,他們的性能差異可以被這兩個(gè)策略所產(chǎn)生的狀態(tài)序列的頻域分布差異所控制:
上述公式中, 表示由策略 所產(chǎn)生的狀態(tài)序列的 次方序列的傅里葉函數(shù), 表示傅里葉函數(shù)的第 個(gè)分量。
這一定理表明狀態(tài)序列的頻域分布仍包含與當(dāng)前策略性能相關(guān)的特征。
規(guī)律性信息
下面我們介紹狀態(tài)序列中存在的「第二種結(jié)構(gòu)性特征」,其涉及到狀態(tài)信號(hào)之間的時(shí)間依賴性,即一段較長(zhǎng)時(shí)期內(nèi)狀態(tài)序列所表現(xiàn)出的規(guī)律性模式。
在許多的真實(shí)場(chǎng)景任務(wù)中,智能體也會(huì)表現(xiàn)出周期性行為,因?yàn)槠洵h(huán)境的狀態(tài)轉(zhuǎn)移函數(shù)本身就是具有周期性的。以工業(yè)裝配機(jī)器人為例,該機(jī)器人的訓(xùn)練目標(biāo)是將零件組裝在一起以創(chuàng)造最終產(chǎn)品,當(dāng)策略訓(xùn)練達(dá)到穩(wěn)定時(shí),它就會(huì)執(zhí)行一個(gè)周期性的動(dòng)作序列,使其能夠有效地將零件組裝在一起。
啟發(fā)于上面的例子,我們提供了一些理論分析,證明了有限狀態(tài)空間中,當(dāng)轉(zhuǎn)移概率矩陣滿足某些假設(shè),對(duì)應(yīng)的狀態(tài)序列在智能體達(dá)到穩(wěn)定策略時(shí)可能表現(xiàn)出「漸近周期性」,具體定理如下:
定理三:對(duì)于狀態(tài)轉(zhuǎn)移矩陣為 的有限維狀態(tài)空間 ,假設(shè) 有 個(gè)循環(huán)類,對(duì)應(yīng)的狀態(tài)轉(zhuǎn)移子矩陣為 。設(shè)這 個(gè)矩陣模為1的特征值個(gè)數(shù)為 ,則對(duì)于任意狀態(tài)的初始分布 ,狀態(tài)分布 呈現(xiàn)出周期為 的漸進(jìn)周期性。
在MuJoCo任務(wù)中,策略訓(xùn)練達(dá)到穩(wěn)定時(shí),智能體也會(huì)表現(xiàn)出周期性的運(yùn)動(dòng)。下圖中給出了MuJoCo任務(wù)中HalfCheetah智能體在一段時(shí)間內(nèi)的狀態(tài)序列示例,可以觀察到明顯的周期性。(更多MuJoCo任務(wù)中帶周期性的狀態(tài)序列示例可參考本論文附錄第E節(jié))
MuJoCo任務(wù)中HalfCheetah智能體在一段時(shí)間內(nèi)狀態(tài)所表現(xiàn)出的周期性
時(shí)間序列在時(shí)域中呈現(xiàn)的信息相對(duì)分散,但在頻域中,序列中的規(guī)律性信息以更加集中的形式呈現(xiàn)。通過(guò)分析頻域中的頻率分量,我們能顯式地捕獲到狀態(tài)序列中存在的周期性特征。
方法介紹
上一部分中,我們從理論上證明狀態(tài)序列的頻域分布能反映策略性能的好壞,并且通過(guò)在頻域上分析頻率分量我們能顯式捕獲到狀態(tài)序列中的周期性特征。
啟發(fā)于上述分析,我們?cè)O(shè)計(jì)了「預(yù)測(cè)無(wú)窮步未來(lái)狀態(tài)序列傅里葉變換」的輔助任務(wù)來(lái)鼓勵(lì)表征提取狀態(tài)序列中的結(jié)構(gòu)性信息。
SPF方法損失函數(shù)
下面介紹我們關(guān)于該輔助任務(wù)的建模。給定當(dāng)前狀態(tài) 和動(dòng)作 ,我們定義未來(lái)的狀態(tài)序列期望如下:
我們的輔助任務(wù)訓(xùn)練表征去預(yù)測(cè)上述狀態(tài)序列期望的離散時(shí)間傅里葉變換(discrete-time Fourier transform, DTFT),即
上述傅里葉變換公式可改寫為如下的遞歸形式:
其中,
其中, 為狀態(tài)空間的維度, 為所預(yù)測(cè)的狀態(tài)序列傅里葉函數(shù)的離散化點(diǎn)的個(gè)數(shù)。
啟發(fā)于Q-learning中優(yōu)化Q值網(wǎng)絡(luò)的TD-error損失函數(shù)[9],我們?cè)O(shè)計(jì)了如下的損失函數(shù):
其中, 和 分別為損失函數(shù)要優(yōu)化的表征編碼器(encoder)和傅里葉函數(shù)預(yù)測(cè)器(predictor)的神經(jīng)網(wǎng)絡(luò)參數(shù), 為存儲(chǔ)樣本數(shù)據(jù)的經(jīng)驗(yàn)池。
進(jìn)一步地,我們可以證明上述的遞歸公式可以表示為一個(gè)壓縮映射:
定理四:令 表示函數(shù)族 ,并定義 上的范數(shù)為:
其中 表示矩陣 的第 行向量。我們定義映射 為
則可以證明 為一個(gè)壓縮映射。
根據(jù)壓縮映射原理,我們可以迭代地使用算子 ,使得 逼近真實(shí)狀態(tài)序列的頻域分布,且在表格型情況(tabular setting)下有收斂性保證。
此外,我們所設(shè)計(jì)的損失函數(shù)只依賴于當(dāng)前時(shí)刻與下一時(shí)刻的狀態(tài),所以無(wú)需存儲(chǔ)未來(lái)多步的狀態(tài)數(shù)據(jù)作為預(yù)測(cè)標(biāo)簽,具有「實(shí)施簡(jiǎn)單且存儲(chǔ)量低」的優(yōu)點(diǎn)。
SPF方法算法框架
下面我們介紹本論文方法(SPF)的算法框架。
基于狀態(tài)序列頻域預(yù)測(cè)的表征學(xué)習(xí)方法(SPF)的算法框架圖
我們將當(dāng)前時(shí)刻和下一時(shí)刻的狀態(tài)-動(dòng)作數(shù)據(jù)分別輸入到在線(online)和目標(biāo)(target)表征編碼器(encoder)中,得到狀態(tài)-動(dòng)作表征數(shù)據(jù),然后將該表征數(shù)據(jù)輸入到傅里葉函數(shù)預(yù)測(cè)器(predictor)得到當(dāng)前時(shí)刻和下一時(shí)刻下的兩組狀態(tài)序列傅里葉函數(shù)預(yù)測(cè)值。通過(guò)代入這兩組傅里葉函數(shù)預(yù)測(cè)值,我們能計(jì)算出損失函數(shù)值。
我們通過(guò)最小化損失函數(shù)來(lái)優(yōu)化更新表征編碼器 和傅里葉函數(shù)預(yù)測(cè)器 ,使預(yù)測(cè)器的輸出能逼近真實(shí)狀態(tài)序列的傅里葉變換,從而鼓勵(lì)表征編碼器提取出包含未來(lái)長(zhǎng)期狀態(tài)序列的結(jié)構(gòu)性信息的特征。
我們將原始狀態(tài)和動(dòng)作輸入到表征編碼器中,將得到的特征作為強(qiáng)化學(xué)習(xí)算法中actor網(wǎng)絡(luò)和critic網(wǎng)絡(luò)的輸入,并用經(jīng)典強(qiáng)化學(xué)習(xí)算法優(yōu)化actor網(wǎng)絡(luò)和critic網(wǎng)絡(luò)。
實(shí)驗(yàn)結(jié)果
(注:本節(jié)僅選取部分實(shí)驗(yàn)結(jié)果,更詳細(xì)的結(jié)果請(qǐng)參考論文原文第6節(jié)及附錄。)
算法性能比較
我們將 SPF 方法在 MuJoCo 仿真機(jī)器人控制環(huán)境上測(cè)試,對(duì)如下 6 種方法進(jìn)行對(duì)比:
- SAC:基于Q值學(xué)習(xí)的soft actor-critic算法[10],一種傳統(tǒng)的RL算法;
- PPO:基于策略優(yōu)化的proximal policy optimization算法[11],一種傳統(tǒng)RL算法;
- SAC-OFE:利用預(yù)測(cè)單步未來(lái)狀態(tài)的輔助任務(wù)進(jìn)行表征學(xué)習(xí),以優(yōu)化SAC算法;
- PPO-OFE:利用預(yù)測(cè)單步未來(lái)狀態(tài)的輔助任務(wù)進(jìn)行表征學(xué)習(xí),以優(yōu)化PPO算法;
- SAC-SPF:利用預(yù)測(cè)無(wú)窮步狀態(tài)序列的頻域函數(shù)的輔助任務(wù)進(jìn)行表征學(xué)習(xí)(我們的方法),以優(yōu)化SAC算法;
- PPO-SPF:利用預(yù)測(cè)無(wú)窮步狀態(tài)序列的頻域函數(shù)的輔助任務(wù)進(jìn)行表征學(xué)習(xí)(我們的方法),以優(yōu)化PPO算法;
基于6種MuJoCo任務(wù)的對(duì)比實(shí)驗(yàn)結(jié)果
上圖顯示了在 6 種 MuJoCo 任務(wù)中,我們所提出的SPF方法(紅線及橙線)與其他對(duì)比方法的性能曲線。結(jié)果顯示,我們所提出的方法相比于其他方法能獲得19.5%的性能提升。
消融實(shí)驗(yàn)
我們對(duì) SPF 方法的各個(gè)模塊進(jìn)行了消融實(shí)驗(yàn),將本方法與不使用投影器模塊(noproj)、不使用目標(biāo)網(wǎng)絡(luò)模塊(notarg)、改變預(yù)測(cè)損失(nofreqloss)、改變特征編碼器網(wǎng)絡(luò)結(jié)構(gòu)(mlp,mlp_cat)時(shí)的性能表現(xiàn)做比較。
SPF方法應(yīng)用于SAC算法的消融實(shí)驗(yàn)結(jié)果圖,測(cè)試于HalfCheetah任務(wù)
可視化實(shí)驗(yàn)
我們使用 SPF 方法所訓(xùn)練好的預(yù)測(cè)器輸出狀態(tài)序列的傅里葉函數(shù),并通過(guò)逆傅里葉變換恢復(fù)出的200步狀態(tài)序列,與真實(shí)的200步狀態(tài)序列進(jìn)行對(duì)比。
基于傅里葉函數(shù)預(yù)測(cè)值恢復(fù)出的狀態(tài)序列示意圖,測(cè)試于Walker2d任務(wù)。其中,藍(lán)線為真實(shí)的狀態(tài)序列示意圖,5條紅線為恢復(fù)出的狀態(tài)序列示意圖,越下方的、顏色越淺的紅線表示利用越久遠(yuǎn)的歷史狀態(tài)所恢復(fù)出的狀態(tài)序列。
結(jié)果顯示,即使用更久遠(yuǎn)的狀態(tài)作為輸入,恢復(fù)出的狀態(tài)序列也和真實(shí)的狀態(tài)序列非常相似,這說(shuō)明 SPF 方法所學(xué)習(xí)出的表征能有效編碼出狀態(tài)序列中包含的結(jié)構(gòu)性信息。