自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

TabTransformer轉(zhuǎn)換器提升多層感知機性能深度解析

譯文 精選
人工智能 機器學(xué)習(xí)
本文將介紹機器學(xué)習(xí)中的一種新型的轉(zhuǎn)換器——TabTransformer,并通過具體的實例展示該轉(zhuǎn)換器的用法。

?如今,轉(zhuǎn)換器(Transformers)成為大多數(shù)先進的自然語言處理(NLP)和計算機視覺(CV)體系結(jié)構(gòu)中的關(guān)鍵模塊。然而,表格式數(shù)據(jù)領(lǐng)域仍然主要以梯度提升決策樹(GBDT)算法為主導(dǎo)。于是,有人試圖彌合這一差距。其中,第一篇基于轉(zhuǎn)換器的表格數(shù)據(jù)建模論文是由Huang等人于2020年發(fā)表的論文《TabTransformer:使用上下文嵌入的表格數(shù)據(jù)建模》。

本文旨在提供該論文內(nèi)容的基本展示,同時將深入探討TabTransformer模型的實現(xiàn)細節(jié),并向您展示如何針對我們自己的數(shù)據(jù)來具體使用TabTransformer。

一、論文概述

上述論文的主要思想是,如果使用轉(zhuǎn)換器將常規(guī)的分類嵌入轉(zhuǎn)換為上下文嵌入,那么,常規(guī)的多層感知器(MLP)的性能將會得到顯著提高。接下來,讓我們更為深入地理解這一描述。

1.分類嵌入(Categorical Embeddings)

在深度學(xué)習(xí)模型中,使用分類特征的經(jīng)典方法是訓(xùn)練其嵌入性。這意味著,每個類別值都有一個唯一的密集型向量表示,并且可以傳遞給下一層。例如,由下圖您可以看到,每個分類特征都使用一個四維數(shù)組表示。然后,這些嵌入與數(shù)字特征串聯(lián),并用作MLP的輸入。

圖片

帶有分類嵌入的MLP

2.上下文嵌入(Contextual Embeddings)

論文作者認為,分類嵌入缺乏上下文含義,即它們并沒有對分類變量之間的任何交互和關(guān)系信息進行編碼。為了將嵌入內(nèi)容更加具體化,有人建議使用NLP領(lǐng)域當(dāng)前所使用的轉(zhuǎn)換器來實現(xiàn)這一目的。

圖片

TabTransformer轉(zhuǎn)換器中的上下文嵌入

為了以可視化方式形象地展示上述想法,我們不妨考慮下面這個訓(xùn)練后得到的上下文嵌入圖像。其中,突出顯示了兩個分類特征:關(guān)系(黑色)和婚姻狀況(藍色)。這些特征是相關(guān)的;所以,“已婚(Married)”、“丈夫(Husband)”和“妻子(Wife)”的值應(yīng)該在向量空間中彼此接近,即使它們來自不同的變量。

圖片

經(jīng)訓(xùn)練后的TabTransformer轉(zhuǎn)換器嵌入結(jié)果示例

通過上圖中經(jīng)過訓(xùn)練的上下文嵌入結(jié)果,我們可以看到,“已婚(Married)”的婚姻狀況更接近“丈夫(Husband)”和“妻子(Wife)”的關(guān)系水平,而“未結(jié)婚(non-married)”的分類值則來自右側(cè)的單獨數(shù)據(jù)簇。這種類型的上下文使這樣的嵌入更加有用,而使用簡單形式的類別嵌入技術(shù)是不可能實現(xiàn)這種效果的。

3.TabTransformer架構(gòu)

為了達到上述目的,論文作者提出了以下架構(gòu):

圖片

TabTransformer轉(zhuǎn)換器架構(gòu)示意圖

(摘取自Huang等人2020年發(fā)表的論文)

我們可以將此體系結(jié)構(gòu)分解為5個步驟:

  • 標(biāo)準(zhǔn)化數(shù)字特征并向前傳遞
  • 嵌入分類特征
  • 嵌入經(jīng)過N次轉(zhuǎn)換器塊處理,以便獲得上下文嵌入
  • 把上下文分類嵌入與數(shù)字特征進行串聯(lián)
  • 通過MLP進行串聯(lián)獲得所需的預(yù)測

雖然模型架構(gòu)非常簡單,但論文作者表示,添加轉(zhuǎn)換器層可以顯著提高計算性能。當(dāng)然,所有的“魔術(shù)”發(fā)生在這些轉(zhuǎn)換器塊內(nèi)部;所以,接下來讓我們更加詳細地研究一下其中的實現(xiàn)過程。

