45倍加速+最新SOTA!VAE與擴(kuò)散模型迎來端到端聯(lián)合訓(xùn)練:REPA-E讓VAE自我進(jìn)化!
文章鏈接: https://arxiv.org/pdf/2504.10483
項目鏈接:https://end2end-diffusion.github.io/
Git鏈接:https://github.com/End2End-Diffusion/REPA-E
模型鏈接:https://huggingface.co/REPA-E
亮點直擊
- 端到端聯(lián)合優(yōu)化的突破首次實現(xiàn)VAE與擴(kuò)散模型的端到端聯(lián)合訓(xùn)練,通過REPA Loss替代傳統(tǒng)擴(kuò)散損失,解決兩階段訓(xùn)練目標(biāo)不一致問題,使隱空間與生成任務(wù)高度適配。
- 訓(xùn)練效率革命性提升REPA-E僅需傳統(tǒng)方法1/45的訓(xùn)練步數(shù)即可收斂,且生成質(zhì)量顯著超越現(xiàn)有方法(如FID從5.9降至4.07),大幅降低計算成本。
- 雙向性能增益不僅提升擴(kuò)散模型性能,還通過反向傳播優(yōu)化VAE的隱空間結(jié)構(gòu),使其成為通用型模塊,可遷移至其他任務(wù)(如替換SD-VAE后下游任務(wù)性能提升)。
總結(jié)速覽
解決的問題
現(xiàn)有隱空間擴(kuò)散模型(LDM)采用兩階段訓(xùn)練(先訓(xùn)練VAE,再固定VAE訓(xùn)練擴(kuò)散模型),導(dǎo)致兩個階段的優(yōu)化目標(biāo)不一致,限制了生成性能。直接端到端聯(lián)合訓(xùn)練VAE和擴(kuò)散模型時,傳統(tǒng)擴(kuò)散損失(Diffusion Loss)失效,甚至導(dǎo)致性能下降。
提出的方案
提出REPA-E訓(xùn)練框架,通過表示對齊損失(REPA Loss)實現(xiàn)VAE與擴(kuò)散模型的端到端聯(lián)合優(yōu)化。REPA Loss通過對齊隱空間表示的結(jié)構(gòu),協(xié)調(diào)兩者的訓(xùn)練目標(biāo),替代傳統(tǒng)擴(kuò)散損失的直接優(yōu)化。
應(yīng)用的技術(shù)
- 表示對齊損失(REPA Loss):在擴(kuò)散模型的去噪過程中,對齊隱空間表示的分布,確保VAE生成的隱空間編碼更適配擴(kuò)散模型的訓(xùn)練目標(biāo)。
- 端到端梯度傳播:通過聯(lián)合優(yōu)化框架,將擴(kuò)散模型的梯度反向傳播至VAE,動態(tài)調(diào)整其隱空間結(jié)構(gòu)。
- 自適應(yīng)隱空間優(yōu)化:根據(jù)擴(kuò)散模型的訓(xùn)練需求,自動平衡VAE的重建能力與隱空間的可學(xué)習(xí)性。
達(dá)到的效果
- 訓(xùn)練加速:相比傳統(tǒng)兩階段訓(xùn)練(4M步),REPA-E僅需400K步即達(dá)到更優(yōu)性能,訓(xùn)練速度提升45倍;相比單階段REPA優(yōu)化(17倍加速)。
- 生成質(zhì)量SOTA:在ImageNet 256×256上,F(xiàn)ID達(dá)到1.26(使用分類器引導(dǎo))和1.83(無引導(dǎo)),刷新當(dāng)前最佳記錄。
- 隱空間改善:對不同初始結(jié)構(gòu)的VAE(如高頻噪聲的SD-VAE、過平滑的IN-VAE),端到端訓(xùn)練自動優(yōu)化其隱空間,提升生成細(xì)節(jié)
REPA-E:解鎖VAE的聯(lián)合訓(xùn)練
概述。 給定一個變分自編碼器(VAE)和隱空間diffusion transformer (例如SiT),本文希望以端到端的方式聯(lián)合優(yōu)化VAE的隱空間表示和擴(kuò)散模型的特征,以實現(xiàn)最佳的最終生成性能。首先提出三個關(guān)鍵見解:1)樸素的端到端調(diào)優(yōu)——直接反向傳播擴(kuò)散損失到VAE是無效的。擴(kuò)散損失鼓勵學(xué)習(xí)更簡單的隱空間結(jié)構(gòu)(下圖3a),這雖然更容易最小化去噪目標(biāo),但會降低最終生成性能。接著分析了最近提出的表示對齊損失,發(fā)現(xiàn):2)更高的表示對齊分?jǐn)?shù)與改進(jìn)的生成性能相關(guān)(圖3b)。這為使用表示對齊分?jǐn)?shù)作為代理來提升最終生成性能提供了另一種途徑。3)使用樸素REPA方法可達(dá)到的最大對齊分?jǐn)?shù)受限于VAE隱空間特征的瓶頸。進(jìn)一步表明,在訓(xùn)練過程中將REPA損失反向傳播到VAE有助于解決這一限制,顯著提高最終的表示對齊分?jǐn)?shù)(圖3c)。
基于上述見解本文提出了REPA-E;一種用于聯(lián)合優(yōu)化VAE和LDM特征的端到端調(diào)優(yōu)方法。我們的核心思想很簡單:不直接使用擴(kuò)散損失進(jìn)行端到端調(diào)優(yōu),而是使用表示對齊分?jǐn)?shù)作為最終生成性能的代理。這促使本文提出最終方法,即不使用擴(kuò)散損失,而是使用表示對齊損失進(jìn)行端到端訓(xùn)練。通過REPA損失的端到端訓(xùn)練有助于更好地提高最終的表示對齊分?jǐn)?shù)(圖3b),從而提升最終生成性能。
用REPA推動端到端訓(xùn)練的動機(jī)
單純端到端訓(xùn)練對擴(kuò)散損失的影響。
更高的表示對齊分?jǐn)?shù)與更好的生成性能相關(guān)。 本文還使用CKNNA分?jǐn)?shù)在不同模型規(guī)模和訓(xùn)練迭代中測量表示對齊。如前圖3b所示,訓(xùn)練過程中更高的表示對齊分?jǐn)?shù)會帶來更好的生成性能。這表明可以通過使用表示對齊目標(biāo)(而非擴(kuò)散損失)進(jìn)行端到端訓(xùn)練來提升生成性能。
表示對齊受限于VAE特征。 圖3c顯示,雖然樸素應(yīng)用REPA損失可以提高表示對齊(CKNNA)分?jǐn)?shù),但可達(dá)到的最大對齊分?jǐn)?shù)仍受限于VAE特征,飽和值約為0.4(最大值為1)。此外,我們發(fā)現(xiàn)將表示對齊損失反向傳播到VAE有助于解決這一限制;允許端到端優(yōu)化VAE特征以最好地支持表示對齊目標(biāo)。
基于REPA的端到端訓(xùn)練
REPA-E——一種用于聯(lián)合訓(xùn)練VAE和LDM特征的端到端調(diào)優(yōu)方案。建議使用表示對齊損失而非直接使用擴(kuò)散損失來進(jìn)行端到端訓(xùn)練。通過REPA損失實現(xiàn)的端到端訓(xùn)練能夠更好地提升最終表示對齊分?jǐn)?shù)(圖3c),從而改善最終生成性能。
VAE隱空間歸一化的批歸一化層
為了實現(xiàn)端到端訓(xùn)練,我們首先在VAE和隱空間擴(kuò)散模型之間引入批歸一化層。典型的LDM訓(xùn)練需要使用預(yù)計算的隱空間統(tǒng)計量(例如SD-VAE的std=1/0.1825)對VAE特征進(jìn)行歸一化。這有助于將VAE隱空間輸出歸一化為零均值和單位方差,從而提高擴(kuò)散模型的訓(xùn)練效率。然而,在端到端訓(xùn)練中,每當(dāng)VAE模型更新時都需要重新計算統(tǒng)計量——這代價高昂。為解決這個問題,我們提出使用批歸一化層,該層使用指數(shù)移動平均(EMA)均值和方差作為數(shù)據(jù)集級統(tǒng)計量的替代。因此,批歸一化層充當(dāng)了可微分歸一化算子,無需在每次優(yōu)化步驟后重新計算數(shù)據(jù)集級統(tǒng)計量。
端到端表示對齊損失
整體訓(xùn)練
最終的整體訓(xùn)練以端到端方式使用以下?lián)p失函數(shù)進(jìn)行:
其中θ、φ、ω分別表示潛擴(kuò)散模型(LDM)、變分自編碼器(VAE)以及可訓(xùn)練的REPA投影層的參數(shù)。
實驗
本文通過廣泛實驗驗證REPA-E的性能及所提組件的影響,主要探究三個關(guān)鍵問題:
- REPA-E能否顯著提升生成性能與訓(xùn)練速度?(下表2、前圖1、下圖4)
- REPA-E是否適用于不同訓(xùn)練設(shè)置(模型規(guī)模、架構(gòu)、REPA編碼器等)?(下表3-8)
- 分析端到端調(diào)優(yōu)(REPA-E)對VAE隱空間結(jié)構(gòu)及下游生成性能的影響。(圖6、表9-10)
實驗設(shè)置
實現(xiàn)細(xì)節(jié)
評估指標(biāo)
圖像生成評估嚴(yán)格遵循ADM:在5萬張生成圖像上報告Fréchet Inception距離(gFID)、結(jié)構(gòu)FID(sFID)、Inception分?jǐn)?shù)(IS)、精確度(Prec.)與召回率(Rec.)。采樣采用SiT和REPA的SDE Euler-Maruyama方法(250步)。VAE評估在ImageNet驗證集5萬張256×256圖像上測量重建FID(rFID)。
訓(xùn)練性能與速度的影響
首先分析REPA-E對隱空間diffusion transformer 訓(xùn)練性能與速度的提升。
定量評估
下表2比較了不同隱空間擴(kuò)散模型(LDM)基線。在ImageNet 256×256生成任務(wù)中評估相似參數(shù)量(~675M)的模型,結(jié)果均未使用分類器無關(guān)引導(dǎo)。關(guān)鍵發(fā)現(xiàn):
- 端到端調(diào)優(yōu)加速訓(xùn)練:相比REPA,gFID從19.40→12.83(20輪)、11.10→7.17(40輪)、7.90→4.07(80輪)持續(xù)提升;
- 端到端訓(xùn)練提升最終性能:REPA-E在80輪時gFID=4.07,優(yōu)于FasterDiT(400輪,gFID=7.91)、MaskDiT、DiT和SiT(均訓(xùn)練1400輪以上)。REPA-E僅需40萬步即超越REPA 400萬步的結(jié)果(gFID=5.9。
定性評估
下圖4對比REPA 與REPA-E在5萬、10萬、40萬步的生成效果。REPA-E在訓(xùn)練早期即生成結(jié)構(gòu)更合理的圖像,且整體質(zhì)量更優(yōu)。
REPA-E的泛化性與可擴(kuò)展性
進(jìn)一步分析REPA-E在不同訓(xùn)練設(shè)置下的適應(yīng)性(模型規(guī)模、分詞器架構(gòu)、表示編碼器、對齊深度等)。默認(rèn)使用SiT-L 為生成模型,SD-VAE為VAE,DINOv2-B為REPA損失的預(yù)訓(xùn)練視覺模型,對齊深度為8。各變體訓(xùn)練10萬步,結(jié)果均未使用分類器無關(guān)引導(dǎo)。
模型規(guī)模的影響
下表3比較SiT-B、SiT-L和SiT-XL:
- REPA-E在所有配置下均優(yōu)于REPA基線,gFID從49.5→34.8(SiT-B)、24.1→16.3(SiT-L)、19.4→12.8(SiT-XL);
- 增益隨模型規(guī)模增大:SiT-B提升29.6%,SiT-L提升32.3%,SiT-XL提升34.0%,表明REPA-E對大模型更具擴(kuò)展性。
表示編碼器的選擇
表4顯示不同感知編碼器(CLIP-L、I-JEPA-H、DINOv2-B/DINOv2-L)下REPA-E均一致提升性能。例如DINOv2-B下gFID從24.1→16.3,DINOv2-L下從23.3→16.0。
VAE的變體。 下表5評估了不同VAE對REPA-E性能的影響。我們報告了使用三種不同VAE的結(jié)果:1)SD-VAE,2)VA-VAE,以及3)IN-VAE(一個在ImageNet上訓(xùn)練的16倍下采樣、32通道的VAE,使用[39]中的官方訓(xùn)練代碼)。在所有變體中,REPA-E始終在性能上優(yōu)于REPA基線。REPA-E將gFID從24.1降至16.3(SD-VAE),從22.7降至12.7(IN-VAE),從12.8降至11.1(VA-VAE)。結(jié)果表明,REPA-E在VAE的架構(gòu)、預(yù)訓(xùn)練數(shù)據(jù)集和訓(xùn)練設(shè)置多樣性的情況下,始終能穩(wěn)健提升生成質(zhì)量。
對齊深度的變體。 下表6研究了在擴(kuò)散模型不同層應(yīng)用對齊損失的效果。觀察到REPA-E在不同對齊深度選擇下均能持續(xù)提升生成質(zhì)量,相較REPA基線,gFID分別從23.0降至16.4(第6層)、24.1降至16.3(第8層)、23.7降至16.2(第10層)。
組件設(shè)計的消融實驗。 本文還進(jìn)行了消融研究,分析各組件的重要性,結(jié)果如表7所示。觀察到每個組件對REPA-E的最終性能都起到了關(guān)鍵作用。特別地,觀察到對擴(kuò)散損失應(yīng)用stop-grad操作有助于防止隱空間結(jié)構(gòu)的退化。同樣地,批歸一化(batch norm)通過自適應(yīng)地規(guī)范化潛變量統(tǒng)計信息,提升了gFID從18.09至16.3。同樣地,正則化損失對保持微調(diào)后VAE的重建性能起到了關(guān)鍵作用,從而將gFID從19.07提升至16.3。
端到端從頭訓(xùn)練。 分析了VAE初始化對端到端訓(xùn)練的影響。如表8所示,我們發(fā)現(xiàn)雖然使用預(yù)訓(xùn)練權(quán)重初始化VAE能略微提升性能,但REPA-E也可以在從頭訓(xùn)練VAE和LDM的情況下使用,并仍然在性能上優(yōu)于REPA,后者在技術(shù)上需要一個VAE訓(xùn)練階段以及LDM訓(xùn)練階段。例如,REPA在400萬次迭代后達(dá)到FID 5.90,而REPA-E在完全從頭訓(xùn)練的情況下(同時訓(xùn)練VAE和LDM)在僅40萬次迭代內(nèi)就達(dá)到了更快更優(yōu)的生成FID 4.34。
端到端微調(diào)對VAE的影響
接下來分析了端到端微調(diào)對VAE的影響。首先展示端到端微調(diào)能改善隱空間結(jié)構(gòu)(下圖6)。然后展示一旦使用REPA-E進(jìn)行微調(diào),微調(diào)后的VAE可以作為原始VAE的直接替代品,顯著提升生成性能。
端到端訓(xùn)練提升隱空間結(jié)構(gòu)。 結(jié)果如圖6所示。本文使用主成分分析(PCA)將隱空間結(jié)構(gòu)可視化為RGB著色的三個通道。我們考慮三種不同的VAE:1)SD-VAE,2)IN-VAE(一個在ImageNet上訓(xùn)練的16倍下采樣、32通道的VAE),3)最新VA-VAE。觀察到使用REPA-E進(jìn)行端到端微調(diào)自動改善了原始VAE的隱空間結(jié)構(gòu)。例如,與同時期工作的發(fā)現(xiàn)一致,我們觀察到SD-VAE的隱空間存在高噪聲成分。應(yīng)用端到端訓(xùn)練后可自動幫助調(diào)整隱空間以減少噪聲。相比之下,其他VAE如最新提出的VA-VAE的隱空間則表現(xiàn)為過度平滑。使用REPA-E進(jìn)行端到端微調(diào)可自動學(xué)習(xí)更具細(xì)節(jié)的隱空間結(jié)構(gòu),以更好地支持生成性能。
端到端訓(xùn)練提升VAE性能。 接下來評估端到端微調(diào)對VAE下游生成性能的影響。首先對最近提出的VA-VAE進(jìn)行端到端微調(diào)。然后我們使用該微調(diào)后的VAE(命名為E2E-VAE),將其下游生成性能與當(dāng)前最先進(jìn)的VAE進(jìn)行比較,包括SD-VAE和VA-VAE。本文進(jìn)行了傳統(tǒng)的潛擴(kuò)散模型訓(xùn)練(不使用REPA-E),即僅更新生成器網(wǎng)絡(luò),同時保持VAE凍結(jié)。表9展示了在不同訓(xùn)練設(shè)置下的VAE下游生成性能比較。端到端微調(diào)后的VAE在各種LDM架構(gòu)和訓(xùn)練設(shè)置下的下游生成任務(wù)中始終優(yōu)于其原始版本。有趣的是,觀察到使用SiT-XL進(jìn)行微調(diào)的VAE即使在使用不同LDM架構(gòu)(如DiT-XL)時仍能帶來性能提升,進(jìn)一步展示了本文方法的穩(wěn)健性。
結(jié)論
本文探討了一個基本問題:“我們是否能夠?qū)崿F(xiàn)基于隱空間擴(kuò)散 Transformer 的端到端訓(xùn)練,從而釋放 VAE 的潛力?”具體來說,觀察到,直接將擴(kuò)散損失反向傳播到 VAE 是無效的,甚至?xí)档妥罱K的生成性能。盡管擴(kuò)散損失無效,但可以使用最近提出的表示對齊損失進(jìn)行端到端訓(xùn)練。所提出的端到端訓(xùn)練方案(REPA-E)顯著改善了隱空間結(jié)構(gòu),并展現(xiàn)出卓越的性能:相較于 REPA 和傳統(tǒng)訓(xùn)練方案,擴(kuò)散模型訓(xùn)練速度分別提升了超過 17× 和 45×。
REPA-E 不僅在不同訓(xùn)練設(shè)置下表現(xiàn)出一致的改進(jìn),還改善了多種 VAE 架構(gòu)下原始的隱空間結(jié)構(gòu)??傮w而言,本文的方法達(dá)到了新的SOTA水平,在使用和不使用 classifier-free guidance 的情況下,分別取得了 1.26 和 1.83 的生成 FID 分?jǐn)?shù)。希望本工作能夠推動進(jìn)一步的研究,推動隱空間擴(kuò)散 Transformer 的端到端訓(xùn)練發(fā)展。
本文轉(zhuǎn)自AI生成未來 ,作者:AI生成未來
