剖析DeepMind神經(jīng)網(wǎng)絡(luò)記憶研究:模擬動物大腦實現(xiàn)連續(xù)學(xué)習(xí)
一、引言
我發(fā)現(xiàn)網(wǎng)絡(luò)上鋪天蓋地的人工智能相關(guān)信息中的絕大多數(shù)都可以分為兩類:一是向外行人士解釋進(jìn)展情況,二是向其他研究者解釋進(jìn)展。我還沒找到什么好資源能讓有技術(shù)背景但對不了解更前沿的進(jìn)展的人可以自己充電。我想要成為這中間的橋梁——通過為前沿研究提供(相對)簡單易懂的詳細(xì)解釋來實現(xiàn)。首先,讓我們從《Overcoming Catastrophic Forgetting in Neural Networks》這篇論文開始吧。
二、動機(jī)
實現(xiàn)通用人工智能的關(guān)鍵步驟是獲得連續(xù)學(xué)習(xí)的能力,也就是說,一個代理(agent)必須能在不遺忘舊任務(wù)的執(zhí)行方法的同時習(xí)得如何執(zhí)行新任務(wù)。然而,這種看似簡單的特性在歷史上卻一直未能實現(xiàn)。McCloskey 和 Cohen(1989)首先注意到了這種能力的缺失——他們首先訓(xùn)練一個神經(jīng)網(wǎng)絡(luò)學(xué)會了給一個數(shù)字加 1,然后又訓(xùn)練該神經(jīng)網(wǎng)絡(luò)學(xué)會了給數(shù)字加 2,但之后該網(wǎng)絡(luò)就不會給數(shù)字加 1 了。他們將這個問題稱為「災(zāi)難性遺忘(catastrophic forgetting)」,因為神經(jīng)網(wǎng)絡(luò)往往是通過快速覆寫來學(xué)習(xí)新任務(wù),而這樣就會失去執(zhí)行之前的任務(wù)所必需的參數(shù)。
克服災(zāi)難性遺忘方面的進(jìn)展一直收效甚微。之前曾有兩篇論文《Policy Distillation》和《Actor-Mimic: Deep Multitask and Transfer Reinforcement Learning》通過在訓(xùn)練過程中提供所有任務(wù)的數(shù)據(jù)而在混合任務(wù)上實現(xiàn)了很好的表現(xiàn)。但是,如果一個接一個地引入這些任務(wù),那么這種多任務(wù)學(xué)習(xí)范式就必須維持一個用于記錄和重放訓(xùn)練數(shù)據(jù)的情景記憶系統(tǒng)(episodic memory system)才能獲得良好的表現(xiàn)。這種方法被稱為系統(tǒng)級鞏固(system-level consolidation),該方法受限于對記憶系統(tǒng)的需求,且該系統(tǒng)在規(guī)模上必須和被存儲的總記憶量相當(dāng);而隨著任務(wù)的增長,這種記憶的量也會增長。
然而,你可能也直覺地想到了帶著一個巨大的記憶庫來進(jìn)行連續(xù)學(xué)習(xí)是錯誤的——畢竟,人類除了能學(xué)會走路,也能學(xué)會說話,而不需要維持關(guān)于學(xué)習(xí)如何走路的記憶。哺乳動物的大腦是如何實現(xiàn)這一能力的?Yang, Pan and Gan (2009) 說明學(xué)習(xí)是通過突觸后樹突棘(postsynaptic dendritic spines)隨時間而進(jìn)行的形成和消除而實現(xiàn)的。樹突棘是指「神經(jīng)元樹突上的突起,通常從突觸的單個軸突接收輸入」,如下所示:
具體而言,這些研究者檢查了小鼠學(xué)習(xí)過針對特定新任務(wù)的移動策略之后的大腦。當(dāng)這些小鼠學(xué)會了***的移動方式之后,研究者觀察到樹突棘形成出現(xiàn)了顯著增多。為了消除運動可能導(dǎo)致樹突棘形成的額外解釋,這些研究者還設(shè)置了一個進(jìn)行運動的對照組——這個組沒有觀察到樹突棘形成。然后這些研究者還注意到,盡管大多數(shù)新形成的樹突棘后面都消失了,但仍有一小部分保留了下來。2015 年的另一項研究表明當(dāng)特定的樹突棘被擦除時,其對應(yīng)的技能也會隨之消失。
Kirkpatrick et. 在論文《Overcoming catastrophic forgetting in neural networks》中提到:「這些實驗發(fā)現(xiàn)……說明在新大腦皮層中的連續(xù)學(xué)習(xí)依賴于特定于任務(wù)的突觸鞏固(synaptic consolidation),其中知識是通過使一部分突觸更少塑性而獲得持久編碼的,所以能長時間保持穩(wěn)定?!刮覀兡苁褂妙愃频姆椒?根據(jù)每個神經(jīng)元的重要性來改變單個神經(jīng)元的可塑性)來克服災(zāi)難性遺忘嗎?
這篇論文的剩余部分推導(dǎo)并演示了他們的初步答案:可以。
三、直覺
假設(shè)有兩個任務(wù) A 和 B,我們想要一個神經(jīng)網(wǎng)絡(luò)按順序?qū)W習(xí)它們。當(dāng)我們談到一個學(xué)習(xí)任務(wù)的神經(jīng)網(wǎng)絡(luò)時,它實際上意味著讓神經(jīng)網(wǎng)絡(luò)調(diào)整權(quán)重和偏置(統(tǒng)稱參數(shù)/θ)以使得神經(jīng)網(wǎng)絡(luò)在該任務(wù)上實現(xiàn)更好的表現(xiàn)。之前的研究表明對于大型網(wǎng)絡(luò)而言,許多不同的參數(shù)配置可以實現(xiàn)類似的表現(xiàn)。通常,這意味著網(wǎng)絡(luò)被過參數(shù)化了,但是我們可以利用這一點:過參數(shù)化(overparameterization)能使得任務(wù) B 的配置可能接近于任務(wù) A 的配置。作者提供了有用的圖形:
在圖中,θ∗A 代表在 A 任務(wù)中表現(xiàn)***的 θ 的配置,還存在多種參數(shù)配置可以接近這個表現(xiàn),灰色表示這一配置的集合;在這里使用橢圓來表示是因為有些參數(shù)的調(diào)整權(quán)重比其他參數(shù)更大。如果神經(jīng)網(wǎng)絡(luò)隨后被設(shè)置為學(xué)習(xí)任務(wù) B 而對記住任務(wù) A 沒有任何興趣(即遵循任務(wù) B 的誤差梯度),則該網(wǎng)絡(luò)將在藍(lán)色箭頭的方向上移動其參數(shù)。B 的***解也具有類似的誤差橢圓體,上面由白色橢圓表示。
然而,我們還想記住任務(wù) A。如果我們只是簡單使參數(shù)固化,就會按綠色箭頭發(fā)展,則處理任務(wù) A 和 B 的性能都將變得糟糕。***的辦法是根據(jù)參數(shù)對任務(wù)的重要程度來選擇其固化的程度;如果這樣的話,神經(jīng)網(wǎng)絡(luò)參數(shù)的變化方向?qū)⒆裱t色箭頭,它將試圖找到同時能夠很好執(zhí)行任務(wù) A 和 B 的配置。作者稱這種算法「彈性權(quán)重鞏固(EWC/Elastic Weight Consolidation)」。這個名稱來自于突觸鞏固(synaptic consolidation),結(jié)合「彈性的」錨定參數(shù)(對先前解決方案的約束限制參數(shù)是二次的,因此是彈性的)。
四、數(shù)學(xué)解釋
在這里存在兩個問題。***,為什么錨定函數(shù)是二次的?第二,如何判定哪個參數(shù)是「重要的」?
在回答這兩個問題之前,我們先要明白從概率的角度來理解神經(jīng)網(wǎng)絡(luò)的訓(xùn)練意味著什么。假設(shè)我們有一些數(shù)據(jù) D,我們希望找到***可能性的參數(shù),它被表示為 p(θ|D)。我們可以是用貝葉斯規(guī)則來計算這個條件概率。
如果我們應(yīng)用對數(shù)變換,則方程可以被重寫為:
假設(shè)數(shù)據(jù) D 由兩個獨立的(independent)部分構(gòu)成,用于任務(wù) A 的數(shù)據(jù) DA 和用于任務(wù) B 的數(shù)據(jù) DB。這個邏輯適用于多于兩個任務(wù),但在這里不用詳述。使用獨立性(independence)的定義,我們可以重寫這個方程:
看看(3)右邊的中間三個項。它們看起來很熟悉嗎?它們應(yīng)該。這三個項是方程(2)的右邊,但是 D 被 DA 代替了。簡單來說,這三個項等價于給定任務(wù) A 數(shù)據(jù)的網(wǎng)絡(luò)參數(shù)的條件概率的對數(shù)。這樣,我們得到了下面這個方程:
讓我們先解釋一下方程(4)。左側(cè)仍然告訴我們?nèi)绾斡嬎阏麄€數(shù)據(jù)集的 p(θ| D),但是當(dāng)求解任務(wù) A 時學(xué)習(xí)的所有信息都包含在條件概率 p(θ| DA)中。這個條件概率可以告訴我們哪些參數(shù)在解決任務(wù) A 中很重要。
下一步是不明確的:「真實的后驗概率是難以處理的,因此,根據(jù) Mackay (19) 對拉普拉斯近似的研究,我們將該后驗近似為一個高斯分布,其帶有由參數(shù)θ∗A 給定的均值和一個由 Fisher 信息矩陣 F 的對角線給出的對角精度?!?/p>
讓我們詳細(xì)解釋一下。首先,為什么真正的后驗概率難以處理?論文并沒有解釋,答案是:貝葉斯規(guī)則告訴我們
p(θ|DA) 取決于 p(DA)=∫p(DA|θ′)p(θ′)dθ′,其中θ′是參數(shù)空間中的參數(shù)的可能配置。通常,該積分沒有封閉形式的解,留下數(shù)值近似以作為替代。數(shù)值近似的時間復(fù)雜性相對于參數(shù)的數(shù)量呈指數(shù)級增長,因此對于具有數(shù)億或更多參數(shù)的深度神經(jīng)網(wǎng)絡(luò),數(shù)值近似是不實際的。
然后,Mackay 關(guān)于拉普拉斯近似的工作是什么,跟這里的研究有什么關(guān)系?我們使用θ*A 作為平均值,而非數(shù)值近似后驗分布,將其建模為多變量正態(tài)分布。方差呢?我們將把每個變量的方差指定為方差的倒數(shù)的精度。為了計算精度,我們將使用 Fisher 信息矩陣 F。Fisher 信息是「一種測量可觀察隨機(jī)變量 X 攜帶的關(guān)于 X 所依賴的概率的未知參數(shù)θ的信息的量的方法?!乖谖覀兊睦又?,我們感興趣的是測量來自 DA 的每個數(shù)據(jù)所攜帶的關(guān)于θ的信息的量。Fisher 信息矩陣比數(shù)值近似計算更可行,這使得它成為一個有用的工具。
因此,我們可以為我們的網(wǎng)絡(luò)在任務(wù) A 上訓(xùn)練后在任務(wù) B 上再定義一個新的損失函數(shù)。讓 LB(θ)僅作為任務(wù) B 的損失。如果我們用 i 索引我們的參數(shù),并且選擇標(biāo)量λ來影響任務(wù) A 對任務(wù) B 的重要性,則在 EWC 中最小化的函數(shù) L 是:
作者聲稱,EWC 具有相對于網(wǎng)絡(luò)參數(shù)的數(shù)量和訓(xùn)練示例的數(shù)量是線性的運行時間。
五、實驗和結(jié)果
1. 隨機(jī)模式
EWC 的***個測試是簡單地看它能否比梯度下降(GD)更長時間地記住簡單的模式。這些研究者訓(xùn)練了一個可將隨機(jī)二元模式和二元結(jié)果關(guān)聯(lián)起來的神經(jīng)網(wǎng)絡(luò)。如果該網(wǎng)絡(luò)看到了一個之前見過的二元模式,那么就通過觀察其信噪比是否超過了一個閾值來評價其是否已經(jīng)「記住」了該模式。使用這種簡單測試的原因是其具有一個分析解決方案。隨著模式數(shù)量的增加,EWC 和 GD 的表現(xiàn)都接近了它們的***答案。但是 EWC 能夠記憶比 GD 遠(yuǎn)遠(yuǎn)更多的模式,如下圖所示:
2. MNIST
這些研究者為 EWC 所進(jìn)行的第二個測試是一個修改過的 MNIST 版本。其沒有使用被給出的數(shù)據(jù),而是生成了三個隨機(jī)的排列(permutation),并將每個排列都應(yīng)用到該數(shù)據(jù)集中的每張圖像上。任務(wù) A 是對被***個排列轉(zhuǎn)換過的 MNIST 圖像中的數(shù)字進(jìn)行分類,任務(wù) B 是對被第二個排列轉(zhuǎn)換過的圖像中的數(shù)字進(jìn)行分類,任務(wù) C 類推。這些研究者構(gòu)建了一個全連接的深度神經(jīng)網(wǎng)絡(luò)并在任務(wù) A、B 和 C 上對該網(wǎng)絡(luò)進(jìn)行了訓(xùn)練,同時在任務(wù) A(在 A 上的訓(xùn)練完成后)、B(在 B 上的訓(xùn)練完成后)和 C(在 C 上的訓(xùn)練完成后)上測試了該網(wǎng)絡(luò)的表現(xiàn)。訓(xùn)練是分別使用隨機(jī)梯度下降(SGD)、使用 L2 正則化的均勻參數(shù)剛度(uniform parameter-rigidity using L2 regularization)、EWC 獨立完成的。下面是它們的結(jié)果:
如預(yù)期的一樣,SGD 出現(xiàn)了災(zāi)難性遺忘;在任務(wù) B 上訓(xùn)練后在任務(wù) A 上的表現(xiàn)出現(xiàn)了快速衰退,在任務(wù) C 上訓(xùn)練后更是進(jìn)一步衰退。使參數(shù)更剛性能維持在***個任務(wù)上的表現(xiàn),但卻不能學(xué)習(xí)后續(xù)的任務(wù)。而 EWC 能在成功學(xué)習(xí)新任務(wù)的同時記住如何執(zhí)行之前的任務(wù)。隨著任務(wù)數(shù)量的增加,EWC 也能維持相對較好的表現(xiàn),相對地,帶有 dropout 正則化的 SGD 的表現(xiàn)會持續(xù)下降,如下所示:
3. Atari 2600
DeepMind 曾在更早的一篇論文中表明:當(dāng)一次訓(xùn)練和測試一個游戲時,深度 Q 網(wǎng)絡(luò)(DQN)能夠在多種 Atari 2600 游戲上實現(xiàn)超人類的表現(xiàn)。為了了解 EWC 可以如何在這種更具挑戰(zhàn)性的強(qiáng)化學(xué)習(xí)環(huán)境中實現(xiàn)連續(xù)學(xué)習(xí),這些研究者對該 DQN 代理進(jìn)行了修改,使其能夠使用 EWC。但是,他們還需要做出一項額外的修改:在哺乳動物的連續(xù)學(xué)習(xí)中,為了確定一個代理當(dāng)前正在學(xué)習(xí)的任務(wù),必需要一個高層面的系統(tǒng),但該 DQN 代理完全不能做出這樣的確定。為了解決這個問題,研究者為其添加了一個基于 forget-me-not(FMN)過程的在線聚類算法,使得該 DQN 代理能夠為每一個推斷任務(wù)維持各自獨立的短期記憶緩存。
這就得到了一個能夠跨兩個時間尺度進(jìn)行學(xué)習(xí)的 DQN 代理。在短期內(nèi),DQN 代理可以使用 SGD 等優(yōu)化器(本案例中研究者使用了 RMSProp)來從經(jīng)歷重放(experience replay)機(jī)制中學(xué)習(xí)。在長期內(nèi),該 DQN 代理使用 EWC 來鞏固其從各種任務(wù)上學(xué)習(xí)到的知識。研究者從 DQN 實現(xiàn)了人類水平表現(xiàn)的 Atari 游戲(共 19 個)中隨機(jī)選出了 10 個,然后該代理在每個單獨的游戲上訓(xùn)練了一段時間,如下所示:
這些研究者對三個不同的 DQN 的代理進(jìn)行了比較。藍(lán)色代理沒有使用 EWC,紅色代理使用了 EWC 和 forget-me-not 任務(wù)標(biāo)識符。褐色代理則使用 EWC,并被提供了真實任務(wù)標(biāo)簽。在一個任務(wù)上實現(xiàn)人類水平的表現(xiàn)被規(guī)范化為分?jǐn)?shù) 1。如你所見,EWC 在這 10 個任務(wù)上都實現(xiàn)了接近人類水平的表現(xiàn),而非 EWC 代理沒能在一個以上的任務(wù)上做到這一點。該代理是否被給出了一個真實標(biāo)簽或是否必須對任務(wù)進(jìn)行推導(dǎo)對結(jié)果的影響不大,但我認(rèn)為這也表明了 FMN 過程的成果,而不僅僅是 EWC 的成功。
接下來的這個部分非常酷。如前所述,EWC 并沒有在這 10 個任務(wù)上實現(xiàn)人類水平的表現(xiàn)。為什么會這樣?一個可能的原因是 Fisher 信息矩陣可能對參數(shù)重要程度的估計不佳。為了實際驗證這一點,這些研究者研究了在僅一個游戲上訓(xùn)練的代理的權(quán)重。不管這個游戲是什么游戲,它們都表現(xiàn)出了以下模式:如果權(quán)重受到了一個均勻隨機(jī)擾動的影響,隨著該擾動的增加,該代理的表現(xiàn)(規(guī)范化為 1)會下降;而如果權(quán)重受到的擾動得到了 Fisher 信息的對角線的逆的影響,那么該分?jǐn)?shù)在面臨更大的擾動時也能保持穩(wěn)定。這說明 Fisher 信息在確定參數(shù)的真正重要性方面是很好的。
然后,研究者嘗試在 null 空間中進(jìn)行擾動。這本來應(yīng)該是無效的,但實際上研究者觀察到了與逆 Fisher 空間中的結(jié)果類似的結(jié)果。這說明使用 Fisher 信息矩陣會導(dǎo)致將一些重要參數(shù)標(biāo)記為不重要的情況——「因此很有可能當(dāng)前實現(xiàn)的主要限制是其低估了參數(shù)的不確定性?!?/p>
六、討論
1. 貝葉斯說明
作者們對 EWC 給出了非常好的貝葉斯解讀:「正式來說,當(dāng)需要學(xué)習(xí)新的任務(wù)時,網(wǎng)絡(luò)參數(shù)被先驗所調(diào)節(jié),也就是先前任務(wù)在給定數(shù)據(jù)參數(shù)上的后驗分布。這能實現(xiàn)先驗任務(wù)被約束的更快的參數(shù)學(xué)習(xí)率,并為重要的參數(shù)降低學(xué)習(xí)率?!?/p>
2. 重疊
在剛開始時,我提到神經(jīng)網(wǎng)絡(luò)的過參數(shù)化(overparameterization)讓 EWC 能實現(xiàn)優(yōu)異的表現(xiàn)。有一個很合理的問題就是:為了每個不同的任務(wù)將網(wǎng)絡(luò)劃分到特定部分,這些神經(jīng)網(wǎng)絡(luò)能否給出更好的表現(xiàn)?或者通過共享表征,這些網(wǎng)絡(luò)是否能高效地使用其能力?為了解答這個問題,作者們測量了任務(wù)對在 Fisher 信息矩陣上的重疊情況(Fisher Overlap)。對高度類似的任務(wù)(例如只有一點不同的兩個隨機(jī)排列)而言,F(xiàn)isher Overlap 相當(dāng)高。即使不相似的任務(wù),F(xiàn)isher Overlap 也高于 0。隨著網(wǎng)絡(luò)深度的增加,F(xiàn)isher Overlap 也會增加。下圖演示了該結(jié)果:
3. 突觸可塑性的理論
研究者還討論了 EWC 可能能為神經(jīng)可塑性方面的研究提供信息。級聯(lián)(Cascade)理論企圖構(gòu)建突觸狀態(tài)的模型來對可塑性和穩(wěn)定性建模。盡管 EWC 不能隨時間緩和參數(shù),也因此不能遺忘先前的信息,但 EWC 和級聯(lián)都能通過讓突觸更不可塑而延展記憶穩(wěn)定性。最近的一項研究提出除了存儲自身的實際權(quán)重之外,突觸也存儲當(dāng)前權(quán)重的不確定性。EWC 是該思路的延展:在 EWC 中,每個突觸存儲三個值:權(quán)重、均值和方差。
七、總結(jié)
- 沒有遺忘的連續(xù)學(xué)習(xí)任務(wù)對智能而言很有必要;
- 研究表明哺乳動物大腦中的突觸鞏固有助于實現(xiàn)連續(xù)學(xué)習(xí);
- EWC 通過讓重要參數(shù)變的更可塑,從而模擬人工神經(jīng)網(wǎng)絡(luò)中的突觸鞏固;
- EWC 應(yīng)用于神經(jīng)網(wǎng)絡(luò)時表現(xiàn)出了比 SGB 更好的表現(xiàn),而且在更大范圍上有一致的參數(shù)穩(wěn)定性;
- EWC 提供了說明權(quán)重鞏固是連續(xù)學(xué)習(xí)的基礎(chǔ)的線索證據(jù)。
原文:http://rylanschaeffer.github.io/content/research/overcoming_catastrophic_forgetting/main.html
【本文是51CTO專欄機(jī)構(gòu)機(jī)器之心的原創(chuàng)譯文,微信公眾號“機(jī)器之心( id: almosthuman2014)”】