聊聊 KAN、KAN 卷積結(jié)合注意力機(jī)制!
第一類 基礎(chǔ)線性層替換
KAN 層替換線性層 Linear:
更新關(guān)于LSTM、TCN、Transformer模型中用 KAN 層替換線性層的故障分類模型。
KAN 的準(zhǔn)確率要優(yōu)于 MLP,我們可以進(jìn)一步嘗試在常規(guī)模型的最后一層線性層都替換為 KAN 層來進(jìn)行對比;KAN 卷積比常規(guī)卷積準(zhǔn)確率有略微的提升!
第二類 并行融合模型
KAN卷積、GRU并行:
故障信號(hào)同時(shí)送入并行模型,分支一經(jīng)過 KAN卷積進(jìn)行學(xué)習(xí),分支二利用 GRU 提取故障時(shí)域特征,然后并行特征進(jìn)行堆疊融合,來增強(qiáng)故障信號(hào)特征提取能力。
2.1 定義 KANConv-GRU 分類網(wǎng)絡(luò)模型
2.2 設(shè)置參數(shù),訓(xùn)練模型
50個(gè)epoch,訓(xùn)練集、驗(yàn)證集準(zhǔn)確率97%,用改進(jìn) KANConv-GRU 并行網(wǎng)絡(luò)分類效果顯著,模型能夠充分提取軸承故障信號(hào)中的故障特征,收斂速度快,性能優(yōu)越,精度高,效果明顯!
2.3 模型評(píng)估
準(zhǔn)確率、精確率、召回率、F1 Score
故障十分類混淆矩陣:
第三類 結(jié)合注意力機(jī)制
3.1 KAN 結(jié)合自注意力機(jī)制:
我們創(chuàng)造性的提出在利用 KAN 層提取的特征作為自注意力機(jī)制的輸入,來進(jìn)一步增加非線性能力,具體步驟如下:
1.輸入嵌入:
首先使用 unsqueeze 將輸入從 ([batch_size, input_dim]) 擴(kuò)展為 ([batch_size, 1, input_dim]),以便兼容后續(xù)的操作。
使用 input_proj 線性層將輸入從 ([batch_size, 1, input_dim]) 映射到 ([batch_size, 1, embed_dim])。
2.查詢-鍵-值投影:
- 使用 qkv_proj 線性層將輸入映射到查詢、鍵和值的嵌入空間,結(jié)果形狀為 ([batch_size, 1, embed_dim * 3])。
3. 重塑和轉(zhuǎn)置:
- 將 qkv 重塑為 ([batch_size, 1, 3, num_heads, head_dim])。
- 然后將維度重新排列為 ([3, batch_size, num_heads, 1, head_dim])。
4.計(jì)算注意力權(quán)重和輸出:
- 通過縮放的點(diǎn)積計(jì)算注意力權(quán)重,并對其進(jìn)行 softmax 歸一化。
- 使用注意力權(quán)重與值進(jìn)行加權(quán)求和,得到注意力輸出。
5.輸出重塑和映射:
- 將注意力輸出重新排列并重塑為 ([batch_size, 1, embed_dim])。
- 使用 o_proj 線性層將自注意力機(jī)制的輸出從 ([batch_size, 1, embed_dim]) 映射回 ([batch_size, 1, input_dim])。
- 使用 squeeze 移除序列長度的維度,得到最終輸出 ([batch_size, input_dim])。
?
?
通過這種方式,輸入和輸出的維度保持一致。自注意力機(jī)制通過計(jì)算每個(gè)輸入元素與其他所有輸入元素之間的相關(guān)性(注意力分?jǐn)?shù)),并利用這些相關(guān)性來加權(quán)求和,更新每個(gè)輸入元素的表示,從而捕捉到輸入序列中元素之間的依賴關(guān)系。進(jìn)一步加強(qiáng)了 KAN 輸出信息對復(fù)雜特征的建模能力。
3.2 KAN 卷積結(jié)合通道注意力機(jī)制SENet:
KAN 卷積與卷積非常相似,但不是在內(nèi)核和圖像中相應(yīng)像素之間應(yīng)用點(diǎn)積,而是對每個(gè)元素應(yīng)用可學(xué)習(xí)的非線性激活函數(shù),然后將它們相加。我們在KAN卷積的基礎(chǔ)上融合通道注意力機(jī)制,進(jìn)一步加強(qiáng)了對特征的提取能力!
從對比實(shí)驗(yàn)可以看出, 在軸承故障診斷任務(wù)中:
KAN卷積融合注意力機(jī)制后,效果提升明顯,后續(xù)還可以進(jìn)一步嘗試與其他類型的注意力機(jī)制做融合!
本文轉(zhuǎn)載自 ??建模先鋒??,作者: 小蝸愛建模
