一階優(yōu)化算法啟發(fā),北大林宙辰團隊提出具有萬有逼近性質(zhì)的神經(jīng)網(wǎng)絡(luò)架構(gòu)的設(shè)計方法
以神經(jīng)網(wǎng)絡(luò)為基礎(chǔ)的深度學習技術(shù)已經(jīng)在諸多應(yīng)用領(lǐng)域取得了有效成果。在實踐中,網(wǎng)絡(luò)架構(gòu)可以顯著影響學習效率,一個好的神經(jīng)網(wǎng)絡(luò)架構(gòu)能夠融入問題的先驗知識,穩(wěn)定網(wǎng)絡(luò)訓練,提高計算效率。目前,經(jīng)典的網(wǎng)絡(luò)架構(gòu)設(shè)計方法包括人工設(shè)計、神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索(NAS)[1]、以及基于優(yōu)化的網(wǎng)絡(luò)設(shè)計方法 [2]。人工設(shè)計的網(wǎng)絡(luò)架構(gòu)如 ResNet 等;神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索則通過搜索或強化學習的方式在搜索空間中尋找最佳網(wǎng)絡(luò)結(jié)構(gòu);基于優(yōu)化的設(shè)計方法中的一種主流范式是算法展開(algorithm unrolling),該方法通常在有顯式目標函數(shù)的情況下,從優(yōu)化算法的角度設(shè)計網(wǎng)絡(luò)結(jié)構(gòu)。
然而,現(xiàn)有經(jīng)典神經(jīng)網(wǎng)絡(luò)架構(gòu)設(shè)計大多忽略了網(wǎng)絡(luò)的萬有逼近性質(zhì) —— 這是神經(jīng)網(wǎng)絡(luò)具備強大性能的關(guān)鍵因素之一。因此,這些設(shè)計方法在一定程度上失去了網(wǎng)絡(luò)的先驗性能保障。盡管兩層神經(jīng)網(wǎng)絡(luò)在寬度趨于無窮的時候就已具有萬有逼近性質(zhì) [3],在實際中,我們通常只能考慮有限寬的網(wǎng)絡(luò)結(jié)構(gòu),而這方面的表示分析的結(jié)果十分有限。實際上,無論是啟發(fā)性的人工設(shè)計,還是黑箱性質(zhì)的神經(jīng)網(wǎng)絡(luò)架構(gòu)搜索,都很難在網(wǎng)絡(luò)設(shè)計中考慮萬有逼近性質(zhì)?;趦?yōu)化的神經(jīng)網(wǎng)絡(luò)設(shè)計雖然相對更具解釋性,但其通常需要一個顯式的目標函數(shù),這使得設(shè)計的網(wǎng)絡(luò)結(jié)構(gòu)種類有限,限制了其應(yīng)用范圍。如何系統(tǒng)性地設(shè)計具有萬有逼近性質(zhì)的神經(jīng)網(wǎng)絡(luò)架構(gòu),仍是一個重要的問題。
為了解決這個問題,北京大學林宙辰教授團隊提出了一種易于操作的基于優(yōu)化算法設(shè)計具有萬有逼近性質(zhì)保障的神經(jīng)網(wǎng)絡(luò)架構(gòu)的方法,其通過將基于梯度的一階優(yōu)化算法的梯度項映射為具有一定性質(zhì)的神經(jīng)網(wǎng)絡(luò)模塊,再根據(jù)實際應(yīng)用問題對模塊結(jié)構(gòu)進行調(diào)整,就可以系統(tǒng)性地設(shè)計具有萬有逼近性質(zhì)的神經(jīng)網(wǎng)絡(luò)架構(gòu),并且可以與現(xiàn)有大多數(shù)基于模塊的網(wǎng)絡(luò)設(shè)計的方法無縫結(jié)合。論文還通過分析神經(jīng)網(wǎng)絡(luò)微分方程(NODE)的逼近性質(zhì)首次證明了具有一般跨層連接的神經(jīng)網(wǎng)絡(luò)的萬有逼近性質(zhì),并利用提出的框架設(shè)計了 ConvNext、ViT 的變種網(wǎng)絡(luò),取得了超越 baseline 的結(jié)果。論文被人工智能頂刊 TPAMI 接收。
- 論文:Designing Universally-Approximating Deep Neural Networks: A First-Order Optimization Approach
- 論文地址:https://ieeexplore.ieee.org/document/10477580
方法簡介
傳統(tǒng)的基于優(yōu)化的神經(jīng)網(wǎng)絡(luò)設(shè)計方法通常從一個具有顯式表示的目標函數(shù)出發(fā),采用特定的優(yōu)化算法進行求解,再將優(yōu)化迭代格式映射為神經(jīng)網(wǎng)絡(luò)架構(gòu),例如著名的 LISTA-NN 就是利用 LISTA 算法求解 LASSO 問題所得 [4],這種方法受限于目標函數(shù)的顯式表達式,可設(shè)計得到的網(wǎng)絡(luò)結(jié)構(gòu)有限。一些研究者嘗試通過自定義目標函數(shù),再利用算法展開等方法設(shè)計網(wǎng)絡(luò)結(jié)構(gòu),但他們也需要如權(quán)重綁定等與實際情況可能不符的假設(shè)。
論文提出的易于操作的網(wǎng)絡(luò)架構(gòu)設(shè)計方法從一階優(yōu)化算法的更新格式出發(fā),將梯度或鄰近點算法寫成如下的更新格式:
其中、
表示第 k 步更新時的(步長)系數(shù),再將梯度項替換為神經(jīng)網(wǎng)絡(luò)中的可學習模塊 T,即可得到 L 層神經(jīng)網(wǎng)絡(luò)的骨架:
整體方法框架見圖 1。
圖 1 網(wǎng)絡(luò)設(shè)計圖示
論文提出的方法可以啟發(fā)設(shè)計 ResNet、DenseNet 等經(jīng)典網(wǎng)絡(luò),并且解決了傳統(tǒng)基于優(yōu)化設(shè)計網(wǎng)絡(luò)架構(gòu)的方法局限于特定目標函數(shù)的問題。
模塊選取與架構(gòu)細節(jié)
該方法所設(shè)計的網(wǎng)絡(luò)模塊 T 只要求有包含兩層網(wǎng)絡(luò)結(jié)構(gòu),即,作為其子結(jié)構(gòu),即可保證所設(shè)計的網(wǎng)絡(luò)具有萬有逼近性質(zhì),其中所表達的層的寬度是有限的(即不隨逼近精度的提高而增長),整個網(wǎng)絡(luò)的萬有逼近性質(zhì)不是靠加寬
的層來獲得的。模塊 T 可以是 ResNet 中廣泛運用的 pre-activation 塊,也可以是 Transformer 中的注意力 + 前饋層的結(jié)構(gòu)。T 中的激活函數(shù)可以是 ReLU、GeLU、Sigmoid 等常用激活函數(shù)。還可以根據(jù)具體任務(wù)在中添加對應(yīng)的歸一化層。另外,
時,設(shè)計的網(wǎng)絡(luò)是隱式網(wǎng)絡(luò) [5],可以用不動點迭代的方法逼近隱格式,或采用隱式微分(implicit differentiation)的方法求解梯度進行更新。
通過等價表示設(shè)計更多網(wǎng)絡(luò)
該方法不要求同一種算法只能對應(yīng)一種結(jié)構(gòu),相反,該方法可以利用優(yōu)化問題的等價表示設(shè)計更多的網(wǎng)絡(luò)架構(gòu),體現(xiàn)其靈活性。例如,線性化交替方向乘子法通常用于求解約束優(yōu)化問題:通過令
即可得到一種可啟發(fā)網(wǎng)絡(luò)的更新迭代格式:
其啟發(fā)的網(wǎng)絡(luò)結(jié)構(gòu)可見圖 2。
圖 2 線性化交替方向乘子法啟發(fā)的網(wǎng)絡(luò)結(jié)構(gòu)
啟發(fā)的網(wǎng)絡(luò)具有萬有逼近性質(zhì)
對該方法設(shè)計的網(wǎng)絡(luò)架構(gòu),可以證明,在模塊滿足此前條件以及優(yōu)化算法(在一般情況下)穩(wěn)定、收斂的條件下,任意一階優(yōu)化算法啟發(fā)的神經(jīng)網(wǎng)絡(luò)在高維連續(xù)函數(shù)空間具有萬有逼近性質(zhì),并給出了逼近速度。論文首次在有限寬度設(shè)定下證明了具有一般跨層連接的神經(jīng)網(wǎng)絡(luò)的萬有逼近性質(zhì)(此前研究基本集中在 FCNN 和 ResNet,見表 1),論文主定理可簡略敘述如下:
主定理(簡略版):設(shè) A 是一個梯度型一階優(yōu)化算法。若算法 A 具有公式 (1) 中的更新格式,且滿足收斂性條件(優(yōu)化算法的常用步長選取均滿足收斂性條件。若在啟發(fā)網(wǎng)絡(luò)中均為可學習的,則可以不需要該條件),則由算法啟發(fā)的神經(jīng)網(wǎng)絡(luò):
在連續(xù)(向量值)函數(shù)空間以及范數(shù)
下具有萬有逼近性質(zhì),其中可學習模塊 T 只要有包含兩層形如
的結(jié)構(gòu)(σ 可以是常用的激活函數(shù))作為其子結(jié)構(gòu)都可以。
常用的 T 的結(jié)構(gòu)如:
1)卷積網(wǎng)絡(luò)中,pre-activation 塊:BN-ReLU-Conv-BN-ReLU-Conv (z),
2)Transformer 中:Attn (z) + MLP (z+Attn (z)).
主定理的證明利用了 NODE 的萬有逼近性質(zhì)以及線性多步方法的收斂性質(zhì),核心是證明優(yōu)化算法啟發(fā)設(shè)計的網(wǎng)絡(luò)結(jié)構(gòu)恰對應(yīng)一種收斂的線性多步方法對連續(xù)的 NODE 的離散化,從而啟發(fā)的網(wǎng)絡(luò) “繼承” 了 NODE 的逼近能力。在證明中,論文還給出了 NODE 逼近 d 維空間連續(xù)函數(shù)的逼近速度,解決了此前論文 [6] 的一個遺留問題。
表 1 此前萬有逼近性質(zhì)的研究基本集中在 FCNN 和 ResNet
實驗結(jié)果
論文利用所提出的網(wǎng)絡(luò)架構(gòu)設(shè)計框架設(shè)計了 8 種顯式網(wǎng)絡(luò)和 3 種隱式網(wǎng)絡(luò)(稱為 OptDNN),網(wǎng)絡(luò)信息見表 2,并在嵌套環(huán)分離、函數(shù)逼近和圖像分類等問題上進行了實驗。論文還以 ResNet, DenseNet, ConvNext 以及 ViT 為 baseline,利用所提出的方法設(shè)計了改進的 OptDNN,并在圖像分類的問題上進行實驗,考慮準確率和 FLOPs 兩個指標。
表 2 所設(shè)計網(wǎng)絡(luò)的有關(guān)信息
首先,OptDNN 在嵌套環(huán)分離和函數(shù)逼近兩個問題上進行實驗,以驗證其萬有逼近性質(zhì)。在函數(shù)逼近問題中,分別考慮了逼近 parity function 和 Talgarsky function,前者可表示為二分類問題,后者則是回歸問題,這兩個問題都是淺層網(wǎng)絡(luò)難以逼近的問題。OptDNN 在嵌套環(huán)分離的實驗結(jié)果如圖 3 所示,在函數(shù)逼近的實驗結(jié)果如圖 3 所示,OptDNN 不僅取得了很好的分離 / 逼近結(jié)果,而且比作為 baseline 的 ResNet 取得了更大的分類間隔和更小的回歸誤差,足以驗證 OptDNN 的萬有逼近性質(zhì)。
圖 3 OptNN 逼近 parity function
圖 4 OptNN 逼近 Talgarsky function
然后,OptDNN 分別在寬 - 淺和窄 - 深兩種設(shè)定下在 CIFAR 數(shù)據(jù)集上進行了圖像分類任務(wù)的實驗,結(jié)果見表 3 與 4。實驗均在較強的數(shù)據(jù)增強設(shè)定下進行,可以看出,一些 OptDNN 在相同甚至更小的 FLOPs 開銷下取得了比 ResNet 更小的錯誤率。論文還在 ResNet 和 DenseNet 設(shè)定下進行了實驗,也取得了類似的實驗結(jié)果。
表 3 OptDNN 在寬 - 淺設(shè)定下的實驗結(jié)果
表 4 OptDNN 在窄 - 深設(shè)定下的實驗結(jié)果
論文進一步選取了此前表現(xiàn)較好的 OptDNN-APG2 網(wǎng)絡(luò),進一步在 ConvNext 和 ViT 的設(shè)定下在 ImageNet 數(shù)據(jù)集上進行了實驗,OptDNN-APG2 的網(wǎng)絡(luò)結(jié)構(gòu)見圖 5,實驗結(jié)果表 5、6。OptDNN-APG2 取得了超過等寬 ConvNext、ViT 的準確率,進一步驗證了該架構(gòu)設(shè)計方法的可靠性。
圖 5 OptDNN-APG2 的網(wǎng)絡(luò)結(jié)構(gòu)
表 5 OptDNN-APG2 在 ImageNet 上的性能比較
表 6 OptDNN-APG2 與等寬(isotropic)的 ConvNeXt 和 ViT 的性能比較
最后,論文依照 Proximal Gradient Descent 和 FISTA 等算法設(shè)計了 3 個隱式網(wǎng)絡(luò),并在 CIFAR 數(shù)據(jù)集上和顯式的 ResNet 以及一些常用的隱式網(wǎng)絡(luò)進行了比較,實驗結(jié)果見表 7。三個隱式網(wǎng)絡(luò)均取得了與先進隱式網(wǎng)絡(luò)相當?shù)膶嶒灲Y(jié)果,也說明了方法的靈活性。
表 7 隱式網(wǎng)絡(luò)的性能比較
總結(jié)
神經(jīng)網(wǎng)絡(luò)架構(gòu)設(shè)計是深度學習中的核心問題之一。論文提出了一個利用一階優(yōu)化算法設(shè)計具有萬有逼近性質(zhì)保障的神經(jīng)網(wǎng)絡(luò)架構(gòu)的統(tǒng)一框架,拓展了基于優(yōu)化設(shè)計網(wǎng)絡(luò)架構(gòu)范式的方法。該方法可以與現(xiàn)有大部分聚焦網(wǎng)絡(luò)模塊的架構(gòu)設(shè)計方法相結(jié)合,可以在幾乎不增加計算量的情況下設(shè)計出高效的模型。在理論方面,論文證明了收斂的優(yōu)化算法誘導的網(wǎng)路架構(gòu)在溫和條件下即具有萬有逼近性質(zhì),并彌合了 NODE 和具有一般跨層連接網(wǎng)絡(luò)的表示能力。該方法還有望與 NAS、 SNN 架構(gòu)設(shè)計等領(lǐng)域結(jié)合,以設(shè)計更高效的網(wǎng)絡(luò)架構(gòu)。