擴(kuò)散模型訓(xùn)練方法一直錯(cuò)了!謝賽寧:Representation matters
是什么讓紐約大學(xué)著名研究者謝賽寧三連呼喊「Representation matters」?他表示:「我們可能一直都在用錯(cuò)誤的方法訓(xùn)練擴(kuò)散模型。」即使對(duì)生成模型而言,表征也依然有用?;诖?,他們提出了 REPA,即表征對(duì)齊技術(shù),其能讓「訓(xùn)練擴(kuò)散 Transformer 變得比你想象的更簡(jiǎn)單?!?/span>
Yann LeCun 也對(duì)他們的研究表示了認(rèn)可:「我們知道,當(dāng)使用自監(jiān)督學(xué)習(xí)訓(xùn)練視覺編碼器時(shí),使用具有重構(gòu)損失的解碼器的效果遠(yuǎn)不如使用具有特征預(yù)測(cè)損失和崩潰預(yù)防機(jī)制的聯(lián)合嵌入架構(gòu)。這篇來自紐約大學(xué) @sainingxie 的論文表明,即使你只對(duì)生成像素感興趣(例如使用擴(kuò)散 Transformer 生成漂亮圖片),也應(yīng)該包含特征預(yù)測(cè)損失,以便解碼器的內(nèi)部表征可以根據(jù)預(yù)訓(xùn)練的視覺編碼器(例如 DINOv2)預(yù)測(cè)特征。」
我們知道,在生成高維視覺數(shù)據(jù)方面,基于去噪的生成模型(如擴(kuò)展模型和基于流的模型)的表現(xiàn)非常好,已經(jīng)得到了廣泛應(yīng)用。近段時(shí)間,也有研究開始探索將擴(kuò)展模型用作表征學(xué)習(xí)器,因?yàn)檫@些模型的隱藏狀態(tài)可以捕獲有意義的判別式特征。
而謝賽寧指導(dǎo)的這個(gè)團(tuán)隊(duì)發(fā)現(xiàn)(另一位指導(dǎo)者是 KAIST 的 Jinwoo Shin),訓(xùn)練擴(kuò)散模型的主要挑戰(zhàn)源于需要學(xué)習(xí)高質(zhì)量的內(nèi)部表征。他們的研究表明:「當(dāng)生成式擴(kuò)散模型得到來自另一個(gè)模型(例如自監(jiān)督視覺編碼器)的外部高質(zhì)量表征的支持時(shí),其性能可以得到大幅提升?!?/span>
REPresentation Alignment(REPA),即表征對(duì)齊技術(shù),便基于此而誕生了。這是一個(gè)基于近期的擴(kuò)散 Transformer(DiT)架構(gòu)的簡(jiǎn)單正則化技術(shù)。
- 論文標(biāo)題:Representation Alignment for Generation: Training Diffusion Transformers Is Easier Than You Think
- 論文地址:https://arxiv.org/pdf/2410.06940
- 項(xiàng)目地址:https://sihyun.me/REPA/
- 代碼地址:https://github.com/sihyun-yu/REPA
本質(zhì)上講,REPA 就是將一張清晰圖像的預(yù)訓(xùn)練自監(jiān)督視覺表征蒸餾成一個(gè)有噪聲輸入的擴(kuò)展 Transformer 表征。這種正則化可以更好地將擴(kuò)展模型表征與目標(biāo)自監(jiān)督表征對(duì)齊。
方法看起來很簡(jiǎn)單,但 REPA 的效果卻很好!據(jù)介紹,REPA 能大幅提升模型訓(xùn)練的效率和效果。相比于原生模型,REPA 能將收斂速度提升 17.5 倍以上。在生成質(zhì)量方面,在使用帶引導(dǎo)間隔(guidance interval)的無(wú)分類器引導(dǎo)時(shí),新方法取得了 FID=1.42 的當(dāng)前最佳結(jié)果。
REPA:用于表征對(duì)齊的正則化
REPresentation Alignment(REPA)是一種簡(jiǎn)單的正則化方法,其使用了近期的擴(kuò)展 Transformer 架構(gòu)。簡(jiǎn)單來說,該技術(shù)就是一種將預(yù)訓(xùn)練的自監(jiān)督視覺表征蒸餾到擴(kuò)展 Transformer 的簡(jiǎn)單又有效的方法。這讓擴(kuò)散模型可以利用這些語(yǔ)義豐富的外部表征進(jìn)行生成,從而大幅提高性能。
觀察
REPA 的誕生基于該團(tuán)隊(duì)得到的幾項(xiàng)重要觀察。
他們研究了在 ImageNet 上預(yù)訓(xùn)練得到的 SiT(可擴(kuò)展插值 Transformer)模型的逐層行為,該模型使用了線性插值和速度預(yù)測(cè)(velocity prediction)進(jìn)行訓(xùn)練。他們研究的重點(diǎn)是擴(kuò)散 Transformer 和當(dāng)前領(lǐng)先的監(jiān)督式 DINOv2 模型之間的表征差距。他們從三個(gè)角度進(jìn)行了研究:語(yǔ)義差距、特征對(duì)齊進(jìn)展以及最終的特征對(duì)齊。
對(duì)于語(yǔ)義差距,他們比較了使用 DINOv2 特征的線性探測(cè)結(jié)果與來自 SiT 模型(訓(xùn)練了 700 萬(wàn)次迭代)的線性探測(cè)結(jié)果,采用的協(xié)議涉及到對(duì)擴(kuò)散 Transformer 的全局池化的隱藏狀態(tài)進(jìn)行線性探測(cè)。
接下來,為了測(cè)量特征對(duì)齊,他們使用了 CKNNA;這是一種與 CKA 相關(guān)的核對(duì)齊(kernel alignment)指標(biāo),但卻是基于相互最近鄰。這樣一來,便能以量化方式評(píng)估對(duì)齊效果了。圖 2 總結(jié)了其結(jié)果。
擴(kuò)散 Transformer 與先進(jìn)視覺編碼器之間的語(yǔ)義差距明顯。如圖 2a 所示,可以觀察到,預(yù)訓(xùn)練擴(kuò)散 Transformer 的隱藏狀態(tài)表征在第 20 層能得到相當(dāng)高的線性探測(cè)峰值。但是,其性能仍遠(yuǎn)低于 DINOv2,表明這兩種表征之間存在相當(dāng)大的語(yǔ)義差距。此外,他們還發(fā)現(xiàn),在此峰值之后,線性探測(cè)性能會(huì)迅速下降,這表明擴(kuò)散 Transformer 必定從重點(diǎn)學(xué)習(xí)語(yǔ)義豐富的表征轉(zhuǎn)向了生成具有高頻細(xì)節(jié)的圖像。
擴(kuò)散表征已經(jīng)與其它視覺表征(細(xì)微地)對(duì)齊了。圖 2b 使用 CKNNA 展示了 SiT 與 DINOv2 之間的表征對(duì)齊情況。可以看到,SiT 模型表征的對(duì)齊已經(jīng)優(yōu)于 MAE,而后者也是一種基于掩碼圖塊重建的自監(jiān)督學(xué)習(xí)方法。但是,相比于其它自監(jiān)督學(xué)習(xí)方法之間的對(duì)齊分?jǐn)?shù),其絕對(duì)對(duì)齊分?jǐn)?shù)依然較低。這些結(jié)果表明,盡管擴(kuò)散 Transformer 表征與自監(jiān)督視覺表征存在一定的對(duì)齊,但對(duì)齊程度不高。
當(dāng)模型增大、訓(xùn)練變多時(shí),對(duì)齊效果會(huì)更好。該團(tuán)隊(duì)還測(cè)量了不同模型大小和訓(xùn)練迭代次數(shù)的 CKNNA 值。圖 2c 表明更大模型和更多訓(xùn)練有助于對(duì)齊。同樣地,相比于其它自監(jiān)督視覺編碼器之間的對(duì)齊,擴(kuò)散表征的絕對(duì)對(duì)齊分?jǐn)?shù)依然較低。
這些發(fā)現(xiàn)并非 SiT 模型所獨(dú)有,其它基于去噪的生成式 Transformer 也能觀察到。該團(tuán)隊(duì)也在 DiT 模型上觀察到了類似的結(jié)果 —— 其使用 DDPM 目標(biāo)在 ImageNet 上完成了預(yù)訓(xùn)練。
與自監(jiān)督表征的表征對(duì)齊
REPA 將模型隱藏狀態(tài)的 patch-wise 投影與預(yù)訓(xùn)練自監(jiān)督視覺表征對(duì)齊。具體來說,該研究使用干凈的(clean)圖像表征作為目標(biāo)并探討其影響。這種正則化的目的是讓擴(kuò)散 transformer 的隱藏狀態(tài)從包含有用語(yǔ)義信息的噪聲輸入中預(yù)測(cè)噪聲不變、干凈的視覺表征。這能為后續(xù)層重建目標(biāo)提供有意義的引導(dǎo)。
形式上,令 ?? 為預(yù)訓(xùn)練編碼器,x* 為干凈圖像。令 y*=??(x*) ∈ ?^{N×D} 為編碼器輸出,其中 N、D > 0 分別是 patch 的數(shù)量和 ?? 的嵌入維度。
REPA 是將與 y* 對(duì)齊,其中
是擴(kuò)散 transformer 編碼器輸出
通過可訓(xùn)練投影頭 h_? 得到的投影。實(shí)踐中 h_? 的參數(shù)化是簡(jiǎn)單地使用多層感知器(MLP)完成的。
特別地,REPA 通過最大化預(yù)訓(xùn)練表征 y* 和隱藏狀態(tài) h_t 之間的 patch-wise 相似性來實(shí)現(xiàn)對(duì)齊,其中 n 是 patch 索引,sim (?,?) 是預(yù)定義的相似度函數(shù)。
在實(shí)踐中,是基于一個(gè)系數(shù) λ 將該項(xiàng)添加到基于擴(kuò)散的原始目標(biāo)中。例如,對(duì)于速度模型的訓(xùn)練,其目標(biāo)變?yōu)椋?/span>
其中 λ > 0 是一個(gè)超參數(shù),用于控制去噪和表示對(duì)齊之間的權(quán)衡。該團(tuán)隊(duì)主要研究這種正則化對(duì)兩個(gè)常用目標(biāo)的影響:DiT 中使用的改進(jìn)版 DDPM 和 SiT 中使用的線性隨機(jī)插值,盡管也可以考慮其他目標(biāo)。
結(jié)果
REPA 改善視覺擴(kuò)展
該研究首先比較兩個(gè) SiT-XL/2 模型在前 400K 次迭代期間生成的圖像,其中一個(gè)模型應(yīng)用 REPA。兩種模型共享相同的噪聲、采樣器和采樣步驟數(shù),并且都不使用無(wú)分類器引導(dǎo)。使用 REPA 訓(xùn)練的模型表現(xiàn)更好。
REPA 在各個(gè)方面都展現(xiàn)出強(qiáng)大的可擴(kuò)展性
該研究通過改變預(yù)訓(xùn)練編碼器和擴(kuò)散 transformer 模型大小來檢查 REPA 的可擴(kuò)展性,結(jié)果表明:與更好的視覺表征相結(jié)合可以改善生成和線性探測(cè)結(jié)果。
REPA 還在大型模型中提供了更顯著的加速,與普通模型相比,實(shí)現(xiàn)了更快的 FID-50K 改進(jìn)。此外,增加模型大小可以在生成和線性評(píng)估方面帶來更快的增益。
REPA 顯著提高訓(xùn)練效率和生成質(zhì)量
最后,該研究比較了普通 DiT 或 SiT 模型與使用 REPA 訓(xùn)練的模型的 FID 值。
在沒有無(wú)分類器引導(dǎo)的情況下,REPA 在 400K 次迭代時(shí)實(shí)現(xiàn)了 FID=7.9,優(yōu)于普通模型在 700 萬(wàn)次迭代時(shí)的性能。
使用無(wú)分類器引導(dǎo),帶有 REPA 的 SiT-XL/2 的性能優(yōu)于最新的擴(kuò)散模型,迭代次數(shù)減少為 1/7,并通過額外的引導(dǎo)調(diào)度實(shí)現(xiàn)了 SOTA FID=1.42。
該團(tuán)隊(duì)也執(zhí)行了消融研究,探索了不同時(shí)間步數(shù)、不同視覺編碼器和不同 λ 值(正則化系數(shù))的影響。詳見原論文。