使用PyTorch實(shí)現(xiàn)去噪擴(kuò)散模型
在深入研究去噪擴(kuò)散概率模型(DDPM)如何工作的細(xì)節(jié)之前,讓我們先看看生成式人工智能的一些發(fā)展,也就是DDPM的一些基礎(chǔ)研究。
VAE
VAE 采用了編碼器、概率潛在空間和解碼器。在訓(xùn)練過程中,編碼器預(yù)測(cè)每個(gè)圖像的均值和方差。然后從高斯分布中對(duì)這些值進(jìn)行采樣,并將其傳遞到解碼器中,其中輸入的圖像預(yù)計(jì)與輸出的圖像相似。這個(gè)過程包括使用KL Divergence來計(jì)算損失。VAEs的一個(gè)顯著優(yōu)勢(shì)在于它們能夠生成各種各樣的圖像。在采樣階段簡(jiǎn)單地從高斯分布中采樣,解碼器創(chuàng)建一個(gè)新的圖像。
GAN
在變分自編碼器(VAEs)的短短一年之后,一個(gè)開創(chuàng)性的生成家族模型出現(xiàn)了——生成對(duì)抗網(wǎng)絡(luò)(GANs),標(biāo)志著一類新的生成模型的開始,其特征是兩個(gè)神經(jīng)網(wǎng)絡(luò)的協(xié)作:一個(gè)生成器和一個(gè)鑒別器,涉及對(duì)抗性訓(xùn)練過程。生成器的目標(biāo)是從隨機(jī)噪聲中生成真實(shí)的數(shù)據(jù),例如圖像,而鑒別器則努力區(qū)分真實(shí)數(shù)據(jù)和生成數(shù)據(jù)。在整個(gè)訓(xùn)練階段,生成器和鑒別器通過競(jìng)爭(zhēng)性學(xué)習(xí)過程不斷完善自己的能力。生成器生成越來越有說服力的數(shù)據(jù),從而比鑒別器更聰明,而鑒別器又提高了辨別真實(shí)樣本和生成樣本的能力。這種對(duì)抗性的相互作用在生成器生成高質(zhì)量、逼真的數(shù)據(jù)時(shí)達(dá)到頂峰。在采樣階段,經(jīng)過GAN訓(xùn)練后,生成器通過輸入隨機(jī)噪聲產(chǎn)生新的樣本。它將這些噪聲轉(zhuǎn)換為通常反映真實(shí)示例的數(shù)據(jù)。
為什么我們需要另一個(gè)模型架構(gòu)
兩種模型都有不同的問題,雖然GANs擅長(zhǎng)于生成與訓(xùn)練集中的圖像非常相似的逼真圖像,但VAEs擅長(zhǎng)于創(chuàng)建各種各樣的圖像,盡管有產(chǎn)生模糊圖像的傾向。但是現(xiàn)有的模型還沒有成功地將這兩種功能結(jié)合起來——?jiǎng)?chuàng)造出既高度逼真又多樣化的圖像。這一挑戰(zhàn)給研究人員帶來了一個(gè)需要解決的重大障礙。
在第一篇GAN論文發(fā)表六年后,在VAE論文發(fā)表七年后,一個(gè)開創(chuàng)性的模型出現(xiàn)了:去噪擴(kuò)散概率模型(DDPM)。DDPM結(jié)合了兩個(gè)世界的優(yōu)勢(shì),擅長(zhǎng)于創(chuàng)造多樣化和逼真的圖像。
在本文中,我們將深入研究DDPM的復(fù)雜性,涵蓋其訓(xùn)練過程,包括正向和逆向過程,并探索如何執(zhí)行采樣。在整個(gè)探索過程中,我們將使用PyTorch從頭開始構(gòu)建DDPM,并完成其完整的訓(xùn)練。
這里假設(shè)你已經(jīng)熟悉深度學(xué)習(xí)的基礎(chǔ)知識(shí),并且在深度計(jì)算機(jī)視覺方面有堅(jiān)實(shí)的基礎(chǔ)。我們不會(huì)介紹這些基本概念,我們的目標(biāo)是生成人類確信其真實(shí)性的圖像。
DDPM
去噪擴(kuò)散概率模型(DDPM)是生成模型領(lǐng)域的一種前沿方法。與依賴顯式似然函數(shù)的傳統(tǒng)模型不同,DDPM通過對(duì)擴(kuò)散過程進(jìn)行迭代去噪來運(yùn)行。這包括逐漸向圖像中添加噪聲并試圖去除該噪聲。其基本理論是基于這樣一種思想:通過一系列擴(kuò)散步驟轉(zhuǎn)換一個(gè)簡(jiǎn)單的分布,例如高斯分布,可以得到一個(gè)復(fù)雜而富有表現(xiàn)力的圖像數(shù)據(jù)分布?;蛘哒f通過將樣本從原始圖像分布轉(zhuǎn)移到高斯分布,我們可以創(chuàng)建一個(gè)模型來逆轉(zhuǎn)這個(gè)過程。這使我們能夠從全高斯分布開始,以圖像分布結(jié)束,有效地生成新圖像。
DDPM的訓(xùn)練包括兩個(gè)基本步驟:產(chǎn)生噪聲圖像這是固定和不可學(xué)習(xí)的正向過程,以及隨后的逆向過程。逆向過程的主要目標(biāo)是使用專門的機(jī)器學(xué)習(xí)模型對(duì)圖像進(jìn)行去噪。
正向擴(kuò)散過程
正向過程是一個(gè)固定且不可學(xué)習(xí)的步驟,但是它需要一些預(yù)定義的設(shè)置。在深入研究這些設(shè)置之前,讓我們先了解一下它是如何工作的。
這個(gè)過程的核心概念是從一個(gè)清晰的圖像開始。在用“T”表示的特定步長(zhǎng)上,少量噪聲按照高斯分布逐漸引入。
從圖像中可以看出,噪聲是在每一步遞增的,我們深入研究這種噪音的數(shù)學(xué)表示。
噪聲是從高斯分布中采樣的。為了在每一步引入少量的噪聲,我們使用馬爾可夫鏈。要生成當(dāng)前時(shí)間戳的圖像,我們只需要上次時(shí)間戳的圖像。馬爾可夫鏈的概念在這里是關(guān)鍵的,并對(duì)隨后的數(shù)學(xué)細(xì)節(jié)至關(guān)重要。
馬爾可夫鏈?zhǔn)且粋€(gè)隨機(jī)過程,其中過渡到任何特定狀態(tài)的概率僅取決于當(dāng)前狀態(tài)和經(jīng)過的時(shí)間,而不是之前的事件序列。這一特性簡(jiǎn)化了噪聲添加過程的建模,使其更易于數(shù)學(xué)分析。
用beta表示的方差參數(shù)被有意地設(shè)置為一個(gè)非常小的值,目的是在每個(gè)步驟中只引入最少量的噪聲。
步長(zhǎng)參數(shù)“T”決定了生成全噪聲圖像所需的步長(zhǎng)。在本文中,該參數(shù)被設(shè)置為1000,這可能顯得很大。我們真的需要為數(shù)據(jù)集中的每個(gè)原始圖像創(chuàng)建1000個(gè)噪聲圖像嗎?馬爾可夫鏈方面被證明有助于解決這個(gè)問題。由于我們只需要上一步的圖像來預(yù)測(cè)下一步,并且每一步添加的噪聲保持不變,因此我們可以通過生成特定時(shí)間戳的噪聲圖像來簡(jiǎn)化計(jì)算。采用對(duì)的再參數(shù)化技巧使我們能夠進(jìn)一步簡(jiǎn)化方程。
將式(3)中引入的新參數(shù)納入式(2)中,對(duì)式(2)進(jìn)行了發(fā)展,得到了結(jié)果。
逆向擴(kuò)散過程
我們已經(jīng)為圖像引入了噪聲下一步就是執(zhí)行逆操作了。除非我們知道初始條件,即t = 0時(shí)的未去噪圖像,否則無法從數(shù)學(xué)上實(shí)現(xiàn)對(duì)圖像進(jìn)行逆向處理去噪。我們的目標(biāo)是直接從噪聲中采樣以創(chuàng)建新圖像,這里缺乏關(guān)于結(jié)果的信息。所以我需要設(shè)計(jì)一種在不知道結(jié)果的情況下逐步去噪圖像的方法。所以就出現(xiàn)了使用深度學(xué)習(xí)模型來近似這個(gè)復(fù)雜的數(shù)學(xué)函數(shù)的解決方案。
有了一點(diǎn)數(shù)學(xué)背景,模型將近似于方程(5)。一個(gè)值得注意的細(xì)節(jié)是,我們將堅(jiān)持DDPM原始論文并保持固定的方差,盡管也有可能使模型學(xué)習(xí)它。
該模型的任務(wù)是預(yù)測(cè)當(dāng)前時(shí)間戳和前一個(gè)時(shí)間戳之間添加的噪聲的平均值。這樣做可以有效地去除噪音,達(dá)到預(yù)期的效果。但是如果我們的目標(biāo)是讓模型預(yù)測(cè)從“原始圖像”到最后一個(gè)時(shí)間戳添加的噪聲呢?
除非我們知道沒有噪聲的初始圖像,否則在數(shù)學(xué)上執(zhí)行逆向過程是具有挑戰(zhàn)性的,讓我們從定義后方差開始。
模型的任務(wù)是預(yù)測(cè)從初始圖像添加到時(shí)間戳t的圖像的噪聲。正向過程使我們能夠執(zhí)行這個(gè)操作,從一個(gè)清晰的圖像開始,并在時(shí)間戳t處進(jìn)展到一個(gè)有噪聲的圖像。
訓(xùn)練算法
我們假設(shè)用于進(jìn)行預(yù)測(cè)的模型體系結(jié)構(gòu)將是一個(gè)U-Net。訓(xùn)練階段的目標(biāo)是:對(duì)于數(shù)據(jù)集中的每個(gè)圖像,在[0,T]范圍內(nèi)隨機(jī)選擇一個(gè)時(shí)間戳,并計(jì)算正向擴(kuò)散過程。這產(chǎn)生了一個(gè)清晰的,有點(diǎn)噪聲的圖像,以及實(shí)際使用的噪聲。然后利用我們對(duì)逆向過程的理解,使用該模型來預(yù)測(cè)添加到圖像中的噪聲。有了真實(shí)的和預(yù)測(cè)的噪聲,我們似乎已經(jīng)進(jìn)入了一個(gè)有監(jiān)督的機(jī)器學(xué)習(xí)問題。
最主要的問題來了,我們應(yīng)該用哪個(gè)損失函數(shù)來訓(xùn)練我們的模型呢?由于處理的是概率潛在空間,Kullback-Leibler (KL)散度是一個(gè)合適的選擇。
KL散度衡量?jī)蓚€(gè)概率分布之間的差異,在我們的例子中,是模型預(yù)測(cè)的分布和期望分布。在損失函數(shù)中加入KL散度不僅可以指導(dǎo)模型產(chǎn)生準(zhǔn)確的預(yù)測(cè),還可以確保潛在空間表示符合期望的概率結(jié)構(gòu)。
KL散度可以近似為L(zhǎng)2損失函數(shù),所以可以得到以下?lián)p失函數(shù):
最終我們得到了論文中提出的訓(xùn)練算法。
采樣
逆向流程已經(jīng)解釋完成了,下面就是如何使用了。從時(shí)刻T的一個(gè)完全隨機(jī)的圖像開始,并使用逆向過程T次,最終到達(dá)時(shí)刻0。這構(gòu)成了本文中概述的第二種算法
參數(shù)
我們有很多不同的參數(shù)beta,beta_tildes,alpha, alpha_hat 等等。目前都不知道如何選擇這些參數(shù)。但是此時(shí)已知的唯一參數(shù)是T,它被設(shè)置為1000。
對(duì)于所有列出的參數(shù),它們的選擇取決于beta。從某種意義上說,Beta決定了我們要在每一步中添加的噪聲量。因此,為了確保算法的成功,仔細(xì)選擇beta是至關(guān)重要的。其他的參數(shù)因?yàn)樘?,?qǐng)參考論文。
在原始論文的實(shí)驗(yàn)階段探索了各種抽樣方法。最初的線性采樣方法圖像要么接收到的噪聲不足,要么變得過于嘈雜。為了解決這個(gè)問題,采用了另一種更常用的方法,即余弦采樣。余弦采樣提供了更平滑和更一致的噪聲添加。
模型Pytorch實(shí)現(xiàn)
我們將利用U-Net架構(gòu)進(jìn)行噪聲預(yù)測(cè),之所以選擇U-Net,是因?yàn)閁-Net是圖像處理、捕獲空間和特征地圖以及提供與輸入相同的輸出大小的理想架構(gòu)。
考慮到任務(wù)的復(fù)雜性和對(duì)每一步使用相同模型的要求(其中模型需要能夠以相同的權(quán)重去噪完全有噪聲的圖像和稍微有噪聲的圖像),調(diào)整模型是必不可少的。這包括合并更復(fù)雜的塊,并通過正弦嵌入步驟引入對(duì)所用時(shí)間戳的感知。這些增強(qiáng)的目的是使模型成為去噪任務(wù)的專家。在繼續(xù)構(gòu)建完整的模型之前,我們將介紹每個(gè)塊。
ConvNext塊
為了滿足提高模型復(fù)雜度的需要,卷積塊起著至關(guān)重要的作用。這里不能僅僅依賴于u-net論文中的基本塊,我們將結(jié)合ConvNext。
輸入由代表圖像的“x”和大小為“time_embedding_dim”的嵌入的時(shí)間戳可視化“t”組成。由于塊的復(fù)雜性以及與輸入和最后一層的殘差連接,在整個(gè)過程中,塊在學(xué)習(xí)空間和特征映射方面起著關(guān)鍵作用。
class ConvNextBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
mult=2,
time_embedding_dim=None,
norm=True,
group=8,
):
super().__init__()
self.mlp = (
nn.Sequential(nn.GELU(), nn.Linear(time_embedding_dim, in_channels))
if time_embedding_dim
else None
)
self.in_conv = nn.Conv2d(
in_channels, in_channels, 7, padding=3, groups=in_channels
)
self.block = nn.Sequential(
nn.GroupNorm(1, in_channels) if norm else nn.Identity(),
nn.Conv2d(in_channels, out_channels * mult, 3, padding=1),
nn.GELU(),
nn.GroupNorm(1, out_channels * mult),
nn.Conv2d(out_channels * mult, out_channels, 3, padding=1),
)
self.residual_conv = (
nn.Conv2d(in_channels, out_channels, 1)
if in_channels != out_channels
else nn.Identity()
)
def forward(self, x, time_embedding=None):
h = self.in_conv(x)
if self.mlp is not None and time_embedding is not None:
assert self.mlp is not None, "MLP is None"
h = h + rearrange(self.mlp(time_embedding), "b c -> b c 1 1")
h = self.block(h)
return h + self.residual_conv(x)
正弦時(shí)間戳嵌入
模型中的關(guān)鍵塊之一是正弦時(shí)間戳嵌入塊,它使給定時(shí)間戳的編碼能夠保留關(guān)于模型解碼所需的當(dāng)前時(shí)間的信息,因?yàn)樵撃P蛯⒂糜谒胁煌臅r(shí)間戳。
這是一個(gè)非常經(jīng)典的是實(shí)現(xiàn),并且應(yīng)用在各個(gè)地方,我們就直接貼代碼了
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim, theta=10000):
super().__init__()
self.dim = dim
self.theta = theta
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(self.theta) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
DownSample & UpSample
class DownSample(nn.Module):
def __init__(self, dim, dim_out=None):
super().__init__()
self.net = nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
nn.Conv2d(dim * 4, default(dim_out, dim), 1),
)
def forward(self, x):
return self.net(x)
class Upsample(nn.Module):
def __init__(self, dim, dim_out=None):
super().__init__()
self.net = nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(dim, dim_out or dim, kernel_size=3, padding=1),
)
def forward(self, x):
return self.net(x)
時(shí)間多層感知器
這個(gè)模塊利用它來基于給定的時(shí)間戳t創(chuàng)建時(shí)間表示。這個(gè)多層感知器(MLP)的輸出也將作為所有修改后的ConvNext塊的輸入“t”。
這里,“dim”是模型的超參數(shù),表示第一個(gè)塊所需的通道數(shù)。它作為后續(xù)塊中通道數(shù)量的基本計(jì)算。
sinu_pos_emb = SinusoidalPosEmb(dim, theta=10000)
time_dim = dim * 4
time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim),
)
注意力
這是unet中使用的可選組件。注意力有助于增強(qiáng)剩余連接在學(xué)習(xí)中的作用。它通過殘差連接計(jì)算的注意機(jī)制和中低潛空間計(jì)算的特征映射,更多地關(guān)注從Unet左側(cè)獲得的重要空間信息。它來源于ACC-UNet論文。
gate 表示下塊的上采樣輸出,而x殘差表示在應(yīng)用注意的水平上的殘差連接。
class BlockAttention(nn.Module):
def __init__(self, gate_in_channel, residual_in_channel, scale_factor):
super().__init__()
self.gate_conv = nn.Conv2d(gate_in_channel, gate_in_channel, kernel_size=1, stride=1)
self.residual_conv = nn.Conv2d(residual_in_channel, gate_in_channel, kernel_size=1, stride=1)
self.in_conv = nn.Conv2d(gate_in_channel, 1, kernel_size=1, stride=1)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
in_attention = self.relu(self.gate_conv(g) + self.residual_conv(x))
in_attention = self.in_conv(in_attention)
in_attention = self.sigmoid(in_attention)
return in_attention * x
最后整合
將前面討論的所有塊(不包括注意力塊)整合到一個(gè)Unet中。每個(gè)塊都包含兩個(gè)殘差連接,而不是一個(gè)。這個(gè)修改是為了解決潛在的過度擬合問題。
class TwoResUNet(nn.Module):
def __init__(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
sinusoidal_pos_emb_theta=10000,
convnext_block_groups=8,
):
super().__init__()
self.channels = channels
input_channels = channels
self.init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(input_channels, self.init_dim, 7, padding=3)
dims = [self.init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
sinu_pos_emb = SinusoidalPosEmb(dim, theta=sinusoidal_pos_emb_theta)
time_dim = dim * 4
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim),
)
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
nn.ModuleList(
[
ConvNextBlock(
in_channels=dim_in,
out_channels=dim_in,
time_embedding_dim=time_dim,
group=convnext_block_groups,
),
ConvNextBlock(
in_channels=dim_in,
out_channels=dim_in,
time_embedding_dim=time_dim,
group=convnext_block_groups,
),
DownSample(dim_in, dim_out)
if not is_last
else nn.Conv2d(dim_in, dim_out, 3, padding=1),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, time_embedding_dim=time_dim)
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, time_embedding_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)
is_first = ind == 0
self.ups.append(
nn.ModuleList(
[
ConvNextBlock(
in_channels=dim_out + dim_in,
out_channels=dim_out,
time_embedding_dim=time_dim,
group=convnext_block_groups,
),
ConvNextBlock(
in_channels=dim_out + dim_in,
out_channels=dim_out,
time_embedding_dim=time_dim,
group=convnext_block_groups,
),
Upsample(dim_out, dim_in)
if not is_last
else nn.Conv2d(dim_out, dim_in, 3, padding=1)
]
)
)
default_out_dim = channels
self.out_dim = default(out_dim, default_out_dim)
self.final_res_block = ConvNextBlock(dim * 2, dim, time_embedding_dim=time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
def forward(self, x, time):
b, _, h, w = x.shape
x = self.init_conv(x)
r = x.clone()
t = self.time_mlp(time)
unet_stack = []
for down1, down2, downsample in self.downs:
x = down1(x, t)
unet_stack.append(x)
x = down2(x, t)
unet_stack.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_block2(x, t)
for up1, up2, upsample in self.ups:
x = torch.cat((x, unet_stack.pop()), dim=1)
x = up1(x, t)
x = torch.cat((x, unet_stack.pop()), dim=1)
x = up2(x, t)
x = upsample(x)
x = torch.cat((x, r), dim=1)
x = self.final_res_block(x, t)
return self.final_conv(x)
擴(kuò)散的代碼實(shí)現(xiàn)
最后我們介紹一下擴(kuò)散是如何實(shí)現(xiàn)的。由于我們已經(jīng)介紹了用于正向、逆向和采樣過程的所有數(shù)學(xué)理論,所里這里將重點(diǎn)介紹代碼。
class DiffusionModel(nn.Module):
SCHEDULER_MAPPING = {
"linear": linear_beta_schedule,
"cosine": cosine_beta_schedule,
"sigmoid": sigmoid_beta_schedule,
}
def __init__(
self,
model: nn.Module,
image_size: int,
*,
beta_scheduler: str = "linear",
timesteps: int = 1000,
schedule_fn_kwargs: dict | None = None,
auto_normalize: bool = True,
) -> None:
super().__init__()
self.model = model
self.channels = self.model.channels
self.image_size = image_size
self.beta_scheduler_fn = self.SCHEDULER_MAPPING.get(beta_scheduler)
if self.beta_scheduler_fn is None:
raise ValueError(f"unknown beta schedule {beta_scheduler}")
if schedule_fn_kwargs is None:
schedule_fn_kwargs = {}
betas = self.beta_scheduler_fn(timesteps, **schedule_fn_kwargs)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
posterior_variance = (
betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)
register_buffer = lambda name, val: self.register_buffer(
name, val.to(torch.float32)
)
register_buffer("betas", betas)
register_buffer("alphas_cumprod", alphas_cumprod)
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
register_buffer("sqrt_recip_alphas", torch.sqrt(1.0 / alphas))
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
register_buffer(
"sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
)
register_buffer("posterior_variance", posterior_variance)
timesteps, *_ = betas.shape
self.num_timesteps = int(timesteps)
self.sampling_timesteps = timesteps
self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
@torch.inference_mode()
def p_sample(self, x: torch.Tensor, timestamp: int) -> torch.Tensor:
b, *_, device = *x.shape, x.device
batched_timestamps = torch.full(
(b,), timestamp, device=device, dtype=torch.long
)
preds = self.model(x, batched_timestamps)
betas_t = extract(self.betas, batched_timestamps, x.shape)
sqrt_recip_alphas_t = extract(
self.sqrt_recip_alphas, batched_timestamps, x.shape
)
sqrt_one_minus_alphas_cumprod_t = extract(
self.sqrt_one_minus_alphas_cumprod, batched_timestamps, x.shape
)
predicted_mean = sqrt_recip_alphas_t * (
x - betas_t * preds / sqrt_one_minus_alphas_cumprod_t
)
if timestamp == 0:
return predicted_mean
else:
posterior_variance = extract(
self.posterior_variance, batched_timestamps, x.shape
)
noise = torch.randn_like(x)
return predicted_mean + torch.sqrt(posterior_variance) * noise
@torch.inference_mode()
def p_sample_loop(
self, shape: tuple, return_all_timesteps: bool = False
) -> torch.Tensor:
batch, device = shape[0], "mps"
img = torch.randn(shape, device=device)
# This cause me a RunTimeError on MPS device due to MPS back out of memory
# No ideas how to resolve it at this point
# imgs = [img]
for t in tqdm(reversed(range(0, self.num_timesteps)), total=self.num_timesteps):
img = self.p_sample(img, t)
# imgs.append(img)
ret = img # if not return_all_timesteps else torch.stack(imgs, dim=1)
ret = self.unnormalize(ret)
return ret
def sample(
self, batch_size: int = 16, return_all_timesteps: bool = False
) -> torch.Tensor:
shape = (batch_size, self.channels, self.image_size, self.image_size)
return self.p_sample_loop(shape, return_all_timesteps=return_all_timesteps)
def q_sample(
self, x_start: torch.Tensor, t: int, noise: torch.Tensor = None
) -> torch.Tensor:
if noise is None:
noise = torch.randn_like(x_start)
sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
)
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
def p_loss(
self,
x_start: torch.Tensor,
t: int,
noise: torch.Tensor = None,
loss_type: str = "l2",
) -> torch.Tensor:
if noise is None:
noise = torch.randn_like(x_start)
x_noised = self.q_sample(x_start, t, noise=noise)
predicted_noise = self.model(x_noised, t)
if loss_type == "l2":
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "l1":
loss = F.l1_loss(noise, predicted_noise)
else:
raise ValueError(f"unknown loss type {loss_type}")
return loss
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, h, w, device, img_size = *x.shape, x.device, self.image_size
assert h == w == img_size, f"image size must be {img_size}"
timestamp = torch.randint(0, self.num_timesteps, (1,)).long().to(device)
x = self.normalize(x)
return self.p_loss(x, timestamp)
擴(kuò)散過程是訓(xùn)練部分的模型。它打開了一個(gè)采樣接口,允許我們使用已經(jīng)訓(xùn)練好的模型生成樣本。
訓(xùn)練的要點(diǎn)總結(jié)
對(duì)于訓(xùn)練部分,我們?cè)O(shè)置了37,000步的訓(xùn)練,每步16個(gè)批次。由于GPU內(nèi)存分配限制,圖像大小被限制為128x128。使用指數(shù)移動(dòng)平均(EMA)模型權(quán)重每1000步生成樣本以平滑采樣,并保存模型版本。
在最初的1000步訓(xùn)練中,模型開始捕捉一些特征,但仍然錯(cuò)過了某些區(qū)域。在10000步左右,這個(gè)模型開始產(chǎn)生有希望的結(jié)果,進(jìn)步變得更加明顯。在3萬步的最后,結(jié)果的質(zhì)量顯著提高,但仍然存在黑色圖像。這只是因?yàn)槟P蜎]有足夠的樣本種類,真實(shí)圖像的數(shù)據(jù)分布并沒有完全映射到高斯分布。
有了最終的模型權(quán)重,我們可以生成一些圖片。盡管由于128x128的尺寸限制,圖像質(zhì)量受到限制,但該模型的表現(xiàn)還是不錯(cuò)的。
注:本文使用的數(shù)據(jù)集是森林地形的衛(wèi)星圖片,具體獲取方式請(qǐng)參考源代碼中的ETL部分。
總結(jié)
我們已經(jīng)完整的介紹了有關(guān)擴(kuò)散模型的必要知識(shí),并且使用Pytorch進(jìn)行了完整的實(shí)現(xiàn),本文的代碼:
https://github.com/Camaltra/this-is-not-real-aerial-imagery/