用AI取代SGD?無需訓(xùn)練ResNet-50,AI秒級(jí)預(yù)測(cè)全部2400萬(wàn)個(gè)參數(shù),準(zhǔn)確率60%
本文轉(zhuǎn)自雷鋒網(wǎng),如需轉(zhuǎn)載請(qǐng)至雷鋒網(wǎng)官網(wǎng)申請(qǐng)授權(quán)。
只需一次前向傳播,這個(gè)圖神經(jīng)網(wǎng)絡(luò),或者說元模型,便可預(yù)測(cè)一個(gè)圖像分類模型的所有參數(shù)。有了它,無需再苦苦等待梯度下降收斂!
來自圭爾夫大學(xué)的論文一作 Boris Knyazev 介紹道,該元模型可以預(yù)測(cè) ResNet-50 的所有2400萬(wàn)個(gè)參數(shù),并且這個(gè) ResNet-50 將在 CIFAR-10 上達(dá)到將近60%的準(zhǔn)確率,無需任何訓(xùn)練。特別是,該模型適用于幾乎任何神經(jīng)網(wǎng)絡(luò)。
基于這個(gè)結(jié)果,作者向我們發(fā)出了靈魂之問:以后還需要 SGD 或 Adam 來訓(xùn)練神經(jīng)網(wǎng)絡(luò)嗎?
“我們離用單一元模型取代手工設(shè)計(jì)的優(yōu)化器又近了一步,該元模型可以在一次前向傳播中預(yù)測(cè)幾乎任何神經(jīng)網(wǎng)絡(luò)的參數(shù)。”
令人驚訝的是,這個(gè)元模型在訓(xùn)練時(shí),沒有接收過任何類似 ResNet-50 的網(wǎng)絡(luò)(作為訓(xùn)練數(shù)據(jù))。
該元模型的適用性非常廣,不僅是ResNet-50,它還可以預(yù)測(cè) ResNet-101、ResNet-152、Wide-ResNets、Visual Transformers 的所有參數(shù),“應(yīng)有盡有”。不止是CIFAR-10,就連在ImageNet這樣更大規(guī)模的數(shù)據(jù)集上,它也能帶來不錯(cuò)的效果。
同時(shí),效率方面也很不錯(cuò)。該元模型可以在平均不到 1 秒的時(shí)間內(nèi)預(yù)測(cè)給定網(wǎng)絡(luò)的所有參數(shù),即使在 CPU 上,它的表現(xiàn)也是如此迅猛!
但天底下終究“沒有免費(fèi)的午餐”,因此當(dāng)該元模型預(yù)測(cè)其它不同類型的架構(gòu)時(shí),預(yù)測(cè)的參數(shù)不會(huì)很準(zhǔn)確(有時(shí)可能是隨機(jī)的)。一般來說,離訓(xùn)練分布越遠(yuǎn)(見圖中的綠框),預(yù)測(cè)的結(jié)果就越差。
但是,即使使用預(yù)測(cè)參數(shù)的網(wǎng)絡(luò)分類準(zhǔn)確率很差,也不要失望。
我們?nèi)匀豢梢詫⑵渥鳛榫哂辛己贸跏蓟瘏?shù)的模型,而不需要像過去那樣,使用隨機(jī)初始化,“我們可以在這種遷移學(xué)習(xí)中受益,尤其是在少樣本學(xué)習(xí)任務(wù)中。”
作者還表示,“作為圖神經(jīng)網(wǎng)絡(luò)的粉絲”,他們特地選用了GNN作為元模型。該模型是基于 Chris Zhang、Mengye Ren 和 Raquel Urtasun發(fā)表的ICLR 2019論文“Graph HyperNetworks for Neural Architecture Search”GHN提出的。
論文地址:https://arxiv.org/abs/1810.05749
在他們的基礎(chǔ)上,作者開發(fā)并訓(xùn)練了一個(gè)新的模型 GHN-2,它具有更好的泛化能力。
簡(jiǎn)而言之,在多個(gè)架構(gòu)上更新 GHN 參數(shù),并正確歸一化預(yù)測(cè)參數(shù)、改善圖中的遠(yuǎn)程交互以及改善收斂性至關(guān)重要。
為了訓(xùn)練 GHN-2,作者引入了一個(gè)神經(jīng)架構(gòu)數(shù)據(jù)集——DeepNets-1M。
這個(gè)數(shù)據(jù)集分為訓(xùn)練集、驗(yàn)證集和測(cè)試集三個(gè)部分。此外,他們還使用更廣、更深、更密集和無歸一化網(wǎng)絡(luò)來進(jìn)行分布外測(cè)試。
作者補(bǔ)充道,DeepNets-1M 可以作為一個(gè)很好的測(cè)試平臺(tái),用于對(duì)不同的圖神經(jīng)網(wǎng)絡(luò) (GNN) 進(jìn)行基準(zhǔn)測(cè)試。“使用我們的 PyTorch 代碼,插入任何 GNN(而不是我們的 Gated GNN )應(yīng)該都很簡(jiǎn)單。”
除了解決參數(shù)預(yù)測(cè)任務(wù)和用于網(wǎng)絡(luò)初始化之外, GHN-2 還可用于神經(jīng)架構(gòu)搜索,“GHN-2可以搜索最準(zhǔn)確、最魯棒(就高斯噪聲而言)、最有效和最容易訓(xùn)練的網(wǎng)絡(luò)。”
這篇論文已經(jīng)發(fā)表在了NeurIPS 2021上,研究人員分別來自圭爾夫大學(xué)、多倫多大學(xué)向量人工智能研究所、CIFAR、FAIR和麥吉爾大學(xué)。
論文地址:https://arxiv.org/pdf/2110.13100.pdf
項(xiàng)目也已經(jīng)開源,趕緊去膜拜這個(gè)神經(jīng)網(wǎng)絡(luò)優(yōu)化器吧!
項(xiàng)目地址:https://github.com/facebookresearch/ppuda
1、模型詳解
考慮在大型標(biāo)注數(shù)據(jù)集(如ImageNet)上訓(xùn)練深度神經(jīng)網(wǎng)絡(luò)的問題, 這個(gè)問題可以形式化為對(duì)給定的神經(jīng)網(wǎng)絡(luò) a 尋找最優(yōu)參數(shù)w。
損失函數(shù)通常通過迭代優(yōu)化算法(如SGD和Adam)來最小化,這些算法收斂于架構(gòu) a 的性能參數(shù)w_p。
盡管在提高訓(xùn)練速度和收斂性方面取得了進(jìn)展,但w_p的獲取仍然是大規(guī)模機(jī)器學(xué)習(xí)管道中的一個(gè)瓶頸。
例如,在 ImageNet 上訓(xùn)練 ResNet-50 可能需要花費(fèi)相當(dāng)多的 GPU 時(shí)間。
隨著網(wǎng)絡(luò)規(guī)模的不斷增長(zhǎng),以及重復(fù)訓(xùn)練網(wǎng)絡(luò)的必要性(如超參數(shù)或架構(gòu)搜索)的存在,獲得 w_p 的過程在計(jì)算上變得不可持續(xù)。
而對(duì)于一個(gè)新的參數(shù)預(yù)測(cè)任務(wù),在優(yōu)化新架構(gòu) a 的參數(shù)時(shí),典型的優(yōu)化器會(huì)忽略過去通過優(yōu)化其他網(wǎng)絡(luò)獲得的經(jīng)驗(yàn)。
然而,利用過去的經(jīng)驗(yàn)可能是減少對(duì)迭代優(yōu)化依賴的關(guān)鍵,從而減少高計(jì)算需求。
為了朝著這個(gè)方向前進(jìn),研究人員提出了一項(xiàng)新任務(wù),即使用超網(wǎng)絡(luò) HD 的單次前向傳播迭代優(yōu)化。
為了解決這一任務(wù),HD 會(huì)利用過去優(yōu)化其他網(wǎng)絡(luò)的知識(shí)。
例如,我們考慮 CIFAR-10 和 ImageNet 圖像分類數(shù)據(jù)集 D,其中測(cè)試集性能是測(cè)試圖像的分類準(zhǔn)確率。
讓 HD 知道如何優(yōu)化其他網(wǎng)絡(luò)的一個(gè)簡(jiǎn)單方法是,在[架構(gòu),參數(shù)]對(duì)的大型訓(xùn)練集上對(duì)其進(jìn)行訓(xùn)練,然而,這個(gè)過程的難度令人望而卻步。
因此,研究人員遵循元學(xué)習(xí)中常見的雙層優(yōu)化范式,即不需要迭代 M 個(gè)任務(wù),而是在單個(gè)任務(wù)(比如圖像分類)上迭代 M 個(gè)訓(xùn)練架構(gòu)。
圖 0:GHN原始架構(gòu)概覽。A:隨機(jī)采樣一個(gè)神經(jīng)網(wǎng)絡(luò)架構(gòu),生成一個(gè)GHN。B:經(jīng)過圖傳播后,GHN 中的每個(gè)節(jié)點(diǎn)都會(huì)生成自己的權(quán)重參數(shù)。C:通過訓(xùn)練GHN,最小化帶有生成權(quán)重的采樣網(wǎng)絡(luò)的訓(xùn)練損失。根據(jù)生成網(wǎng)絡(luò)的性能進(jìn)行排序。來源:https://arxiv.org/abs/1810.05749
通過優(yōu)化,超網(wǎng)絡(luò) HD 逐漸獲得了如何預(yù)測(cè)訓(xùn)練架構(gòu)的性能參數(shù)的知識(shí),然后它可以在測(cè)試時(shí)利用這些知識(shí)。
為此,需要設(shè)計(jì)架構(gòu)空間 F 和 HD。
對(duì)于 F,研究人員基于已有的神經(jīng)架構(gòu)設(shè)計(jì)空間,我們以兩種方式對(duì)其進(jìn)行了擴(kuò)展:對(duì)不同架構(gòu)進(jìn)行采樣的能力和包括多種架構(gòu)的擴(kuò)展設(shè)計(jì)空間,例如 ResNets 和 Visual Transformers。
這樣的架構(gòu)可以以計(jì)算圖的形式完整描述(圖 1)。
因此,為了設(shè)計(jì)超網(wǎng)絡(luò) HD,將依賴于圖結(jié)構(gòu)數(shù)據(jù)機(jī)器學(xué)習(xí)的最新進(jìn)展。
特別是,研究人員的方案建立在 Graph HyperNetworks (GHNs) 方法的基礎(chǔ)上。
通過設(shè)計(jì)多樣化的架構(gòu)空間 F 和改進(jìn) GHN,GHN-2在 CIFAR-10和 ImageNet上預(yù)測(cè)未見過架構(gòu)時(shí),圖像識(shí)別準(zhǔn)確率分別提高到77% (top-1)和48% (top-5)。
令人驚訝的是,GHN-2 顯示出良好的分布外泛化,比如對(duì)于相比訓(xùn)練集中更大和更深的架構(gòu),它也能預(yù)測(cè)出良好的參數(shù)。
例如,GHN-2可以在不到1秒的時(shí)間內(nèi)在 GPU 或 CPU 上預(yù)測(cè) ResNet-50 的所有 2400 萬(wàn)個(gè)參數(shù),在 CIFAR-10 上達(dá)到約 60%的準(zhǔn)確率,無需任何梯度更新(圖 1,(b))。
總的來說,該框架和結(jié)果為訓(xùn)練網(wǎng)絡(luò)開辟了一條新的、更有效的范式。
本論文的貢獻(xiàn)如下:
(a)引入了使用單個(gè)超網(wǎng)絡(luò)前向傳播預(yù)測(cè)不同前饋神經(jīng)網(wǎng)絡(luò)的性能參數(shù)的新任務(wù);
(b)引入了 DEEPNETS-1M數(shù)據(jù)集,這是一個(gè)標(biāo)準(zhǔn)化的基準(zhǔn)測(cè)試,具有分布內(nèi)和分布外數(shù)據(jù),用于跟蹤任務(wù)的進(jìn)展;
(c)定義了幾個(gè)基線,并提出了 GHN-2 模型,該模型在 CIFAR-10 和 ImageNet( 5.1 節(jié))上表現(xiàn)出奇的好;
(d)該元模型學(xué)習(xí)了神經(jīng)網(wǎng)絡(luò)架構(gòu)的良好表示,并且對(duì)于初始化神經(jīng)網(wǎng)絡(luò)是有用的。
上圖圖1(a)展示了GHN 模型概述(詳見第 4 節(jié)),基于給定圖像數(shù)據(jù)集和DEEPNETS-1M架構(gòu)數(shù)據(jù)集,通過反向傳播來訓(xùn)練GHN模型,以預(yù)測(cè)圖像分類模型的參數(shù)。
研究人員對(duì) vanilla GHN 的主要改進(jìn)包括Meta-batching、Virtual edges、Parameter normalization等。
其中,Meta-batching僅在訓(xùn)練 GHN 時(shí)使用,而Virtual edges、Parameter normalization用于訓(xùn)練和測(cè)試時(shí)。a1 的可視化計(jì)算圖如表 1 所示。
圖1(b)比較了由 GHN 預(yù)測(cè)ResNet-50 的所有參數(shù)的分類準(zhǔn)確率與使用 SGD 訓(xùn)練其參數(shù)時(shí)的分類準(zhǔn)確率。盡管自動(dòng)化預(yù)測(cè)參數(shù)得到的網(wǎng)絡(luò)準(zhǔn)確率仍遠(yuǎn)遠(yuǎn)低于人工訓(xùn)練的網(wǎng)絡(luò),但可以作為不錯(cuò)的初始化手段。
2、實(shí)驗(yàn):參數(shù)預(yù)測(cè)
盡管 GHN-2 從未觀察過測(cè)試架構(gòu),但 GHN-2 為它們預(yù)測(cè)了良好的參數(shù),使測(cè)試網(wǎng)絡(luò)在兩個(gè)圖像數(shù)據(jù)集上的表現(xiàn)都出奇的好(表 3 和表 4)。
表 3:GHN-2在DEEPNETS-1M 的未見過 ID 和 OOD 架構(gòu)的預(yù)測(cè)參數(shù)結(jié)果(CIFAR-10 )
GHN-2甚至在 ImageNet 上展示了良好的結(jié)果,其中對(duì)于某些架構(gòu),實(shí)現(xiàn)了高達(dá) 48.3% 的top-5準(zhǔn)確率。
雖然這些結(jié)果對(duì)于直接下游應(yīng)用來說很不夠,但由于三個(gè)主要原因,它們非常有意義。
首先,不依賴于通過 SGD 訓(xùn)練架構(gòu) F 的昂貴得令人望而卻步的過程。
其次,GHN 依靠單次前向傳播來預(yù)測(cè)所有參數(shù)。
第三,這些結(jié)果是針對(duì)未見過的架構(gòu)獲得的,包括 OOD 架構(gòu)。即使在嚴(yán)重的分布變化(例如 ResNet-506 )和代表性不足的網(wǎng)絡(luò)(例如 ViT7 )的情況下,GHN-2仍然可以預(yù)測(cè)比隨機(jī)參數(shù)表現(xiàn)更好的參數(shù)。
在 CIFAR-10 上,GHN-2 的泛化能力特別強(qiáng),在 ResNet-50 上的準(zhǔn)確率為 58.6%。
在這兩個(gè)圖像數(shù)據(jù)集上,GHN-2 在 DEEPNETS-1M 的所有測(cè)試子集上都顯著優(yōu)于 GHN-1,在某些情況下絕對(duì)增益超過 20%,例如BN-FREE 網(wǎng)絡(luò)上的 36.8% 與 13.7%(表 3)。
利用計(jì)算圖的結(jié)構(gòu)是 GHN 的一個(gè)關(guān)鍵特性,當(dāng)用 MLP 替換 GHN-2 的 GatedGNN 時(shí),在 ID(甚至在 OOD)架構(gòu)上的準(zhǔn)確率從 66.9% 下降到 42.2%。
與迭代優(yōu)化方法相比,GHN-2 預(yù)測(cè)參數(shù)的準(zhǔn)確率分別與 CIFAR-10 和 ImageNet 上 SGD 的 ∼2500 次和 ∼5000 次迭代相近。
相比之下,GHN-1 的性能分別與僅 ~500 次和 ~2000次(未在表 4 中展示)迭代相似。
消融實(shí)驗(yàn)(表 5)表明第 4 節(jié)中提出的所有三個(gè)組件都很重要。
表 5:在 CIFAR-10 上消融 GHN-2,在所有 ID 和 OOD 測(cè)試架構(gòu)中計(jì)算模型的平均排名