反向傳播、前向傳播都不要,這種無(wú)梯度學(xué)習(xí)方法是Hinton想要的嗎?
「我們應(yīng)該拋棄反向傳播并重新開(kāi)始?!乖缭趲啄昵?,使反向傳播成為深度學(xué)習(xí)核心技術(shù)之一的 Geoffrey Hinton 就發(fā)表過(guò)這樣一個(gè)觀點(diǎn)。
而一直對(duì)反向傳播持懷疑態(tài)度的也是 Hinton。因?yàn)檫@種方法既不符合生物學(xué)機(jī)理,與大規(guī)模模型的并行性也不兼容。所以,Hinton 等人一直在尋找替代反向傳播的新方法,比如 2022 年的前向 - 前向算法。但由于性能、可泛化性等方面仍然存在問(wèn)題,這一方向的探索一直沒(méi)有太大起色。
最近,來(lái)自牛津大學(xué)和 Mila 實(shí)驗(yàn)室的研究者向這一問(wèn)題發(fā)起了挑戰(zhàn)。他們開(kāi)發(fā)了一種名為 NoProp 的新型學(xué)習(xí)方法,該方法既不依賴前向傳播也不依賴反向傳播。相反,NoProp 從擴(kuò)散和流匹配(flow matching)方法中汲取靈感,每一層獨(dú)立地學(xué)習(xí)對(duì)噪聲目標(biāo)進(jìn)行去噪。
- 論文標(biāo)題:NOPROP: TRAINING NEURAL NETWORKS WITHOUT BACK-PROPAGATION OR FORWARD-PROPAGATION
- 論文鏈接:https://arxiv.org/pdf/2503.24322v1
研究人員認(rèn)為這項(xiàng)工作邁出了引入一種新型無(wú)梯度學(xué)習(xí)方法的第一步。這種方法不學(xué)習(xí)分層表示 —— 至少不是通常意義上的分層表示。NoProp 需要預(yù)先將每一層的表示固定為目標(biāo)的帶噪聲版本,學(xué)習(xí)一個(gè)局部去噪過(guò)程,然后可以在推理時(shí)利用這一過(guò)程。
他們?cè)?MNIST、CIFAR-10 和 CIFAR-100 圖像分類(lèi)基準(zhǔn)測(cè)試上展示了該方法的有效性。研究結(jié)果表明,NoProp 是一種可行的學(xué)習(xí)算法,與其他現(xiàn)有的無(wú)反向傳播方法相比,它實(shí)現(xiàn)了更高的準(zhǔn)確率,更易于使用且計(jì)算效率更高。通過(guò)擺脫傳統(tǒng)的基于梯度的學(xué)習(xí)范式,NoProp 改變了網(wǎng)絡(luò)內(nèi)部的貢獻(xiàn)分配(credit assignment)方式,實(shí)現(xiàn)了更高效的分布式學(xué)習(xí),并可能影響學(xué)習(xí)過(guò)程的其他特性。
在看了論文之后,有人表示,「NoProp 用獨(dú)立的、無(wú)梯度的、基于去噪的層訓(xùn)練取代了傳統(tǒng)的反向傳播,以實(shí)現(xiàn)高效且非層次化的貢獻(xiàn)分配。這是一項(xiàng)具有開(kāi)創(chuàng)性意義的工作,可能會(huì)對(duì)分布式學(xué)習(xí)系統(tǒng)產(chǎn)生重大影響,因?yàn)樗鼜母旧细淖兞素暙I(xiàn)分配機(jī)制。
其數(shù)學(xué)公式中涉及每層特定的噪聲模型和優(yōu)化目標(biāo),這使得無(wú)需梯度鏈即可進(jìn)行獨(dú)立學(xué)習(xí)。其優(yōu)勢(shì)在于通過(guò)讓每一層獨(dú)立地對(duì)一個(gè)固定的噪聲目標(biāo)進(jìn)行去噪,從而繞過(guò)了反向傳播中基于順序梯度的貢獻(xiàn)分配方式。這種方式能夠?qū)崿F(xiàn)更高效、可并行化的更新,避免了梯度消失等問(wèn)題,盡管它并未構(gòu)建傳統(tǒng)的層次化表示。」
還有人表示,「我在查看擴(kuò)散模型架構(gòu)時(shí)也產(chǎn)生過(guò)這樣的想法…… 然而,我認(rèn)為這可能是一種非最優(yōu)的方法,所以它現(xiàn)在表現(xiàn)得如此出色讓我感到很神秘。顯而易見(jiàn)的是其并行化優(yōu)勢(shì)?!?/span>
為什么要尋找反向傳播的替代方案?
反向傳播雖是訓(xùn)練神經(jīng)網(wǎng)絡(luò)的主流方法,但研究人員一直在尋找替代方案,原因有三:
- 生物學(xué)合理性不足:反向傳播需要前向傳遞和后向傳遞嚴(yán)格交替,與生物神經(jīng)系統(tǒng)運(yùn)作方式不符。
- 內(nèi)存消耗大:必須存儲(chǔ)中間激活值以計(jì)算梯度,造成顯著內(nèi)存開(kāi)銷(xiāo)。
- 并行計(jì)算受限:梯度的順序傳播限制了并行處理能力,影響大規(guī)模分布式學(xué)習(xí),并導(dǎo)致學(xué)習(xí)過(guò)程中的干擾和災(zāi)難性遺忘問(wèn)題。
目前為止,反向傳播的替代優(yōu)化方法包括:
- 無(wú)梯度方法:如直接搜索方法和基于模型的方法
- 零階梯度方法:使用有限差分近似梯度
- 進(jìn)化策略
- 基于局部損失的方法:如差異目標(biāo)傳播(difference target propagation)和前向 - 前向算法
但這些方法因在準(zhǔn)確性、計(jì)算效率、可靠性和可擴(kuò)展性方面的限制,尚未在神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)中廣泛應(yīng)用。
方法解析
NoProp
設(shè) x 和 y 是分類(lèi)數(shù)據(jù)集中的一個(gè)輸入 - 標(biāo)簽樣本對(duì),假設(shè)從數(shù)據(jù)分布 q?(x,y) 中抽取,z?,z?,...,z? ∈ R? 是神經(jīng)網(wǎng)絡(luò)中 T 個(gè)模塊的對(duì)應(yīng)隨機(jī)中間激活值,目標(biāo)是訓(xùn)練該網(wǎng)絡(luò)以估計(jì) q?(y|x)。
定義兩個(gè)分布 p 和 q,按以下方式分解:
p 分布可以被解釋為一個(gè)隨機(jī)前向傳播過(guò)程,它迭代地計(jì)算下一個(gè)激活值 z?,給定前一個(gè)激活值 z??? 和輸入 x。實(shí)際上,可以看到它可以被明確表示為一個(gè)添加了高斯噪聲的殘差網(wǎng)絡(luò):
其中 N?(?|0,1) 是一個(gè) d 維高斯密度函數(shù),均值向量為 0,協(xié)方差矩陣為單位矩陣,a?,b?,c? 是標(biāo)量(如下所示),b?z??? 是一個(gè)加權(quán)跳躍連接,而 ?θ?(z???,x) 是由參數(shù) θ? 參數(shù)化的殘差塊。注意,這種計(jì)算結(jié)構(gòu)不同于標(biāo)準(zhǔn)深度神經(jīng)網(wǎng)絡(luò),后者沒(méi)有從輸入 x 到每個(gè)模塊的直接連接。遵循變分?jǐn)U散模型方法,也可以將 p 解釋為給定 x 條件下 y 的條件隱變量模型,其中 z? 是一系列隱變量??梢允褂米兎止綄W(xué)習(xí)前向過(guò)程 p,其中 q 分布作為變分后驗(yàn)。關(guān)注的目標(biāo)是 ELBO,這是對(duì)數(shù)似然 log p (y|x)(即證據(jù))的下界:
遵循 Sohl-Dickstein 和 Kingma 等人的方法,將變分后驗(yàn) q 固定為一個(gè)易于處理的高斯分布。在這里使用方差保持的 Ornstein-Uhlenbeck 過(guò)程:
其中 u? 是類(lèi)別標(biāo)簽 y 在 R? 中的嵌入,由可訓(xùn)練的嵌入矩陣 W (Embed) ∈ R??? 定義,m 是類(lèi)別數(shù)量。嵌入由 u? = {W (Embed)}? 給出。利用高斯分布的標(biāo)準(zhǔn)性質(zhì),我們可以得到:
其中 ?? = ∏????α?,μ?(z???,u?) = a?u? + b?z???,a? = √(??(1-α???))/(1-????),b? = √(α???(1-??))/(1-????),以及 c? = (1-??)(1-α???)/(1-????)。為了優(yōu)化 ELBO,將 p 參數(shù)化以匹配 q 的形式:
其中 p (z?) 被選為 Ornstein-Uhlenbeck 過(guò)程的平穩(wěn)分布,?θ?(z???,x) 是由參數(shù) θ? 參數(shù)化的神經(jīng)網(wǎng)絡(luò)模塊。給定 z??? 和 x 對(duì) z? 進(jìn)行采樣的結(jié)果計(jì)算如殘差架構(gòu)(方程 3)所示,其中 a?,b?,c? 如上所述。最后,將此參數(shù)化代入 ELBO(方程 4)并簡(jiǎn)化,得到 NoProp 目標(biāo)函數(shù):
其中 SNR (t) = ??/(1-??) 是信噪比,η 是一個(gè)超參數(shù),U {1,T} 是在整數(shù) 1,...,T 上的均勻分布。我們看到每個(gè) ?θ?(z???,x) 都被訓(xùn)練為直接預(yù)測(cè) u?,給定 z??? 和 x,使用 L2 損失,而 p?θout (y|z?) 被訓(xùn)練為最小化交叉熵?fù)p失。每個(gè)模塊 ?θ?(z???,x) 都是獨(dú)立訓(xùn)練的,這是在沒(méi)有通過(guò)網(wǎng)絡(luò)進(jìn)行前向或反向傳播的情況下實(shí)現(xiàn)的。
實(shí)現(xiàn)細(xì)節(jié)
NoProp 架構(gòu)如圖 1 所示。
在推理階段,NoProp 架構(gòu)從高斯噪聲 z?開(kāi)始,通過(guò)一系列擴(kuò)散步驟轉(zhuǎn)換潛變量。每個(gè)步驟中,潛變量 z?通過(guò)擴(kuò)散動(dòng)態(tài)塊 u?演化,形成序列 z?→z?→...→z?,其中每個(gè) u?都以前一狀態(tài) z???和輸入圖像 x 為條件。最終,z?通過(guò)線性層和 softmax 函數(shù)映射為預(yù)測(cè)標(biāo)簽?。
訓(xùn)練時(shí),各時(shí)間步驟被采樣,每個(gè)擴(kuò)散塊 u?獨(dú)立訓(xùn)練,同時(shí)線性層和嵌入矩陣與擴(kuò)散塊共同優(yōu)化以防止類(lèi)別嵌入崩潰。對(duì)于流匹配變體,u?表示 ODE 動(dòng)態(tài),標(biāo)簽預(yù)測(cè)通過(guò)尋找與 z?在歐幾里得距離上最接近的類(lèi)別嵌入獲得。
訓(xùn)練所用的模型如圖 6 所示,其中左邊為離散時(shí)間情況的模型,右邊為連續(xù)時(shí)間情況的模型。
作者在三種情況下構(gòu)建了相似但有區(qū)別的神經(jīng)網(wǎng)絡(luò)模型:
- 離散時(shí)間擴(kuò)散:神經(jīng)網(wǎng)絡(luò) ?θt 將圖像 x 和潛變量 zt?1 通過(guò)不同嵌入路徑處理后合并。圖像用卷積模塊處理,潛變量根據(jù)維度匹配情況用卷積或全連接網(wǎng)絡(luò)處理。合并后的表示通過(guò)全連接層產(chǎn)生 logits,應(yīng)用 softmax 后得到類(lèi)別嵌入上的概率分布,最終輸出為類(lèi)別嵌入的加權(quán)和。
- 連續(xù)時(shí)間擴(kuò)散:在離散模型基礎(chǔ)上增加時(shí)間戳 t 作為輸入,使用位置嵌入編碼并與其他特征合并,整體結(jié)構(gòu)與離散情況相似。
- 流匹配:架構(gòu)與連續(xù)時(shí)間擴(kuò)散相同,但不應(yīng)用 softmax 限制,允許 v?θ 表示嵌入空間中的任意方向,而非僅限于類(lèi)別嵌入的凸組合。
所有模型均使用線性層加 softmax 來(lái)參數(shù)化相應(yīng)方程中的條件概率分布。
對(duì)于離散時(shí)間擴(kuò)散,作者使用固定余弦噪聲調(diào)度。對(duì)于連續(xù)時(shí)間擴(kuò)散,作者將噪聲調(diào)度與模型共同訓(xùn)練。
實(shí)驗(yàn)結(jié)果
作者對(duì) NoProp 方法進(jìn)行了評(píng)估,分別在離散時(shí)間設(shè)置下與反向傳播方法進(jìn)行比較,在連續(xù)時(shí)間設(shè)置下與伴隨敏感性方法(adjoint sensitivity method)進(jìn)行比較,場(chǎng)景是圖像分類(lèi)任務(wù)。
結(jié)果如表 1 所示,表明 NoProp-DT 在離散時(shí)間設(shè)置下在 MNIST、CIFAR-10 和 CIFAR-100 數(shù)據(jù)集上的性能與反向傳播方法相當(dāng),甚至更好。此外,NoProp-DT 在性能上優(yōu)于以往的無(wú)反向傳播方法,包括 Forward-Forward 算法、Difference Target 傳播以及一種稱為 Local Greedy Forward Gradient Activity-Perturbed 的前向梯度方法。雖然這些方法使用了不同的架構(gòu),并且不像 NoProp 那樣顯式地對(duì)圖像輸入進(jìn)行條件約束 —— 這使得直接比較變得困難 —— 但 NoProp 具有不依賴前向傳播的獨(dú)特優(yōu)勢(shì)。
此外,如表 2 所示,NoProp 在訓(xùn)練過(guò)程中減少了 GPU 內(nèi)存消耗。
為了說(shuō)明學(xué)習(xí)到的類(lèi)別嵌入,圖 2 可視化了 CIFAR-10 數(shù)據(jù)集中類(lèi)別嵌入的初始化和最終學(xué)習(xí)結(jié)果,其中嵌入維度與圖像維度匹配。
在連續(xù)時(shí)間設(shè)置下,NoProp-CT 和 NoProp-FM 的準(zhǔn)確率低于 NoProp-DT,這可能是由于它們對(duì)時(shí)間變量 t 的額外條件約束。然而,它們?cè)?CIFAR-10 和 CIFAR-100 數(shù)據(jù)集上通常優(yōu)于伴隨敏感性方法,無(wú)論是在準(zhǔn)確率還是計(jì)算效率方面。雖然伴隨方法在 MNIST 數(shù)據(jù)集上達(dá)到了與 NoProp-CT 和 NoProp-FM 相似的準(zhǔn)確率,但其訓(xùn)練速度明顯較慢,如圖 3 所示。
對(duì)于 CIFAR-100 數(shù)據(jù)集,當(dāng)使用 one-hot 編碼時(shí),NoProp-FM 無(wú)法有效學(xué)習(xí),導(dǎo)致準(zhǔn)確率提升非常緩慢。相比之下,NoProp-CT 仍然優(yōu)于伴隨方法。然而,一旦類(lèi)別嵌入與模型聯(lián)合學(xué)習(xí),NoProp-FM 的性能顯著提高。
作者還對(duì)類(lèi)別概率的參數(shù)化和類(lèi)別嵌入矩陣 W_Embed 的初始化進(jìn)行了消融研究,結(jié)果分別如圖 4 和圖 5 所示。消融結(jié)果表明,類(lèi)別概率的參數(shù)化方法之間沒(méi)有一致的優(yōu)勢(shì),性能因數(shù)據(jù)集而異。對(duì)于類(lèi)別嵌入的初始化,正交初始化和原型初始化通常與隨機(jī)初始化相當(dāng),甚至優(yōu)于隨機(jī)初始化。
更多詳細(xì)內(nèi)容請(qǐng)參見(jiàn)原論文。