上交最新時空預(yù)測模型PredFormer,純Transformer架構(gòu),多個數(shù)據(jù)集取得SOTA效果
今天給大家介紹一篇時空預(yù)測最新模型PredFormer,由上海交大等多所高校發(fā)表,采用純Transformer模型結(jié)構(gòu),在多個數(shù)據(jù)集中取得SOTA效果。
1.背景
時空預(yù)測學(xué)習(xí)是一個擁有廣泛應(yīng)用場景的領(lǐng)域,比如天氣預(yù)測,交通流預(yù)測,降水預(yù)測,自動駕駛,人體運(yùn)動預(yù)測等。
提起時空預(yù)測,不得不提到經(jīng)典模型ConvLSTM和最經(jīng)典的benchmark moving mnist,在ConvLSTM時代,對于Moving MNIST的預(yù)測存在肉眼可見的偽影和預(yù)測誤差。而在最新模型PredFormer中,對Moving MNIST的誤差達(dá)到肉眼難以分辨的近乎完美的預(yù)測結(jié)果。
在以前的時空預(yù)測工作中,主要分為兩個流派,基于循環(huán)(自回歸)的模型,以ConvLSTM/PredRNN//E3DLSTM/SwinLSTM/VMRNN等工作為代表;更近年來,研究者提出無需循環(huán)的SimVP框架,由CNN Encoder-Decoder結(jié)構(gòu)和一個時間轉(zhuǎn)換器組成,以SimVP/TAU/OpenSTL等工作為代表。
RNN。系列模型的缺陷在于,無法并行化,自回歸速度慢,顯存占用高,效率低;CNN系列模型無需循環(huán)提高了效率,得益于歸納偏置,但往往以犧牲泛化性和可擴(kuò)展性為代價,模型上限低。
于是作者提出了問題,時空預(yù)測,真的需要RNN嗎?真的需要CNN嗎?是否能夠設(shè)計一個模型,可以自動地學(xué)習(xí)數(shù)據(jù)中的時空依賴,而不需要依賴于歸納偏置呢?
一個直覺的想法是利用Transformer,因?yàn)樗诟鞣N視覺任務(wù)中的廣泛成功,并且是RNN和CNN的有力替代者。在此前的時空預(yù)測工作中,已有研究者把Transformer嵌入到上述兩種框架中,比如SwinLSTM(ICCV23)融合了Swin Transformer和LSTM,比如OpenSTL(NeurIPS23)把各種MetaFormer結(jié)構(gòu)(比如ViT,Swin Transformer等)作為SimVP框架中的時間轉(zhuǎn)換器。但是,純Transformer結(jié)構(gòu)的網(wǎng)絡(luò)鮮有探索。
但純Transformer模型的挑戰(zhàn)在于,如何在一個框架中同時處理時間和空間信息。一個簡單的想法是合并空間序列和時間序列,計算時空全注意力,由于Transformer的計算復(fù)雜度是序列長度的二次復(fù)雜度,這樣的做法會導(dǎo)致計算復(fù)雜度較大。
在這篇文章中,作者提出了用于時空預(yù)測學(xué)習(xí)的新框架PredFormer,這是一個純ViT模型,既沒有自回歸也沒有任何卷積。作者利用精心設(shè)計的基于門控Transfomer模塊,對3D Attention進(jìn)行了全面的分析,包括時空全注意力,時空分解的注意力,和時空交錯的注意力。PredFormer 采用非循環(huán)、基于Transformer的設(shè)計,既簡單又高效,更少參數(shù)量,F(xiàn)lops,更快推理速度,性能顯著優(yōu)于以前的方法。在合成和真實(shí)數(shù)據(jù)集上進(jìn)行的大量實(shí)驗(yàn)表明,PredFormer 實(shí)現(xiàn)了最先進(jìn)的性能。在 Moving MNIST 上,PredFormer 相對于 SimVP 實(shí)現(xiàn)了 51.3% 的 MSE 降低,突破性地達(dá)到11.6。對于 TaxiBJ,該模型將 MSE 降低了 33.1%,并將 FPS 從 533 提高到 2364。此外,在 WeatherBench 上,它將 MSE 降低了 11.1%,同時將 FPS 從 196 提高到 404。這些準(zhǔn)確度和效率方面的性能提升證明了 PredFormer 在實(shí)際應(yīng)用中的潛力。
2.實(shí)現(xiàn)方法
PredFormer模型遵循標(biāo)準(zhǔn)ViT的設(shè)計,先對輸入進(jìn)行Patch Embedding,把輸入為[B, T, C, H, W]的時空序列轉(zhuǎn)換為[B, T, N, D]的張量。在位置編碼環(huán)節(jié),作者采用了不同于一般ViT設(shè)計的可學(xué)習(xí)的位置編碼,而是采用了基于sin函數(shù)的絕對位置編碼,作者在消融實(shí)驗(yàn)中進(jìn)一步闡述了絕對位置編碼在時空任務(wù)中的優(yōu)越性。
PredFormer的編碼器部分,由門控Transfomer模塊以不同的方式堆疊而成。由于編碼器部分是純Transformer結(jié)構(gòu),沒有任何卷積,也沒有分辨率的下降,每一個門控Transformer模塊都建模了全局信息,這允許模型只需使用一個簡單的解碼器就可以構(gòu)成一個性能強(qiáng)大的預(yù)測模型。作者采用了一個線性層作為解碼器來進(jìn)行Patch Recovery,這讓模型的輸出從[B, T, N, D]恢復(fù)到[B, T, C, H, W]。
不同于標(biāo)準(zhǔn)Transformer模型采用MLP作為FFN,PredFormer采用了Gated Linear Unit(GLU)作為FFN,這是受GLU在NLP任務(wù)中優(yōu)于MLP啟發(fā)的改進(jìn)。作者在消融實(shí)驗(yàn)中進(jìn)一步闡述了GLU相比于MLP在時空任務(wù)上的優(yōu)越性。
作者對3D Attention進(jìn)行了全面的分析,并提出了9種PredFormer變體。在以前用于視頻分類的Video ViT設(shè)計中,TimesFormer(ICML21), ViviT(ICCV21), TSViT(CVPR23)等工作也對時空分解進(jìn)行了分析,但是TimesFormer是在self-attention層面進(jìn)行分解,也就是spatial attention和temporal attention共用一個MLP。ViviT則是提出了在Encoder層面(先空間后時間),self-attention層面和head層面進(jìn)行時空分解。而TSViT發(fā)現(xiàn)先時間后空間的Encoder對衛(wèi)星序列圖像分類更有效。
不同于以上工作,PredFormer是在Gated Transformer Block(GTB)層面(多了基于Gated Linear Unit)進(jìn)行時空分解。對時間和空間的self-attention都加GLU是至關(guān)重要的,因?yàn)樗梢宰寣W(xué)習(xí)到的特征互相作用并且增強(qiáng)非線性。
PredFormer提出了時空全注意力Encoder,時間在前和空間在前的2種分解Encoder和6種新穎的時空交錯的Encoder,一共9種模型。PredFormer提出了PredFormer Layer的概念,即一個既能建??臻g信息,又能建模時間信息的最小單元?;谶@種想法,作者提出了三種基本范式,二元組(由一個Temporal GTB和一個Spatial GTB組成,有T-S和S-T兩種方式),三元組(T-S-T和S-T-S),四元組(兩個二元組以相反的方向重組)。
這一設(shè)計源于不同的時空預(yù)測任務(wù)往往有著不同的空間分辨率和時間分辨率(時間間隔以及變化程度),這意味著不同的數(shù)據(jù)集上對時間信息和空間信息的依賴程度不同,作者設(shè)計了這些模型以提高PredFormer模型在不同任務(wù)上的適應(yīng)性。
3.實(shí)驗(yàn)效果
在實(shí)驗(yàn)部分,作者控制了提出的每種變體使用相同的GTB數(shù)目,這可以保證模型的參數(shù)量基本一致,從而對比不同模型的性能。
實(shí)驗(yàn)發(fā)現(xiàn)了一些規(guī)律,(1)時間在前的分解Encoder模型優(yōu)于時空全注意力模型,由于空間在前的分解Encoder模型 (2)時空交錯的6種模型在大多數(shù)任務(wù)上表現(xiàn)都很好,都能達(dá)到sota,但最優(yōu)模型因?yàn)閿?shù)據(jù)集本身的不同時空依賴特性而不同,這體現(xiàn)了PredFormer這種框架和時空交錯設(shè)計的優(yōu)勢 (3)作者在討論環(huán)節(jié)提出了建議,在其他的時空預(yù)測任務(wù)上,從四元組-TSST開始嘗試,因?yàn)檫@個模型在三個數(shù)據(jù)集上都表現(xiàn)sota,先調(diào)整M個TSST(即4M個門控Transformer)的M參數(shù),然后嘗試M個TST和M個STS以確定數(shù)據(jù)集是時間依賴更強(qiáng)或空間依賴更強(qiáng)的模型。得益于Transformer構(gòu)架的可擴(kuò)展性,不同于SimVP框架的CNN Encoder-Decoder模型,對spatial和temporal的hidden dim以及block數(shù)都設(shè)置了不同的值,PredFormer對spatial和temporal GTB采用相同的固定的參數(shù),因此只需要調(diào)整M的值,在比較少次數(shù)的調(diào)整后就可以達(dá)到最優(yōu)性能。
ViT模型的訓(xùn)練通常要求較大的數(shù)據(jù)集,在時空預(yù)測任務(wù)上,大多數(shù)據(jù)集在幾千到幾萬的量級,數(shù)據(jù)集少,因此很容易過擬合。作者還探索了不同的正則化策略,包括dropout和drop path,通過廣泛的消融實(shí)驗(yàn),作者發(fā)現(xiàn)同時使用dropout和uniform的drop path(不同于一般ViT使用線性增加的drop path rate)會產(chǎn)生最優(yōu)的模型效果。
作者還進(jìn)行了可視化比較,可以看到,在PredFormer相對于TAU明顯減少了預(yù)測誤差。作者還給出了一個特殊例子來證明PredFormerr模型相比于CNN模型在泛化性上的優(yōu)越性。在交通流預(yù)測任務(wù)上,當(dāng)?shù)谒膸啾惹叭龓黠@減少流量時,TAU受限于歸納偏置仍然預(yù)測了較高的流量,而PredFormer卻能捕捉到這里的變化。PredFormerr預(yù)測劇烈變化的能力在交通流和天氣預(yù)測中可能有非常寶貴的應(yīng)用價值。
本文轉(zhuǎn)載自 ??圓圓的算法筆記??,作者:yujin
