TabR:檢索增強(qiáng)能否讓深度學(xué)習(xí)在表格數(shù)據(jù)上超過梯度增強(qiáng)模型?
這是一篇7月新發(fā)布的論文,他提出了使用自然語言處理的檢索增強(qiáng)Retrieval Augmented技術(shù),目的是讓深度學(xué)習(xí)在表格數(shù)據(jù)上超過梯度增強(qiáng)模型。
檢索增強(qiáng)一直是NLP中研究的一個(gè)方向,但是引入了檢索增強(qiáng)的表格深度學(xué)習(xí)模型在當(dāng)前實(shí)現(xiàn)與非基于檢索的模型相比幾乎沒有改進(jìn)。所以論文作者提出了一個(gè)新的TabR模型,模型通過增加一個(gè)類似注意力的檢索組件來改進(jìn)現(xiàn)有模型。據(jù)說,這種注意力機(jī)制的細(xì)節(jié)可以顯著提高表格數(shù)據(jù)任務(wù)的性能。TabR模型在表格數(shù)據(jù)上的平均性能優(yōu)于其他DL模型,在幾個(gè)數(shù)據(jù)集上設(shè)置了新的標(biāo)準(zhǔn),在某些情況下甚至超過了GBDT模型,特別是在通常被視為GBDT友好的數(shù)據(jù)集上。
TabR
表格數(shù)據(jù)集通常被表示為特征和標(biāo)簽對{(xi, yi)},其中xi和yi分別是第i個(gè)對象的特征和標(biāo)簽。一般有三種類型的主要任務(wù):二元分類、多類分類和回歸。
對于表格數(shù)據(jù)我們會將數(shù)據(jù)集分為訓(xùn)練部分、驗(yàn)證部分和測試部分,模型對“輸入”或“目標(biāo)”對象進(jìn)行預(yù)測。當(dāng)使用檢索技術(shù)時(shí),檢索是在一組“上下文候選”或“候選”中完成的,被檢索的對象稱為“上下文對象”或簡稱為“上下文”。同一組候選對象用于所有輸入對象。
論文的實(shí)驗(yàn)設(shè)置涉及調(diào)優(yōu)和評估協(xié)議,其中需要超參數(shù)調(diào)優(yōu)和基于驗(yàn)證集性能的早期停止。然后在15個(gè)隨機(jī)種子的平均測試集上測試最佳超參數(shù),并在算法比較中考慮標(biāo)準(zhǔn)偏差。
論文作者的目標(biāo)是將檢索功能集成到傳統(tǒng)的前饋網(wǎng)絡(luò)中。該過程包括通過編碼器傳遞目標(biāo)對象及其上下文候選者,然后檢索組件會對目標(biāo)對象進(jìn)行的表示,最后預(yù)測器進(jìn)行預(yù)測。
編碼器和預(yù)測器模塊很簡單簡單,因?yàn)樗鼈儾皇枪ぷ鞯闹攸c(diǎn)。檢索模塊對目標(biāo)對象的表示以及候選對象的表示和標(biāo)簽進(jìn)行操作。這個(gè)模塊可以看作是注意力機(jī)制的一般化版本。
這個(gè)過程包括幾個(gè)步驟:
- 如果編碼器包含至少一個(gè)塊,則將表示進(jìn)行規(guī)范化;
- 根據(jù)與目標(biāo)對象的相似性定義上下文對象;
- 基于softmax函數(shù)對上下文對象的相似性分配權(quán)重;
- 定義上下文對象的值;
- 使用值和權(quán)重輸出加權(quán)聚合。
上下文大小設(shè)置為一個(gè)較大的值96,softmax函數(shù)會自動(dòng)選擇有效的上下文大小。
檢索模塊是最重要的部分
作者探討了檢索模塊的不同實(shí)現(xiàn),特別是相似度模塊和值模塊。并且說明了是通過一下幾個(gè)步驟得到最終的模型。
1、作者評估了傳統(tǒng)注意力的相似性和值模塊,發(fā)現(xiàn)該配置與多層感知器(MLP)相似,因此不能證明使用檢索組件是合理的。
2、然后他們將上下文標(biāo)簽添加到值模塊中,但發(fā)現(xiàn)這并沒有改進(jìn),這表明傳統(tǒng)注意力的相似性模塊可能是瓶頸。
3、為了改進(jìn)相似度模塊,作者刪除了查詢的概念,并用L2距離替換點(diǎn)積。這種調(diào)整使得幾個(gè)數(shù)據(jù)集上性能的顯著躍升。
4、值模塊也進(jìn)行改進(jìn),靈感來自最近提出的DNNR(用于回歸問題的kNN算法的廣義版本)。新的值模塊帶來了進(jìn)一步的性能改進(jìn)。
5、最后,作者創(chuàng)建模型TabR。在相似性模塊中省略縮放項(xiàng),不包括目標(biāo)對象在其自身的上下文中(使用交叉注意),平均而言會得到更好的結(jié)果。
生成的TabR模型為基于檢索的表格深度學(xué)習(xí)問題提供了一種健壯的方法。
作者也強(qiáng)調(diào)了TabR模型的兩個(gè)主要局限性:
與所有檢索增強(qiáng)模型一樣,從應(yīng)用程序的角度來看,使用真實(shí)的訓(xùn)練對象進(jìn)行預(yù)測可能會帶來一些問題,例如隱私和道德問題。
TabR的檢索組件雖然比以前的工作更有效,但會產(chǎn)生明顯的開銷。所以它可能無法有效地?cái)U(kuò)展以處理真正的大型數(shù)據(jù)集。
實(shí)驗(yàn)結(jié)果
作者將TabR與現(xiàn)有的檢索增強(qiáng)解決方案和最先進(jìn)的參數(shù)模型進(jìn)行比較。除了完全配置的TabR,他們還使用了一個(gè)簡化版本,TabR- s,它不使用特征嵌入,只有一個(gè)線性編碼器和一個(gè)塊預(yù)測器。
與全參數(shù)深度學(xué)習(xí)模型的比較表明,TabR在幾個(gè)數(shù)據(jù)集上優(yōu)于大多數(shù)模型,除了MI數(shù)據(jù)集,在其他數(shù)據(jù)集也很有競爭力。在許多數(shù)據(jù)集上,它比多層感知器(MLP)提供了顯著的提升。
與GBDT模型相比,調(diào)整后的TabR在幾個(gè)數(shù)據(jù)集上也有明顯的改進(jìn),并且在其他數(shù)據(jù)集上保持競爭力(除了MI數(shù)據(jù)集),并且TabR的平均表現(xiàn)也優(yōu)于GBDT模型。
總之,TabR將自己確立為表格數(shù)據(jù)問題的強(qiáng)大深度學(xué)習(xí)解決方案,展示了強(qiáng)大的平均性能,并在幾個(gè)數(shù)據(jù)集上設(shè)置了新的基準(zhǔn)。它的基于檢索的方法具有良好的潛力,并且在某些數(shù)據(jù)集上可以明顯優(yōu)于梯度增強(qiáng)的決策樹。
一些研究
1、凍結(jié)上下文以更快地訓(xùn)練TabR
在TabR的原始實(shí)現(xiàn)中,由于需要對所有候選對象進(jìn)行編碼并計(jì)算每個(gè)訓(xùn)練批次的相似度,因此在大型數(shù)據(jù)集上的訓(xùn)練可能很慢。作者提到在完整的“Weather prediction”數(shù)據(jù)集上訓(xùn)練一個(gè)TabR需要18個(gè)多小時(shí),該數(shù)據(jù)集有300多萬個(gè)對象。
作者注意到在訓(xùn)練過程中,平均訓(xùn)練對象的上下文(即,根據(jù)相似度模塊S,前m個(gè)候選對象及其分布)趨于穩(wěn)定,這為優(yōu)化提供了機(jī)會。在一定數(shù)量的epoch之后,他們提出了一個(gè)“上下文凍結(jié)”,即最后一次計(jì)算所有訓(xùn)練對象的最新上下文,然后在其余的訓(xùn)練中重用。
這種簡單的技術(shù)可以加速TabR的訓(xùn)練,并且不會在指標(biāo)上造成重大損失。在上面提到的完整的“Weather prediction”數(shù)據(jù)集上,它使速度提高了近7倍(將訓(xùn)練時(shí)間從18小時(shí)9分鐘減少到3小時(shí)15分鐘),同時(shí)仍然保持有競爭力的均方根誤差(RMSE)值。
2、用新的訓(xùn)練數(shù)據(jù)更新TabR不需要再訓(xùn)練(初步探索)
在現(xiàn)實(shí)世界的場景中,在機(jī)器學(xué)習(xí)模型已經(jīng)訓(xùn)練完之后,通常會收到新的、看不見的訓(xùn)練數(shù)據(jù)。作者測試了TabR在不需要再訓(xùn)練的情況下合并新數(shù)據(jù)的能力,方法是將新數(shù)據(jù)添加到候選檢索集中。
他們使用完整的“Weather prediction”數(shù)據(jù)集進(jìn)行了這個(gè)測試。結(jié)果表明在線更新可以有效地將新數(shù)據(jù)整合到訓(xùn)練好的TabR模型中。這種方法可以通過在數(shù)據(jù)子集上訓(xùn)練模型并從完整數(shù)據(jù)集中檢索模型來將TabR擴(kuò)展到更大的數(shù)據(jù)集。
3、使用檢索組件增強(qiáng)XGBoost
作者試圖通過結(jié)合類似于TabR中的檢索組件來提高XGBoost的性能。這種方法涉及在原始特征空間中找到與給定輸入對象最接近的96個(gè)訓(xùn)練對象(匹配TabR的上下文大小)。然后對這些最近鄰的特征和標(biāo)簽進(jìn)行平均,將標(biāo)簽按原樣用于回歸任務(wù),并將其轉(zhuǎn)換為用于分類任務(wù)的單一編碼。
將這些平均數(shù)據(jù)與目標(biāo)對象的特征和標(biāo)簽連接起來,形成XGBoost的新輸入向量。但是該策略并沒有顯著提高XGBoost的性能。試圖改變鄰居的數(shù)量也沒有產(chǎn)生任何顯著的改善。
總結(jié)
深度學(xué)習(xí)模型在表格類數(shù)據(jù)上一直沒有超越梯度增強(qiáng)模型,TabR還在這個(gè)方向繼續(xù)努力。