4.轉(zhuǎn)換器

圖片

轉(zhuǎn)換器(Transformer)架構(gòu)示意

(選自Vaswani等人于2017年發(fā)表的論文)

您可能以前見過轉(zhuǎn)換器架構(gòu),但為了快速介紹起見,請記住該轉(zhuǎn)換器是由編碼器和解碼器兩部分組成(見上圖)。對于TabTransformer,我們只關(guān)心將輸入的嵌入內(nèi)容上下文化的編碼器部分(解碼器部分將這些嵌入內(nèi)容轉(zhuǎn)換為最終輸出結(jié)果)。但它到底是如何做到的呢?答案是——多頭注意力機制。

5.多頭注意力機制(Multi-head-attention)

引用我最喜歡的關(guān)于注意力機制的文章的描述,是這樣的:

“自我關(guān)注(self attention)背后的關(guān)鍵概念是,這種機制允許神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)如何在輸入序列的各個片段之間以最好的路由方案進行信息調(diào)度?!?/p>

換句話說,自我關(guān)注(self-attention)有助于模型找出在表示某個單詞/類別時,輸入的哪些部分更重要,哪些部分相對不重要。為此,我強烈建議您閱讀一下上面引用的這篇文章,以便對自我關(guān)注為什么如此有效有一個更為直觀的理解。

圖片

多頭注意力機制

(選自Vaswani等人于2017年發(fā)表的論文)

注意力是通過3個學(xué)習(xí)過的矩陣來計算的——Q、K和V,它們代表查詢(Query)、鍵(Key)和值(Value)。首先,我們將矩陣Q和K相乘得到注意力矩陣。該矩陣被縮放并通過softmax層傳遞。然后,我們將其乘以V矩陣,得出最終值。為了更直觀地理解起見,請考慮下面的示意圖,它顯示了我們?nèi)绾问褂镁仃嘠、K和V實現(xiàn)從輸入嵌入轉(zhuǎn)換到上下文嵌入。

圖片

自我關(guān)注流程可視化

通過重復(fù)該過程h次(使用不同的Q、K、V矩陣),我們就能夠得到多個上下文嵌入,它們形成我們最終的多頭注意力。

6.簡短回顧

讓我們總結(jié)一下上面所介紹的內(nèi)容:

  • 簡單的分類嵌入不包含上下文信息
  • 通過轉(zhuǎn)換器編碼器傳遞分類嵌入,我們就能夠?qū)⑶度肷舷挛幕?/li>
  • 轉(zhuǎn)換器部分能夠?qū)⑶度肷舷挛幕?,因為它使用了多頭注意力機制
  • 多頭注意力機制在編碼變量時使用矩陣Q、K和V來尋找有用的相互作用和相關(guān)性信息
  • 在TabTransformer中,被上下文化的嵌入與數(shù)字輸入相串聯(lián),并通過一個簡單的MLP輸出預(yù)測

雖然TabTransformer背后的想法很簡單,但您可能需要一些時間才能掌握注意力機制。因此,我強烈建議您重新閱讀以上解釋。如果您感到有些迷茫,請認真閱讀本文中所有建議的鏈接相關(guān)內(nèi)容。我保證,做到這些后,您就不難搞明白注意力機制的原理了。

7.試驗結(jié)果展示

圖片

結(jié)果數(shù)據(jù)(選自Huang等人2020年發(fā)表的論文)

根據(jù)報告的結(jié)果,TabTransformer轉(zhuǎn)換器優(yōu)于所有其他深度學(xué)習(xí)表格模型,此外,它接近GBDT的性能水平,這非常令人鼓舞。該模型對缺失數(shù)據(jù)和噪聲數(shù)據(jù)也相對穩(wěn)健,并且在半監(jiān)督環(huán)境下優(yōu)于其他模型。然而,這些數(shù)據(jù)集顯然不是詳盡無遺的,正如以后發(fā)表的一些相關(guān)論文所證實的那樣,仍有很大的改進空間。

二、構(gòu)建我們自己的示例程序

現(xiàn)在,讓我們最終來確定一下如何將模型應(yīng)用于我們自己的數(shù)據(jù)。接下來的示例數(shù)據(jù)取自著名的Tabular Playground Kaggle比賽。為了方便使用TabTransformer轉(zhuǎn)換器,我創(chuàng)建了一個tabtransformertf包。它可以使用如下pip命令進行安裝:

pip install tabtransformertf

并允許我們使用該模型,而無需進行大量預(yù)處理。

