破解ChatGPT驚人耗電!DeepMind新算法訓(xùn)練提效13倍,能耗暴降10倍
ChatGPT早已成為世界耗能大戶:一天用掉超50萬(wàn)度電,相當(dāng)于1.7萬(wàn)個(gè)美國(guó)家庭的用電量!
然而,大模型對(duì)能源的吞噬,遠(yuǎn)不僅如此。
國(guó)際能源署(IEA)預(yù)測(cè),從2022年到2026年,數(shù)據(jù)中心的用電量將翻一番。
隨著AI計(jì)算需求的膨脹,還需要用水來(lái)冷卻計(jì)算系統(tǒng)。研究稱,微軟用水量從2021年到22年飆升了34%,ChatGPT每處理5-50個(gè)提示就會(huì)消耗接近半升水。
針對(duì)這種現(xiàn)狀,我們有更好的解決策略嗎?
最近,谷歌DeepMind研究團(tuán)隊(duì)提出了一種加快AI訓(xùn)練的新方法——多模態(tài)對(duì)比學(xué)習(xí)與聯(lián)合示例選擇(JEST),大大減少了所需的計(jì)算資源和時(shí)間。
JEST以13倍更少的迭代次數(shù),以及10倍更少的計(jì)算量,超越了最先進(jìn)的模型!
論文地址:https://arxiv.org/pdf/2406.17711
預(yù)訓(xùn)練的參考模型,已經(jīng)學(xué)習(xí)了什么樣的數(shù)據(jù)是有「優(yōu)質(zhì)的」或「有用的」。然后通過(guò)模型,來(lái)引導(dǎo)數(shù)據(jù)選擇那些精心篩選過(guò)的小型數(shù)據(jù)集。
這一發(fā)現(xiàn)揭示了,數(shù)據(jù)篩選水平可以作為評(píng)判Scaling Law的一個(gè)新維度。
網(wǎng)友激動(dòng)表示,「我沒(méi)想到這么快就會(huì)發(fā)生。模型能夠自主選擇訓(xùn)練數(shù)據(jù)的能力是巨大的,因?yàn)樗褂?xùn)練變得顯著更容易,你不再需要猜測(cè)什么是高質(zhì)量的訓(xùn)練數(shù)據(jù),你有一個(gè)能夠『理解』什么樣的數(shù)據(jù)對(duì)自身學(xué)習(xí)最有價(jià)值的模型」。
前谷歌、蘋(píng)果軟件工程師稱贊道,這項(xiàng)研究非常令人印象深刻。
從「超級(jí)batch」中篩選數(shù)據(jù)
無(wú)論是語(yǔ)言、視覺(jué)還是多模態(tài)模型,數(shù)據(jù)質(zhì)量是預(yù)訓(xùn)練性能的重要驅(qū)動(dòng)因素。比如Phi-3、Gemma 2等模型的成功讓我們看到了,更少、更高質(zhì)量的數(shù)據(jù)有可能實(shí)現(xiàn)更強(qiáng)大的性能。
要篩選出高質(zhì)量的數(shù)據(jù),數(shù)據(jù)管道的建立就成為重要的工作。現(xiàn)有的方法大體可以分為兩種:1)手動(dòng)管理 2)基于模型的數(shù)據(jù)管理,用正在訓(xùn)練模型的特征選擇高質(zhì)量數(shù)據(jù)。
前者成本高昂且難以擴(kuò)展,后者則有望為多模態(tài)LLM實(shí)現(xiàn)Scaling Law。
然而,現(xiàn)有方法忽略了一個(gè)事實(shí)。
如果僅在單個(gè)數(shù)據(jù)點(diǎn)的層面進(jìn)行篩選,就沒(méi)有考慮到數(shù)據(jù)集以及batch的總體組成。畢竟,訓(xùn)練數(shù)據(jù)是以batch為單位,數(shù)據(jù)點(diǎn)之間的依賴性不可忽視。
許多計(jì)算機(jī)視覺(jué)的研究都曾表明,hard negatives(表達(dá)空間中相近但標(biāo)簽不同的樣本)相比可被平凡解的數(shù)據(jù)簇,能提供更有效的學(xué)習(xí)信號(hào)。
那么如何讓模型以batch為單位篩選數(shù)據(jù)呢?
論文提出的JEST算法正是要解決這個(gè)問(wèn)題,原理很好理解:就是直接從「超級(jí)batch」中篩選出「子batch」。
技術(shù)介紹
用數(shù)學(xué)語(yǔ)言來(lái)描述這個(gè)問(wèn)題,就是從大小為B的「超級(jí)batch」??中提取出與學(xué)習(xí)最相關(guān)的子batch ?={????,??∈[1,…,??]}???,過(guò)濾比率可以寫(xiě)作??=1???/??。
之前的優(yōu)先采樣(prioritized sampling)會(huì)使用基于模型的評(píng)分函數(shù)對(duì)每個(gè)數(shù)據(jù)點(diǎn)打分,再按比例采樣。JEST則直接對(duì)整個(gè)子batch評(píng)分,再按照batch級(jí)別的分?jǐn)?shù)采樣。
一種最直觀的啟發(fā)式方法就是在現(xiàn)有模型參數(shù) ?? : ??hard?(?|??)=??(?|??) 中,直接選擇損失值最高的batch,這種方法可被稱之為「硬學(xué)習(xí)」(hard learner)。
這種方法具有丟棄瑣碎數(shù)據(jù)的理想屬性,已被證明適用于小型、干凈的數(shù)據(jù)集;然而對(duì)于較大、較少管理的數(shù)據(jù)集往往弊大于利,因?yàn)樗琅f會(huì)采樣到噪聲數(shù)據(jù)。
另一種方法常用于多模態(tài),使用具有參數(shù) ???:??^easy?(?|???)=???(?|???) 的參考模型為預(yù)訓(xùn)練模型采樣數(shù)據(jù)。但作者依舊否定了這個(gè)方案,因?yàn)樗鼰o(wú)法直接反映模型當(dāng)前的狀態(tài),可能過(guò)度依賴參考模型的選擇,而且不易于擴(kuò)展。
最后,論文選擇借鑒ICML 2022年的一篇論文中提到的方法,將上述兩方面的評(píng)分結(jié)合起來(lái):??^learn?(?|??,???)=??hard?(?|??)+??^easy?(?|???)=??(?|??)???(?|???),并將這種啟發(fā)式方法稱為「可學(xué)習(xí)性評(píng)分」(learnability score)。
其中,batch上的損失值??(?|??)是各數(shù)據(jù)點(diǎn)之和,使用sigmoid對(duì)比損失函數(shù)計(jì)算(sigmoid-contrastive loss),因?yàn)橄啾萻oftmax對(duì)比損失而言,它的擴(kuò)展性更強(qiáng)。
由于batch上的對(duì)比損失可以分解為每個(gè)樣本的條件損失之和,因此可學(xué)習(xí)性評(píng)分可被分解為單個(gè)樣本可學(xué)習(xí)性評(píng)分???(??|??,???,?)之和,寫(xiě)作:
使用的順序采樣方法則受到了block Gibbs采樣的啟發(fā)。在第n次迭代、對(duì)第B_n個(gè)batch進(jìn)行采樣時(shí),依據(jù)如下概率公式對(duì)塊{X_k}進(jìn)行無(wú)替換采樣:
將X_k塊添加到B_n中來(lái)更新當(dāng)前采樣的batch,直至迭代數(shù)n=N時(shí)終止。算法的總體流程如下圖所示:
實(shí)驗(yàn)中發(fā)現(xiàn),使用迭代數(shù)N=16且每次迭代時(shí)獨(dú)立采樣b/N=2048個(gè)樣本時(shí),就足以恢復(fù)出學(xué)習(xí)性非常高的batch。
可學(xué)習(xí)性評(píng)分中涉及到使用參考模型為數(shù)據(jù)點(diǎn)打分,之前的方法慣常使用額外的小型模型,但這會(huì)增加每次迭代的計(jì)算成本,降低總體FLOP效率增益。
因此論文使用了在線模型近似的方法以及效率較高的FlexiViT架構(gòu),只使用降低分辨率的32×32的patch來(lái)評(píng)估「超級(jí)batch」,與全分辨率、patch大小為16×16的方法相比減少了72%的FLOP,以及67%的掛鐘時(shí)間(wall-clock time)。
此外,論文還提出了進(jìn)行多分辨率訓(xùn)練的技巧。將每個(gè)batch隨機(jī)分成兩半,使用不同分辨率編碼后再拼接起來(lái),提升了評(píng)分過(guò)程和訓(xùn)練的效率。
下圖詳細(xì)描述了全分辨率JEST和多分辨率Flexi-JEST方法的偽代碼實(shí)現(xiàn)。
所有JEST實(shí)驗(yàn)都在WebLI數(shù)據(jù)集上運(yùn)行,包含經(jīng)過(guò)寬松過(guò)濾的十億規(guī)模的英語(yǔ)圖像-文本對(duì),參考模型的訓(xùn)練則使用其中經(jīng)過(guò)高質(zhì)量過(guò)濾100M大小的子集(被稱為WebLI-curated)。
在WebLI的基礎(chǔ)上,作者還額外從網(wǎng)絡(luò)上抓取了6億個(gè)文本-圖像對(duì)并經(jīng)過(guò)同樣強(qiáng)度的過(guò)濾,組成WebLI-curated++數(shù)據(jù)集訓(xùn)練參考模型,拓展出JEST++/FlexiJEST++方法,來(lái)探索對(duì)數(shù)據(jù)管理的擴(kuò)展。
論文所報(bào)告的平均性能包括4個(gè)多模態(tài)規(guī)范基準(zhǔn):ImageNet 0-Shot和10-Shot 分類以及COCO圖像到文本和文本到圖像的top-1檢索。
實(shí)驗(yàn)結(jié)果
圖1中可以看到,使用JEST或FlexiJEST方法的最明顯優(yōu)勢(shì)就是效率提升。
左圖中,相比原有的SigLIP基線模型,JEST++可以在訓(xùn)練數(shù)據(jù)量減少13.1×的情況下達(dá)到相同準(zhǔn)確率。即使考慮到額外引入的打分成本,也有近10×的FLOP效率提升(中圖)。
右圖展現(xiàn)了JEST++/FlexiJEST++(綠色)與先前方法(灰色)的比較,相比CLIP、EVA-CLIP經(jīng)典模型實(shí)現(xiàn)了計(jì)算成本和性能的雙重提升。
左圖和中圖的平均準(zhǔn)確率由8個(gè)下游任務(wù)得出,右圖性能由ImageNet和COCO基準(zhǔn)測(cè)試得出
產(chǎn)生可學(xué)習(xí)batch
研究人員首先評(píng)估了JEST在選擇可學(xué)習(xí)batch方面的效果。
為了直觀地理解這一方法,作者們先將可學(xué)習(xí)性矩陣進(jìn)行可視化,即學(xué)習(xí)模型和參考模型之間,對(duì)batch中所有示例對(duì)的損失差異。
JEST就是按照示例子矩陣的可學(xué)習(xí)性總和比例進(jìn)行采樣。
由于矩陣明顯非對(duì)角關(guān)系(圖2,左),獨(dú)立選擇顯然是次優(yōu)的。
經(jīng)過(guò)少量迭代(對(duì)應(yīng)于用N=16個(gè)塊填充batch),作者發(fā)現(xiàn)子batch的可學(xué)習(xí)性快速增加,達(dá)到了需要數(shù)千次迭代的暴力吉布斯采樣(Gibbs sampling )所提取batch的可學(xué)習(xí)性(圖2,中)。
對(duì)于0.5、0.8和0.9的過(guò)濾比例,他們從大小分別為65,536、163,840和327,680的超級(jí)batch中選擇32,768個(gè)示例的子batch。
在圖2右側(cè),研究者還發(fā)現(xiàn)子batch的可學(xué)習(xí)性隨著更大的過(guò)濾比例而增加。
總之,JEST算法是在訓(xùn)練過(guò)程中選擇高度可學(xué)習(xí)batch的有效,且高效的方法。
加速多模態(tài)學(xué)習(xí)
接下來(lái),研究人員使用JEST算法選擇的可學(xué)習(xí)batch,檢驗(yàn)訓(xùn)練模型的效果。
所有實(shí)驗(yàn)都使用在WebLI-curated上訓(xùn)練的參考模型,這是一個(gè)ViT-B/16和Bert-B圖像-文本雙編碼器,30億訓(xùn)練樣本,采用sigmoid對(duì)比損失函數(shù)。
圖3(左)顯示了在訓(xùn)練過(guò)程中多個(gè)下游任務(wù)(ImageNet 0-Shot/10-Shot準(zhǔn)確率和COCO圖像到文本/文本到圖像檢索)的平均性能。
結(jié)果還發(fā)現(xiàn),JEST顯著加速了學(xué)習(xí)過(guò)程。
在使用50%、80%和90%的過(guò)濾比例時(shí),分別只需20億、10億和6.7億訓(xùn)練樣本就達(dá)到了30億均勻基準(zhǔn)的最終性能。
在更大的過(guò)濾比例下,坐著觀察到類似于更大batch size時(shí)的訓(xùn)練不穩(wěn)定性,需要修改Adam優(yōu)化器(β2 = 0.95)以穩(wěn)定訓(xùn)練,這表明JEST的數(shù)據(jù)篩選可以被視為增加了有效batch size。
在最終性能方面,當(dāng)過(guò)濾90%的數(shù)據(jù)時(shí),JEST也帶來(lái)了高達(dá)6%的顯著提升(圖3,中間,藍(lán)色曲線)。
值得注意的是,這種scaling行為這種性能提升在獨(dú)立樣本選擇方法中,并沒(méi)有觀察到。(圖3,中間,橙色曲線)。
最后,研究者還評(píng)估JEST是否也改善了,除可學(xué)習(xí)性之外的其他優(yōu)先標(biāo)準(zhǔn)。
圖3右側(cè)顯示了使用easy-reference優(yōu)先選擇的模型在不同過(guò)濾比例下的性能。
與基于可學(xué)習(xí)性的優(yōu)先選擇一致,JEST仍優(yōu)于獨(dú)立樣本選擇,特別是在高過(guò)濾比例下(在這種情況下,獨(dú)立樣本選擇導(dǎo)致性能下降)。
優(yōu)先選擇具有最高損失的數(shù)據(jù)產(chǎn)生了較小的收益,并且隨著過(guò)濾更多數(shù)據(jù)而更快地退化(圖10)。
由于基于可學(xué)習(xí)性的JEST產(chǎn)生了最佳的scaling行為,研究人員在后續(xù)實(shí)驗(yàn)中保留了這一標(biāo)準(zhǔn)。
多分辨率訓(xùn)練和在線batch選擇之間的協(xié)同效應(yīng)
隨著數(shù)據(jù)batch中被過(guò)濾的比例增加,基于可學(xué)習(xí)性評(píng)分的JEST變得更加高效。
然而,評(píng)分的成本會(huì)帶來(lái)顯著的提升:過(guò)濾超級(jí)batch 80%的數(shù)據(jù)會(huì)導(dǎo)致每次迭代的浮點(diǎn)運(yùn)算量是IID訓(xùn)練的4倍,或者在緩存參考模型得分時(shí)是2.3倍。
盡管JEST在訓(xùn)練迭代次數(shù)方面(以下簡(jiǎn)稱「訓(xùn)練效率」)顯著提高了效率,但額外的評(píng)分浮點(diǎn)運(yùn)算降低了其相對(duì)于IID基準(zhǔn)的計(jì)算效率(圖1,左vs右)。
因此,作者還研究了一種計(jì)算效率更高的變體,稱為Flexi-JEST,它使用多分辨率訓(xùn)練和低分辨率評(píng)分,將總開(kāi)銷降低到僅比基準(zhǔn)高10%(圖4,左)。
這些近似方法對(duì)性能有什么影響?
正如預(yù)期的那樣,F(xiàn)lexi-JEST的每次迭代性能相對(duì)于JEST有所下降,但仍然比IID有顯著的加速(圖1,左;圖4,中)。
然而,考慮到總浮點(diǎn)運(yùn)算量的減少,每次迭代性能的下降是非常有利的:最好的Flexi-JEST模型與40B Siglip運(yùn)行產(chǎn)生相同的平均性能,但浮點(diǎn)運(yùn)算量減少了9.9倍,比全分辨率JEST少2倍(圖1,右;圖4,中)。
這些實(shí)驗(yàn)表明了多分辨率訓(xùn)練和聯(lián)合示例選擇之間的協(xié)同效應(yīng),前者為加速后者提供了高效和準(zhǔn)確的評(píng)分能力。
實(shí)驗(yàn)結(jié)果,還指出了數(shù)據(jù)策劃策略的帕累托前沿(pareto front)。
如果以計(jì)算為代價(jià)來(lái)最大化訓(xùn)練速度或訓(xùn)練效率,全分辨率JEST方法相對(duì)于可比的IID訓(xùn)練運(yùn)行,可以產(chǎn)生高達(dá)13倍的加速。
實(shí)現(xiàn)強(qiáng)大數(shù)據(jù)質(zhì)量引導(dǎo)
可學(xué)習(xí)性評(píng)分的核心是,一個(gè)在人類選擇的小型、精心篩選的數(shù)據(jù)集上,訓(xùn)練的參考模型。
JEST的性能如何隨不同的篩選策略(在質(zhì)量和數(shù)量之間權(quán)衡)而變化?
此外,JEST訓(xùn)練的改進(jìn)是否與參考模型的性能相關(guān),還是這些指標(biāo)是分離的?
理解質(zhì)量與數(shù)量的權(quán)衡
研究人員探索了三種規(guī)模的數(shù)據(jù)篩選,每種都是原始WebLI數(shù)據(jù)集的一個(gè)子集:
- 弱篩選(十億級(jí)規(guī)模):使用圖像-文本對(duì)齊(ITA)過(guò)濾器。
- 中度篩選(3億級(jí)規(guī)模):使用ITA過(guò)濾器或文本質(zhì)量(TQ)過(guò)濾器。
- 強(qiáng)篩選(1億級(jí)規(guī)模):結(jié)合使用TQ、ITA和額外的圖像質(zhì)量(aesthetic)過(guò)濾器。
在整個(gè)過(guò)程中,作者將這個(gè)強(qiáng)篩選子集稱為「WebLI-curated」。
然后,他們?cè)谶@四個(gè)WebLI子集上,各訓(xùn)練10個(gè)epoch的標(biāo)準(zhǔn)SigLIP編碼器,并將它們用作在全WebLI數(shù)據(jù)集上進(jìn)行JEST訓(xùn)練的參考模型。
在不同的數(shù)據(jù)篩選方法中,參考模型的性能和JEST的性能似乎是解耦的(甚至可能是反相關(guān)的;圖5,左)。
雖然增加篩選(和減少數(shù)據(jù)集大?。?huì)產(chǎn)生較弱的模型,但當(dāng)它們被用作JEST預(yù)訓(xùn)練的參考模型時(shí),卻產(chǎn)生了相反的效果:
使用強(qiáng)篩選參考模型的JEST獲得了2.7%的改進(jìn),中度篩選獲得了1.5%的改進(jìn),弱篩選獲得了0.3%的改進(jìn)。
擴(kuò)展數(shù)據(jù)篩選
假設(shè)參考模型性能與JEST性能之間的普遍解耦,可能僅僅是由數(shù)據(jù)篩選所施加的數(shù)據(jù)集大小限制造成的。
為了理解這種效果,研究人員在WebLI-curated上訓(xùn)練了5個(gè)參考模型,同時(shí)改變所見(jiàn)的總樣本數(shù)(從2.5億到30億)。
在這種情況下,圖5(右)顯示了改進(jìn)的參考模型與更好的JEST預(yù)訓(xùn)練之間存在著顯著的相關(guān)性。
這表明「解耦」現(xiàn)象主要可以歸因于參考模型因篩選后數(shù)據(jù)集大小減少而導(dǎo)致的飽和。
此外,研究人員還注意到,當(dāng)數(shù)據(jù)集達(dá)到飽和時(shí),圖5(右)中的相關(guān)性開(kāi)始崩解,即在10個(gè)epoch或者看到10億個(gè)樣本之后。
這些結(jié)果表明,JEST可能會(huì)從進(jìn)一步擴(kuò)大參考數(shù)據(jù)集的數(shù)據(jù)篩選中獲益。
鑒于使用WebLI-curated++對(duì)數(shù)據(jù)進(jìn)行擴(kuò)展整理能顯著提高參考模型的性能,作者提出了是否有必要在原始WebLI數(shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練的問(wèn)題。
然而,在評(píng)估參考模型在不同數(shù)據(jù)集上的性能時(shí),卻發(fā)現(xiàn):雖然它在2個(gè)下游任務(wù)上的性能優(yōu)于WebLI預(yù)訓(xùn)練,但在其他6個(gè)任務(wù)上的性能,以及平均性能都明顯低于WebLI預(yù)訓(xùn)練(表 5)。
與現(xiàn)有數(shù)據(jù)比較
最后,論文應(yīng)用JEST++在公開(kāi)的LAION-2B數(shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練,刪除了其中不安全的圖像-文本對(duì),但沒(méi)有進(jìn)行其他的預(yù)先過(guò)濾。
這個(gè)數(shù)據(jù)規(guī)模相比的SOTA方法DBP減少了4×,但JEST++依舊遠(yuǎn)遠(yuǎn)超過(guò)了所有之前的離線數(shù)據(jù)管理方法。
簡(jiǎn)化數(shù)據(jù)管理
之前提到過(guò),用于預(yù)訓(xùn)練的WebLI-curated是原始數(shù)據(jù)集WebLI過(guò)濾后得到的,以求篩選出高質(zhì)量的圖像-文本對(duì)齊的數(shù)據(jù)。
如表3所示,這種離線數(shù)據(jù)管理流程對(duì)IID(獨(dú)立同分布)訓(xùn)練方法的性能至關(guān)重要,但JEST++則表現(xiàn)出了對(duì)預(yù)過(guò)濾流程的魯棒性。即使沒(méi)有過(guò)濾,JEST++的性能也沒(méi)有出現(xiàn)明顯下滑,降低了模型對(duì)基礎(chǔ)數(shù)據(jù)集的要求。
結(jié)論和局限性
總體來(lái)說(shuō),JEST方法展現(xiàn)出了「數(shù)據(jù)質(zhì)量引導(dǎo)」(data quality bootstrapping)方法的巨大潛力,即使用小規(guī)模精選數(shù)據(jù)集來(lái)指導(dǎo)對(duì)更大的、未經(jīng)管理的數(shù)據(jù)集的學(xué)習(xí)。
最近的研究表明,在下游任務(wù)未知時(shí),靜態(tài)數(shù)據(jù)集的過(guò)濾會(huì)限制模型性能。這篇論文的結(jié)果則表明,相比單獨(dú)選擇樣本的方法,在線構(gòu)建batch能提高預(yù)訓(xùn)練的效率。
無(wú)論是使用JEST參考模型對(duì)數(shù)據(jù)集進(jìn)行預(yù)評(píng)分,還是通過(guò)可學(xué)習(xí)性評(píng)分來(lái)根據(jù)模型需求進(jìn)行動(dòng)態(tài)調(diào)整,都可以成為通用基礎(chǔ)數(shù)據(jù)集的更有效率的替代方案。
論文的最后,作者也提出了該方法的局限性。雖然JEST同時(shí)實(shí)現(xiàn)了性能增益和訓(xùn)練成本降低,但依舊依賴于小型、精心管理的參考數(shù)據(jù)集,它指定了未經(jīng)管理的更大數(shù)據(jù)集中優(yōu)先考慮的分布。
因此,未來(lái)的工作可以探索一種方法,從指定的下游任務(wù)中如何推斷出參考數(shù)據(jù)集的組成和分布。