終于把圖神經(jīng)網(wǎng)絡(luò)算法搞懂了!?。?/h1>
今天給大家分享一個(gè)強(qiáng)大的算法模型,GNN。
圖神經(jīng)網(wǎng)絡(luò)(GNN)是一類專門處理圖結(jié)構(gòu)數(shù)據(jù)的深度學(xué)習(xí)模型。
在傳統(tǒng)的深度學(xué)習(xí)中,輸入數(shù)據(jù)通常是結(jié)構(gòu)化的(如圖像、文本、時(shí)間序列等),這些數(shù)據(jù)都可以表示為一個(gè)規(guī)則的網(wǎng)格或序列。然而,圖數(shù)據(jù)具有更加復(fù)雜的非歐幾里得結(jié)構(gòu),節(jié)點(diǎn)和邊之間可能沒有固定的順序,也可能存在不同的連接模式。
GNN 通過設(shè)計(jì)一種特定的機(jī)制來學(xué)習(xí)和表示圖結(jié)構(gòu)數(shù)據(jù)中的節(jié)點(diǎn)、邊和全圖的信息。
圖片
圖的基本組成
在討論 GNN 之前,先了解一下圖的基本構(gòu)成。
- 節(jié)點(diǎn)(Node),圖中的基本元素,通常表示圖中實(shí)體或?qū)ο蟆?/li>
- 邊(Edge),連接節(jié)點(diǎn)之間的關(guān)系,可能是有向的或無向的。
- 鄰接關(guān)系(Adjacency),描述哪些節(jié)點(diǎn)之間通過邊相連。鄰接矩陣通常用于表示這種關(guān)系。
- 節(jié)點(diǎn)特征(Node Feature),每個(gè)節(jié)點(diǎn)可能有附加的屬性或特征,如社交網(wǎng)絡(luò)中用戶的年齡、性別等。
- 邊特征(Edge Feature),邊也可以有特征,例如在交通網(wǎng)絡(luò)中,邊可能表示道路的長度或交通流量。
圖片
圖神經(jīng)網(wǎng)絡(luò)的核心思想
GNN 的核心思想是利用圖的拓?fù)浣Y(jié)構(gòu),通過節(jié)點(diǎn)間的鄰接關(guān)系來傳播信息和進(jìn)行學(xué)習(xí)。在 GNN 中,節(jié)點(diǎn)的表示不僅依賴于其自身的特征,還依賴于其鄰居節(jié)點(diǎn)的特征。
圖神經(jīng)網(wǎng)絡(luò)的計(jì)算通常包括以下幾個(gè)步驟。
- 信息傳遞節(jié)點(diǎn)通過與其鄰居節(jié)點(diǎn)交換信息來更新自身的表示。這一過程通常通過消息傳遞機(jī)制實(shí)現(xiàn),節(jié)點(diǎn)會將自己的特征向量傳遞給鄰居節(jié)點(diǎn),鄰居節(jié)點(diǎn)再根據(jù)自己的特征和接收到的信息來更新自身的特征。
- 聚合每個(gè)節(jié)點(diǎn)會根據(jù)鄰居節(jié)點(diǎn)的特征進(jìn)行聚合操作,常見的聚合操作包括求和、均值、最大值等。這個(gè)步驟使得每個(gè)節(jié)點(diǎn)不僅包含自身的信息,還融合了鄰居的信息。
- 更新聚合后的信息會與當(dāng)前節(jié)點(diǎn)的原始特征一起傳入一個(gè)非線性函數(shù)(通常是一個(gè)神經(jīng)網(wǎng)絡(luò)層),來更新節(jié)點(diǎn)的表示。
- 迭代GNN 是一個(gè)迭代過程,通常會執(zhí)行多次消息傳遞和特征更新,每次迭代都會使得節(jié)點(diǎn)的表示更加豐富,能夠捕捉到更廣泛的上下文信息。
- 輸出層根據(jù)任務(wù)需求,最終會從節(jié)點(diǎn)特征或者圖特征中提取出有用的信息進(jìn)行分類、回歸等任務(wù)。
GNN 任務(wù)類型
節(jié)點(diǎn)級任務(wù)
節(jié)點(diǎn)級任務(wù)主要關(guān)注圖中單個(gè)節(jié)點(diǎn)的預(yù)測或嵌入。它通常依賴于節(jié)點(diǎn)的特征及其鄰居節(jié)點(diǎn)的信息。
節(jié)點(diǎn)級任務(wù)常見的應(yīng)用包括節(jié)點(diǎn)分類、節(jié)點(diǎn)嵌入等。
- 節(jié)點(diǎn)分類:預(yù)測每個(gè)節(jié)點(diǎn)的類別。
- 節(jié)點(diǎn)嵌入:學(xué)習(xí)每個(gè)節(jié)點(diǎn)的低維表示,通常用于下游任務(wù)(如聚類或分類)。
- 節(jié)點(diǎn)回歸:預(yù)測節(jié)點(diǎn)的連續(xù)值。
示例代碼:節(jié)點(diǎn)分類任務(wù)
假設(shè)我們有一個(gè)社交網(wǎng)絡(luò)圖,任務(wù)是預(yù)測每個(gè)用戶的興趣類別(例如,體育、音樂、科技等)。
我們使用 PyTorch Geometric 框架實(shí)現(xiàn)一個(gè)簡單的圖卷積網(wǎng)絡(luò)(GCN)。
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import GCNConv
# GCN模型定義
class GCN(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super(GCN, self).__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
# 第一次圖卷積
x = self.conv1(x, edge_index)
x = torch.relu(x)
# 第二次圖卷積
x = self.conv2(x, edge_index)
return x
# 假設(shè)我們有一個(gè)圖,包含節(jié)點(diǎn)特征和邊的連接關(guān)系
# 節(jié)點(diǎn)特征: x, 鄰接矩陣: edge_index
x = torch.randn(100, 16) # 100個(gè)節(jié)點(diǎn),16維特征
edge_index = torch.randint(0, 100, (2, 500)) # 500條邊
# 目標(biāo)標(biāo)簽:節(jié)點(diǎn)的類別(假設(shè)有10個(gè)類別)
y = torch.randint(0, 10, (100,))
# 創(chuàng)建GCN模型
model = GCN(in_channels=16, hidden_channels=32, out_channels=10)
# 定義損失函數(shù)和優(yōu)化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 訓(xùn)練過程
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(x, edge_index) # 獲取節(jié)點(diǎn)的分類輸出
loss = criterion(out, y) # 計(jì)算損失
loss.backward()
optimizer.step()
if epoch % 20 == 0:
print(f'Epoch {epoch}, Loss: {loss.item()}')
節(jié)點(diǎn)級任務(wù)的應(yīng)用場景
- 社交網(wǎng)絡(luò)分析:預(yù)測社交網(wǎng)絡(luò)中每個(gè)用戶的興趣標(biāo)簽。
- 生物信息學(xué):預(yù)測基因、蛋白質(zhì)的功能類別。
- 推薦系統(tǒng):預(yù)測用戶或物品的類別或偏好。
邊級任務(wù)
邊級任務(wù)關(guān)注圖中節(jié)點(diǎn)間的關(guān)系。
邊級任務(wù)常見的應(yīng)用包括鏈接預(yù)測、邊分類等。
- 鏈接預(yù)測:預(yù)測兩個(gè)節(jié)點(diǎn)之間是否存在邊,或預(yù)測未觀察到的潛在邊。
- 邊分類:對圖中的邊進(jìn)行分類任務(wù),如判斷兩個(gè)節(jié)點(diǎn)之間的關(guān)系類型。
- 邊回歸:預(yù)測邊的連續(xù)值,如邊的權(quán)重或相似度。
示例代碼:鏈接預(yù)測任務(wù)
在鏈接預(yù)測任務(wù)中,我們預(yù)測圖中節(jié)點(diǎn)對是否存在邊。
通過GNN學(xué)習(xí)到的節(jié)點(diǎn)表示,可以計(jì)算節(jié)點(diǎn)對之間的相似度,進(jìn)而預(yù)測鏈接。
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling
# GCN模型定義(用于鏈接預(yù)測)
class GCNLinkPrediction(nn.Module):
def __init__(self, in_channels, hidden_channels):
super(GCNLinkPrediction, self).__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, hidden_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = torch.relu(x)
x = self.conv2(x, edge_index)
return x
# 假設(shè)我們有一個(gè)圖,包含節(jié)點(diǎn)特征和邊的連接關(guān)系
x = torch.randn(100, 16) # 100個(gè)節(jié)點(diǎn),16維特征
edge_index = torch.randint(0, 100, (2, 500)) # 500條邊
# 創(chuàng)建GCN模型
model = GCNLinkPrediction(in_channels=16, hidden_channels=32)
# 定義損失函數(shù)和優(yōu)化器
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 訓(xùn)練過程
model.train()
for epoch in range(200):
optimizer.zero_grad()
# 前向傳播,得到節(jié)點(diǎn)嵌入表示
out = model(x, edge_index)
# 負(fù)采樣生成不存在的邊
neg_edge_index = negative_sampling(edge_index, num_nodes=100, num_neg_samples=edge_index.size(1))
# 獲取真實(shí)邊和負(fù)邊
pos_out = out[edge_index[0]] * out[edge_index[1]]
neg_out = out[neg_edge_index[0]] * out[neg_edge_index[1]]
# 計(jì)算損失
pos_loss = torch.sigmoid(pos_out).sum()
neg_loss = torch.sigmoid(neg_out).sum()
loss = -(pos_loss - neg_loss)
loss.backward()
optimizer.step()
if epoch % 20 == 0:
print(f'Epoch {epoch}, Loss: {loss.item()}')
邊級任務(wù)的應(yīng)用場景
- 社交網(wǎng)絡(luò)分析:預(yù)測用戶之間是否會建立新的聯(lián)系(如好友推薦)。
- 推薦系統(tǒng):預(yù)測用戶與物品之間的潛在關(guān)系(例如,是否購買)。
- 知識圖譜:預(yù)測實(shí)體之間的關(guān)系(如“巴黎”和“法國”的“首都”關(guān)系)。
圖級任務(wù)
圖級任務(wù)關(guān)注整個(gè)圖的預(yù)測或表示。
任務(wù)的目標(biāo)是將整個(gè)圖映射到一個(gè)類別或一個(gè)值,常見的任務(wù)包括圖分類、圖回歸等。
- 圖分類:對整個(gè)圖進(jìn)行分類,常用于生物分子分類、文檔分類等。
- 圖回歸:預(yù)測整個(gè)圖的連續(xù)值,如預(yù)測圖的某種特性(例如分子的毒性)。
GNN的類型
不同的 GNN 變種在消息傳遞、聚合和更新機(jī)制上有所不同。
以下是一些常見的GNN模型:
- GCNGCN 是最經(jīng)典的圖卷積網(wǎng)絡(luò),它借鑒了卷積神經(jīng)網(wǎng)絡(luò)的思想,通過對鄰居節(jié)點(diǎn)的特征進(jìn)行加權(quán)平均來更新節(jié)點(diǎn)表示。GCN 使用了圖的鄰接矩陣來定義節(jié)點(diǎn)間的信息傳播規(guī)則。
- GATGAT 引入了注意力機(jī)制,在信息傳遞的過程中,給不同的鄰居節(jié)點(diǎn)分配不同的權(quán)重(即鄰接節(jié)點(diǎn)的影響力不同)。這種方式使得 GAT 能夠更靈活地處理圖中節(jié)點(diǎn)的異質(zhì)性。
- GraphSAGEGraphSAGE 通過對每個(gè)節(jié)點(diǎn)的鄰居進(jìn)行采樣來減少計(jì)算開銷,而不是直接使用全部鄰居節(jié)點(diǎn)。
GNN的應(yīng)用場景
圖神經(jīng)網(wǎng)絡(luò)在很多領(lǐng)域得到了廣泛應(yīng)用
- 社交網(wǎng)絡(luò)分析在社交網(wǎng)絡(luò)中,節(jié)點(diǎn)表示人或社交媒體賬戶,邊表示他們之間的互動(dòng)關(guān)系。GNN可以用來進(jìn)行用戶推薦、社交圈分析、輿情分析等任務(wù)。
- 化學(xué)分子建模在化學(xué)中,分子結(jié)構(gòu)可以用圖表示,其中節(jié)點(diǎn)代表原子,邊代表原子之間的化學(xué)鍵。GNN可以用來預(yù)測分子的性質(zhì)、藥物設(shè)計(jì)等。
- 知識圖譜知識圖譜是包含實(shí)體和關(guān)系的大型圖結(jié)構(gòu),GNN 可以用于關(guān)系預(yù)測、實(shí)體鏈接等任務(wù)。
- 推薦系統(tǒng)在推薦系統(tǒng)中,用戶和物品可以構(gòu)成圖結(jié)構(gòu),GNN 可以用于用戶偏好預(yù)測、物品推薦等。
- 自然語言處理在文本中,詞語之間的關(guān)系可以通過圖表示,GNN 可以用來進(jìn)行句子理解、語義分析等任務(wù)。