1.數(shù)據(jù)預(yù)處理

第一步是設(shè)置適當(dāng)?shù)臄?shù)據(jù)類型,并將我們的訓(xùn)練和驗證數(shù)據(jù)轉(zhuǎn)換為TF數(shù)據(jù)集。其中,前面安裝的軟件包中就提供了一個很好的實用程序可以做到這一點。

from tabtransformertf.utils.preprocessing import df_to_dataset, build_categorical_prep

# 設(shè)置數(shù)據(jù)類型
train_data[CATEGORICAL_FEATURES] = train_data[CATEGORICAL_FEATURES].astype(str)
val_data[CATEGORICAL_FEATURES] = val_data[CATEGORICAL_FEATURES].astype(str)

train_data[NUMERIC_FEATURES] = train_data[NUMERIC_FEATURES].astype(float)
val_data[NUMERIC_FEATURES] = val_data[NUMERIC_FEATURES].astype(float)

# 轉(zhuǎn)換成TF數(shù)據(jù)集
train_dataset = df_to_dataset(train_data[FEATURES + [LABEL]], LABEL, batch_size=1024)
val_dataset = df_to_dataset(val_data[FEATURES + [LABEL]], LABEL, shuffle=False, batch_size=1024)

下一步是為分類數(shù)據(jù)準(zhǔn)備預(yù)處理層。該分類數(shù)據(jù)稍后將被傳遞給我們的主模型。

from tabtransformertf.utils.preprocessing import build_categorical_prep

category_prep_layers = build_categorical_prep(train_data, CATEGORICAL_FEATURES)

# 輸出結(jié)果是一個字典結(jié)構(gòu),其中鍵部分是特征名稱,值部分是StringLookup層
# category_prep_layers ->
# {'product_code': <keras.layers.preprocessing.string_lookup.StringLookup at 0x7f05d28ee4e0>,
# 'attribute_0': <keras.layers.preprocessing.string_lookup.StringLookup at 0x7f05ca4fb908>,
# 'attribute_1': <keras.layers.preprocessing.string_lookup.StringLookup at 0x7f05ca4da5f8>}

這就是預(yù)處理!現(xiàn)在,我們可以開始構(gòu)建模型了。

2.構(gòu)建TabTransformer模型

初始化模型很容易。其中,有幾個參數(shù)需要指定,但最重要的幾個參數(shù)是:embeding_dim、depth和heads。所有參數(shù)都是在超參數(shù)調(diào)整后選擇的。

from tabtransformertf.models.tabtransformer import TabTransformer

tabtransformer = TabTransformer(
numerical_features = NUMERIC_FEATURES, # 帶有數(shù)字特征名稱的列表
categorical_features = CATEGORICAL_FEATURES, # 帶有分類特征名稱的列表
categorical_lookup=category_prep_layers, # 帶StringLookup層的Dict
numerical_discretisers=None, # None代表我們只是簡單地傳遞數(shù)字特征
embedding_dim=32, # 嵌入維數(shù)
out_dim=1, # Dimensionality of output (binary task)
out_activatinotallow='sigmoid', # 輸出層激活
depth=4, # 轉(zhuǎn)換器塊層的個數(shù)
heads=8, # 轉(zhuǎn)換器塊中注意力頭的個數(shù)
attn_dropout=0.1, # 在轉(zhuǎn)換器塊中的丟棄率
ff_dropout=0.1, # 在最后MLP中的丟棄率
mlp_hidden_factors=[2, 4], # 我們?yōu)槊恳粚觿澐肿罱K嵌入的因子
use_column_embedding=True, #如果我們想使用列嵌入,設(shè)置此項為真
)

# 模型運行中摘要輸出:
# 總參數(shù)個數(shù): 1,778,884
# 可訓(xùn)練的參數(shù)個數(shù): 1,774,064
# 不可訓(xùn)練的參數(shù)個數(shù): 4,820

模型初始化后,我們可以像任何其他Keras模型一樣安裝它。訓(xùn)練參數(shù)也可以調(diào)整,所以可以隨意調(diào)整學(xué)習(xí)速度和提前停止。

LEARNING_RATE = 0.0001
WEIGHT_DECAY = 0.0001
NUM_EPOCHS = 1000

optimizer = tfa.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

tabtransformer.compile(
optimizer = optimizer,
loss = tf.keras.losses.BinaryCrossentropy(),
metrics= [tf.keras.metrics.AUC(name="PR AUC", curve='PR')],
)

