針對(duì)深度學(xué)習(xí)的“失憶癥”,科學(xué)家提出基于相似性加權(quán)交錯(cuò)學(xué)習(xí),登上PNAS
與人類不同,人工神經(jīng)網(wǎng)絡(luò)在學(xué)習(xí)新事物時(shí)會(huì)迅速遺忘先前學(xué)到的信息,必須通過新舊信息的交錯(cuò)來重新訓(xùn)練;但是,交錯(cuò)全部舊信息非常耗時(shí),并且可能沒有必要。只交錯(cuò)與新信息有實(shí)質(zhì)相似性的舊信息可能就足夠了。
近日,美國科學(xué)院院報(bào)(PNAS)刊登了一篇論文,“Learning in deep neural networks and brains with similarity-weighted interleaved learning”,由加拿大皇家學(xué)會(huì)會(huì)士、知名神經(jīng)科學(xué)家 Bruce McNaughton 的團(tuán)隊(duì)發(fā)表。他們的工作發(fā)現(xiàn),通過將舊信息與新信息進(jìn)行相似性加權(quán)交錯(cuò)訓(xùn)練,深度網(wǎng)絡(luò)可以快速學(xué)習(xí)新事物,不僅降低了遺忘率,而且使用的數(shù)據(jù)量大幅減少。
論文作者還作出一個(gè)假設(shè):通過跟蹤最近活躍的神經(jīng)元和神經(jīng)動(dòng)力學(xué)吸引子(attractor dynamics)的持續(xù)興奮性軌跡,可以在大腦中實(shí)現(xiàn)相似性加權(quán)交錯(cuò)。這些發(fā)現(xiàn)可能會(huì)促進(jìn)神經(jīng)科學(xué)和機(jī)器學(xué)習(xí)的進(jìn)一步發(fā)展。
1 研究背景
了解大腦如何終身學(xué)習(xí)仍然是一項(xiàng)長期挑戰(zhàn)。
在人工神經(jīng)網(wǎng)絡(luò)(ANN)中,過快地整合新信息會(huì)產(chǎn)生災(zāi)難性干擾,即先前獲得的知識(shí)突然丟失?;パa(bǔ)學(xué)習(xí)系統(tǒng)理論 (Complementary Learning Systems Theory,CLST) 表明,通過將新記憶與現(xiàn)有知識(shí)交錯(cuò),新記憶可以逐漸融入新皮質(zhì)。
CLST指出,大腦依賴于互補(bǔ)的學(xué)習(xí)系統(tǒng):海馬體 (HC) 用于快速獲取新記憶,新皮層 (NC) 用于將新數(shù)據(jù)逐漸整合到與上下文無關(guān)的結(jié)構(gòu)化知識(shí)中。在“離線期間”,例如睡眠和安靜的清醒休息期間,HC觸發(fā)回放最近在NC中的經(jīng)歷,而NC自發(fā)地檢索和交錯(cuò)現(xiàn)有類別的表征。交錯(cuò)回放允許以梯度下降的方式逐步調(diào)整NC突觸權(quán)重,以創(chuàng)建與上下文無關(guān)的類別表征,從而優(yōu)雅地整合新記憶并克服災(zāi)難性干擾。許多研究已經(jīng)成功地使用交錯(cuò)回放實(shí)現(xiàn)了神經(jīng)網(wǎng)絡(luò)的終身學(xué)習(xí)。
然而,在實(shí)踐中應(yīng)用CLST時(shí),有兩個(gè)重要問題亟待解決。首先,當(dāng)大腦無法訪問所有舊數(shù)據(jù)時(shí),如何進(jìn)行全面的信息交錯(cuò)呢?一種可能的解決方案是“偽排練”,其中隨機(jī)輸入可以引發(fā)內(nèi)部表征的生成式回放,而無需顯式訪問先前學(xué)習(xí)的示例。類吸引子動(dòng)力學(xué)可能使大腦完成“偽排練”,但“偽排練”的內(nèi)容尚未明確。因此,第二個(gè)問題是,每進(jìn)行新的學(xué)習(xí)活動(dòng)之后,大腦是否有充足的時(shí)間交織所有先前學(xué)習(xí)的信息。
相似性加權(quán)交錯(cuò)學(xué)習(xí)(Similarity-Weighted Interleaved Learning,SWIL)算法被認(rèn)為是第二個(gè)問題的解決方案,這表明僅交錯(cuò)與新信息具有實(shí)質(zhì)表征相似性的舊信息可能就足夠了。實(shí)證行為研究表明,高度一致的新信息可以快速整合到NC結(jié)構(gòu)化知識(shí)中,幾乎沒有干擾。這表明整合新信息的速度取決于其與先驗(yàn)知識(shí)的一致性。受此行為結(jié)果的啟發(fā),并通過重新檢查先前獲得的類別之間的災(zāi)難性干擾分布,McClelland等人證明SWIL可以在具有兩個(gè)上義詞類別(例如,“水果”是“蘋果”和“香蕉”的上義詞)的簡單數(shù)據(jù)集中,每個(gè)epoch使用少于2.5倍的數(shù)據(jù)量學(xué)習(xí)新信息,實(shí)現(xiàn)了與在全部數(shù)據(jù)上訓(xùn)練網(wǎng)絡(luò)相同的性能。然而,研究人員在使用更復(fù)雜的數(shù)據(jù)集時(shí)并沒有發(fā)現(xiàn)類似的效果,這引發(fā)了對(duì)該算法可擴(kuò)展性的擔(dān)憂。
實(shí)驗(yàn)表明,深度非線性人工神經(jīng)網(wǎng)絡(luò)可以通過僅交錯(cuò)與新信息共享大量表征相似性的舊信息子集來學(xué)習(xí)新信息。通過使用SWIL算法,ANN能夠以相似的精度水平和最小的干擾快速學(xué)習(xí)新信息,同時(shí)使用的每個(gè)時(shí)期呈現(xiàn)的舊信息量少之又少,這意味著數(shù)據(jù)利用率高且可以快速學(xué)習(xí)。
同時(shí),SWIL也可應(yīng)用于序列學(xué)習(xí)框架。此外,學(xué)習(xí)一種新類別可以極大地提高數(shù)據(jù)利用率 。如果舊信息與之前學(xué)習(xí)過的類別有著非常少的相似性,那么呈現(xiàn)的舊信息數(shù)量就會(huì)少得多,這很可能是人類學(xué)習(xí)的實(shí)際情況。最后,作者提出了一個(gè)關(guān)于SWIL如何在大腦中實(shí)現(xiàn)的理論模型,其興奮性偏差與新信息的重疊成正比。
2 應(yīng)用于圖像分類數(shù)據(jù)集的
DNN動(dòng)力學(xué)模型
McClelland等人的實(shí)驗(yàn)表明,在具有一個(gè)隱藏層的深度線性網(wǎng)絡(luò)中,SWIL可以學(xué)習(xí)一個(gè)新類別,類似于完全交錯(cuò)學(xué)習(xí) (Fully Interleaved Learning,F(xiàn)IL),即將整個(gè)舊類別與新類別交錯(cuò),但使用的數(shù)據(jù)量減少了40%。
然而,網(wǎng)絡(luò)是在一個(gè)非常簡單的數(shù)據(jù)集上訓(xùn)練的,只有兩個(gè)上義詞類別,這就對(duì)算法的可擴(kuò)展性提出了疑問。
首先針對(duì)更復(fù)雜的數(shù)據(jù)集(如Fashion-MNIST),探索不同類別的學(xué)習(xí)在具有一個(gè)隱藏層的深度線性神經(jīng)網(wǎng)絡(luò)中如何演變。移出了“boot”(“靴子”)和“bag”(“紙袋”)類別后,該模型在剩余的8個(gè)類別上的測試準(zhǔn)確率達(dá)到了87%。然后作者團(tuán)隊(duì)重新訓(xùn)練模型,在兩種不同的條件下學(xué)習(xí)(新的)“boot”類,每個(gè)條件重復(fù)10次:
1)集中學(xué)習(xí)(Focused Learning ,F(xiàn)oL),即僅呈現(xiàn)新的“boot”類;
2)完全交錯(cuò)學(xué)習(xí) (FIL),即所有類別(新類別+以前學(xué)過的類別)以相等的概率呈現(xiàn)。在這兩種情況下,每個(gè)epoch總共呈現(xiàn)180張圖像,每個(gè)epoch中的圖像相同。
該網(wǎng)絡(luò)在總共9000張從未見過的圖像上進(jìn)行了測試,其中測試數(shù)據(jù)集由每類1000張圖像組成,不包括“bag”類別。當(dāng)網(wǎng)絡(luò)的性能達(dá)到漸近線時(shí),訓(xùn)練停止。
不出所料,F(xiàn)oL對(duì)舊類別造成了干擾,而FIL克服了這一點(diǎn)(圖1第2列)。如上所述,F(xiàn)oL對(duì)舊數(shù)據(jù)的干擾因類別而異,這是SWIL最初靈感的一部分,并表明新“boot”類別和舊類別之間存在分級(jí)相似關(guān)系。例如,“sneaker”(“運(yùn)動(dòng)鞋”)和“sandals”(“涼鞋”)的召回率比“trouser”(“褲子”)下降得更快(圖1第2列),可能是因?yàn)檎闲碌摹癰oot”類會(huì)選擇性地改變代表“sneaker”和“sandals”類的突觸權(quán)重,從而造成更多的干擾。
圖1:預(yù)訓(xùn)練網(wǎng)絡(luò)在兩種情況下學(xué)習(xí)新“boot”類的性能對(duì)比分析:FoL(上)和 FIL(下)。從左到右依次為預(yù)測新“boot”類別的召回率(橄欖色)、現(xiàn)有類別的召回率(用不同顏色繪制)、總準(zhǔn)確度(高分意味著低誤差)和交叉熵?fù)p失(總誤差的度量)曲線,是保留的測試數(shù)據(jù)集上與epoch數(shù)有關(guān)的函數(shù)。
3 計(jì)算不同類別之間的相似度
FoL在學(xué)習(xí)新類別的時(shí)候,在相似的舊類別上的分類性能會(huì)大幅下降。
之前已經(jīng)探討了多類別屬性相似度和學(xué)習(xí)之間的關(guān)系,并且表明深度線性網(wǎng)絡(luò)可以快速獲取已知的一致屬性。相比之下,在現(xiàn)有類別層次結(jié)構(gòu)中添加新分支的不一致屬性,需要緩慢、漸進(jìn)、交錯(cuò)的學(xué)習(xí)。
在當(dāng)前的工作中,作者團(tuán)隊(duì)使用已提出的方法在特征級(jí)別計(jì)算相似度。簡言之,計(jì)算目標(biāo)隱藏層(通常是倒數(shù)第二層)現(xiàn)有類別和新類別的平均每類激活向量之間的余弦相似度。圖2A顯示了基于Fashion MNIST數(shù)據(jù)集的新“boot”類別和舊類別,作者團(tuán)隊(duì)根據(jù)預(yù)訓(xùn)練網(wǎng)絡(luò)的倒數(shù)第二層激活函數(shù)計(jì)算的相似度矩陣。
類別之間的相似性與我們對(duì)物體的視覺感知一致。例如,在層次聚類圖(圖2B)中,我們可以觀察到“boot”類與“sneaker”和“sandal”類之間、以及“shirt”(“襯衫”)和“t-shirt”(“T恤”)類之間具有較高的相似性。相似度矩陣(圖2A)與混淆矩陣(圖2C)完全對(duì)應(yīng)。相似度越高,越容易混淆,例如,“襯衫”類與“T恤”、“套頭衫”和“外套”類圖像容易混淆,這表明相似性度量預(yù)測了神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)動(dòng)態(tài)。
在上一節(jié)的FoL結(jié)果圖(圖1)中,舊類別的召回率曲線中存在相近的類相似度曲線。與不同的舊類別(“trouser”等)相比,F(xiàn)oL學(xué)習(xí)新“boot”類的時(shí)候會(huì)快速遺忘相似的舊類別(“sneaker” 和 “sandal”)。
圖2:( A ) 作者團(tuán)隊(duì)根據(jù)預(yù)訓(xùn)練網(wǎng)絡(luò)的倒數(shù)第二層激活函數(shù),計(jì)算的現(xiàn)有類別和新“boot”類的相似度矩陣,其中對(duì)角線值(同一類別的相似性繪制為白色)被刪除。( B ) 對(duì)A中的相似矩陣進(jìn)行層次聚類。( C ) FIL算法在訓(xùn)練學(xué)習(xí)“boot”類后生成的混淆矩陣。為了縮放清晰,刪除了對(duì)角線值。
4 深度線性神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)快速和
高效學(xué)習(xí)新事物
接下來在前兩個(gè)條件基礎(chǔ)上增加了3種新條件,研究了新的分類學(xué)習(xí)動(dòng)態(tài),其中每個(gè)條件重復(fù)10次:
1)FoL(共計(jì)n=6000張圖像/epoch);
2) FIL(共計(jì)n=54000張圖像/epoch,6000張圖像/類);
3) 部分交錯(cuò)學(xué)習(xí) (Partial Interleaved Learning,PIL)使用了很小的圖像子集(共計(jì)n=350張圖像/epoch,大約39張圖像/類),每一類別(新類別+現(xiàn)有類別)的圖像以相等的概率呈現(xiàn);
4) SWIL,每個(gè)epoch使用與PIL 相同的圖像總數(shù)進(jìn)行重新訓(xùn)練,但根據(jù)與(新)“boot”類別的相似性對(duì)現(xiàn)有類別圖像進(jìn)行加權(quán);
5)等權(quán)交錯(cuò)學(xué)習(xí)(Equally Weighted Interleaved Learning,EqWIL),使用與SWIL相同數(shù)量的“boot”類圖像重新訓(xùn)練,但現(xiàn)有類別圖像的權(quán)重相同(圖3A)。
作者團(tuán)隊(duì)使用了上述相同的測試數(shù)據(jù)集(共有n=9000張圖像)。當(dāng)在每種條件下神經(jīng)網(wǎng)絡(luò)的性能都達(dá)到漸近線時(shí),停止訓(xùn)練。盡管每個(gè)epoch使用的訓(xùn)練數(shù)據(jù)較少,預(yù)測新“boot”類的準(zhǔn)確率需要更長的時(shí)間達(dá)到漸近線,與FIL(H=7.27,P<0.05)相比,PIL的召回率更低(圖3B第1列和表1“New class”列)。
對(duì)于SWIL,相似度計(jì)算用于確定要交錯(cuò)的現(xiàn)有舊類別圖像的比例。在此基礎(chǔ)上,作者團(tuán)隊(duì)從每個(gè)舊類別中隨機(jī)抽取具有加權(quán)概率的輸入圖像。與其他類別相比,“sneaker”和“sandal”類最相似,從而導(dǎo)致被交錯(cuò)的比例更高(圖3A)。
根據(jù)樹狀圖(圖2B),作者團(tuán)隊(duì)將“sneaker”和“sandal”類稱為相似的舊類,其余則稱為不同的舊類。與PIL(H=5.44,P<0.05)相比,使用SWIL時(shí),模型學(xué)習(xí)新“boot”類的速度更快,對(duì)現(xiàn)有類別的干擾也相近。此外,SWIL(H=0.056,P>0.05)的新類別召回率(圖3B第1列和表1“New class”列)、總準(zhǔn)確率和損失與FIL相當(dāng)。EqWIL(H=10.99,P<0.05)中新“boot”類的學(xué)習(xí)與SWIL相同,但對(duì)相近的舊類別有更大程度的干擾(圖3B第2列和表1“Similar old class”列)。
作者團(tuán)隊(duì)使用以下兩種方法比較SWIL和FIL:
1) 內(nèi)存比,即FIL和SWIL中存儲(chǔ)的圖像數(shù)量之比,表示存儲(chǔ)的數(shù)據(jù)量減少;
2) 加速比,即在FIL和SWIL中呈現(xiàn)的內(nèi)容總數(shù)的比率,以達(dá)到新類別回憶的飽和精度,表明學(xué)習(xí)新類別所需的時(shí)間減少。
SWIL可以在數(shù)據(jù)需求減少的情況下學(xué)習(xí)新內(nèi)容,內(nèi)存比=154.3x (54000/350),并且速度更快,加速比=77.1x (54000/(350×2))。即使和新內(nèi)容有關(guān)的圖像數(shù)量較少,該模型也可以通過使用SWIL,利用模型先驗(yàn)知識(shí)的層次結(jié)構(gòu)實(shí)現(xiàn)相同的性能。SWIL在PIL和EqWIL之間提供了一個(gè)中間緩沖區(qū),允許集成一個(gè)新類別,并將對(duì)現(xiàn)有類別的干擾降到最低。
圖3 ( A ) 作者團(tuán)隊(duì)在五種不同的學(xué)習(xí)條件下預(yù)訓(xùn)練神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)新的“boot”類(橄欖綠),直到性能平穩(wěn):1)FoL(共計(jì)n=6000張圖像/epoch);2)FIL(共計(jì)n=54000張圖像/epoch);3) PIL(共計(jì)n=350張圖像/epoch);4) SWIL(共計(jì)n=350張圖像/epoch)和 5) EqWIL(共計(jì)n=350張圖像/epoch)。(B)FoL(黑色)、FIL(藍(lán)色)、PIL(棕色)、SWIL(洋紅色)和 EqWIL(金色)預(yù)測新類別、相似舊類別(“sneaker”和“sandals”)和不同舊類別的召回率,預(yù)測所有類別的總準(zhǔn)確率,以及在測試數(shù)據(jù)集上的交叉熵?fù)p失,其中橫坐標(biāo)都是epoch數(shù)。
5 基于CIFAR10使用SWIL在CNN中學(xué)習(xí)新類別
接下來,為了測試SWIL是否可以在更復(fù)雜的環(huán)境中工作,作者團(tuán)隊(duì)訓(xùn)練了一個(gè)具有全連接輸出層的6層非線性CNN(圖4A),以識(shí)別CIFAR10數(shù)據(jù)集中剩余8個(gè)不同類別(“cat”和“car”除外)的圖像。他們還對(duì)模型進(jìn)行了重新訓(xùn)練,在之前定義的5種不同訓(xùn)練條件(FoL、FIL、PIL、SWIL和EqWIL)下學(xué)習(xí)“cat”(“貓”)類。圖4C顯示了5種情況下每類圖像的分布。對(duì)于SWIL、PIL和EqWIL條件,每個(gè)epoch的總圖像數(shù)為2400,而對(duì)于FIL和FoL,每個(gè)epoch的總圖像數(shù)分別為45000和5000。作者團(tuán)隊(duì)針對(duì)每種情況對(duì)網(wǎng)絡(luò)分別進(jìn)行訓(xùn)練,直到性能趨于穩(wěn)定。
他們?cè)谥拔匆娺^的總共9000張圖像(1000張圖像/類,不包括“car”(“轎車”)類)上對(duì)該模型進(jìn)行了測試。圖4B是作者團(tuán)隊(duì)基于CIFAR10數(shù)據(jù)集計(jì)算的相似性矩陣?!癱at”類和“dog”(“狗”)類更類似,而其他動(dòng)物類屬于同一分支(圖4B左)。
根據(jù)樹狀圖(圖4B),將“truck” (“貨車”)、“ship”(“輪船”) 和 “plane”(“飛機(jī)”) 類別稱為不同的舊類別,除“cat”類外其余的動(dòng)物類別稱為相似的舊類別。對(duì)于FoL,模型學(xué)習(xí)了新的“cat”類,但遺忘了舊類別。與Fashion-MNIST數(shù)據(jù)集結(jié)果類似,“dog”類(與“cat”類相似性最大)和“truck”類(與“cat”類相似性最?。┚嬖诟蓴_梯度,其中“dog”類的遺忘率最高,而“truck”類遺忘率最低。
如圖4D所示,F(xiàn)IL算法學(xué)習(xí)新的“cat”類時(shí)克服了災(zāi)難性的干擾。對(duì)于PIL算法,模型在每個(gè)epoch使用18.75倍的數(shù)據(jù)量學(xué)習(xí)新的“cat”類,但“cat”類的召回率比FIL(H=5.72,P<0.05)低。對(duì)于SWIL,在新類別、相似和不同舊類別上的召回率、總準(zhǔn)確率和損失與FIL相當(dāng)(H=0.42,P>0.05;見表2和圖4D)。SWIL對(duì)新“cat”類的召回率高于PIL(H=7.89,P<0.05)。使用EqWIL算法時(shí),新“cat”類的學(xué)習(xí)情況與SWIL和FIL相似,但對(duì)相似舊類別的干擾較大(H=24.77,P<0.05;見表2)。
FIL、PIL、SWIL和EqWIL這4種算法預(yù)測不同舊類別的性能相當(dāng)(H=0.6,P>0.05)。SWI比PIL更好地融合了新的“cat”類,并有助于克服EqWIL中的觀測干擾。與FIL相比,使用SWIL學(xué)習(xí)新類別速度更快,加速比=31.25x (45000×10/(2400×6)),同時(shí)使用更少的數(shù)據(jù)量 (內(nèi)存比=18.75x)。這些結(jié)果證明,即使在非線性CNN和更真實(shí)的數(shù)據(jù)集上,SWIL也可以有效學(xué)習(xí)新類別事物。
圖4:( A ) 作者團(tuán)隊(duì)使用具有全連接輸出層的6層非線性CNN學(xué)習(xí)CIFAR10數(shù)據(jù)集中的8類事物。( B ) 相似度矩陣 (右)是在呈現(xiàn)新的“cat”類之后,作者團(tuán)隊(duì)根據(jù)最后一個(gè)卷積層的激活函數(shù)計(jì)算獲得。對(duì)相似矩陣應(yīng)用層次聚類(左),在樹狀圖中顯示動(dòng)物(橄欖綠)和交通工具(藍(lán)色)兩個(gè)上義詞類別的分組情況。( C ) 作者團(tuán)隊(duì)在5種不同的條件下預(yù)訓(xùn)練CNN學(xué)習(xí)新的“cat”類(橄欖綠),直到性能平穩(wěn):1)FoL(共計(jì)n=5000張圖像/epoch);2)FIL(共計(jì)n=45000張圖像/epoch);3) PIL(共計(jì)n=2400張圖像/epoch);4) SWIL(共計(jì)n=2400張圖像/epoch);5) EqWIL(共計(jì)n=2400張圖像/epoch)。每個(gè)條件重復(fù)10次。(D)FoL(黑色)、FIL(藍(lán)色)、PIL(棕色)、SWIL(洋紅色)和 EqWIL(金色)預(yù)測新類別、相似舊類別(CIFAR10數(shù)據(jù)集中的其他動(dòng)物類)和不同舊類別(“plane” 、“ship” 和 “truck”)的召回率,預(yù)測所有類別的總準(zhǔn)確率,以及在測試數(shù)據(jù)集上的交叉熵?fù)p失,其中橫坐標(biāo)都是epoch數(shù)。
6 新內(nèi)容與舊類別的一致性對(duì)學(xué)習(xí)時(shí)間和所需數(shù)據(jù)的影響
如果一項(xiàng)新內(nèi)容可以添加到先前學(xué)習(xí)過的類別中,而不需要對(duì)網(wǎng)絡(luò)進(jìn)行較大更改,則稱二者具有一致性?;诖丝蚣?,與干擾多個(gè)現(xiàn)有類別(低一致性)的新類別相比,學(xué)習(xí)干擾更少現(xiàn)有類別(高一致性)的新類別可以更容易地集成到網(wǎng)絡(luò)中。
為了測試上述推斷,作者團(tuán)隊(duì)使用上一節(jié)中經(jīng)過預(yù)訓(xùn)練的CNN,在前面描述的所有5種學(xué)習(xí)條件下,學(xué)習(xí)了一個(gè)新的“car”類別。圖5A顯示了“car”類別的相似性矩陣,與其他現(xiàn)有類別相比,“car”和“truck”、“ship”和“plane”在同一層次節(jié)點(diǎn)下,說明它們更相似。為了進(jìn)一步確認(rèn),作者團(tuán)隊(duì)在用于相似性計(jì)算的激活層上進(jìn)行了t-SNE降維可視化分析(圖5B)。研究發(fā)現(xiàn)“car”類與其他交通工具類(“truck”、“ship”和“plane”)有顯著重疊,而“cat”類與其他動(dòng)物類(“dog”、 “frog”(“青蛙”)、“horse”(“馬”)、“bird”(“鳥”)和“deer”(“鹿”))有重疊。
和作者團(tuán)隊(duì)預(yù)期相符,F(xiàn)oL學(xué)習(xí)“car”類別時(shí)會(huì)產(chǎn)生災(zāi)難性干擾,對(duì)相近的舊類別干擾性更強(qiáng),而使用FIL克服了這一點(diǎn)(圖5D)。對(duì)于PIL、SWIL和EqWIL,每個(gè)epoch總共有n=2000張圖像(圖5C)。使用SWIL算法,模型學(xué)習(xí)新的“car”類別可以達(dá)到和FIL(H=0.79,P>0.05)相近的精度,而對(duì)現(xiàn)有類別(包括相似和不同類別)的干擾最小。如圖5D第2列所示,使用EqWIL,模型學(xué)習(xí)新“car”類的方式與SWIL相同,但對(duì)其他相似類別(例如“truck”)的干擾程度更高(H=53.81,P<0.05)。
與FIL相比,SWIL可以更快地學(xué)習(xí)新內(nèi)容,加速比=48.75x(45000×12/(2000×6)),內(nèi)存需求減少,內(nèi)存比=22.5x。與“cat”(48.75x vs.31.25x)相比,“car”可以通過交錯(cuò)更少的類(如“truck”、“ship”和“plane”)更快地學(xué)習(xí),而“cat”與更多的類別(如“dog” 、“frog” 、“horse” 、“frog” 和“deer”)重疊。這些仿真實(shí)驗(yàn)表明,交叉和加速學(xué)習(xí)新類別所需的舊類別數(shù)據(jù)量,取決于新信息與先驗(yàn)知識(shí)的一致性。
圖 5:( A ) 作者團(tuán)隊(duì)根據(jù)倒數(shù)第二層激活函數(shù)計(jì)算獲得相似度矩陣(左),以及呈現(xiàn)新的“car”類別后對(duì)相似度矩陣進(jìn)行層次聚類后的結(jié)果圖(右)。( B ) 模型分別學(xué)習(xí)新的“car”類別和“cat”類別,經(jīng)過最后一個(gè)卷積層過激活函數(shù)后,作者團(tuán)隊(duì)進(jìn)行t-SNE降維可視化的結(jié)果圖。( C ) 作者團(tuán)隊(duì)在5種不同的條件下預(yù)訓(xùn)練CNN學(xué)習(xí)新的“car”類(橄欖綠),直到性能平穩(wěn):1)FoL(共計(jì)n=5000張圖像/epoch);2)FIL(共計(jì)n=45000張圖像/epoch);3) PIL(共計(jì)n=2000張圖像/epoch);4) SWIL(共計(jì)n=2000張圖像/epoch);5) EqWIL(共計(jì)n=2000張圖像/epoch)。(D)FoL(黑色)、FIL(藍(lán)色)、PIL(棕色)、SWIL(洋紅色)和 EqWIL(金色)預(yù)測新類別、相似舊類別(“plane” 、“ship” 和 “truck”)和不同舊類別(CIFAR10數(shù)據(jù)集中的其他動(dòng)物類)的召回率,預(yù)測所有類別的總準(zhǔn)確率,以及在測試數(shù)據(jù)集上的交叉熵?fù)p失,其中橫坐標(biāo)都是epoch數(shù)。每張圖顯示的是重復(fù)10次后的平均值,陰影區(qū)域?yàn)椤? SEM。
7 利用SWIL進(jìn)行序列學(xué)習(xí)
接下來,作者團(tuán)隊(duì)測試是否可以使用SWIL學(xué)習(xí)序列化形式呈現(xiàn)的新內(nèi)容(序列學(xué)習(xí)框架)。為此他們采用了圖4中經(jīng)過訓(xùn)練的CNN模型,在FIL和SWIL條件下學(xué)習(xí)CIFAR10數(shù)據(jù)集中的“cat”類(任務(wù)1),只在CIFAR10的剩余9個(gè)類別上訓(xùn)練,然后在每個(gè)條件下訓(xùn)練模型學(xué)習(xí)新的“car”類(任務(wù)2)。圖6第1列顯示了SWIL條件下學(xué)習(xí)“car”類別時(shí),其他各項(xiàng)類別的圖像數(shù)量分布情況(共計(jì)n=2500張圖像/epoch)。需要注意的是,預(yù)測“cat”類時(shí)也交叉學(xué)習(xí)新的“car”類。由于在FIL條件下模型性能最佳,SWIL僅與FIL進(jìn)行了結(jié)果比較。
如圖6所示,SWIL預(yù)測新、舊類別的能力與FIL相當(dāng)(H=14.3,P>0.05)。模型使用SWIL算法可以更快地學(xué)習(xí)新的“car”類別,加速比為45x(50000×20/(2500×8)),每個(gè)epoch的內(nèi)存占用比FIL少20倍。模型學(xué)習(xí)“cat”和“car”類別時(shí),在SWIL條件下每個(gè)epoch使用的圖像數(shù)量(內(nèi)存比和加速比分別為18.75x 和 20x),少于在FIL條件下每個(gè)epoch使用的整個(gè)數(shù)據(jù)集(內(nèi)存比和加速比分別為31.25x 和45x),并且仍然可以快速學(xué)習(xí)新類別。擴(kuò)展這一思想,隨著學(xué)過的類別數(shù)目不斷增加,作者團(tuán)隊(duì)預(yù)期模型的學(xué)習(xí)時(shí)間和數(shù)據(jù)存儲(chǔ)會(huì)成倍減少,從而更高效地學(xué)習(xí)新類別,這或許反映了人類大腦實(shí)際學(xué)習(xí)時(shí)的情況。
實(shí)驗(yàn)結(jié)果表明,SWIL可在序列學(xué)習(xí)框架中集成多個(gè)新類,使神經(jīng)網(wǎng)絡(luò)能夠在不受干擾的情況下持續(xù)學(xué)習(xí)。
圖6:作者團(tuán)隊(duì)訓(xùn)練6層CNN學(xué)習(xí)新的“cat”類(任務(wù)1),然后學(xué)習(xí)“car”類(任務(wù)2),直到性能在以下兩種情況下趨于穩(wěn)定:1)FIL:包含所有舊類別(以不同顏色繪制)和以相同概率呈現(xiàn)的新類別(“cat”/“car”)圖像;2) SWIL:根據(jù)與新類別(“cat”/“car”)的相似性進(jìn)行加權(quán)并按比例使用舊類別示例。同時(shí)將任務(wù)1中學(xué)習(xí)的“cat”類包括在內(nèi),并根據(jù)任務(wù)2中學(xué)習(xí)“car”類的相似性進(jìn)行加權(quán)。第1張子圖表示每個(gè)epoch使用的圖像數(shù)量分布情況,其余各子圖分別表示FIL(藍(lán)色)和SWIL(洋紅色)預(yù)測新類別、相似舊類別和不同舊類別的召回率,預(yù)測所有類別的總準(zhǔn)確率,以及在測試數(shù)據(jù)集上的交叉熵?fù)p失,其中橫坐標(biāo)都是epoch數(shù)。
8 利用SWIL擴(kuò)大類別間的距離,減少學(xué)習(xí)時(shí)間和數(shù)據(jù)量
作者團(tuán)隊(duì)最后測試了SWIL算法的泛化性,驗(yàn)證其是否可以學(xué)習(xí)包括更多類別的數(shù)據(jù)集,以及是否適用于更復(fù)雜的網(wǎng)絡(luò)架構(gòu)。
他們?cè)贑IFAR100數(shù)據(jù)集(訓(xùn)練集500張圖像/類,測試集100張圖像/類)上訓(xùn)練了一個(gè)復(fù)雜的CNN模型-VGG19(共有19層),學(xué)習(xí)了其中的90個(gè)類別。然后對(duì)網(wǎng)絡(luò)進(jìn)行再訓(xùn)練,學(xué)習(xí)新類別。圖7A顯示了基于CIFAR100數(shù)據(jù)集,作者團(tuán)隊(duì)根據(jù)倒數(shù)第二層的激活函數(shù)計(jì)算的相似性矩陣。如圖7B所示,新“train”(“火車”)類與許多現(xiàn)有的交通工具類別(如“bus” (“公共汽車”)、“streetcar” (“有軌電車”)和“tractor”(“拖拉機(jī)”)等)很相似。
與FIL相比,SWIL可以更快地學(xué)習(xí)新事物(加速比=95.45x (45500×6/(1430×2)))并且使用的數(shù)據(jù)量 (內(nèi)存比=31.8x) 顯著減少,而性能基本相同(H=8.21, P>0.05) 。如圖7C所示,在PIL(H=10.34,P<0.05)和EqWIL(H=24.77,P<0.05)條件下,模型預(yù)測新類別的召回率較低并且產(chǎn)生的干擾較大,而SWIL克服了上述不足。
同時(shí),為了探索不同類別表征之間的較大距離是否構(gòu)成了加速模型學(xué)習(xí)的基本條件,作者團(tuán)隊(duì)另外訓(xùn)練了兩種神經(jīng)網(wǎng)絡(luò)模型:
1)6層CNN(與基于CIFAR10的圖4和圖5相同);
2)VGG11(11層)學(xué)習(xí)CIFAR100數(shù)據(jù)集中的90個(gè)類別,僅在FIL和SWIL兩個(gè)條件下對(duì)新的“train”類進(jìn)行訓(xùn)練。
如圖7B所示,對(duì)于上述兩種網(wǎng)絡(luò)模型,新的“train”類和交通工具類別之間的重疊度更高,但與VGG19模型相比,各類別的分離度較低。與FIL相比,SWIL學(xué)習(xí)新事物的速度與層數(shù)的增加大致呈線性關(guān)系(斜率=0.84)。該結(jié)果表明,類別間表征距離的增加可以加速學(xué)習(xí)并減少內(nèi)存負(fù)載。
圖7:( A ) VGG19學(xué)習(xí)新的“train”類后,作者團(tuán)隊(duì)根據(jù)倒數(shù)第二層激活函數(shù)計(jì)算的相似性矩陣?!皌ruck” 、“streetcar” 、“bus” 、“house” 和 “tractor”5種類別與“train”的相似性最大。從相似度矩陣中排除對(duì)角元素(相似度 =1)。(B,左)作者團(tuán)隊(duì)針對(duì)6層CNN、VGG11和VGG19網(wǎng)絡(luò),經(jīng)過倒數(shù)第二層激活函數(shù)后,進(jìn)行t-SNE降維可視化的結(jié)果圖。(B,右)縱軸表示加速比(FIL/SWIL),橫軸表示3個(gè)不同網(wǎng)絡(luò)的層數(shù)相對(duì)于6層CNN的比率。黑色虛線、紅色虛線和藍(lán)色實(shí)線分別代表斜率 =1的標(biāo)準(zhǔn)線、最佳擬合線和仿真結(jié)果。( C ) VGG19模型的學(xué)習(xí)情況:FoL(黑色)、FIL(藍(lán)色)、PIL(棕色)、SWIL(洋紅色)和 EqWIL(金色)預(yù)測新“train”類、相似舊類別(交通工具類別)和不同舊類別(除了交通工具類別)的召回率,預(yù)測所有類別的總準(zhǔn)確率,以及在測試數(shù)據(jù)集上的交叉熵?fù)p失,其中橫坐標(biāo)都是epoch數(shù)。每張圖顯示的是重復(fù)10次后的平均值,陰影區(qū)域?yàn)椤? SEM。( D ) 從左到右依次表示模型預(yù)測Fashion-MNIST“boot”類(圖3)、CIFAR10“cat”類(圖4)、CIFAR10“car”類(圖5)和CIFAR100“train”類的召回率,是SWIL(洋紅色)和FIL(藍(lán)色)使用的圖像總數(shù)(對(duì)數(shù)比例)的函數(shù)。“N”表示每種學(xué)習(xí)條件下每個(gè)epoch使用的圖像總數(shù)(包括新、舊類別)。
如果在更多非重疊類上訓(xùn)練網(wǎng)絡(luò),并且各表征之間的距離更大,速度是否會(huì)進(jìn)一步提升?
為此,作者團(tuán)隊(duì)采用了一個(gè)深度線性網(wǎng)絡(luò)(用于圖1-3中的Fashion-MNIST示例),并對(duì)其進(jìn)行訓(xùn)練,以學(xué)習(xí)由8個(gè)Fashion-MNIST類別(不包括“bags”和“boot”類)和10個(gè)Digit-MNIST類別形成的組合數(shù)據(jù)集,然后訓(xùn)練網(wǎng)絡(luò)學(xué)習(xí)新的“boot”類別。
和作者團(tuán)隊(duì)的預(yù)期相符,“boot”與舊類別“sandals”和“sneaker”相似度更高,其次是其余的Fashion-MNIST類(主要包括服飾類圖像),最后Digit-MNIST類(主要包括數(shù)字類圖像)。
基于此,作者團(tuán)隊(duì)首先交織了更多相似的舊類別樣本,再交織Fashion-MNIST和Digit-MNIST類樣本(共計(jì)n=350張圖像/epoch)。實(shí)驗(yàn)結(jié)果表明,與FIL類似,SWIL可以快速學(xué)習(xí)新類別內(nèi)容而不受干擾,但使用的數(shù)據(jù)子集要小得多,內(nèi)存比為325.7x (114000/350) ,加速比為162.85x (228000/1400)。作者團(tuán)隊(duì)在當(dāng)前結(jié)果中觀察到的加速比為2.1x (162.85/77.1),與Fashion-MNIST數(shù)據(jù)集相比,類別數(shù)目增加了 2.25倍 (18/8)。
本節(jié)的實(shí)驗(yàn)結(jié)果有助于確定SWIL可以適用于更復(fù)雜的數(shù)據(jù)集 (CIFAR100) 和神經(jīng)網(wǎng)絡(luò)模型(VGG19),證明了該算法的泛化性。同時(shí)證明了擴(kuò)大類別之間的內(nèi)部距離或增加非重疊類別的數(shù)量,可能會(huì)進(jìn)一步提高學(xué)習(xí)速度并降低內(nèi)存負(fù)載。
9 總結(jié)
人工神經(jīng)網(wǎng)絡(luò)在持續(xù)學(xué)習(xí)方面面臨重大挑戰(zhàn),通常表現(xiàn)出災(zāi)難性干擾。
為了克服此問題,許多研究都使用了完全交錯(cuò)學(xué)習(xí)(FIL),即新舊內(nèi)容交叉學(xué)習(xí),聯(lián)合訓(xùn)練網(wǎng)絡(luò)。FIL需要在每次學(xué)新信息時(shí)交織所有現(xiàn)有信息,使其成為一個(gè)生物學(xué)意義上不可信且耗時(shí)的過程。最近,有研究表明FIL可能并非必需,僅交錯(cuò)與新內(nèi)容具有實(shí)質(zhì)表征相似性的舊內(nèi)容,即采用相似性加權(quán)交錯(cuò)學(xué)習(xí)(SWIL)的方法可以達(dá)到相同的學(xué)習(xí)效果。然而,有人對(duì)SWIL的可擴(kuò)展性表示了擔(dān)憂。本文擴(kuò)展了SWIL算法,并基于不同的數(shù)據(jù)集(Fashion-MNIST、CIFAR10 和 CIFAR100)和神經(jīng)網(wǎng)絡(luò)模型(深度線性網(wǎng)絡(luò)和CNN)對(duì)其進(jìn)行了測試。在所有條件下,與部分交錯(cuò)學(xué)習(xí)(PIL)相比,相似性加權(quán)交錯(cuò)學(xué)習(xí)(SWIL)和等權(quán)交錯(cuò)學(xué)習(xí)(EqWIL)在學(xué)習(xí)新類別方面的表現(xiàn)更好。這和作者團(tuán)隊(duì)的預(yù)期相符,因?yàn)榕c舊類別相比,SWIL和EqWIL增加了新類別的相對(duì)頻率。
本文同時(shí)還證明,與同等子抽樣現(xiàn)有類別(即EqWIL方法)相比,仔細(xì)選擇和交織相似內(nèi)容減少了對(duì)相近舊類別的災(zāi)難性干擾。在預(yù)測新類別和現(xiàn)有類別方面,SWIL的性能與FIL類似,卻顯著加快了學(xué)習(xí)新內(nèi)容的速度(圖7D),同時(shí)大大減少了所需的訓(xùn)練數(shù)據(jù)。SWIL可以在序列學(xué)習(xí)框架中學(xué)習(xí)新類別,進(jìn)一步證明了其泛化能力。
最后,與許多舊類別具有相似性的新類別相比,如果其與之前學(xué)過的類別重疊更少(距離更大),可以縮短集成時(shí)間,并且數(shù)據(jù)效率更高??傮w來說,實(shí)驗(yàn)結(jié)果提供了一種可能的見解,即大腦事實(shí)上通過減少不切實(shí)際的訓(xùn)練時(shí)間,克服了原始CLST模型的一項(xiàng)主要弱點(diǎn)。?