解密GCN,手把手教你用PyTorch實(shí)現(xiàn)圖卷積網(wǎng)絡(luò)
圖神經(jīng)網(wǎng)絡(luò)(GNNs,Graph Neural Networks)是一類專為圖結(jié)構(gòu)數(shù)據(jù)設(shè)計(jì)的強(qiáng)大神經(jīng)網(wǎng)絡(luò),擅長捕捉數(shù)據(jù)之間的復(fù)雜聯(lián)系和關(guān)系。
相較于傳統(tǒng)神經(jīng)網(wǎng)絡(luò),GNN在處理相互關(guān)聯(lián)的數(shù)據(jù)點(diǎn)時更具優(yōu)勢,比如在社交網(wǎng)絡(luò)分析、分子結(jié)構(gòu)建?;蚪煌ㄏ到y(tǒng)優(yōu)化等領(lǐng)域,GNN能夠發(fā)揮出卓越的性能。
1 GNN概述
圖神經(jīng)網(wǎng)絡(luò)是近年來新興的一類深度學(xué)習(xí)模型,擅長處理圖形數(shù)據(jù)。
傳統(tǒng)神經(jīng)網(wǎng)絡(luò)處理的是像數(shù)字列表這樣的簡單數(shù)據(jù),而圖神經(jīng)網(wǎng)絡(luò)能處理更復(fù)雜的圖形數(shù)據(jù),比如由很多點(diǎn)(稱為節(jié)點(diǎn))和連接這些點(diǎn)的線(稱為邊)組成的圖形,并且能從這些圖形中找出重要的信息。
其核心機(jī)制是讓圖中的每個節(jié)點(diǎn)通過與鄰近節(jié)點(diǎn)的信息交換,來學(xué)習(xí)自己在整體圖形中的位置和特性。這種基于信息傳遞的方法,讓圖神經(jīng)網(wǎng)絡(luò)能夠快速捕捉到圖形里的結(jié)構(gòu)和關(guān)系。
這種技術(shù)在很多領(lǐng)域都大放異彩,比如社交網(wǎng)絡(luò)分析、分子結(jié)構(gòu)預(yù)測、知識圖譜構(gòu)建等等。
隨著科學(xué)家們不斷地研究和創(chuàng)新,圖神經(jīng)網(wǎng)絡(luò)也在蓬勃發(fā)展,衍生出多種新模型,為機(jī)器學(xué)習(xí)在圖形數(shù)據(jù)領(lǐng)域的應(yīng)用開辟了新的可能性。
2 圖卷積網(wǎng)絡(luò)(Graph Convolutional Networks)
簡單來說,圖卷積網(wǎng)絡(luò)(GCN)跟傳統(tǒng)神經(jīng)網(wǎng)絡(luò)一樣,是由多層結(jié)構(gòu)堆疊而成的。
在深度學(xué)習(xí)中,圖卷積網(wǎng)絡(luò)(GCN)的核心是圖卷積層,其工作機(jī)制與卷積神經(jīng)網(wǎng)絡(luò)(CNN)的卷積層頗為相似。
在CNN中,卷積層負(fù)責(zé)捕捉圖像中局部區(qū)域的像素信息,這個過程稱之為“感受野”(Receptive Field),通過它,我們可以提取出圖像的簡化和低維特征。
GCN層的工作原理與之類似,不過不是處理像素,而是處理圖中的節(jié)點(diǎn)信息。它通過收集每個節(jié)點(diǎn)及其相鄰節(jié)點(diǎn)的信息,來構(gòu)建節(jié)點(diǎn)的表示,從而捕捉圖中的結(jié)構(gòu)特征。
3 推導(dǎo)GCN方程式
來聊聊圖卷積網(wǎng)絡(luò)(GNN)的數(shù)學(xué)原理。
首先,GNN的輸入是一個圖,這個圖可以用節(jié)點(diǎn)特征的矩陣和鄰接矩陣來表示。鄰接矩陣?yán)锏?代表兩個節(jié)點(diǎn)之間有連接,0則表示沒有連接。
這個例子的鄰接矩陣是這樣的:
節(jié)點(diǎn) 1 -- 節(jié)點(diǎn) 2
|
節(jié)點(diǎn) 3
當(dāng)我們用A乘以節(jié)點(diǎn)特征矩陣X,得到的結(jié)果是每個節(jié)點(diǎn)的鄰居對每個特征的貢獻(xiàn)總和。簡單來說,就是把每個節(jié)點(diǎn)i的鄰居j的特征加起來:
然而,我們不應(yīng)忽視節(jié)點(diǎn)自身的特征。為了將節(jié)點(diǎn)自身的特征也考慮進(jìn)來,可以在鄰接矩陣A的對角線上增加1,這在數(shù)學(xué)上相當(dāng)于引入了單位矩陣I。
這樣:
但是,還有一個問題:節(jié)點(diǎn)的鄰居數(shù)量可能不一樣。有的節(jié)點(diǎn)有幾百個鄰居,有的可能只有一兩個。為了公平起見,我們需要對總和進(jìn)行歸一化。
一種方法是用每個節(jié)點(diǎn)的鄰居數(shù)(也就是節(jié)點(diǎn)的度)來除以這個總和。可以創(chuàng)建一個對角線上是節(jié)點(diǎn)度的對角度矩陣D,然后歸一化方程:
這樣:
直觀地說,行歸一化就是取鄰居特征的平均值,而列歸一化則考慮了鄰居的鄰居數(shù)。
為了兩者兼顧,采用對稱歸一化:
這考慮了當(dāng)前節(jié)點(diǎn)的鄰居數(shù)和鄰居的鄰居數(shù)。
這樣一來,我們的方程式就越來越完整了!
最后,我們需要一些參數(shù)來訓(xùn)練機(jī)器學(xué)習(xí)模型,就像在線性回歸中那樣,可以簡單地插入一個權(quán)重矩陣。
而且,我們知道添加非線性可以提供更好的特征表示,所以還可以在上面加一個ReLU激活函數(shù)。
最后:
4 PyTorch 實(shí)現(xiàn)
接下來,看看如何在 PyTorch 中實(shí)現(xiàn)圖卷積網(wǎng)絡(luò)。
首先,在類的初始化方法__init__中,我們會設(shè)置好鄰接矩陣A、度矩陣D和權(quán)重矩陣W。
然后,在模型的前向傳播過程中,利用這些組件來構(gòu)建節(jié)點(diǎn)的新特征矩陣H。
import torch
import torch.nn as nn
import torch.nn.functional as F
class GCNLayer(nn.Module):
"""
GCN 層
參數(shù):
input_dim (int): 輸入的維度
output_dim (int): 輸出的維度(softmax 分布)
A (torch.Tensor): 2D 鄰接矩陣
"""
def __init__(self, input_dim: int, output_dim: int, A: torch.Tensor):
super(GCNLayer, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.A = A
# A_hat = A + I
self.A_hat = self.A + torch.eye(self.A.size(0))
# 創(chuàng)建對角度矩陣 D
self.ones = torch.ones(input_dim, input_dim)
self.D = torch.matmul(self.A.float(), self.ones.float())
# 提取對角元素
self.D = torch.diag(self.D)
# 創(chuàng)建一個新張量,對角線上是元素,其他地方是零
self.D = torch.diag_embed(self.D)
# 創(chuàng)建 D^{-1/2}
self.D_neg_sqrt = torch.diag_embed(torch.diag(torch.pow(self.D, -0.5)))
# 初始化權(quán)重矩陣作為參數(shù)
self.W = nn.Parameter(torch.rand(input_dim, output_dim))
def forward(self, X: torch.Tensor):
# D^-1/2 * (A_hat * D^-1/2)
support_1 = torch.matmul(self.D_neg_sqrt, torch.matmul(self.A_hat, self.D_neg_sqrt))
# (D^-1/2 * A_hat * D^-1/2) * (X * W)
support_2 = torch.matmul(support_1, torch.matmul(X, self.W))
# ReLU(D^-1/2 * A_hat * D^-1/2 * X * W)
H = F.relu(support_2)
return H
if __name__ == "__main__":
# 示例用法
input_dim = 3 # 假設(shè)輸入維度是 3
output_dim = 2 # 假設(shè)輸出維度是 2
# 示例鄰接矩陣
A = torch.tensor([[1., 0., 0.],
[0., 1., 1.],
[0., 1., 1.]])
# 創(chuàng)建 GCN 層
gcn_layer = GCNLayer(input_dim, output_dim, A)
# 示例輸入特征矩陣
X = torch.tensor([[1., 2., 3.],
[4., 5., 6.],
[7., 8., 9.]])
# 前向傳遞
output = gcn_layer(X)
print(output)
# tensor([[ 6.3438, 5.8004],
# [13.3558, 13.7459],
# [15.5052, 16.0948]], grad_fn=<ReluBackward0>)
本文轉(zhuǎn)載自 ??AI科技論談??,作者: AI科技論談
