自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

圖注意力網(wǎng)絡(luò)論文詳解和PyTorch實(shí)現(xiàn)

開(kāi)發(fā) 前端
圖神經(jīng)網(wǎng)絡(luò)(gnn)是一類功能強(qiáng)大的神經(jīng)網(wǎng)絡(luò),它對(duì)圖結(jié)構(gòu)數(shù)據(jù)進(jìn)行操作。它們通過(guò)從節(jié)點(diǎn)的局部鄰域聚合信息來(lái)學(xué)習(xí)節(jié)點(diǎn)表示(嵌入)。這個(gè)概念在圖表示學(xué)習(xí)文獻(xiàn)中被稱為“消息傳遞”。

圖神經(jīng)網(wǎng)絡(luò)(gnn)是一類功能強(qiáng)大的神經(jīng)網(wǎng)絡(luò),它對(duì)圖結(jié)構(gòu)數(shù)據(jù)進(jìn)行操作。它們通過(guò)從節(jié)點(diǎn)的局部鄰域聚合信息來(lái)學(xué)習(xí)節(jié)點(diǎn)表示(嵌入)。這個(gè)概念在圖表示學(xué)習(xí)文獻(xiàn)中被稱為“消息傳遞”。

消息(嵌入)通過(guò)多個(gè)GNN層在圖中的節(jié)點(diǎn)之間傳遞。每個(gè)節(jié)點(diǎn)聚合來(lái)自其鄰居的消息以更新其表示。這個(gè)過(guò)程跨層重復(fù),允許節(jié)點(diǎn)獲得編碼有關(guān)圖的更豐富信息的表示。gnn的一主要變體有GraphSAGE[2]、Graph Convolution Network[3]等。

圖注意力網(wǎng)絡(luò)(GAT)[1]是一類特殊的gnn,主要的改進(jìn)是消息傳遞的方式。他們引入了一種可學(xué)習(xí)的注意力機(jī)制,通過(guò)在每個(gè)源節(jié)點(diǎn)和目標(biāo)節(jié)點(diǎn)之間分配權(quán)重,使節(jié)點(diǎn)能夠在聚合來(lái)自本地鄰居的消息時(shí)決定哪個(gè)鄰居節(jié)點(diǎn)更重要,而不是以相同的權(quán)重聚合來(lái)自所有鄰居的信息。

圖注意力網(wǎng)絡(luò)在節(jié)點(diǎn)分類、鏈接預(yù)測(cè)和圖分類等任務(wù)上優(yōu)于許多其他GNN模型。他們?cè)趲讉€(gè)基準(zhǔn)圖數(shù)據(jù)集上也展示了最先進(jìn)的性能。

在這篇文章中,我們將介紹原始“Graph Attention Networks”(by Veli?kovi? )論文的關(guān)鍵部分,并使用PyTorch實(shí)現(xiàn)論文中提出的概念,這樣以更好地掌握GAT方法。

論文引言

在第1節(jié)“引言”中對(duì)圖表示學(xué)習(xí)文獻(xiàn)中的現(xiàn)有方法進(jìn)行了廣泛的回顧之后,論文介紹了圖注意網(wǎng)絡(luò)(GAT)。

然后將論文的方法與現(xiàn)有的一些方法進(jìn)行比較,并指出它們之間的一般異同,這是論文的常用格式,就不多介紹了。

GAT的架構(gòu)

本節(jié)是本文的主要部分,對(duì)圖注意力網(wǎng)絡(luò)的體系結(jié)構(gòu)進(jìn)行了詳細(xì)的闡述。為了進(jìn)一步解釋,假設(shè)所提出的架構(gòu)在一個(gè)有N個(gè)節(jié)點(diǎn)的圖上執(zhí)行(V = {V′};i=1,…,N),每個(gè)節(jié)點(diǎn)用向量h ^ (F個(gè)元素)表示,節(jié)點(diǎn)之間存在任意邊。

作者首先描述了單個(gè)圖注意力層的特征,以及它是如何運(yùn)作的(因?yàn)樗菆D注意力網(wǎng)絡(luò)的基本構(gòu)建塊)。一般來(lái)說(shuō),單個(gè)GAT層應(yīng)該將具有給定節(jié)點(diǎn)嵌入(表示)的圖作為輸入,將信息傳播到本地鄰居節(jié)點(diǎn),并輸出更新后的節(jié)點(diǎn)表示。

