谷歌發(fā)現(xiàn)大模型「領(lǐng)悟」現(xiàn)象!訓(xùn)練久了突然不再死記硬背,多么痛的領(lǐng)悟
本文經(jīng)AI新媒體量子位(公眾號ID:QbitAI)授權(quán)轉(zhuǎn)載,轉(zhuǎn)載請聯(lián)系出處。
哪怕只有幾十個神經(jīng)元,AI也能出現(xiàn)泛化能力!
這是幾個谷歌科學(xué)家在搞正經(jīng)研究時,“不經(jīng)意間”發(fā)現(xiàn)的新成果。
他們給一些很簡單的AI模型“照了個X光”——將它們的訓(xùn)練過程可視化后,發(fā)現(xiàn)了有意思的現(xiàn)象:
隨著訓(xùn)練時間增加,一些AI會從“死記硬背”的狀態(tài)中脫離出來,進(jìn)化出“領(lǐng)悟力”(grokking),對沒見過的數(shù)據(jù)表現(xiàn)出概括能力。
這正是AI掌握泛化能力的關(guān)鍵。
基于此,幾位科學(xué)家專門寫了個博客,探討了其中的原理,并表示他們會繼續(xù)研究,試圖弄清楚大模型突然出現(xiàn)強(qiáng)理解力的真正原因。
一起來看看。
并非所有AI都能學(xué)會“領(lǐng)悟”
科學(xué)家們先探討了AI出現(xiàn)“領(lǐng)悟力”(grokking)的過程和契機(jī),發(fā)現(xiàn)了兩個現(xiàn)象:
- 一、雖然訓(xùn)練時loss會突然下降,但“領(lǐng)悟”并不是突然發(fā)生的,它是一個平滑的變化過程。
- 二、并非所有AI都能學(xué)會“領(lǐng)悟”。
先來看第一個結(jié)論。他們設(shè)計了一個單層MLP,訓(xùn)練它完成“數(shù)奇數(shù)”任務(wù)。
“數(shù)奇數(shù)”任務(wù),指識別一串長達(dá)30位“0”“1”序列中的前3位是否有奇數(shù)個“1”。例如,在000110010110001010111001001011中,前3位沒有奇數(shù)個1;010110010110001010111001001011中,前3位有奇數(shù)個1。
在訓(xùn)練前期階段,模型中各神經(jīng)元的權(quán)重(下圖中的熱圖)是雜亂無章的,因為AI不知道完成這一任務(wù)只需要看前3個數(shù)字。
但經(jīng)過一段時間的訓(xùn)練后,AI突然“領(lǐng)悟了”,學(xué)會了只看序列中的前3個數(shù)字。具體到模型中,表現(xiàn)為只剩下幾個權(quán)重會隨著輸入發(fā)生變化:
這個訓(xùn)練過程的目標(biāo)被稱之為最小化損失(提升模型輸出準(zhǔn)確率),采用的技術(shù)則被稱之為權(quán)重衰減(防止模型過擬合)。
訓(xùn)練過程中,有一些權(quán)重與任務(wù)的“干擾數(shù)字”(30位序列的后27位)相關(guān),下圖可視化為灰色;有一些則與完成任務(wù)的“前3位數(shù)字”有關(guān),下圖可視化為綠色。
當(dāng)最后一個灰色權(quán)重降到接近0,模型就會出現(xiàn)“領(lǐng)悟力”,顯然這個過程不是突然發(fā)生的。
再來看第二個結(jié)論。不是所有AI模型都能學(xué)會“領(lǐng)悟”。
科學(xué)家們訓(xùn)練了1125個模型,其中模型之間的超參數(shù)不同,每組超參數(shù)訓(xùn)練9個模型。
最后歸納出4類模型,只有2類模型會出現(xiàn)“領(lǐng)悟力”。
如下圖,“白色”和“灰色”代表學(xué)不會“領(lǐng)悟”的AI模型,“黃色”和“藍(lán)色”代表能“領(lǐng)悟”的AI模型。
總結(jié)概括規(guī)律就是,一旦權(quán)重衰減、模型大小、數(shù)據(jù)量和超參數(shù)的設(shè)置不合適,AI的“領(lǐng)悟力”就有可能消失——
以權(quán)重衰減為例。如果權(quán)重衰減太小,會導(dǎo)致模型過擬合;權(quán)重衰減太大,又會導(dǎo)致模型學(xué)不到任何東西。
嗯,調(diào)參是門技術(shù)活……
了解現(xiàn)象之后,還需要探明背后的原因。
接下來,科學(xué)家們又設(shè)計了兩個小AI模型,用它來探索模型出現(xiàn)“領(lǐng)悟力”、最終掌握泛化能力出現(xiàn)的機(jī)制。
更大的模型學(xué)會泛化的機(jī)制
科學(xué)家們分別設(shè)計了一個24個神經(jīng)元的單層MLP和一個5個神經(jīng)元的單層MLP,訓(xùn)練它們學(xué)會做模加法(modular addition)任務(wù)。
模加法,指(a + b) mod n。輸入整數(shù)a和b,用它們的和減去模數(shù)n,直到獲得一個比n小的整數(shù),確保輸出位于0~(n-1)之間。
顯然,這個任務(wù)的輸出是周期性的,答案一定位于0~66之間。
首先,給只有5個神經(jīng)元的單層MLP一點“提示”,設(shè)置權(quán)重時就加入周期性(sin、cos函數(shù))。
在人為幫助下,模型在訓(xùn)練時擬合得很好,很快學(xué)會了模加法。
然后,試著“從頭訓(xùn)練”具有24個神經(jīng)元的單層MLP,不特別設(shè)置任何權(quán)重。
可以看到,訓(xùn)練前期,這只MLP模型的權(quán)重(下面的熱圖)變化還是雜亂無章的:
然而到達(dá)某個訓(xùn)練階段后,模型權(quán)重變化會變得非常規(guī)律,甚至隨著輸入改變,呈現(xiàn)出某種周期性變化:
如果將單個神經(jīng)元的權(quán)重拎出來看,隨著訓(xùn)練步數(shù)的增加,這種變化更加明顯:
這也是AI從死記硬背轉(zhuǎn)變?yōu)榫哂蟹夯芰Φ年P(guān)鍵現(xiàn)象:神經(jīng)元權(quán)重隨著輸入出現(xiàn)周期性變化,意味著模型自己找到并學(xué)會了某種數(shù)學(xué)結(jié)構(gòu)(sin、cos函數(shù))。
這里面的頻率(freq)不是固定的一個值,而是有好幾個。
之所以會用到多個頻率(freq),是因為24個神經(jīng)元的單層MLP還自己學(xué)會了使用相長干涉(constructive interference),避免出現(xiàn)過擬合的情況。
不同的頻率組合,都能達(dá)到讓AI“領(lǐng)悟”的效果:
用離散傅里葉變換(DFT)對頻率進(jìn)行隔離,可以發(fā)現(xiàn)和“數(shù)奇數(shù)”類似的現(xiàn)象,核心只有幾個權(quán)重起作用:
總結(jié)來看,就像前面提到的“數(shù)奇數(shù)”任務(wù)一樣,“模加法”實驗表明,參數(shù)量更大的AI也能在這個任務(wù)中學(xué)會“領(lǐng)悟”,而這個過程同樣用到了權(quán)重衰減。
從5個神經(jīng)元到24個神經(jīng)元,科學(xué)家們成功探索了更大的AI能學(xué)習(xí)“領(lǐng)悟”的機(jī)制。
接下來,他們還計劃將這種思路套用到更大的模型中,以至于最后能歸納出大模型具備強(qiáng)理解力的原因。
不僅如此,這一成果還有助于自動發(fā)現(xiàn)神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)算法,最終讓AI自己設(shè)計AI。
團(tuán)隊介紹
撰寫博客的作者來自谷歌的People + AI Research(PAIR)團(tuán)隊。
這是谷歌的一個多學(xué)科團(tuán)隊,致力于通過基礎(chǔ)研究、構(gòu)建工具、創(chuàng)建框架等方法,來研究AI的公平性、可靠性等。
一句話總結(jié)就是,讓“AI更好地造福于人”。
博客地址:https://pair.withgoogle.com/explorables/grokking/