Yann LeCun團(tuán)隊(duì)新研究成果:對(duì)自監(jiān)督學(xué)習(xí)逆向工程,原來聚類是這樣實(shí)現(xiàn)的
自監(jiān)督學(xué)習(xí)(SSL)在最近幾年取得了很大的進(jìn)展,在許多下游任務(wù)上幾乎已經(jīng)達(dá)到監(jiān)督學(xué)習(xí)方法的水平。但是,由于模型的復(fù)雜性以及缺乏有標(biāo)注訓(xùn)練數(shù)據(jù)集,我們還一直難以理解學(xué)習(xí)到的表征及其底層的工作機(jī)制。此外,自監(jiān)督學(xué)習(xí)中使用的 pretext 任務(wù)通常與特定下游任務(wù)的直接關(guān)系不大,這就進(jìn)一步增大了解釋所學(xué)習(xí)到的表征的復(fù)雜性。而在監(jiān)督式分類中,所學(xué)到的表征的結(jié)構(gòu)往往很簡(jiǎn)單。
相比于傳統(tǒng)的分類任務(wù)(目標(biāo)是準(zhǔn)確將樣本歸入特定類別),現(xiàn)代 SSL 算法的目標(biāo)通常是最小化包含兩大成分的損失函數(shù):一是對(duì)增強(qiáng)過的樣本進(jìn)行聚類(不變性約束),二是防止表征坍縮(正則化約束)。舉個(gè)例子,對(duì)于同一樣本經(jīng)過不同增強(qiáng)之后的數(shù)據(jù),對(duì)比式學(xué)習(xí)方法的目標(biāo)是讓這些樣本的分類結(jié)果一樣,同時(shí)又要能區(qū)分經(jīng)過增強(qiáng)之后的不同樣本。另一方面,非對(duì)比式方法要使用正則化器(regularizer)來避免表征坍縮。
自監(jiān)督學(xué)習(xí)可以利用輔助任務(wù)(pretext)無監(jiān)督數(shù)據(jù)中挖掘自身的監(jiān)督信息,通過這種構(gòu)造的監(jiān)督信息對(duì)網(wǎng)絡(luò)進(jìn)行訓(xùn)練,從而可以學(xué)習(xí)到對(duì)下游任務(wù)有價(jià)值的表征。近日,圖靈獎(jiǎng)得主 Yann LeCun 在內(nèi)的多位研究者發(fā)布了一項(xiàng)研究,宣稱對(duì)自監(jiān)督學(xué)習(xí)進(jìn)行了逆向工程,讓我們得以了解其訓(xùn)練過程的內(nèi)部行為。
論文地址:https://arxiv.org/abs/2305.15614v2
這篇論文通過一系列精心設(shè)計(jì)的實(shí)驗(yàn)對(duì)使用 SLL 的表征學(xué)習(xí)進(jìn)行了深度分析,幫助人們理解訓(xùn)練期間的聚類過程。具體來說,研究揭示出增強(qiáng)過的樣本會(huì)表現(xiàn)出高度聚類的行為,這會(huì)圍繞共享同一圖像的增強(qiáng)樣本的含義嵌入形成質(zhì)心。更出人意料的是,研究者觀察到:即便缺乏有關(guān)目標(biāo)任務(wù)的明確信息,樣本也會(huì)根據(jù)語義標(biāo)簽發(fā)生聚類。這表明 SSL 有能力根據(jù)語義相似性對(duì)樣本進(jìn)行分組。
問題設(shè)置
由于自監(jiān)督學(xué)習(xí)(SSL)通常用于預(yù)訓(xùn)練,讓模型做好準(zhǔn)備適應(yīng)下游任務(wù),這帶來了一個(gè)關(guān)鍵問題:SSL 訓(xùn)練會(huì)對(duì)所學(xué)到的表征產(chǎn)生什么影響?具體來說,訓(xùn)練期間 SSL 的底層工作機(jī)制是怎樣的,這些表征函數(shù)能學(xué)到什么類別?
為了調(diào)查這些問題,研究者在多種設(shè)置上訓(xùn)練了 SSL 網(wǎng)絡(luò)并使用不同的技術(shù)分析了它們的行為。
數(shù)據(jù)和增強(qiáng):本文提到的所有實(shí)驗(yàn)都使用了 CIFAR100 圖像分類數(shù)據(jù)集。為了訓(xùn)練模型,研究者使用了 SimCLR 中提出的圖像增強(qiáng)協(xié)議。每一個(gè) SSL 訓(xùn)練 session 都執(zhí)行 1000 epoch,使用了帶動(dòng)量的 SGD 優(yōu)化器。
骨干架構(gòu):所有的實(shí)驗(yàn)都使用了 RES-L-H 架構(gòu)作為骨干,再加上了兩層多層感知器(MLP)投射頭。
線性探測(cè)(linear probing):為了評(píng)估從表征函數(shù)中提取給定離散函數(shù)(例如類別)的有效性,這里使用的方法是線性探測(cè)。這需要基于該表征訓(xùn)練一個(gè)線性分類器(也稱為線性探針),這需要用到一些訓(xùn)練樣本。
樣本層面的分類:為了評(píng)估樣本層面的可分離性,研究者創(chuàng)建了一個(gè)專門的新數(shù)據(jù)集。
其中訓(xùn)練數(shù)據(jù)集包含來自 CIFAR-100 訓(xùn)練集的 500 張隨機(jī)圖像。每張圖像都代表一個(gè)特定類別并會(huì)進(jìn)行 100 種不同的增強(qiáng)。因此,訓(xùn)練數(shù)據(jù)集包含 500 個(gè)類別的共計(jì) 50000 個(gè)樣本。測(cè)試集依然是用這 500 張圖像,但要使用 20 種不同的增強(qiáng),這些增強(qiáng)都來自同一分布。因此,測(cè)試集中的結(jié)果由 10000 個(gè)樣本構(gòu)成。為了在樣本層面衡量給定表征函數(shù)的線性或 NCC(nearest class-center / 最近類別中心)準(zhǔn)確度,這里采用的方法是先使用訓(xùn)練數(shù)據(jù)計(jì)算出一個(gè)相關(guān)的分類器,然后再在相應(yīng)測(cè)試集上評(píng)估其準(zhǔn)確率。
揭示自監(jiān)督學(xué)習(xí)的聚類過程
在幫助分析深度學(xué)習(xí)模型方面,聚類過程一直以來都發(fā)揮著重要作用。為了直觀地理解 SSL 訓(xùn)練,圖 1 通過 UMAP 可視化展示了網(wǎng)絡(luò)的訓(xùn)練樣本的嵌入空間,其中包含訓(xùn)練前后的情況并分了不同層級(jí)。
圖 1:SSL 訓(xùn)練引起的語義聚類
正如預(yù)期的那樣,訓(xùn)練過程成功地在樣本層面上對(duì)樣本進(jìn)行了聚類,映射了同一圖像的不同增強(qiáng)(如第一行圖示)??紤]到目標(biāo)函數(shù)本身就會(huì)鼓勵(lì)這種行為(通過不變性損失項(xiàng)),因此這樣的結(jié)果倒是不意外。然而,更值得注意的是,該訓(xùn)練過程還會(huì)根據(jù)標(biāo)準(zhǔn) CIFAR-100 數(shù)據(jù)集的原始「語義類別」進(jìn)行聚類,即便該訓(xùn)練過程期間缺乏標(biāo)簽。有趣的是,更高的層級(jí)(超類別)也能被有效聚類。這個(gè)例子表明,盡管訓(xùn)練流程直接鼓勵(lì)的是樣本層面的聚類,但 SSL 訓(xùn)練的數(shù)據(jù)表征還會(huì)在不同層面上根據(jù)語義類別來進(jìn)行聚類。
為了進(jìn)一步量化這個(gè)聚類過程,研究者使用 VICReg 訓(xùn)練了一個(gè) RES-10-250。研究者衡量的是 NCC 訓(xùn)練準(zhǔn)確度,既有樣本層面的,也有基于原始類別的。值得注意的是,SSL 訓(xùn)練的表征在樣本層面上展現(xiàn)出了神經(jīng)坍縮(neural collapse,即 NCC 訓(xùn)練準(zhǔn)確度接近于 1.0),然而在語義類別方面的聚類也很顯著(在原始目標(biāo)上約為 0.41)。
如圖 2 左圖所示,涉及增強(qiáng)(網(wǎng)絡(luò)直接基于其訓(xùn)練的)的聚類過程大部分都發(fā)生在訓(xùn)練過程初期,然后陷入停滯;而在語義類別方面的聚類(訓(xùn)練目標(biāo)中并未指定)則會(huì)在訓(xùn)練過程中持續(xù)提升。
圖 2:SSL 算法根據(jù)語義目標(biāo)對(duì)對(duì)數(shù)據(jù)的聚類
之前有研究者觀察到,監(jiān)督式訓(xùn)練樣本的頂層嵌入會(huì)逐漸向一個(gè)類質(zhì)心的結(jié)構(gòu)收斂。為了更好地理解 SSL 訓(xùn)練的表征函數(shù)的聚類性質(zhì),研究者調(diào)查了 SSL 過程中的類似情況。其 NCC 分類器是一種線性分類器,其表現(xiàn)不會(huì)超過最佳的線性分類器。通過評(píng)估 NCC 分類器與同樣數(shù)據(jù)上訓(xùn)練的線性分類器的準(zhǔn)確度之比,能夠在不同粒度層級(jí)上研究數(shù)據(jù)聚類。圖 2 的中圖給出了樣本層面類別和原始目標(biāo)類別上的這一比值的變化情況,其值根據(jù)初始化的值進(jìn)行了歸一化。隨著 SSL 訓(xùn)練的進(jìn)行,NCC 準(zhǔn)確度和線性準(zhǔn)確度之間的差距會(huì)變小,這說明增強(qiáng)后的樣本會(huì)根據(jù)其樣本身份和語義屬性逐漸提升聚類水平。
此外,該圖還說明,樣本層面的比值起初會(huì)高一些,這說明增強(qiáng)后的樣本會(huì)根據(jù)它們的身份進(jìn)行聚類,直到收斂至質(zhì)心(NCC 準(zhǔn)確度和線性準(zhǔn)確度的比值在 100 epoch 時(shí) ≥ 0.9)。但是,隨著訓(xùn)練繼續(xù),樣本層面的比值會(huì)飽和,而類別層面的比值會(huì)繼續(xù)增長(zhǎng)并收斂至 0.75 左右。這說明增強(qiáng)后的樣本首先會(huì)根據(jù)樣本身份進(jìn)行聚類,實(shí)現(xiàn)之后,再根據(jù)高層面的語義類別進(jìn)行聚類。
SSL 訓(xùn)練中隱含的信息壓縮
如果能有效進(jìn)行壓縮,那么就能得到有益又有用的表征。但 SSL 訓(xùn)練過程中是否會(huì)出現(xiàn)那樣的壓縮卻仍是少有人研究的課題。
為了了解這一點(diǎn),研究者使用了互信息神經(jīng)估計(jì)(Mutual Information Neural Estimation/MINE),這種方法可以估計(jì)訓(xùn)練過程中輸入與其對(duì)應(yīng)嵌入表征之間的互信息。這個(gè)度量可用于有效衡量表征的復(fù)雜度水平,其做法是展現(xiàn)其編碼的信息量(比特?cái)?shù)量)。
圖 3 的中圖報(bào)告了在 5 個(gè)不同的 MINE 初始化種子上計(jì)算得到的平均互信息。如圖所示,訓(xùn)練過程會(huì)有顯著的壓縮,最終形成高度緊湊的訓(xùn)練表征。
圖 3:(左)一個(gè) SSL 訓(xùn)練的模型在訓(xùn)練期間的正則化和不變性損失以及原始目標(biāo)線性測(cè)試準(zhǔn)確度。(中)訓(xùn)練期間輸入和表征之間的互信息的壓縮。(右)SSL 訓(xùn)練學(xué)習(xí)聚類的表征。
正則化損失的作用
目標(biāo)函數(shù)包含兩項(xiàng):不變性和正則化。不變性項(xiàng)的主要功能是強(qiáng)化同一樣本的不同增強(qiáng)的表征之間的相似性。而正則化項(xiàng)的目標(biāo)是幫助防止表征坍縮。
為了探究這些分量對(duì)聚類過程的作用,研究者將目標(biāo)函數(shù)分解為了不變性項(xiàng)和正則化項(xiàng),并觀察它們?cè)谟?xùn)練過程中的行為。比較結(jié)果見圖 3 左圖,其中給出了原始語義目標(biāo)上的損失項(xiàng)的演變以及線性測(cè)試準(zhǔn)確度。不同于普遍流行的想法,不變性損失項(xiàng)在訓(xùn)練過程中并不會(huì)顯著改善。相反,損失(以及下游的語義準(zhǔn)確度)的改善是通過降低正則化損失實(shí)現(xiàn)的。
由此可以得出結(jié)論:SSL 的大部分訓(xùn)練過程都是為了提升語義準(zhǔn)確度和所學(xué)表征的聚類,而非樣本層面的分類準(zhǔn)確度和聚類。
從本質(zhì)上講,這里的發(fā)現(xiàn)表明:盡管自監(jiān)督學(xué)習(xí)的直接目標(biāo)是樣本層面的分類,但其實(shí)大部分訓(xùn)練時(shí)間都用于不同層級(jí)上基于語義類別的數(shù)據(jù)聚類。這一觀察結(jié)果表明 SSL 方法有能力通過聚類生成有語義含義的表征,這也讓我們得以了解其底層機(jī)制。
監(jiān)督學(xué)習(xí)和 SSL 聚類的比較
深度網(wǎng)絡(luò)分類器往往是基于訓(xùn)練樣本的類別將它們聚類到各個(gè)質(zhì)心。但學(xué)習(xí)得到的函數(shù)要能真正聚類,必須要求這一性質(zhì)對(duì)測(cè)試樣本依然有效;這是我們期望得到的效果,但效果會(huì)差一點(diǎn)。
這里有一個(gè)有趣的問題:相比于監(jiān)督學(xué)習(xí)的聚類,SSL 能在多大程度上根據(jù)樣本的語義類別來執(zhí)行聚類?圖 3 右圖報(bào)告了在不同場(chǎng)景(使用和不使用增強(qiáng)的監(jiān)督學(xué)習(xí)以及 SSL)的訓(xùn)練結(jié)束時(shí)的 NCC 訓(xùn)練和測(cè)試準(zhǔn)確度比率。
盡管監(jiān)督式分類器的 NCC 訓(xùn)練準(zhǔn)確度為 1.0,顯著高于 SSL 訓(xùn)練的模型的 NCC 訓(xùn)練準(zhǔn)確度,但 SSL 模型的 NCC 測(cè)試準(zhǔn)確度卻略高于監(jiān)督式模型的 NCC 測(cè)試準(zhǔn)確度。這說明兩種模型根據(jù)語義類別的聚類行為具有相似的程度。有意思的是,使用增強(qiáng)樣本訓(xùn)練監(jiān)督式模型會(huì)稍微降低 NCC 訓(xùn)練準(zhǔn)確度,卻會(huì)大幅提升 NCC 測(cè)試準(zhǔn)確度。
探索語義類別學(xué)習(xí)和隨機(jī)性的影響
語義類別是根據(jù)輸入的內(nèi)在模式來定義輸入和目標(biāo)的關(guān)系。另一方面,如果將輸入映射到隨機(jī)目標(biāo),則會(huì)看到缺乏可辨別的模式,這會(huì)導(dǎo)致輸入和目標(biāo)之間的連接看起來很任意。
研究者還探究了隨機(jī)性對(duì)模型學(xué)習(xí)所需目標(biāo)的熟練程度的影響。為此,他們構(gòu)建了一系列具有不同隨機(jī)度的目標(biāo)系統(tǒng),然后檢查了隨機(jī)度對(duì)所學(xué)表征的影響。他們?cè)谟糜诜诸惖耐粩?shù)據(jù)集上訓(xùn)練了一個(gè)神經(jīng)網(wǎng)絡(luò)分類器,然后使用其不同 epoch 的目標(biāo)預(yù)測(cè)作為具有不同隨機(jī)度的目標(biāo)。在 epoch 0 時(shí),網(wǎng)絡(luò)是完全隨機(jī)的,會(huì)得到確定的但看似任意的標(biāo)簽。隨著訓(xùn)練進(jìn)行,其函數(shù)的隨機(jī)性下降,最終得到與基本真值目標(biāo)對(duì)齊的目標(biāo)(可認(rèn)為是完全不隨機(jī))。這里將隨機(jī)度歸一化到 0(完全不隨機(jī),訓(xùn)練結(jié)束時(shí))到 1(完全隨機(jī),初始化時(shí))之間。
圖 4 左圖展示了不同隨機(jī)度目標(biāo)的線性測(cè)試準(zhǔn)確度。每條線都對(duì)應(yīng)于不同隨機(jī)度的 SSL 不同訓(xùn)練階段的準(zhǔn)確度??梢钥吹?,在訓(xùn)練過程中,模型會(huì)更高效地捕獲與「語義」目標(biāo)(更低隨機(jī)度)更接近的類別,同時(shí)在高隨機(jī)度的目標(biāo)上沒有表現(xiàn)出顯著的性能改進(jìn)。
圖 4:SSL 持續(xù)學(xué)習(xí)語義目標(biāo),而非隨機(jī)目標(biāo)
深度學(xué)習(xí)的一個(gè)關(guān)鍵問題是理解中間層對(duì)分類不同類型類別的作用和影響。比如,不同的層會(huì)學(xué)到不同類型的類別嗎?研究者也探索了這個(gè)問題,其做法是在訓(xùn)練結(jié)束時(shí)不同目標(biāo)隨機(jī)度下評(píng)估不同層表征的線性測(cè)試準(zhǔn)確度。如圖 4 中圖所示,隨著隨機(jī)度下降,線性測(cè)試準(zhǔn)確度持續(xù)提升,更深度的層在所有類別類型上都表現(xiàn)更優(yōu),而對(duì)于接近語義類別的分類,性能差距會(huì)更大。
研究者還使用了其它一些度量來評(píng)估聚類的質(zhì)量:NCC 準(zhǔn)確度、CDNV、平均每類方差、類別均值之間的平均平方距離。為了衡量表征隨訓(xùn)練進(jìn)行的改進(jìn)情況,研究者為語義目標(biāo)和隨機(jī)目標(biāo)計(jì)算了這些指標(biāo)的比率。圖 4 右圖展示了這些比率,結(jié)果表明相比于隨機(jī)目標(biāo),表征會(huì)更加偏向根據(jù)語義目標(biāo)來聚類數(shù)據(jù)。有趣的是,可以看到 CDNV(方差除以平方距離)會(huì)降低,其原因僅僅是平方距離的下降。方差比率在訓(xùn)練期間相當(dāng)穩(wěn)定。這會(huì)鼓勵(lì)聚類之間的間距拉大,這一現(xiàn)象已被證明能帶來性能提升。
了解類別層級(jí)結(jié)構(gòu)和中間層
之前的研究已經(jīng)證明,在監(jiān)督學(xué)習(xí)中,中間層會(huì)逐漸捕獲不同抽象層級(jí)的特征。初始的層傾向于低層級(jí)的特征,而更深的層會(huì)捕獲更抽象的特征。接下來,研究者探究了 SSL 網(wǎng)絡(luò)能否學(xué)習(xí)更高層面的層次屬性以及哪些層面與這些屬性的關(guān)聯(lián)性更好。
在實(shí)驗(yàn)中,他們計(jì)算了三個(gè)層級(jí)的線性測(cè)試準(zhǔn)確度:樣本層級(jí)、原始的 100 個(gè)類別、20 個(gè)超類別。圖 2 右圖給出了為這三個(gè)不同類別集計(jì)算的數(shù)量。可以觀察到,在訓(xùn)練過程中,相較于樣本層級(jí)的類別,在原始類別和超類別層級(jí)上的表現(xiàn)的提升更顯著。
接下來是 SSL 訓(xùn)練的模型的中間層的行為以及它們捕獲不同層級(jí)的目標(biāo)的能力。圖 5 左和中圖給出了不同訓(xùn)練階段在所有中間層上的線性測(cè)試準(zhǔn)確度,這里度量了原始目標(biāo)和超目標(biāo)。圖 5 右圖給出超類別和原始類別之間的比率。
圖 5:SSL 能在整體中間層中有效學(xué)習(xí)語義類別
研究者基于這些結(jié)果得到了幾個(gè)結(jié)論。首先,可以觀察到隨著層的深入,聚類效果會(huì)持續(xù)提升。此外,與監(jiān)督學(xué)習(xí)情況類似,研究者發(fā)現(xiàn)在 SSL 訓(xùn)練期間,網(wǎng)絡(luò)每一層的線性準(zhǔn)確度都有提升。值得注意的是,他們發(fā)現(xiàn)對(duì)于原始類別,最終層并不是最佳層。近期的一些 SSL 研究表明:下游任務(wù)能高度影響不同算法的性能。本文的研究拓展了這一觀察結(jié)果,并且表明網(wǎng)絡(luò)的不同部分可能適合不同的下游任務(wù)與任務(wù)層級(jí)。根據(jù)圖 5 右圖,可以看出,在網(wǎng)絡(luò)的更深層,超類別的準(zhǔn)確度的提升幅度超過原始類別。