基于多級注意力機(jī)制的并行預(yù)測模型
創(chuàng)新點(diǎn):
1.利用 Transformer 來提取序列中的長期依賴關(guān)系的時序特征,采用并行結(jié)構(gòu),加快模型的訓(xùn)練和推理速度,提高模型對關(guān)鍵信息的感知能力;
2.通過雙向門控循環(huán)單元(BiGRU)同時從前向和后向?qū)π蛄羞M(jìn)行建模,以更好地捕獲序列中的依賴關(guān)系,同時應(yīng)用全局注意力機(jī)制GlobalAttention,對BiGRU的輸出進(jìn)行加權(quán)處理,使模型能夠聚焦于序列中最重要的部分,提高預(yù)測性能;
3.利用交叉注意力進(jìn)行并行網(wǎng)絡(luò)時空特征的融合,這樣可以同時考慮時序關(guān)系和位置關(guān)系,從而更好地捕捉時空序列數(shù)據(jù)中的特征,增強(qiáng)特征的表示能力來實(shí)現(xiàn)高精度的預(yù)測。
在多個數(shù)據(jù)集上表現(xiàn)出高精度的預(yù)測性能!
模型訓(xùn)練可視化圖:
多特征貢獻(xiàn)度可視化分析圖:
此分析代碼是我們團(tuán)隊(duì)原創(chuàng),如何利用 深度學(xué)習(xí) 訓(xùn)練好的模型 在對多特征預(yù)測任務(wù) 中進(jìn)行特征重要性(貢獻(xiàn)度)可視化?。ㄒ部梢杂糜谄渌疃葘W(xué)習(xí)模型做特征重要性可視化,代碼適用性高)
前言
本文基于前期介紹的電力變壓器(文末附數(shù)據(jù)集),介紹一種基于Transformer-BiGRUGlobalAttention-CrossAttention并行預(yù)測模型,以提高時間序列數(shù)據(jù)的預(yù)測性能。電力變壓器數(shù)據(jù)集的詳細(xì)介紹可以參考下文:
??電力變壓器數(shù)據(jù)集介紹和預(yù)處理??
1 模型整體結(jié)構(gòu)
模型整體結(jié)構(gòu)如下所示,多特征變量時間序列數(shù)據(jù)先經(jīng)過基于多頭注意的Transformer編碼器層提取長期依賴特征,同時數(shù)據(jù)通過基于GlobalAttention優(yōu)化的BiGRU網(wǎng)絡(luò)提取全局時序特征,使用交叉注意力機(jī)制進(jìn)行特征融合,通過計(jì)算注意力權(quán)重,使得模型更關(guān)注重要的特征再進(jìn)行特征增強(qiáng)融合,最后經(jīng)過全連接層進(jìn)行高精度預(yù)測。
分支一:通過基于多頭注意的Transformer編碼器層模型,Transformer是一種基于自注意力機(jī)制的序列建模方法,通過注意力機(jī)制來建模序列中不同位置之間的依賴關(guān)系,能夠捕捉序列中的全局上下文信息。Transformer是一種基于自注意力機(jī)制的序列建模方法,通過注意力機(jī)制來建模序列中不同位置之間的依賴關(guān)系,能夠捕捉序列中的全局上下文信息。
分支二:多特征序列數(shù)據(jù)同時通過基于GlobalAttention優(yōu)化的BiGRU網(wǎng)絡(luò),GlobalAttention是一種用于加強(qiáng)模型對輸入序列不同部分的關(guān)注程度的機(jī)制。在 BiGRU 模型中,全局注意力機(jī)制可以幫助模型更好地聚焦于輸入序列中最相關(guān)的部分,從而提高模型的性能和泛化能力。在每個時間步,全局注意力機(jī)制計(jì)算一個權(quán)重向量,表示模型對輸入序列各個部分的關(guān)注程度,然后將這些權(quán)重應(yīng)用于 BiGRU 輸出的特征表示,通過對所有位置的特征進(jìn)行加權(quán),使模型能夠更有針對性地關(guān)注重要的時域特征, 提高了模型對多特征序列時域特征的感知能力;
并行預(yù)測:模型采用并行結(jié)構(gòu),能夠同時預(yù)測多個時間步的目標(biāo)。并行預(yù)測可以加快模型的訓(xùn)練和推理速度,并且能夠充分利用時序數(shù)據(jù)中的信息,提高預(yù)測性能。
交叉注意力機(jī)制特征融合:使用交叉注意力機(jī)制融空間和時序特征,可以通過計(jì)算注意力權(quán)重,學(xué)習(xí)時空特征中不同位置之間的相關(guān)性,可以更好地捕捉時空序列數(shù)據(jù)中的特征,提高模型性能和泛化能力。
全局注意力機(jī)制:
Global Attention Mechanism
2 多特征變量數(shù)據(jù)集制作與預(yù)處理
2.1 導(dǎo)入數(shù)據(jù)
2.2 制作數(shù)據(jù)集
制作數(shù)據(jù)集與分類標(biāo)簽
3 交叉注意力機(jī)制
3.1 Cross attention概念
- Transformer架構(gòu)中混合兩種不同嵌入序列的注意機(jī)制
- 兩個序列必須具有相同的維度
- 兩個序列可以是不同的模式形態(tài)(如:文本、聲音、圖像)
- 一個序列作為輸入的Q,定義了輸出的序列長度,另一個序列提供輸入的K&V
3.2 Cross-attention算法
- 擁有兩個序列S1、S2
- 計(jì)算S1的K、V
- 計(jì)算S2的Q
- 根據(jù)K和Q計(jì)算注意力矩陣
- 將V應(yīng)用于注意力矩陣
- 輸出的序列長度與S2一致
在融合過程中,我們將經(jīng)過Transformer的時序特征作為查詢序列,GlobalAttention優(yōu)化的BiGRU提取的全局空間特征作為鍵值對序列。通過計(jì)算查詢序列與鍵值對序列之間的注意力權(quán)重,我們可以對不同特征之間的關(guān)聯(lián)程度進(jìn)行建模。
4 基于多級注意力機(jī)制的并行高精度預(yù)測模型
4.1 定義網(wǎng)絡(luò)模型
注意:輸入數(shù)據(jù)形狀為 [64, 7, 7], batch_size=64,7代表序列長度(滑動窗口取值), 維度7維代表7個變量的維度。
4.2 設(shè)置參數(shù),訓(xùn)練模型
50個epoch,訓(xùn)練誤差極小,多變量特征序列Transformer-BiGRUGlobalAttention-CrossAttention并行融合網(wǎng)絡(luò)模型預(yù)測效果顯著,模型能夠充分提取時間序列的空間特征和時序特征,收斂速度快,性能優(yōu)越,預(yù)測精度高,能夠從序列時空特征中提取出對模型預(yù)測重要的特征,效果明顯!
注意調(diào)整參數(shù):
- 可以適當(dāng)增加Transformer編碼器層數(shù)和隱藏層的維度、多頭注意力頭數(shù),微調(diào)學(xué)習(xí)率;
- 調(diào)整BiGRU層數(shù)和每層神經(jīng)元個數(shù),增加更多的 epoch (注意防止過擬合)
- 可以改變滑動窗口長度(設(shè)置合適的窗口長度)
5 模型評估與可視化
5.1 結(jié)果可視化
5.2 模型評估
5.3 特征可視化
點(diǎn)擊下載:原文完整數(shù)據(jù)、Python代碼
??https://mbd.pub/o/bread/ZpWTl51w??
本文轉(zhuǎn)載自?? 建模先鋒??,作者:小蝸愛建模
