Diffusion 和Stable Diffusion的數(shù)學(xué)和工作原理詳細(xì)解釋
擴(kuò)散模型的興起可以被視為人工智能生成藝術(shù)領(lǐng)域最近取得突破的主要因素。而穩(wěn)定擴(kuò)散模型的發(fā)展使得我們可以通過(guò)一個(gè)文本提示輕松地創(chuàng)建美妙的藝術(shù)插圖。所以在本文中,我將解釋它們是如何工作的。
擴(kuò)散模型 Diffusion
擴(kuò)散模型的訓(xùn)練可以分為兩部分:
- 正向擴(kuò)散→在圖像中添加噪聲。
- 反向擴(kuò)散過(guò)程→去除圖像中的噪聲。
正向擴(kuò)散過(guò)程
正向擴(kuò)散過(guò)程逐步對(duì)輸入圖像 x? 加入高斯噪聲,一共有 T 步。該過(guò)程將產(chǎn)生一系列噪聲圖像樣本 x?, …, x_T。
當(dāng) T → ∞ 時(shí),最終的結(jié)果將變成一張完包含噪聲的圖像,就像從各向同性高斯分布中采樣一樣。
但是我們可以使用一個(gè)封閉形式的公式在特定的時(shí)間步長(zhǎng) t 直接對(duì)有噪聲的圖像進(jìn)行采樣,而不是設(shè)計(jì)一種算法來(lái)迭代地向圖像添加噪聲。
封閉公式
封閉形式的抽樣公式可以通過(guò)重新參數(shù)化技巧得到。
通過(guò)這個(gè)技巧,我們可以將采樣圖像x?表示為:
然后我們可以遞歸展開(kāi)它,最終得到閉式公式:
這里的ε 是 i.i.d. (獨(dú)立同分布)標(biāo)準(zhǔn)正態(tài)隨機(jī)變量。使用不同的符號(hào)和下標(biāo)區(qū)分它們很重要,因?yàn)樗鼈兪仟?dú)立的并且它們的值在采樣后可能不同。
但是,上面公式是如何從第4行跳到第5行呢?
有些人覺(jué)得這一步很難理解。下面我詳細(xì)介紹如何工作的:
讓我們用 X 和 Y 來(lái)表示這兩項(xiàng)。它們可以被視為來(lái)自?xún)蓚€(gè)不同正態(tài)分布的樣本。即 X ~ N(0, α?(1-α???)I) 和 Y ~ N(0, (1-α?)I)。
兩個(gè)正態(tài)分布(獨(dú)立)隨機(jī)變量的總和也是正態(tài)分布的。即如果 Z = X + Y,則 Z ~ N(0, σ2?+σ2?)。因此我們可以將它們合并在一起并以重新以參數(shù)化的形式表示合并后的正態(tài)分布。
重復(fù)這些步驟將為得到只與輸入圖像 x? 相關(guān)的公式:
現(xiàn)在我們可以使用這個(gè)公式在任何時(shí)間步驟直接對(duì)x?進(jìn)行采樣,這使得向前的過(guò)程更快。
反向擴(kuò)散過(guò)程
與正向過(guò)程不同,不能使用q(x???|x?)來(lái)反轉(zhuǎn)噪聲,因?yàn)樗请y以處理的(無(wú)法計(jì)算)。所以我們需要訓(xùn)練神經(jīng)網(wǎng)絡(luò)pθ(x???|x?)來(lái)近似q(x???|x?)。近似pθ(x???|x?)服從正態(tài)分布,其均值和方差設(shè)置如下:
損失函數(shù)
損失定義為負(fù)對(duì)數(shù)似然:
這個(gè)設(shè)置與VAE中的設(shè)置非常相似。我們可以?xún)?yōu)化變分的下界,而不是優(yōu)化損失函數(shù)本身。
通過(guò)優(yōu)化一個(gè)可計(jì)算的下界,我們可以間接優(yōu)化不可處理的損失函數(shù)。
通過(guò)展開(kāi),我們發(fā)現(xiàn)它可以用以下三項(xiàng)表示:
1、L_T:常數(shù)項(xiàng)
由于 q 沒(méi)有可學(xué)習(xí)的參數(shù),p 只是一個(gè)高斯噪聲概率,因此這一項(xiàng)在訓(xùn)練期間將是一個(gè)常數(shù),因此可以忽略。
2、L???:逐步去噪項(xiàng)
這一項(xiàng)是比較目標(biāo)去噪步驟 q 和近似去噪步驟 pθ。通過(guò)以 x? 為條件,q(x???|x?, x?) 變得易于處理。
經(jīng)過(guò)一系列推導(dǎo),上圖為q(x???|x?,x?)的平均值μ′?。為了近似目標(biāo)去噪步驟q,我們只需要使用神經(jīng)網(wǎng)絡(luò)近似其均值。所以我們將近似均值 μθ 設(shè)置為與目標(biāo)均值 μ?? 相同的形式(使用可學(xué)習(xí)的神經(jīng)網(wǎng)絡(luò) εθ):
目標(biāo)均值和近似值之間的比較可以使用均方誤差(MSE)進(jìn)行:
經(jīng)過(guò)實(shí)驗(yàn),通過(guò)忽略加權(quán)項(xiàng)并簡(jiǎn)單地將目標(biāo)噪聲和預(yù)測(cè)噪聲與 MSE 進(jìn)行比較,可以獲得更好的結(jié)果。所以為了逼近所需的去噪步驟 q,我們只需要使用神經(jīng)網(wǎng)絡(luò) εθ 來(lái)逼近噪聲 ε?。
3、L?:重構(gòu)項(xiàng)
這是最后一步去噪的重建損失,在訓(xùn)練過(guò)程中可以忽略,因?yàn)椋?/p>
- 可以使用 L??? 中的相同神經(jīng)網(wǎng)絡(luò)對(duì)其進(jìn)行近似。
- 忽略它會(huì)使樣本質(zhì)量更好,并更易于實(shí)施。
所以最終簡(jiǎn)化的訓(xùn)練目標(biāo)如下:
我們發(fā)現(xiàn)在真實(shí)變分界上訓(xùn)練我們的模型比在簡(jiǎn)化目標(biāo)上訓(xùn)練產(chǎn)生更好的碼長(zhǎng),正如預(yù)期的那樣,但后者產(chǎn)生了最好的樣本質(zhì)量。[2]
通過(guò)測(cè)試在變分邊界上訓(xùn)練模型比在簡(jiǎn)化目標(biāo)上訓(xùn)練會(huì)減少代碼的長(zhǎng)度,但后者產(chǎn)生最好的樣本質(zhì)量。[2]
U-Net模型
在每一個(gè)訓(xùn)練輪次
- 每個(gè)訓(xùn)練樣本(圖像)隨機(jī)選擇一個(gè)時(shí)間步長(zhǎng)t。
- 對(duì)每個(gè)圖像應(yīng)用高斯噪聲(對(duì)應(yīng)于t)。
- 將時(shí)間步長(zhǎng)轉(zhuǎn)換為嵌入(向量)。
訓(xùn)練過(guò)程的偽代碼
官方的訓(xùn)練算法如上所示,下圖是訓(xùn)練步驟如何工作的說(shuō)明:
反向擴(kuò)散
我們可以使用上述算法從噪聲中生成圖像。下面的圖表說(shuō)明了這一點(diǎn):
在最后一步中,只是輸出學(xué)習(xí)的平均值μθ(x?,1),而沒(méi)有添加噪聲。反向擴(kuò)散就是我們說(shuō)的采樣過(guò)程,也就是從高斯噪聲中繪制圖像的過(guò)程。
擴(kuò)散模型的速度問(wèn)題
擴(kuò)散(采樣)過(guò)程會(huì)迭代地向U-Net提供完整尺寸的圖像獲得最終結(jié)果。這使得純擴(kuò)散模型在總擴(kuò)散步數(shù)T和圖像大小較大時(shí)極其緩慢。
穩(wěn)定擴(kuò)散就是為了解決這一問(wèn)題而設(shè)計(jì)的。
穩(wěn)定擴(kuò)散 Stable Diffusion
穩(wěn)定擴(kuò)散模型的原名是潛擴(kuò)散模型(Latent Diffusion Model, LDM)。正如它的名字所指出的那樣,擴(kuò)散過(guò)程發(fā)生在潛在空間中。這就是為什么它比純擴(kuò)散模型更快。
潛在空間
首先訓(xùn)練一個(gè)自編碼器,學(xué)習(xí)將圖像數(shù)據(jù)壓縮為低維表示。
通過(guò)使用訓(xùn)練過(guò)的編碼器E,可以將全尺寸圖像編碼為低維潛在數(shù)據(jù)(壓縮數(shù)據(jù))。然后通過(guò)使用經(jīng)過(guò)訓(xùn)練的解碼器D,將潛在數(shù)據(jù)解碼回圖像。
潛在空間的擴(kuò)散
將圖像編碼后,在潛在空間中進(jìn)行正向擴(kuò)散和反向擴(kuò)散過(guò)程。
- 正向擴(kuò)散過(guò)程→向潛在數(shù)據(jù)中添加噪聲
- 反向擴(kuò)散過(guò)程→從潛在數(shù)據(jù)中去除噪聲
條件作用/調(diào)節(jié)
穩(wěn)定擴(kuò)散模型的真正強(qiáng)大之處在于它可以從文本提示生成圖像。這是通過(guò)修改內(nèi)部擴(kuò)散模型來(lái)接受條件輸入來(lái)完成的。
通過(guò)使用交叉注意機(jī)制增強(qiáng)其去噪 U-Net,將內(nèi)部擴(kuò)散模型轉(zhuǎn)變?yōu)闂l件圖像生成器。
上圖中的開(kāi)關(guān)用于在不同類(lèi)型的調(diào)節(jié)輸入之間進(jìn)行控制:
- 對(duì)于文本輸入,首先使用語(yǔ)言模型 ??θ(例如 BERT、CLIP)將它們轉(zhuǎn)換為嵌入(向量),然后通過(guò)(多頭)Attention(Q, K, V) 映射到 U-Net 層。
- 對(duì)于其他空間對(duì)齊的輸入(例如語(yǔ)義映射、圖像、修復(fù)),可以使用連接來(lái)完成調(diào)節(jié)。
訓(xùn)練
訓(xùn)練目標(biāo)(損失函數(shù))與純擴(kuò)散模型中的訓(xùn)練目標(biāo)非常相似。唯一的變化是:
- 輸入潛在數(shù)據(jù)z?而不是圖像x?。
- U-Net增加條件輸入??θ(y)。
采樣
由于潛在數(shù)據(jù)的大小比原始圖像小得多,所以去噪過(guò)程會(huì)快得多。
架構(gòu)的比較
比較純擴(kuò)散模型和穩(wěn)定擴(kuò)散模型(潛在擴(kuò)散模型)的整體架構(gòu)。
Diffusion Model
Stable Diffusion (Latent Diffusion Model)
快速總結(jié)一下:
- 擴(kuò)散模型分為正向擴(kuò)散和反向擴(kuò)散兩部分。
- 正擴(kuò)散可以用封閉形式的公式計(jì)算。
- 反向擴(kuò)散可以用訓(xùn)練好的神經(jīng)網(wǎng)絡(luò)來(lái)完成。
- 為了近似所需的去噪步驟q,我們只需要使用神經(jīng)網(wǎng)絡(luò)εθ近似噪聲ε?。
- 在簡(jiǎn)化損失函數(shù)上進(jìn)行訓(xùn)練可以獲得更好的樣本質(zhì)量。
- 穩(wěn)定擴(kuò)散(潛擴(kuò)散模型)是在潛空間中進(jìn)行擴(kuò)散過(guò)程,因此比純擴(kuò)散模型快得多。
- 純擴(kuò)散模型被修改為接受條件輸入,如文本、圖像、語(yǔ)義等。