OpenAI攻克擴(kuò)散模型短板,清華校友路橙、宋飏合作最新論文
擴(kuò)散模型很成功,但也有一塊重大短板:采樣速度非常慢,生成一個(gè)樣本往往需要執(zhí)行成百上千步采樣。為此,研究社區(qū)已經(jīng)提出了多種擴(kuò)展蒸餾(diffusion distillation)技術(shù),包括直接蒸餾、對(duì)抗蒸餾、漸進(jìn)式蒸餾和變分分?jǐn)?shù)蒸餾(VSD)。但是,這些方法也有自己的問(wèn)題,包括成本高、復(fù)雜性高、多樣性有限等。
一致性模型(CM)在解決這些問(wèn)題方面具有巨大的優(yōu)勢(shì)。這又進(jìn)一步分為離散時(shí)間 CM 和連續(xù)時(shí)間 CM。其中離散時(shí)間 CM 會(huì)引入離散化誤差,并且需要仔細(xì)調(diào)度時(shí)間步長(zhǎng)網(wǎng)格,這可能會(huì)導(dǎo)致樣本質(zhì)量不佳。而連續(xù)時(shí)間 CM 雖可避免這些問(wèn)題,但也會(huì)有訓(xùn)練不穩(wěn)定的問(wèn)題。
近日,OpenAI 的研究科學(xué)家路橙(Cheng Lu)與戰(zhàn)略探索團(tuán)隊(duì)負(fù)責(zé)人宋飏(Yang Song)發(fā)布了一篇研究論文,提出了一些可簡(jiǎn)化、穩(wěn)定化和擴(kuò)展連續(xù)時(shí)間一致性模型的技術(shù)。值得一提的是,這兩位作者都是清華校友,師從朱軍教授,在擴(kuò)散概率模型領(lǐng)域做出過(guò)代表性工作。
- 論文標(biāo)題:Simplifying, Stabilizing & Scaling Continuous-Time Consistency Models
- 論文地址:https://arxiv.org/pdf/2410.11081v1
他們的貢獻(xiàn)包括:
- TrigFlow,一個(gè)將 EDM(arXiv:2206.00364)與流匹配(Flow Matching)統(tǒng)一起來(lái)的公式,其能極大簡(jiǎn)化擴(kuò)展模型、相關(guān)的概率流 ODE 和一致性模型(CM。
- 在此基礎(chǔ)上,他們分析了一致性模型訓(xùn)練不穩(wěn)定的根本原因,并提出了一種完整的緩解方案。他們的方法包括改進(jìn)網(wǎng)絡(luò)架構(gòu)中的時(shí)間調(diào)節(jié)和自適應(yīng)分組歸一化。
- 此外,他們還重新構(gòu)建了連續(xù)時(shí)間 CM 的訓(xùn)練目標(biāo),其中整合了關(guān)鍵項(xiàng)的自適應(yīng)加權(quán)和歸一化以及漸進(jìn)退火,以實(shí)現(xiàn)穩(wěn)定且可擴(kuò)展的訓(xùn)練。
簡(jiǎn)化連續(xù)時(shí)間一致性模型
作為前提,這里先給出離散時(shí)間和連續(xù)時(shí)間一致性模型的公式:
離散時(shí)間 CM:
連續(xù)時(shí)間 CM:
此前的一致性模型采用了 EDM 中的模型參數(shù)化和擴(kuò)散過(guò)程。具體來(lái)說(shuō),一致性模型會(huì)被參數(shù)化以下形式:
其中,F(xiàn) 是一個(gè)神經(jīng)網(wǎng)絡(luò),θ 是其參數(shù);c_skip、c_out、c_in 都是固定的系數(shù),用以確保在所有時(shí)間步驟上初始化時(shí)擴(kuò)散目標(biāo)的方差相等;c_noise 是對(duì) t 的一個(gè)變換運(yùn)算,以便更好地實(shí)現(xiàn)時(shí)間調(diào)節(jié)。
由于在 EDM 擴(kuò)散過(guò)程中,方差會(huì)爆炸式增長(zhǎng),也就意味著 x_t = x_0 + tz_t,基于此可以推導(dǎo)出下面三式:
雖然這些系數(shù)對(duì)于訓(xùn)練效率很重要,但由于它們與 t 和 σ_d 之間存在復(fù)雜的算術(shù)關(guān)系,因此會(huì)使得對(duì)一致性模型的理論分析變得復(fù)雜。
為了簡(jiǎn)化 EDM 及隨之的一致性模型,他們提出了 TrigFlow。這種擴(kuò)散模型形式保留了 EDM 性質(zhì),但滿足 c_skip (t) = cos (t)、c_out (t) = sin (t)、c_in (t) ≡ 1/σ_d。
TrigFlow 是流匹配(也稱為隨機(jī)插值或整流)和 v 預(yù)測(cè)參數(shù)化的一種特例。它與之前一些研究團(tuán)隊(duì)提出的三角插值非常相似,但經(jīng)過(guò)修改從而納入了對(duì)數(shù)據(jù)分布 p_d 的標(biāo)準(zhǔn)差 σ_d 的考量。
由于 TrigFlow 是流匹配的一個(gè)特例,同時(shí)滿足 EDM 原理,因此其集兩者之長(zhǎng),同時(shí)還讓擴(kuò)散過(guò)程、擴(kuò)散模型參數(shù)化、PF-ODE、擴(kuò)散訓(xùn)練目標(biāo)和一致性模型參數(shù)化全都變得更簡(jiǎn)單了。
讓連續(xù)時(shí)間一致性模型變得穩(wěn)定
連續(xù)時(shí)間 CM 的訓(xùn)練一直都高度不穩(wěn)定。因此,它們的表現(xiàn)一直不及之前研究中的離散時(shí)間 CM。
為了解決這個(gè)問(wèn)題,該團(tuán)隊(duì)在 TrigFlow 框架的基礎(chǔ)上,引入了幾項(xiàng)基于理論研究的改進(jìn)措施,其中重點(diǎn)關(guān)注的是參數(shù)化、網(wǎng)絡(luò)架構(gòu)和訓(xùn)練目標(biāo)。
參數(shù)化和網(wǎng)絡(luò)架構(gòu)
連續(xù)時(shí)間 CM 的訓(xùn)練的關(guān)鍵是 (2) 式,其取決于正切函數(shù),而在 TrigFlow 框架下,該正切函數(shù)由以下公式給出:
其中 表示 PF-ODE,其要么在一致性蒸餾中使用預(yù)訓(xùn)練的擴(kuò)散模型估計(jì)得出,要么就在一致性訓(xùn)練中使用從噪聲和干凈樣本計(jì)算得到的無(wú)偏估計(jì)器估計(jì)得出。
為了讓訓(xùn)練過(guò)程穩(wěn)定下來(lái),必須確保 (6) 式中的正切函數(shù)在不同時(shí)間步驟中保持穩(wěn)定。該團(tuán)隊(duì)在實(shí)踐中發(fā)現(xiàn) σ_dF_θ、PF-ODE 和噪聲樣本 x_t 能保持相對(duì)穩(wěn)定。進(jìn)一步分析后,他們發(fā)現(xiàn)通常經(jīng)過(guò)良好調(diào)節(jié)。因此不穩(wěn)定的來(lái)源是時(shí)間導(dǎo)數(shù)
,而其可被分解為:
其中,emb (?) 是指時(shí)間嵌入,通常以位置嵌入或傅里葉嵌入的形式出現(xiàn)在擴(kuò)散模型和 CM 的相關(guān)文獻(xiàn)中。
該團(tuán)隊(duì)穩(wěn)定 (7) 中每個(gè)元素的方法包括恒等時(shí)間變換(c_noise (t) = t)、位置時(shí)間嵌入和自適應(yīng)雙重歸一化。詳見(jiàn)原論文。
圖 4 可視化地展示了在 CIFAR-10 上訓(xùn)練 CM 時(shí)穩(wěn)定時(shí)間導(dǎo)數(shù)的情況。研究表明,這些改進(jìn)可在不損害擴(kuò)散模型訓(xùn)練的前提下穩(wěn)定 CM 的訓(xùn)練動(dòng)態(tài)。
訓(xùn)練目標(biāo)
使用 TrigFlow 和前述的優(yōu)化技術(shù),(2) 式中連續(xù)時(shí)間 CM 訓(xùn)練的梯度就會(huì)變?yōu)椋?/span>
之后,該團(tuán)隊(duì)又使用了另外一些技術(shù)來(lái)顯式地控制該梯度,以提升穩(wěn)定性,其中包括正切歸一化、自適應(yīng)加權(quán)、擴(kuò)散微調(diào)和正切預(yù)熱。詳見(jiàn)原論文。
有了這些技術(shù),離散時(shí)間和連續(xù)時(shí)間 CM 訓(xùn)練的穩(wěn)定性都能得到顯著改善。
該團(tuán)隊(duì)在相同的設(shè)置下訓(xùn)練了連續(xù)時(shí)間 CM 和離散時(shí)間 CM。如圖 5 (c) 所示,增加離散時(shí)間 CM 中的離散化步驟數(shù) N 可提高樣本質(zhì)量,原因是這樣做可減少離散化誤差來(lái);但一旦 N 變得太大(N > 1024 之后),樣本質(zhì)量就會(huì)降低,這是因?yàn)闀?huì)出現(xiàn)數(shù)值精度問(wèn)題。
相較之下,在所有 N 值上,連續(xù)時(shí)間 CM 的表現(xiàn)都顯著優(yōu)于離散時(shí)間 CM。這能為我們提供選擇連續(xù)時(shí)間 CM 的強(qiáng)有力依據(jù)。
該團(tuán)隊(duì)將他們的模型稱為 sCM,其中 s 代表 simple、stable、scalable,即簡(jiǎn)單、穩(wěn)定和可擴(kuò)展。下面是 sCM 訓(xùn)練的詳細(xì)偽代碼。
擴(kuò)展連續(xù)時(shí)間一致性模型
在這部分,研究者通過(guò)在各種具有挑戰(zhàn)性的數(shù)據(jù)集上訓(xùn)練大規(guī)模 sCM 來(lái)測(cè)試上述內(nèi)容中提出的所有改進(jìn)措施。
大規(guī)模模型中的正切計(jì)算
訓(xùn)練大規(guī)模擴(kuò)散模型的常見(jiàn)設(shè)置包括使用半精度(FP16)和 Flash Attention。由于訓(xùn)練連續(xù)時(shí)間 CM 需要精確計(jì)算正切 ,我們需要提高數(shù)值精度,同時(shí)支持高效記憶注意力計(jì)算,詳情如下。
計(jì)算時(shí),需要計(jì)算
,這可通過(guò)
與輸入向量
和正切向量
的雅可比向量積(JVP)高效獲得。然而根據(jù)經(jīng)驗(yàn),當(dāng) t 接近 0 或 π/2 時(shí),正切可能會(huì)在中間層溢出。為了提高數(shù)值精度,研究者認(rèn)為應(yīng)該重新安排正切的計(jì)算。
具體來(lái)說(shuō),由于公式 (8) 中的目標(biāo)函數(shù)包含,而
與
成正比,因此可以這樣計(jì)算 JVP:
即的 JVP,輸入為
,正切為
這種重新排列大大緩解了中間層的溢出問(wèn)題,使 FP16 的訓(xùn)練更加穩(wěn)定。
Flash Attention 被廣泛用于大規(guī)模模型訓(xùn)練中的注意力計(jì)算,既能節(jié)省 GPU 內(nèi)存,又能加快訓(xùn)練速度。然而,F(xiàn)lash Attention 并不計(jì)算雅可比向量積(JVP)。為了填補(bǔ)這一空白,研究者提出了一種類似的算法,它能以 Flash Attention 的風(fēng)格在一次前向傳遞中高效計(jì)算 softmax 自注意力及其 JVP,從而顯著減少注意力層中 JVP 計(jì)算所需的 GPU 內(nèi)存用量。
實(shí)驗(yàn)
sCM 的訓(xùn)練計(jì)算。研究者在所有數(shù)據(jù)集上使用與教師擴(kuò)散模型相同的批大小。sCD 每次訓(xùn)練迭代的有效計(jì)算量大約是教師模型的兩倍。他們觀察到,sCD 的兩步采樣質(zhì)量收斂很快,只用了不到教師模型 20% 的訓(xùn)練計(jì)算量,就獲得了與教師擴(kuò)散模型相當(dāng)?shù)慕Y(jié)果。在實(shí)踐中,只需使用 sCD 進(jìn)行 20k 次微調(diào)迭代,就能獲得高質(zhì)量的樣本。
基準(zhǔn)。在表 1 和表 2 中,研究者通過(guò) FID 和函數(shù)評(píng)估次數(shù)(NFE)的基準(zhǔn),將本文結(jié)果與之前的方法進(jìn)行了比較。首先,sCM 優(yōu)于之前所有不依賴與其他網(wǎng)絡(luò)聯(lián)合訓(xùn)練的幾步式方法,與對(duì)抗訓(xùn)練取得的最佳結(jié)果相當(dāng),甚至超越。值得注意的是,sCD-XXL 在 ImageNet 512×512 上的一步 FID 超過(guò)了 StyleGAN-XL 和 VAR。此外,sCD-XXL 的兩步 FID 性能優(yōu)于除擴(kuò)散模型外的所有生成模型,可與需要 63 個(gè)連續(xù)步驟的最佳擴(kuò)散模型相媲美。其次,兩步式 sCM 模型將與教師擴(kuò)散模型的 FID 差距顯著縮小到 10% 以內(nèi)。此外,sCT 在較小的擴(kuò)展上更有效,但在較大擴(kuò)展上的方差會(huì)增大,而 sCD 在小型擴(kuò)展和大型擴(kuò)展上都表現(xiàn)出一致的性能。
Scaling 研究。如圖 6 所示,首先,隨著模型 FLOPs 的增加,sCT 和 sCD 的樣本質(zhì)量都有所提高,這表明這兩種方法都能從 Scaling 中獲益。其次,與 sCD 相比,sCT 在較小分辨率下的計(jì)算效率更高,但在較大分辨率下的效率較低。第三,對(duì)于給定的數(shù)據(jù)集,sCD 的 Scaling 是可預(yù)測(cè)的,在不同大小的模型中,F(xiàn)ID 的相對(duì)差異保持一致。這表明,sCD 的 FID 下降速度與教師擴(kuò)散模型相同,因此,sCD 與教師擴(kuò)散模型一樣具有可擴(kuò)展性。隨著教師擴(kuò)散模型的 FID 隨規(guī)模的擴(kuò)大而減小,sCD 與教師模型之間 FID 的絕對(duì)差異也隨之減小。最后,F(xiàn)ID 的相對(duì)差異隨著采樣步驟的增加而減小,兩步式 sCD 的采樣質(zhì)量與教師擴(kuò)散模型相當(dāng)。
與 VSD 的對(duì)比。如圖 7 所示,研究者對(duì)比了 sCD、VSD、sCD 和 VSD 的組合等(通過(guò)簡(jiǎn)單地將兩種損失相加)并觀察到,VSD 具有與擴(kuò)散模型中應(yīng)用大 guidance scale 類似的人工效應(yīng):它提高了保真度(表現(xiàn)為更高的精確度分?jǐn)?shù)),同時(shí)降低了多樣性(表現(xiàn)為更低的召回分?jǐn)?shù))。這種效應(yīng)隨著 guidance scale 的增加而變得更加明顯,最終導(dǎo)致嚴(yán)重的模式崩潰。相比之下,兩步式 sCD 的精確度和召回分?jǐn)?shù)與教師擴(kuò)散模型相當(dāng),因此 FID 分?jǐn)?shù)比 VSD 更高。
更多研究細(xì)節(jié),可參考原論文。