如上所述,ga層的所有輸入節(jié)點(diǎn)特征向量(h′)都是線性變換的(即乘以一個(gè)權(quán)重矩陣W),在PyTorch中,通常是這樣做的:

import torch from torch import nn # in_features -> F and out_feature -> 
F' in_features = ... out_feature = ... # instanciate the learnable weight matrix 
W (FxF') W = nn.Parameter(torch.empty(size=(in_features, out_feature))) # 
Initialize the weight matrix W nn.init.xavier_normal_(W) # multiply W and h (h 
is input features of all the nodes -> NxF matrix) h_transformed = torch.mm(h, 
W)

獲得了輸入節(jié)點(diǎn)特征(嵌入)的轉(zhuǎn)換版本后我們先跳到最后查看和理解GAT層的最終目標(biāo)是什么。

如論文所述,在圖注意層的最后,對(duì)于每個(gè)節(jié)點(diǎn)i,我們需要從其鄰域獲得一個(gè)新的特征向量,該特征向量更具有結(jié)構(gòu)和上下文感知性。

這是通過(guò)計(jì)算相鄰節(jié)點(diǎn)特征的加權(quán)和,然后是非線性激活函數(shù)σ來(lái)完成的。根據(jù)Graph ML文獻(xiàn),這個(gè)加權(quán)和在一般GNN層操作中也被稱為“聚合”步驟。

論文的這些權(quán)重α′?∈[0,1]是通過(guò)一種關(guān)注機(jī)制來(lái)學(xué)習(xí)和計(jì)算的,該機(jī)制表示在消息傳遞和聚合過(guò)程中節(jié)點(diǎn)i的鄰居j特征的重要性。

每一對(duì)節(jié)點(diǎn)i和它的鄰居j計(jì)算這些注意權(quán)值α′?的計(jì)算方法如下:

其中e ^?是注意力得分,在應(yīng)用Softmax函數(shù)后,有權(quán)重都會(huì)在[0,1]區(qū)間內(nèi),并且和為1。現(xiàn)在通過(guò)注意函數(shù)a(…)計(jì)算每個(gè)節(jié)點(diǎn)i和它的鄰居j∈N′之間的注意分?jǐn)?shù)e′?,如下所示:

上圖中的||表示兩個(gè)轉(zhuǎn)換后的節(jié)點(diǎn)嵌入的連接,a是大小為2 * F '(轉(zhuǎn)換后嵌入大小的兩倍)的可學(xué)習(xí)參數(shù)(即注意力參數(shù))向量。而(a1)是向量a的轉(zhuǎn)置,導(dǎo)致整個(gè)表達(dá)式a1[Wh′|| Wh?]是“a”與轉(zhuǎn)換后的嵌入的連接之間的點(diǎn)(內(nèi))積。

整個(gè)操作說(shuō)明如下:

在PyTorch中,我們采用了一種稍微不同的方法。因?yàn)橛?jì)算所有節(jié)點(diǎn)對(duì)之間的e′?然后只選擇代表節(jié)點(diǎn)之間現(xiàn)有邊的那些是更有效的。來(lái)計(jì)算所有的e′?

# instanciate the learnable attention parameter vector `a` a = 
nn.Parameter(torch.empty(size=(2 * out_feature, 1))) # Initialize the parameter 
vector `a` nn.init.xavier_normal_(a) # we obtained `h_transformed` in the 
previous code snippet # calculating the dot product of all node embeddings # and 
first half the attention vector parameters (corresponding to neighbor messages) 
source_scores = torch.matmul(h_transformed, self.a[:out_feature, :]) # 
calculating the dot product of all node embeddings # and second half the 
attention vector parameters (corresponding to target node) target_scores = 
torch.matmul(h_transformed, self.a[out_feature:, :]) # broadcast add e = 
source_scores + target_scores.T e = self.leakyrelu(e)

代碼片段的最后一部分(# broadcast add)將所有一對(duì)一的源和目標(biāo)分?jǐn)?shù)相加,得到一個(gè)包含所有e′?分?jǐn)?shù)的NxN矩陣。(下圖所示)

到目前為止,我們假設(shè)圖是完全連接的,我們計(jì)算的是所有可能的節(jié)點(diǎn)對(duì)之間的注意力得分。但是其實(shí)大部分情況下圖不可能是完全連接的,所以為了解決這個(gè)問(wèn)題,在將LeakyReLU激活應(yīng)用于注意力分?jǐn)?shù)之后,注意力分?jǐn)?shù)基于圖中現(xiàn)有的邊被屏蔽,這意味著我們只保留與現(xiàn)有邊對(duì)應(yīng)的分?jǐn)?shù)。

