GraphSAGE圖神經(jīng)網(wǎng)絡(luò)算法詳解
GraphSAGE 是 17 年的文章了,但是一直在工業(yè)界受到重視,最主要的就是它論文名字中的兩個(gè)關(guān)鍵詞:inductive 和 large graph。今天我們就梳理一下這篇文章的核心思路,和一些容易被忽視的細(xì)節(jié)。
為什么要用 GraphSAGE
大家先想想圖為什么這么火,主要有這么幾點(diǎn)原因,圖的數(shù)據(jù)來源豐富,圖包含的信息多。所以現(xiàn)在都在考慮如何更好的使用圖的信息。
那么我們用圖需要做到什么呢?最核心的就是利用圖的結(jié)構(gòu)信息,為每個(gè) node 學(xué)到一個(gè)合適的 embedding vector。只要有了合適的 embedding 的結(jié)果,接下來無論做什么工作,我們就可以直接拿去套模型了。
在 GraphSAGE 之前,主要的方法有 DeepWalk,GCN 這些,但是不足在于需要對全圖進(jìn)行學(xué)習(xí)。而且是以 transductive learning 為主,也就是說需要在訓(xùn)練的時(shí)候,圖就已經(jīng)包含了要預(yù)測的節(jié)點(diǎn)。
考慮到實(shí)際應(yīng)用中,圖的結(jié)構(gòu)會頻繁變化,在最終的預(yù)測階段,可能會往圖中新添加一些節(jié)點(diǎn)。那么該怎么辦呢?GraphSAGE 就是為此而提出的,它的核心思路其實(shí)就是它的名字 GraphSAGE = Graph Sample Aggregate。也就是說對圖進(jìn)行 sample 和 aggregate。
GraphSAGE 的思路
我們提到了 sample 和 aggregate,具體指的是什么呢?這個(gè)步驟如何進(jìn)行?為什么它可以應(yīng)用到大規(guī)模的圖上?接下來就為大家用通俗易懂的語言描述清楚。
顧名思義,sample 就是選一些點(diǎn)出來,aggregate 就是再把它們的信息聚合起來。那么整個(gè)流程怎么走?看下面這張圖:
我們在第一幅圖上先學(xué)習(xí) sample 的過程。假如我有一張這樣的圖,需要對最中心的節(jié)點(diǎn)進(jìn)行 mebedding 的更新,先從它的鄰居中選擇 S1 個(gè)(這里的例子中是選擇 3 個(gè))節(jié)點(diǎn),假如 K=2,那么我們對第 2 層再進(jìn)行采樣,也就是對剛才選擇的 S1 個(gè)鄰居再選擇它們的鄰居。
在第二幅圖上,我們就可以看到對于聚合的操作,也就是說先拿鄰居的鄰居來更新鄰居的信息,再用更新后的鄰居的信息來更新目標(biāo)節(jié)點(diǎn)(也就是中間的紅色點(diǎn))的信息。聽起來可能稍微有點(diǎn)啰嗦,但是思路上并不繞,大家仔細(xì)梳理一下就明白了。
第三幅圖中,如果我們要預(yù)測一個(gè)未知節(jié)點(diǎn)的信息,只需要用它的鄰居們來進(jìn)行預(yù)測就可以了。
我們再梳理一下這個(gè)思路:如果我想知道小明是一個(gè)什么性格的人,我去找?guī)讉€(gè)他關(guān)系好的小伙伴觀察一下,然后我為了進(jìn)一步確認(rèn),我再去選擇他的小伙伴們的其他小伙伴,再觀察一下。也就是說,通過小明的小伙伴們的小伙伴,來判斷小明的小伙伴們是哪一類人,然后再根據(jù)他的小伙伴們,我就可以粗略的得知,小明是哪一類性格的人了。
GraphSAGE 思路補(bǔ)充
現(xiàn)在我們知道了 GraphSAGE 的基本思路,可能小伙伴們還有一些困惑:單個(gè)節(jié)點(diǎn)的思路是這樣子,那么整體的訓(xùn)練過程該怎么進(jìn)行呢?至今也沒有告訴我們 GraphSAGE 為什么可以應(yīng)用在大規(guī)模的圖上,為什么是 inductive 的呢?
接下來我們就補(bǔ)充一下 GraphSAGE 的訓(xùn)練過程,以及在這個(gè)過程它有哪些優(yōu)勢。
首先是考慮到我們要從初始特征開始,一層一層的做 embedding 的更新,我們該如何知道自己需要對哪些點(diǎn)進(jìn)行聚合呢?應(yīng)用前面提到的 sample 的思路,具體的方法來看一看算法:
首先看算法的第 2-7 行,其實(shí)就是一個(gè) sample 的過程,并且將 sample 的結(jié)果保存到 B 中。接下來的 9-15 行,就是一個(gè) aggregate 的過程,按照前面 sample 的結(jié)果,將對應(yīng)的鄰居信息 aggregate 到目標(biāo)節(jié)點(diǎn)上來。
細(xì)心的小伙伴肯定發(fā)現(xiàn)了 sample 的過程是從 K 到 1 的(看第 2 行),而 aggregate 的過程是從 1 到 K 的(第 9 行)。這個(gè)道理很明顯,采樣的時(shí)候,我們先從整張圖選擇自己要給哪些節(jié)點(diǎn) embedding,然后對這些節(jié)點(diǎn)的鄰居進(jìn)行采樣,并且逐漸采樣到遠(yuǎn)一點(diǎn)的鄰居上。
但是在聚合時(shí),肯定先從最遠(yuǎn)處的鄰居上開始進(jìn)行聚合,最后第 K 層的時(shí)候,才能聚合到目標(biāo)節(jié)點(diǎn)上來。這就是 GraphSAGE 的完整思路。
那么需要思考一下的是,這么簡單的思路其中有哪些奧妙呢?
GraphSAGE 的精妙之處
首先是為什么要提出 GraphSAGE 呢?其實(shí)最主要的是 inductive learning 這一點(diǎn)。這兩天在幾個(gè)討論群同時(shí)看到有同學(xué)對 transductive learning 和 inductive learning 有一些討論,總體來說,inductive learning 無疑是可以在測試時(shí),對新加入的內(nèi)容進(jìn)行推理的。
因此,GraphSAGE 的一大優(yōu)點(diǎn)就是,訓(xùn)練好了以后,可以對新加入圖網(wǎng)絡(luò)中的節(jié)點(diǎn)也進(jìn)行推理,這在實(shí)際場景的應(yīng)用中是非常重要的。
另一方面,在圖網(wǎng)絡(luò)的運(yùn)用中,往往是數(shù)據(jù)集都非常大,因此 mini batch 的能力就非常重要了。但是正因?yàn)?GraphSAGE 的思路,我們只需要對自己采樣的數(shù)據(jù)進(jìn)行聚合,無需考慮其它節(jié)點(diǎn)。每個(gè) batch 可以是一批 sample 結(jié)果的組合。
再考慮一下聚合函數(shù)的部分,這里訓(xùn)練的結(jié)果中,聚合函數(shù)占很大的重要性。關(guān)于聚合函數(shù)的選擇有兩個(gè)條件:
首先要可導(dǎo),因?yàn)橐聪騻鬟f來訓(xùn)練目標(biāo)的聚合函數(shù)參數(shù);
其次是對稱,這里的對稱指的是對輸入不敏感,因?yàn)槲覀冊诰酆系臅r(shí)候,圖中的節(jié)點(diǎn)關(guān)系并沒有順序上的特征。
所以在作者原文中選擇的都是諸如 Mean,max pooling 之類的聚合器,雖然作者也使用了 LSTM,但是在輸入前會將節(jié)點(diǎn)進(jìn)行 shuffle 操作,也就是說 LSTM 從序列順序中并不能學(xué)到什么知識。
此外在論文中還有一個(gè)小細(xì)節(jié),我初次看的時(shí)候沒有細(xì)讀論文,被一位朋友指出后才發(fā)現(xiàn)果然如此,先貼一下原文:
這里的 lines 4 and 5 in Algorithm 1,也就是我們前面給出的算法中的第 11 和 12 行。
也就是說,作者在文中提到的 GraphSAGE-GCN 其實(shí)就是用上面這個(gè)聚合函數(shù),替代掉其它方法中先聚合,再 concat 的操作,并且作者指出這種方法是局部譜卷積的線性近似,因此將其稱為 GCN 聚合器。
來點(diǎn)善后工作
最后我們就簡單的補(bǔ)充一些喜聞樂見,且比較簡單的東西吧。用 GraphSAGE 一般用來做什么?
首先作者提出,它既可以用來做無監(jiān)督學(xué)習(xí),也可以用來做有監(jiān)督學(xué)習(xí),有監(jiān)督學(xué)習(xí)我們就可以直接使用最終預(yù)測的損失函數(shù)為目標(biāo),反向傳播來訓(xùn)練。那么無監(jiān)督學(xué)習(xí)呢?
其實(shí)無論是哪種用途,需要注意的是圖本身,我們還是主要用它來完成 embedding 的操作。也就是得到一個(gè)節(jié)點(diǎn)的 embedding 后比較有效的 feature vector。那么做無監(jiān)督時(shí),如何知道它的 embedding 結(jié)果是對是錯(cuò)呢?
作者選擇了一個(gè)很容易理解的思路,就是鄰居的關(guān)系。默認(rèn)當(dāng)兩個(gè)節(jié)點(diǎn)距離相近時(shí),就會讓它們的 embedding 結(jié)果比較相似,如果距離遠(yuǎn),那 embedding 的結(jié)果自然應(yīng)該區(qū)別較大。這樣一來下面的損失函數(shù)就很容易理解了:
z_v 表示是目標(biāo)節(jié)點(diǎn) u 的鄰居,而 v_n 則表示不是,P_n(v) 是負(fù)樣本的分布,Q 是負(fù)樣本的數(shù)量。
那么現(xiàn)在剩下唯一的問題就是鄰居怎么定義?
作者選擇了一個(gè)很簡單的思路:直接使用 DeepWalk 進(jìn)行隨機(jī)游走,步長為 5,測試 50 次,走得到的都是鄰居。
總結(jié)
實(shí)驗(yàn)結(jié)果我們就不展示了,其實(shí)可以看到作者在很多地方都用了一些比較 baseline 的思路,大家可以在對應(yīng)的地方進(jìn)行更換和調(diào)整,以適應(yīng)自己的業(yè)務(wù)需求。
后面我們也會繼續(xù)分享 GNN 和 embedding 方面比較經(jīng)典和啟發(fā)性的一些 paper,歡迎大家持續(xù)關(guān)注~~~