out_file = './tabTransformerBasic'
checkpoint = ModelCheckpoint(
out_file, mnotallow="val_loss", verbose=1, save_best_notallow=True, mode="min"
)
early = EarlyStopping(mnotallow="val_loss", mode="min", patience=10, restore_best_weights=True)
callback_list = [checkpoint, early]

history = tabtransformer.fit(
train_dataset,
epochs=NUM_EPOCHS,
validation_data=val_dataset,
callbacks=callback_list
)

3.評價

競賽中最關(guān)鍵的指標(biāo)是ROC AUC。因此,讓我們將其與PR AUC指標(biāo)一起輸出來評估一下模型的性能。

val_preds = tabtransformer.predict(val_dataset)

print(f"PR AUC: {average_precision_score(val_data['isFraud'], val_preds.ravel())}")
print(f"ROC AUC: {roc_auc_score(val_data['isFraud'], val_preds.ravel())}")

# PR AUC: 0.26
# ROC AUC: 0.58

您也可以自己給測試集評分,然后將結(jié)果值提交給Kaggle官方。我現(xiàn)在選擇的這個解決方案使我躋身前35%,這并不壞,但也不太好。那么,為什么TabTransfromer在上述方案中表現(xiàn)不佳呢?可能有以下幾個原因:

  • 數(shù)據(jù)集太小,而深度學(xué)習(xí)模型以需要大量數(shù)據(jù)著稱
  • TabTransformer很容易在表格式數(shù)據(jù)示例領(lǐng)域出現(xiàn)過擬合
  • 沒有足夠的分類特征使模型有用

三、結(jié)論

本文探討了TabTransformer背后的主要思想,并展示了如何使用Tabtransformertf包來具體應(yīng)用此轉(zhuǎn)換器。

歸納起來看,TabTransformer的確是一種有趣的體系結(jié)構(gòu),它在當(dāng)時的表現(xiàn)明顯優(yōu)于大多數(shù)深度表格模型。它的主要優(yōu)點是將分類嵌入語境化,從而增強其表達能力。它使用在分類特征上的多頭注意力機制來實現(xiàn)這一點,而這是在表格數(shù)據(jù)領(lǐng)域使用轉(zhuǎn)換器的第一個應(yīng)用實例。

TabTransformer體系結(jié)構(gòu)的一個明顯缺點是,數(shù)字特征被簡單地傳遞到最終的MLP層。因此,它們沒有語境化,它們的價值也沒有在分類嵌入中得到解釋。在下一篇文章中,我將探討如何修復(fù)此缺陷并進一步提高性能。

譯者介紹

朱先忠,51CTO社區(qū)編輯,51CTO專家博客、講師,濰坊一所高校計算機教師,自由編程界老兵一枚。

原文鏈接:https://towardsdatascience.com/transformers-for-tabular-data-tabtransformer-deep-dive-5fb2438da820?source=collection_home---------4----------------------------

責(zé)任編輯:武曉燕 來源: 51CTO技術(shù)棧
相關(guān)推薦

2010-02-04 10:05:28

Dalvik虛擬機

2020-10-14 10:25:20

深度學(xué)習(xí)機器學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)

2010-06-10 14:33:03

協(xié)議轉(zhuǎn)換器

2013-06-13 15:10:27

.NET代碼轉(zhuǎn)換

2010-06-10 14:44:33

協(xié)議轉(zhuǎn)換器

2017-04-14 08:58:55

深度學(xué)習(xí)感知機深度網(wǎng)絡(luò)

2010-06-10 15:03:13

協(xié)議轉(zhuǎn)換器

2020-09-08 13:02:00

Python神經(jīng)網(wǎng)絡(luò)感知器

2009-12-28 13:38:35

WPF類型轉(zhuǎn)換器

2021-09-04 17:26:31

SpringBoot轉(zhuǎn)換器參數(shù)

2016-10-25 13:46:25

深度學(xué)習(xí)機器學(xué)習(xí)性能提升

2025-03-06 10:41:32

2009-09-11 12:41:41

C#類型轉(zhuǎn)換

2010-06-10 14:38:30

協(xié)議轉(zhuǎn)換器

2010-06-10 14:49:07

協(xié)議轉(zhuǎn)換器

2021-09-08 07:44:26

人工智能keras神經(jīng)網(wǎng)絡(luò)

2009-06-17 11:31:23

Open XMLUOF文檔

2024-05-11 08:47:36

Python工具裝飾器

2024-08-23 08:57:13

PyTorch視覺轉(zhuǎn)換器ViT

2009-12-14 17:15:15

協(xié)議轉(zhuǎn)換器
點贊
收藏

51CTO技術(shù)棧公眾號