它可以通過(guò)給不存在邊的節(jié)點(diǎn)之間的分?jǐn)?shù)矩陣中的元素分配一個(gè)大的負(fù)分?jǐn)?shù)(近似于-∞)來(lái)完成,這樣它們對(duì)應(yīng)的注意力權(quán)重在softmax之后變?yōu)榱?還記得我們以前發(fā)的注意力掩碼么,就是一樣的道理)。

這里的注意力掩碼是通過(guò)使用圖的鄰接矩陣來(lái)實(shí)現(xiàn)的。鄰接矩陣是一個(gè)NxN矩陣,如果節(jié)點(diǎn)i和j之間存在一條邊,則在第i行和第j列處為1,在其他地方為0。因此,我們通過(guò)將鄰接矩陣的零元素賦值為-∞并在其他地方賦值為0來(lái)創(chuàng)建掩碼。然后將掩碼添加到分?jǐn)?shù)矩陣中。然后在它的行上應(yīng)用softmax函數(shù)。

connectivity_mask = -9e16 * torch.ones_like(e) # adj_mat is the N by N 
adjacency matrix e = torch.where(adj_mat > 0, e, connectivity_mask) # masked 
attention scores # attention coefficients are computed as a softmax over the 
rows # for each column j in the attention score matrix e attention = 
F.softmax(e, dim=-1)

最后,根據(jù)論文描述,在獲得注意力分?jǐn)?shù)并將其與現(xiàn)有的邊進(jìn)行掩碼遮蔽后,通過(guò)對(duì)分?jǐn)?shù)矩陣的行執(zhí)行softmax,得到注意力權(quán)重α1?。

我們通過(guò)一個(gè)完整的可視化圖過(guò)程如下:

最后就是計(jì)算節(jié)點(diǎn)嵌入的加權(quán)和:

# final node embeddings are computed as a weighted average of the features of 
its neighbors h_prime = torch.matmul(attention, h_transformed)

以上一個(gè)一個(gè)注意力頭的工作流程和原理,論文還引入了多頭的概念,其中所有操作都是通過(guò)多個(gè)并行的操作流來(lái)完成的。

Ebrahim Pichka多頭注意力和聚合過(guò)程如下圖所示:

節(jié)點(diǎn)1在其鄰域中的多頭注意力(K = 3個(gè)頭),不同的箭頭樣式和顏色表示獨(dú)立的注意力計(jì)算。將來(lái)自每個(gè)頭部的聚合特征連接或平均以獲得h '。

為了以更簡(jiǎn)潔的模塊化形式(作為PyTorch模塊)封裝實(shí)現(xiàn)并合并多頭注意力的功能,整個(gè)Graph關(guān)注層的實(shí)現(xiàn)如下:

import torch from torch import nn import torch.nn.functional as F 
################################ ### GAT LAYER DEFINITION ### 
################################ class GraphAttentionLayer(nn.Module): def 
__init__(self, in_features: int, out_features: int, n_heads: int, concat: bool = 
False, dropout: float = 0.4, leaky_relu_slope: float = 0.2): 
super(GraphAttentionLayer, self).__init__() self.n_heads = n_heads # Number of 
attention heads self.concat = concat # wether to concatenate the final attention 
heads self.dropout = dropout # Dropout rate if concat: # concatenating the 
attention heads self.out_features = out_features # Number of output features per 
node assert out_features % n_heads == 0 # Ensure that out_features is a multiple 
of n_heads self.n_hidden = out_features // n_heads else: # averaging output over 
the attention heads (Used in the main paper) self.n_hidden = out_features # A 
shared linear transformation, parametrized by a weight matrix W is applied to 
every node # Initialize the weight matrix W self.W = 
nn.Parameter(torch.empty(size=(in_features, self.n_hidden * n_heads))) # 
Initialize the attention weights a self.a = 
nn.Parameter(torch.empty(size=(n_heads, 2 * self.n_hidden, 1))) self.leakyrelu = 
nn.LeakyReLU(leaky_relu_slope) # LeakyReLU activation function self.softmax = 
nn.Softmax(dim=1) # softmax activation function to the attention coefficients 
self.reset_parameters() # Reset the parameters def reset_parameters(self): 
nn.init.xavier_normal_(self.W) nn.init.xavier_normal_(self.a) def 
_get_attention_scores(self, h_transformed: torch.Tensor): source_scores = 
torch.matmul(h_transformed, self.a[:, :self.n_hidden, :]) target_scores = 
torch.matmul(h_transformed, self.a[:, self.n_hidden:, :]) # broadcast add # 
(n_heads, n_nodes, 1) + (n_heads, 1, n_nodes) = (n_heads, n_nodes, n_nodes) e = 
source_scores + target_scores.mT return self.leakyrelu(e) def forward(self, h: 
torch.Tensor, adj_mat: torch.Tensor): n_nodes = h.shape[0] # Apply linear 
transformation to node feature -> W h # output shape (n_nodes, n_hidden * 
n_heads) h_transformed = torch.mm(h, self.W) h_transformed = 
F.dropout(h_transformed, self.dropout, training=self.training) # splitting the 
heads by reshaping the tensor and putting heads dim first # output shape 
(n_heads, n_nodes, n_hidden) h_transformed = h_transformed.view(n_nodes, 
self.n_heads, self.n_hidden).permute(1, 0, 2) # getting the attention scores # 
output shape (n_heads, n_nodes, n_nodes) e = 
self._get_attention_scores(h_transformed) # Set the attention score for 
non-existent edges to -9e15 (MASKING NON-EXISTENT EDGES) connectivity_mask = 
-9e16 * torch.ones_like(e) e = torch.where(adj_mat > 0, e, connectivity_mask) 
# masked attention scores # attention coefficients are computed as a softmax 
over the rows # for each column j in the attention score matrix e attention = 
F.softmax(e, dim=-1) attention = F.dropout(attention, self.dropout, 
training=self.training) # final node embeddings are computed as a weighted 
average of the features of its neighbors h_prime = torch.matmul(attention, 
h_transformed) # concatenating/averaging the attention heads # output shape 
(n_nodes, out_features) if self.concat: h_prime = h_prime.permute(1, 0, 
2).contiguous().view(n_nodes, self.out_features) else: h_prime = 
h_prime.mean(dim=0) return h_prime

最后將上面所有的代碼整合成一個(gè)完整的GAT模型:

class GAT(nn.Module): def __init__(self, in_features, n_hidden, n_heads, 
num_classes, concat=False, dropout=0.4, leaky_relu_slope=0.2): super(GAT, 
self).__init__() # Define the Graph Attention layers self.gat1 = 
GraphAttentionLayer( in_features=in_features, out_features=n_hidden, 
n_heads=n_heads, concat=concat, dropout=dropout, 
leaky_relu_slope=leaky_relu_slope ) self.gat2 = GraphAttentionLayer( 
in_features=n_hidden, out_features=num_classes, n_heads=1, concat=False, 
dropout=dropout, leaky_relu_slope=leaky_relu_slope ) def forward(self, 
input_tensor: torch.Tensor , adj_mat: torch.Tensor): # Apply the first Graph 
Attention layer x = self.gat1(input_tensor, adj_mat) x = F.elu(x) # Apply ELU 
activation function to the output of the first layer # Apply the second Graph 
Attention layer x = self.gat2(x, adj_mat) return F.softmax(x, dim=1) # Apply 
softmax activation function

方法對(duì)比

作者對(duì)GATs和其他一些現(xiàn)有GNN方法/架構(gòu)進(jìn)行了比較:

由于GATs能夠計(jì)算注意力權(quán)重并并行執(zhí)行局部聚合,因此它比現(xiàn)有的一些方法計(jì)算效率更高。

GATs可以在聚合消息時(shí)為節(jié)點(diǎn)的鄰居分配不同的重要性,這可以實(shí)現(xiàn)模型容量的飛躍并提高可解釋性。

GAT不考慮節(jié)點(diǎn)的完整鄰域(不需要從鄰域采樣),也不假設(shè)節(jié)點(diǎn)內(nèi)部有任何排序。

通過(guò)將偽坐標(biāo)函數(shù)設(shè)置為u(x, y) = f(x)||f(y), GAT可以重新表述為MoNet的一個(gè)特定實(shí)例(Monti等人,2016),其中f(x)表示(可能是mlp轉(zhuǎn)換的)節(jié)點(diǎn)x的特征,而||是連接;權(quán)函數(shù)為wj(u) = softmax(MLP(u))

基準(zhǔn)測(cè)試

在論文的第三部分中,作者描述了評(píng)估GAT的基準(zhǔn)、數(shù)據(jù)集和任務(wù)。然后,他們提出了他們對(duì)模型的評(píng)估結(jié)果。

論文中用作基準(zhǔn)的數(shù)據(jù)集分為兩種類型的任務(wù),轉(zhuǎn)換和歸納。

歸納學(xué)習(xí):這是一種監(jiān)督學(xué)習(xí)任務(wù),其中模型僅在一組標(biāo)記的訓(xùn)練樣例上進(jìn)行訓(xùn)練,并且在訓(xùn)練過(guò)程中完全未觀察到的樣例上對(duì)訓(xùn)練后的模型進(jìn)行評(píng)估和測(cè)試。這是一種被稱為普通監(jiān)督學(xué)習(xí)的學(xué)習(xí)類型。

傳導(dǎo)學(xué)習(xí):在這種類型的任務(wù)中,所有的數(shù)據(jù),包括訓(xùn)練、驗(yàn)證和測(cè)試實(shí)例,都在訓(xùn)練期間使用。但是在每個(gè)階段,模型只訪問(wèn)相應(yīng)的標(biāo)簽集。這意味著在訓(xùn)練期間,模型只使用由訓(xùn)練實(shí)例和標(biāo)簽產(chǎn)生的損失進(jìn)行訓(xùn)練,但測(cè)試和驗(yàn)證特征用于消息傳遞。這主要是因?yàn)槭纠写嬖诘慕Y(jié)構(gòu)和上下文信息。

論文使用四個(gè)基準(zhǔn)數(shù)據(jù)集來(lái)評(píng)估GATs,其中三個(gè)對(duì)應(yīng)于傳導(dǎo)學(xué)習(xí),另一個(gè)用作歸納學(xué)習(xí)任務(wù)。

轉(zhuǎn)導(dǎo)學(xué)習(xí)數(shù)據(jù)集,即Cora、Citeseer和Pubmed (Sen et al., 2008)數(shù)據(jù)集都是引文圖,其中節(jié)點(diǎn)是已發(fā)布的文檔,邊(連接)是它們之間的引用,節(jié)點(diǎn)特征是文檔的詞包表示的元素。

歸納學(xué)習(xí)數(shù)據(jù)集是一個(gè)蛋白質(zhì)-蛋白質(zhì)相互作用(PPI)數(shù)據(jù)集,其中包含不同人體組織的圖形(Zitnik & Leskovec, 2017)。數(shù)據(jù)集的詳細(xì)描述如下:

作者報(bào)告了四個(gè)基準(zhǔn)測(cè)試的以下性能,顯示了GATs與現(xiàn)有GNN方法的可比結(jié)果。

總結(jié)

通過(guò)閱讀這篇文章并試用代碼,希望你能夠?qū)ATs的工作原理以及如何在實(shí)際場(chǎng)景中應(yīng)用它們有一個(gè)扎實(shí)的理解。

責(zé)任編輯:華軒 來(lái)源: DeepHub IMBA
相關(guān)推薦

2024-07-16 14:15:09

2021-02-02 14:47:58

微軟PyTorch可視化

2021-08-04 10:17:19

開(kāi)發(fā)技能代碼

2024-04-03 14:31:08

大型語(yǔ)言模型PytorchGQA

2024-08-12 08:40:00

PyTorch代碼

2025-02-24 13:00:00

YOLOv12目標(biāo)檢測(cè)Python

2024-09-19 10:07:41

2017-07-07 15:46:38

循環(huán)神經(jīng)網(wǎng)絡(luò)視覺(jué)注意力模擬

2018-08-26 22:25:36

自注意力機(jī)制神經(jīng)網(wǎng)絡(luò)算法

2025-02-19 15:30:00

模型訓(xùn)練數(shù)據(jù)

2020-09-17 12:40:54

神經(jīng)網(wǎng)絡(luò)CNN機(jī)器學(xué)習(xí)

2025-02-25 09:40:00

模型數(shù)據(jù)AI

2018-05-03 16:27:29

RNN神經(jīng)網(wǎng)絡(luò)ResNet

2024-10-09 15:30:00

2024-11-04 10:40:00

AI模型

2023-10-07 07:21:42

注意力模型算法

2011-07-07 13:12:58

移動(dòng)設(shè)備端設(shè)計(jì)注意力

2025-02-10 00:00:55

MHAValue向量

2024-12-04 15:55:21

2024-06-28 08:04:43

語(yǔ)言模型應(yīng)用
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)