譯者 | 朱先忠
審校 | 重樓
簡(jiǎn)介
在我最近發(fā)表的幾篇文章中,我談到了生成式深度學(xué)習(xí)算法,這些算法大多與文本生成任務(wù)有關(guān)。所以,我認(rèn)為現(xiàn)在轉(zhuǎn)向圖像生成的生成算法研究會(huì)很有趣。我們知道,如今已經(jīng)有很多專(zhuān)門(mén)用于生成圖像的深度學(xué)習(xí)模型,例如自動(dòng)編碼器、變分自動(dòng)編碼器(VAE)、生成對(duì)抗網(wǎng)絡(luò)(GAN)和神經(jīng)風(fēng)格遷移(NST)。
在本文中,我想討論所謂的擴(kuò)散模型 ——這是深度學(xué)習(xí)領(lǐng)域中對(duì)圖像生成影響最大的模型之一。該算法的想法最早是在2015年由Sohl-Dickstein等人撰寫(xiě)的題為《使用非平衡熱力學(xué)的深度無(wú)監(jiān)督學(xué)習(xí)》的論文中提出的【引文1】。他們的框架隨后由Ho等人在2020年的論文《去噪擴(kuò)散概率模型》中進(jìn)一步開(kāi)發(fā)【引文2】。DDPM后來(lái)被OpenAI和Google改編以開(kāi)發(fā)DALLE-2和Imagen,我們知道這些模型具有生成高質(zhì)量圖像的令人印象深刻的能力。
擴(kuò)散模型的工作原理
一般來(lái)說(shuō),擴(kuò)散模型的工作原理是從噪聲中生成圖像。我們可以把它想象成一個(gè)藝術(shù)家將畫(huà)布上的一抹顏料變成一幅美麗的藝術(shù)品。為了做到這一點(diǎn),首先需要訓(xùn)練擴(kuò)散模型。訓(xùn)練模型需要遵循兩個(gè)主要步驟,即前向擴(kuò)散和后向擴(kuò)散。
圖1.正向和反向擴(kuò)散過(guò)程【引文3】
如上圖所示,前向擴(kuò)散是一個(gè)將高斯噪聲迭代地應(yīng)用于原始圖像的過(guò)程。我們不斷添加噪聲,直到圖像完全無(wú)法識(shí)別,此時(shí)我們可以說(shuō)圖像現(xiàn)在位于潛在空間中。與自動(dòng)編碼器和GAN中的潛在空間通常比原始圖像的維度低不同,DDPM中的潛在空間保持與原始圖像完全相同的維度。這個(gè)噪聲過(guò)程遵循馬爾可夫鏈的原理,這意味著時(shí)間步t處的圖像僅受時(shí)間步t-1的影響。前向擴(kuò)散被認(rèn)為很容易,因?yàn)槲覀兓旧现皇且徊揭徊降靥砑右恍┰肼暋?/p>
第二個(gè)訓(xùn)練階段稱(chēng)為反向擴(kuò)散,我們的目標(biāo)是一點(diǎn)一點(diǎn)地去除噪聲,直到獲得清晰的圖像。這個(gè)過(guò)程遵循逆馬爾可夫鏈的原理,其中時(shí)間步長(zhǎng)t-1的圖像只能基于時(shí)間步長(zhǎng)t的圖像獲得。這樣的去噪過(guò)程非常困難,因?yàn)槲覀冃枰聹y(cè)哪些像素是噪聲,哪些像素屬于實(shí)際圖像內(nèi)容。因此,我們需要采用神經(jīng)網(wǎng)絡(luò)模型來(lái)實(shí)現(xiàn)這一點(diǎn)。
DDPM使用U-Net作為反向擴(kuò)散深度學(xué)習(xí)架構(gòu)的基礎(chǔ)。但是,我們不需要使用原始的U-Net模型【引文4】,而是需要對(duì)其進(jìn)行一些修改,以使其更適合我們的任務(wù)。稍后,我將在MNIST手寫(xiě)數(shù)字?jǐn)?shù)據(jù)集【引文5】上訓(xùn)練此模型,看看它是否可以生成類(lèi)似的圖像。
好了,以上就是目前你需要了解的有關(guān)擴(kuò)散模型的所有基本概念。在接下來(lái)的內(nèi)容中,我們將從頭開(kāi)始實(shí)現(xiàn)擴(kuò)散模型算法,從而使讀者更深入地了解有關(guān)細(xì)節(jié)。
PyTorch實(shí)現(xiàn)
我們將首先導(dǎo)入所需的模塊。如果你還不熟悉下面的導(dǎo)入,那么告訴你torch和torchvision都是我們用于準(zhǔn)備模型和數(shù)據(jù)集的庫(kù)。同時(shí),matplotlib和tqdm將幫助我們顯示圖像和進(jìn)度條。
# Codeblock 1
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
由于模塊已導(dǎo)入,接下來(lái)要做的是初始化一些配置參數(shù)。詳細(xì)信息請(qǐng)參閱下面的Codeblock 2。
# Codeblock 2
IMAGE_SIZE = 28 #(1)
NUM_CHANNELS = 1 #(2)
BATCH_SIZE = 2
NUM_EPOCHS = 10
LEARNING_RATE = 0.001
NUM_TIMESTEPS = 1000 #(3)
BETA_START = 0.0001 #(4)
BETA_END = 0.02 #(5)
TIME_EMBED_DIM = 32 #(6)
DEVICE = torch.device("cuda" if torch.cuda.is_available else "cpu") #(7)
DEVICE
# Codeblock 2 Output
device(type='cuda')
在代碼行#(1)和#(2)中,我設(shè)置IMAGE_SIZE和NUM_CHANNELS兩個(gè)常量分別為28和1,這些數(shù)字是從MNIST數(shù)據(jù)集中的圖像維度獲得的。變量BATCH_SIZE,NUM_EPOCHS和LEARNING_RATE非常簡(jiǎn)單,所以我不需要進(jìn)一步解釋它們。
在代碼行#(3),變量NUM_TIMESTEPS表示前向和后向擴(kuò)散過(guò)程中的迭代次數(shù)。時(shí)間步長(zhǎng)0是圖像處于原始狀態(tài)的情況(圖1中最左邊的圖像)。在這種情況下,由于我們將此參數(shù)設(shè)置為1000,時(shí)間步長(zhǎng)數(shù)999將是圖像完全無(wú)法識(shí)別的情況(上面圖1中最右邊的圖像)。重要的是要記住,時(shí)間步長(zhǎng)的選擇涉及模型準(zhǔn)確性和計(jì)算成本之間的權(quán)衡。如果我們?yōu)榉峙湟粋€(gè)較小的值NUM_TIMESTEPS,推理時(shí)間會(huì)更短,但由于模型在后向擴(kuò)散階段細(xì)化圖像的步驟較少,因此生成的圖像可能不是很好。另一方面,增加NUM_TIMESTEPS會(huì)減慢推理過(guò)程,但我們可以預(yù)期輸出圖像具有更好的質(zhì)量,這要?dú)w功于逐步的去噪過(guò)程,從而可以實(shí)現(xiàn)更精確的重建。
接下來(lái),BETA_START(代碼行#(4))和BETA_END(代碼行#(5))變量分別用于控制每個(gè)時(shí)間步添加的高斯噪聲量,而TIME_EMBED_DIM(代碼行#(6))用于確定用于存儲(chǔ)時(shí)間步信息的特征向量長(zhǎng)度。最后,在第#(7)行,如果Pytorch檢測(cè)到我們的機(jī)器上安裝了GPU,我將分配“cuda”給變量。我強(qiáng)烈建議你在GPU上運(yùn)行此項(xiàng)目,因?yàn)橛?xùn)練擴(kuò)散模型的計(jì)算成本很高。除了上述參數(shù)之外,為NUM_TIMESTEPS、BETA_START和BETA_END設(shè)定的值均直接采用自DDPM論文【引文2】。
完整的實(shí)現(xiàn)過(guò)程可分為幾個(gè)步驟:構(gòu)建U-Net模型、準(zhǔn)備數(shù)據(jù)集、定義擴(kuò)散過(guò)程的噪聲調(diào)度程序、訓(xùn)練和推理。我們將在以下小節(jié)中討論每個(gè)階段。
U-Net架構(gòu):時(shí)間嵌入
正如我之前提到的,擴(kuò)散模型的基礎(chǔ)是U-Net。之所以使用這種架構(gòu),是因?yàn)樗妮敵鰧舆m合表示圖像,這絕對(duì)是有道理的,因?yàn)樗畛跏菫閳D像分割任務(wù)引入的。下圖顯示了原始U-Net架構(gòu)的樣子。
圖2.【引文4】中提出的原始U-Net模型
然而,有必要修改一下此架構(gòu),以便它也能考慮時(shí)間步長(zhǎng)信息。不僅如此,由于我們只使用MNIST數(shù)據(jù)集,我們還需要使模型更小。只需記住深度學(xué)習(xí)中的慣例,即更簡(jiǎn)單的模型通常對(duì)簡(jiǎn)單任務(wù)更有效。
下圖中我展示了經(jīng)過(guò)修改的整個(gè)U-Net模型。在這里你可以看到時(shí)間嵌入張量在每個(gè)階段都被注入到模型中,稍后將通過(guò)逐元素求和來(lái)完成,從而使模型能夠捕獲時(shí)間步長(zhǎng)信息。接下來(lái),我們不會(huì)像原始U-Net那樣將每個(gè)下采樣和上采樣階段重復(fù)四次,而是在本例中只重復(fù)兩次。此外,值得注意的是,下采樣階段的堆棧也稱(chēng)為編碼器,而上采樣階段的堆棧通常稱(chēng)為解碼器。
圖3.針對(duì)我們的擴(kuò)散任務(wù)修改后的U-Net模型【引文3】
現(xiàn)在,讓我們開(kāi)始構(gòu)建架構(gòu),創(chuàng)建一個(gè)用于生成時(shí)間嵌入張量的類(lèi),其思想類(lèi)似于Transformer中的位置嵌入。有關(guān)詳細(xì)信息,請(qǐng)參閱下面的Codeblock 3。
# Codeblock 3
class TimeEmbedding(nn.Module):
def forward(self):
time = torch.arange(NUM_TIMESTEPS, device=DEVICE).reshape(NUM_TIMESTEPS, 1) #(1)
print(f"time\t\t: {time.shape}")
i = torch.arange(0, TIME_EMBED_DIM, 2, device=DEVICE)
denominator = torch.pow(10000, i/TIME_EMBED_DIM)
print(f"denominator\t: {denominator.shape}")
even_time_embed = torch.sin(time/denominator) #(1)
odd_time_embed = torch.cos(time/denominator) #(2)
print(f"even_time_embed\t: {even_time_embed.shape}")
print(f"odd_time_embed\t: {odd_time_embed.shape}")
stacked = torch.stack([even_time_embed, odd_time_embed], dim=2) #(3)
print(f"stacked\t\t: {stacked.shape}")
time_embed = torch.flatten(stacked, start_dim=1, end_dim=2) #(4)
print(f"time_embed\t: {time_embed.shape}")
return time_embed
上述代碼中,我們基本上要?jiǎng)?chuàng)建一個(gè)大小為NUM_TIMESTEPS×TIME_EMBED_DIM(1000×32)的張量,其中該張量的每一行都包含時(shí)間步長(zhǎng)信息。稍后,1000個(gè)時(shí)間步長(zhǎng)中的每一個(gè)都將由長(zhǎng)度為32的特征向量表示。張量本身的值是根據(jù)圖4中的兩個(gè)方程獲得的。在上面的Codeblock 3中,這兩個(gè)方程分別在第#(1)和#(2)行實(shí)現(xiàn),每個(gè)方程都形成一個(gè)大小為1000×16的張量。接下來(lái),使用第#(3)行和#(4)行的代碼將這些張量組合起來(lái)。
在這里,我還打印出了上述代碼塊中完成的每個(gè)步驟,以便你更好地理解TimeEmbedding類(lèi)中實(shí)際執(zhí)行的操作。如果你仍想了解有關(guān)上述代碼的更多解釋?zhuān)?qǐng)隨時(shí)閱讀我之前關(guān)于Transformer的博客文章,你可以通過(guò)本文末尾的鏈接訪問(wèn)該文章。單擊鏈接后,你可以一直向下滾動(dòng)到“Positional Encoding”部分。
圖4. Transformer論文【引文6】中的正弦位置編碼公式
現(xiàn)在,讓我們使用以下測(cè)試代碼檢查是否TimeEmbedding類(lèi)正常工作。結(jié)果輸出顯示它成功生成了一個(gè)大小為1000×32的張量,這正是我們之前所期望的。
# Codeblock 4
time_embed_test = TimeEmbedding()
out_test = time_embed_test()
# Codeblock 4 Output
time : torch.Size([1000, 1])
denominator : torch.Size([16])
even_time_embed : torch.Size([1000, 16])
odd_time_embed : torch.Size([1000, 16])
stacked : torch.Size([1000, 16, 2])
time_embed : torch.Size([1000, 32])
U-Net架構(gòu):DoubleConv
如果仔細(xì)觀察修改后的架構(gòu),你會(huì)發(fā)現(xiàn)我們實(shí)際上得到了許多重復(fù)的模型,例如下圖中黃色框中突出顯示的模型。
圖5.黃色框內(nèi)完成的流程將在DoubleConv課堂上實(shí)現(xiàn)【引文3】
這五個(gè)黃色框具有相同的結(jié)構(gòu),它們由兩個(gè)卷積層組成,在執(zhí)行第一個(gè)卷積操作后立即注入時(shí)間嵌入張量。因此,我們現(xiàn)在要做的是創(chuàng)建另一個(gè)名為DoubleConv的類(lèi)來(lái)重現(xiàn)此結(jié)構(gòu)。查看下面的代碼塊5a和5b以了解我如何做到這一點(diǎn)。
# Codeblock 5a
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels): #(1)
super().__init__()
self.conv_0 = nn.Conv2d(in_channels=in_channels, #(2)
out_channels=out_channels,
kernel_size=3,
bias=False,
padding=1)
self.bn_0 = nn.BatchNorm2d(num_features=out_channels) #(3)
self.time_embedding = TimeEmbedding() #(4)
self.linear = nn.Linear(in_features=TIME_EMBED_DIM, #(5)
out_features=out_channels)
self.conv_1 = nn.Conv2d(in_channels=out_channels, #(6)
out_channels=out_channels,
kernel_size=3,
bias=False,
padding=1)
self.bn_1 = nn.BatchNorm2d(num_features=out_channels) #(7)
self.relu = nn.ReLU(inplace=True) #(8)
上述方法__init__()中的兩個(gè)輸入?yún)?shù)使我們可以靈活地配置輸入和輸出通道的數(shù)量(#(1));這樣一來(lái),DoubleConv類(lèi)可以通過(guò)調(diào)整輸入?yún)?shù)來(lái)實(shí)例化所有五個(gè)黃色框。顧名思義,這里我們初始化了兩個(gè)卷積層(行#(2)和#(6)),每個(gè)卷積層后跟一個(gè)批量歸一化層和一個(gè)ReLU激活函數(shù)。請(qǐng)記住,兩個(gè)歸一化層需要單獨(dú)初始化(行#(3)和#(7)),因?yàn)樗鼈兠總€(gè)都有自己可訓(xùn)練的歸一化參數(shù)。同時(shí),ReLU激活函數(shù)應(yīng)該只初始化一次(#(8)),因?yàn)樗话魏螀?shù),因此可以在網(wǎng)絡(luò)的不同部分多次使用。在代碼行#(4),我們初始化之前創(chuàng)建的TimeEmbedding層,該層稍后將連接到標(biāo)準(zhǔn)線性層(#(5))。該線性層負(fù)責(zé)調(diào)整時(shí)間嵌入張量的維度,以便可以以逐元素的方式將得到的輸出與第一個(gè)卷積層的輸出相加。
現(xiàn)在,讓我們看一下下面的Codeblock 5b,以更好地理解該DoubleConv塊的運(yùn)行流程。在這部分代碼中,你可以看到forward()方法接受兩個(gè)輸入:原始圖像x和時(shí)間步長(zhǎng)信息t,如第#(1)行所示。我們首先使用第一個(gè)Conv-BN-ReLU序列處理圖像(#(2–4))。這種Conv-BN-ReLU結(jié)構(gòu)通常用于基于CNN的模型,即使插圖沒(méi)有明確顯示批量標(biāo)準(zhǔn)化和ReLU層。除了圖像之外,我們還從相應(yīng)圖像的嵌入張量(#(5))中獲取第t個(gè)時(shí)間步長(zhǎng)信息并將其傳遞給線性層(#(6))。我們?nèi)匀恍枰褂玫?(7)行的代碼擴(kuò)展結(jié)果張量的維度,然后在第#(8)行執(zhí)行逐元素求和。最后,我們使用第二個(gè)Conv-BN-ReLU序列(#(9–11))處理結(jié)果張量。
# Codeblock 5b
def forward(self, x, t): #(1)
print(f'images\t\t\t: {x.size()}')
print(f'timesteps\t\t: {t.size()}, {t}')
x = self.conv_0(x) #(2)
x = self.bn_0(x) #(3)
x = self.relu(x) #(4)
print(f'\nafter first conv\t: {x.size()}')
time_embed = self.time_embedding()[t] #(5)
print(f'\ntime_embed\t\t: {time_embed.size()}')
time_embed = self.linear(time_embed) #(6)
print(f'time_embed after linear\t: {time_embed.size()}')
time_embed = time_embed[:, :, None, None] #(7)
print(f'time_embed expanded\t: {time_embed.size()}')
x = x + time_embed #(8)
print(f'\nafter summation\t\t: {x.size()}')
x = self.conv_1(x) #(9)
x = self.bn_1(x) #(10)
x = self.relu(x) #(11)
print(f'after second conv\t: {x.size()}')
return x
為了查看我們的DoubleConv類(lèi)實(shí)現(xiàn)是否正常工作,我們將使用下面的Codeblock 6對(duì)其進(jìn)行測(cè)試。在這里,我想模擬此塊的第一個(gè)實(shí)例,它對(duì)應(yīng)于圖5中最左邊的黃色框。為此,我們需要將in_channels和out_channels參數(shù)分別設(shè)置為1和64(#(1))。接下來(lái),我們初始化兩個(gè)輸入張量,即x_test和t_test。其中,x_test張量的大小為2×1×28×28,表示一批大小為28×28的兩個(gè)灰度圖像(#(2))。請(qǐng)記住,這只是一個(gè)賦予了隨機(jī)數(shù)值的虛擬張量,它將在訓(xùn)練階段的后期被來(lái)自MNIST數(shù)據(jù)集的實(shí)際圖像替換。同時(shí),t_test是一個(gè)包含相應(yīng)圖像的時(shí)間步長(zhǎng)數(shù)的張量(#(3))。此張量的值在0到NUM_TIMESTEPS(1000)之間隨機(jī)選擇。請(qǐng)注意,此張量的數(shù)據(jù)類(lèi)型必須是整數(shù),因?yàn)檫@些數(shù)字將用于索引,如Codeblock 5b中的第#(5)行所示。最后,在代碼行#(4)我們將x_test和t_test張量傳遞給double_conv_test層。
順便說(shuō)一句,print()在運(yùn)行以下代碼之前,我重新運(yùn)行了前面的代碼塊并刪除了函數(shù),以便輸出結(jié)果看起來(lái)更整潔一些。
# Codeblock 6
double_conv_test = DoubleConv(in_channels=1, out_channels=64).to(DEVICE) #(1)
x_test = torch.randn((BATCH_SIZE, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)).to(DEVICE) #(2)
t_test = torch.randint(0, NUM_TIMESTEPS, (BATCH_SIZE,)).to(DEVICE) #(3)
out_test = double_conv_test(x_test, t_test) #(4)
# Codeblock 6 Output
images : torch.Size([2, 1, 28, 28]) #(1)
timesteps : torch.Size([2]), tensor([468, 304], device='cuda:0') #(2)
after first conv : torch.Size([2, 64, 28, 28]) #(3)
time_embed : torch.Size([2, 32]) #(4)
time_embed after linear : torch.Size([2, 64])
time_embed expanded : torch.Size([2, 64, 1, 1]) #(5)
after summation : torch.Size([2, 64, 28, 28]) #(6)
after second conv : torch.Size([2, 64, 28, 28]) #(7)
原始輸入張量的形狀可以在上面的輸出的第#(1)和#(2)行看到。具體來(lái)說(shuō),在第#(2)行,我還打印出了我們隨機(jī)選擇的兩個(gè)時(shí)間步。在這個(gè)例子中,我們假設(shè)x張量中的兩個(gè)圖像在輸入網(wǎng)絡(luò)之前已經(jīng)用來(lái)自第468和第304個(gè)時(shí)間步的噪聲級(jí)別進(jìn)行了噪聲處理。我們可以看到,在經(jīng)過(guò)第一個(gè)卷積層(#(3)行)后,圖像張量x的形狀變?yōu)?×64×28×28。同時(shí),我們的時(shí)間嵌入張量的大小變?yōu)?×32(#(4)行),這是通過(guò)從大小為1000×32的原始嵌入中提取第468行和第304行獲得的。為了能夠執(zhí)行元素級(jí)求和(#(6)行),我們需要將32維時(shí)間嵌入向量映射到64維并擴(kuò)展它們的軸,從而得到大小為2×64×1×1的張量(#(5)行),以便可以將其廣播到2×64×28×28張量。求和完成后,我們將張量傳遞到第二個(gè)卷積層,此時(shí)張量維度完全不會(huì)改變(#(7)行)。
U-Net架構(gòu):編碼器
到目前為止,既然我們已經(jīng)成功實(shí)現(xiàn)了DoubleConv塊,那么接下來(lái)要做的就是實(shí)現(xiàn)所謂的DownSample塊。在下面的圖6中,這對(duì)應(yīng)于紅色框內(nèi)的部分。
圖6.網(wǎng)絡(luò)中以紅色突出顯示的部分是所謂的DownSample塊【引文3】
DownSample塊的目的是減少圖像的空間維度,但需要注意的是,它同時(shí)增加了通道數(shù)。為了實(shí)現(xiàn)這一點(diǎn),我們可以簡(jiǎn)單地堆疊一個(gè)DoubleConv塊和一個(gè)最大池化操作。在這種情況下,池化使用2×2內(nèi)核大小,步長(zhǎng)為2,導(dǎo)致圖像的空間維度是輸入的兩倍。此塊的實(shí)現(xiàn)可以在下面的Codeblock 7中看到。
# Codeblock 7
class DownSample(nn.Module):
def __init__(self, in_channels, out_channels): #(1)
super().__init__()
self.double_conv = DoubleConv(in_channels=in_channels, #(2)
out_channels=out_channels)
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) #(3)
def forward(self, x, t): #(4)
print(f'original\t\t: {x.size()}')
print(f'timesteps\t\t: {t.size()}, {t}')
convolved = self.double_conv(x, t) #(5)
print(f'\nafter double conv\t: {convolved.size()}')
maxpooled = self.maxpool(convolved) #(6)
print(f'after pooling\t\t: {maxpooled.size()}')
return convolved, maxpooled #(7)
在這里,我將__init__()方法設(shè)置為采用輸入和輸出通道的數(shù)量,以便我們可以使用它來(lái)創(chuàng)建圖6中突出顯示的兩個(gè)DownSample塊,而無(wú)需將它們寫(xiě)入單獨(dú)的類(lèi)(#(1))。接下來(lái),DoubleConv和最大池化層分別在第#(2)和#(3)行初始化。請(qǐng)記住,由于DoubleConv塊接受圖像x和相應(yīng)的時(shí)間步t作為輸入,我們還需要設(shè)置此DownSample塊的forward()方法,以便使其也接受它們兩個(gè)參數(shù)(#(4))。然后,當(dāng)double_conv層處理兩個(gè)張量時(shí),x和t中包含的信息被組合在一起,輸出被存儲(chǔ)在名為convolved(#(5))的變量中。之后,我們現(xiàn)在實(shí)際上在第1行使用最大池化操作執(zhí)行下采樣(#(6)),產(chǎn)生一個(gè)名為maxpooled的張量。值得注意的是,convolved和maxpooled張量都將被返回,這基本上是因?yàn)槲覀兩院髸?huì)把maxpooled帶入下一個(gè)下采樣階段,而convolved張量將通過(guò)跳過(guò)連接直接傳輸?shù)浇獯a器中的上采樣階段。
現(xiàn)在,我們使用下面的Codeblock 8來(lái)測(cè)試該DownSample類(lèi)。這里使用的輸入張量與Codeblock 6中的完全相同。根據(jù)結(jié)果輸出,我們可以看到池化操作成功地將DoubleConv塊的輸出從2×64×28×28(#(1))轉(zhuǎn)換為2×64×14×14(#(2)),這表明我們的DownSample類(lèi)正常工作。
# Codeblock 8
down_sample_test = DownSample(in_channels=1, out_channels=64).to(DEVICE)
x_test = torch.randn((BATCH_SIZE, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)).to(DEVICE)
t_test = torch.randint(0, NUM_TIMESTEPS, (BATCH_SIZE,)).to(DEVICE)
out_test = down_sample_test(x_test, t_test)
# Codeblock 8 Output
original : torch.Size([2, 1, 28, 28])
timesteps : torch.Size([2]), tensor([468, 304], device='cuda:0')
after double conv : torch.Size([2, 64, 28, 28]) #(1)
after pooling : torch.Size([2, 64, 14, 14]) #(2)
U-Net架構(gòu):解碼器
我們需要在解碼器中引入所謂的UpSample塊,它負(fù)責(zé)將中間層中的張量恢復(fù)到原始圖像維度。為了保持對(duì)稱(chēng)結(jié)構(gòu),UpSample塊的數(shù)量必須與DownSample塊的數(shù)量相匹配。觀察下面的圖7,就可以看到兩個(gè)UpSample塊的位置。
圖7.藍(lán)色框內(nèi)的組件就是所謂的UpSample塊【引文3】
由于兩個(gè)UpSample塊的結(jié)構(gòu)相同,我們可以為它們初始化一個(gè)類(lèi),就像DownSample我們之前創(chuàng)建的類(lèi)一樣。請(qǐng)查看下面的Codeblock 9,了解我如何實(shí)現(xiàn)它。
# Codeblock 9
class UpSample(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv_transpose = nn.ConvTranspose2d(in_channels=in_channels, #(1)
out_channels=out_channels,
kernel_size=2, stride=2) #(2)
self.double_conv = DoubleConv(in_channels=in_channels, #(3)
out_channels=out_channels)
def forward(self, x, t, connection): #(4)
print(f'original\t\t: {x.size()}')
print(f'timesteps\t\t: {t.size()}, {t}')
print(f'connection\t\t: {connection.size()}')
x = self.conv_transpose(x) #(5)
print(f'\nafter conv transpose\t: {x.size()}')
x = torch.cat([x, connection], dim=1) #(6)
print(f'after concat\t\t: {x.size()}')
x = self.double_conv(x, t) #(7)
print(f'after double conv\t: {x.size()}')
return x
在該__init__()方法中,我們使用nn.ConvTranspose2d對(duì)空間維度進(jìn)行上采樣(#(1))。內(nèi)核大小和步長(zhǎng)都設(shè)置為2,這樣輸出將是原來(lái)的兩倍(#(2))。接下來(lái),DoubleConv將使用塊來(lái)減少通道數(shù),同時(shí)結(jié)合來(lái)自時(shí)間嵌入張量的時(shí)間步長(zhǎng)信息(#(3))。
這個(gè)UpSample類(lèi)的流程比DownSample類(lèi)稍微復(fù)雜一些。如果我們仔細(xì)看看這個(gè)架構(gòu),我們會(huì)發(fā)現(xiàn)我們還有一個(gè)直接來(lái)自編碼器的跳過(guò)連接。因此,除了原始圖像x和時(shí)間步長(zhǎng)之外,我們還需要forward()方法接受另一個(gè)參數(shù)t,即殘差張量connection(#(4))。我們?cè)谶@個(gè)方法中做的第一件事是用轉(zhuǎn)置卷積層處理原始圖像(#(5))。事實(shí)上,這個(gè)層不僅對(duì)空間大小進(jìn)行了上采樣,而且還同時(shí)減少了通道數(shù)。然而,得到的張量隨后以通道方式直接連接(#(6)),使其看起來(lái)好像沒(méi)有執(zhí)行任何通道減少。重要的是要知道,此時(shí)這兩個(gè)張量只是連接在一起,這意味著來(lái)自?xún)烧叩男畔⑸形唇Y(jié)合。我們最終將這些連接的張量輸入到double_conv層(#(7)),允許它們通過(guò)卷積層內(nèi)的可學(xué)習(xí)參數(shù)相互共享信息。
下面的Codeblock 10展示了我如何測(cè)試該類(lèi)UpSample。需要傳入的張量的大小是根據(jù)第二個(gè)上采樣塊(即圖7中最右邊的藍(lán)色框)設(shè)置的。
# Codeblock 10
up_sample_test = UpSample(in_channels=128, out_channels=64).to(DEVICE)
x_test = torch.randn((BATCH_SIZE, 128, 14, 14)).to(DEVICE)
t_test = torch.randint(0, NUM_TIMESTEPS, (BATCH_SIZE,)).to(DEVICE)
connection_test = torch.randn((BATCH_SIZE, 64, 28, 28)).to(DEVICE)
out_test = up_sample_test(x_test, t_test, connection_test)
在下面的結(jié)果輸出中,如果我們將輸入張量(#(1))與最終張量形狀(#(2))進(jìn)行比較,我們可以清楚地看到通道數(shù)成功地從128減少到64,同時(shí)空間維度從14×14增加到28×28。這實(shí)際上意味著,我們的UpSample類(lèi)現(xiàn)在可以在主U-Net架構(gòu)中使用了。
original : torch.Size([2, 128, 14, 14]) #(1)
timesteps : torch.Size([2]), tensor([468, 304], device='cuda:0')
connection : torch.Size([2, 64, 28, 28])
after conv transpose : torch.Size([2, 64, 28, 28])
after concat : torch.Size([2, 128, 28, 28])
after double conv : torch.Size([2, 64, 28, 28]) #(2)
U-Net架構(gòu):將所有組件整合在一起
一旦創(chuàng)建了所有U-Net組件,我們接下來(lái)要做的就是將它們包裝成一個(gè)類(lèi)。請(qǐng)參閱下面的代碼塊11a和11b了解詳細(xì)信息。
# Codeblock 11a
class UNet(nn.Module):
def __init__(self):
super().__init__()
self.downsample_0 = DownSample(in_channels=NUM_CHANNELS, #(1)
out_channels=64)
self.downsample_1 = DownSample(in_channels=64, #(2)
out_channels=128)
self.bottleneck = DoubleConv(in_channels=128, #(3)
out_channels=256)
self.upsample_0 = UpSample(in_channels=256, #(4)
out_channels=128)
self.upsample_1 = UpSample(in_channels=128, #(5)
out_channels=64)
self.output = nn.Conv2d(in_channels=64, #(6)
out_channels=NUM_CHANNELS,
kernel_size=1)
你可以在上面的__init__()方法中看到,我們初始化兩個(gè)下采樣(#(1-2))和兩個(gè)上采樣(#(4-5))塊;其中,輸入和輸出通道的數(shù)量根據(jù)圖中所示的架構(gòu)設(shè)置。實(shí)際上,還有兩個(gè)我還沒(méi)有解釋的額外組件,即瓶頸層bottleneck(#(3))和輸出層output(#(6))。前者本質(zhì)上只是一個(gè)DoubleConv塊,它充當(dāng)編碼器和解碼器之間的主要連接。查看下面的圖8,以查看網(wǎng)絡(luò)的哪些組件屬于瓶頸層bottleneck。接下來(lái),輸出層是一個(gè)標(biāo)準(zhǔn)卷積層,負(fù)責(zé)將上一個(gè)UpSampling階段生成的64通道圖像轉(zhuǎn)換為僅1通道。此操作使用大小為1×1的內(nèi)核完成,這意味著,它結(jié)合所有通道的信息,同時(shí)在每個(gè)像素位置獨(dú)立操作。
圖8.瓶頸層(模型的下部)充當(dāng)U-Net編碼器和解碼器之間的主要橋梁【引文3】
我想,以下代碼塊中的整個(gè)U-Net的forward()方法應(yīng)該是非常簡(jiǎn)單的,因?yàn)槲覀冊(cè)谶@里所做的基本上是將張量從一層傳遞到另一層——只是不要忘記在下采樣和上采樣塊之間包含跳過(guò)連接。
# Codeblock 11b
def forward(self, x, t): #(1)
print(f'original\t\t: {x.size()}')
print(f'timesteps\t\t: {t.size()}, {t}')
convolved_0, maxpooled_0 = self.downsample_0(x, t)
print(f'\nmaxpooled_0\t\t: {maxpooled_0.size()}')
convolved_1, maxpooled_1 = self.downsample_1(maxpooled_0, t)
print(f'maxpooled_1\t\t: {maxpooled_1.size()}')
x = self.bottleneck(maxpooled_1, t)
print(f'after bottleneck\t: {x.size()}')
upsampled_0 = self.upsample_0(x, t, convolved_1)
print(f'upsampled_0\t\t: {upsampled_0.size()}')
upsampled_1 = self.upsample_1(upsampled_0, t, convolved_0)
print(f'upsampled_1\t\t: {upsampled_1.size()}')
x = self.output(upsampled_1)
print(f'final output\t\t: {x.size()}')
return x
現(xiàn)在,讓我們通過(guò)運(yùn)行以下測(cè)試代碼來(lái)看看我們是否正確構(gòu)建了上面的U-Net類(lèi)。
# Codeblock 12
unet_test = UNet().to(DEVICE)
x_test = torch.randn((BATCH_SIZE, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)).to(DEVICE)
t_test = torch.randint(0, NUM_TIMESTEPS, (BATCH_SIZE,)).to(DEVICE)
out_test = unet_test(x_test, t_test)
# Codeblock 12 Output
original : torch.Size([2, 1, 28, 28]) #(1)
timesteps : torch.Size([2]), tensor([468, 304], device='cuda:0')
maxpooled_0 : torch.Size([2, 64, 14, 14]) #(2)
maxpooled_1 : torch.Size([2, 128, 7, 7]) #(3)
after bottleneck : torch.Size([2, 256, 7, 7]) #(4)
upsampled_0 : torch.Size([2, 128, 14, 14])
upsampled_1 : torch.Size([2, 64, 28, 28])
final output : torch.Size([2, 1, 28, 28]) #(5)
在上面的輸出中我們可以看到,兩個(gè)下采樣階段成功地將原始大小為1×28×28(#(1))的張量分別轉(zhuǎn)換為64×14×14(#(2))和128×7×7(#(3))。然后,該張量通過(guò)瓶頸層,導(dǎo)致其通道數(shù)擴(kuò)展到256,而不會(huì)改變空間維度(#(4))。最后,我們對(duì)張量進(jìn)行兩次上采樣,最終將通道數(shù)縮小到1(#(5))。根據(jù)這個(gè)輸出,我們的模型似乎已經(jīng)可以運(yùn)行正常了。因此,它現(xiàn)在可以用于我們的擴(kuò)散任務(wù)了。
數(shù)據(jù)集準(zhǔn)備
由于我們已經(jīng)成功創(chuàng)建了整個(gè)U-Net架構(gòu),接下來(lái)要做的就是準(zhǔn)備MNIST手寫(xiě)數(shù)字?jǐn)?shù)據(jù)集。在實(shí)際加載它之前,我們需要首先使用Torchvision中的transforms.Compose()方法定義預(yù)處理步驟,如Codeblock13中#(1)行所示。這里我們做兩件事:將圖像轉(zhuǎn)換為PyTorch張量,這也會(huì)將像素值從0-255縮放到0-1(#(2)),并對(duì)其進(jìn)行規(guī)范化,以便最終像素值介于-1和1之間(#(3))。
接下來(lái),我們使用datasets.MNIST()下載數(shù)據(jù)集。在本例中,我們將從訓(xùn)練數(shù)據(jù)中獲取圖像,因此我們使用train=True(#(5))。不要忘記將我們之前初始化的變量transform傳遞給transform參數(shù)(transform=transform),以便它在加載圖像時(shí)自動(dòng)預(yù)處理圖像(#(6))。最后,我們需要使用DataLoader加載來(lái)自mnist_dataset的圖像(#(7))。注意,我用于輸入?yún)?shù)的參數(shù),為的是在每次迭代中從數(shù)據(jù)集中隨機(jī)挑選BATCH_SIZE(2)張圖像。
# Codeblock 13
transform = transforms.Compose([ #(1)
transforms.ToTensor(), #(2)
transforms.Normalize((0.5,), (0.5,)) #(3)
])
mnist_dataset = datasets.MNIST( #(4)
root='./data',
train=True, #(5)
download=True,
transform=transform #(6)
)
loader = DataLoader(mnist_dataset, #(7)
batch_size=BATCH_SIZE,
drop_last=True,
shuffle=True)
在下面的代碼塊中,我嘗試從數(shù)據(jù)集中加載一批圖像。在每次迭代中,loader都會(huì)提供圖像和相應(yīng)的標(biāo)簽,因此我們需要將它們存儲(chǔ)在兩個(gè)單獨(dú)的變量中:images和labels。
# Codeblock 14
images, labels = next(iter(loader))
print('images\t\t:', images.shape)
print('labels\t\t:', labels.shape)
print('min value\t:', images.min())
print('max value\t:', images.max())
我們可以在下面的結(jié)果輸出中看到,images張量的大小為2×1×28×28(#(1)),表示已成功加載兩張大小為28×28的灰度圖像。在這里我們還可以看到labels張量的長(zhǎng)度為2,與加載的圖像數(shù)量相匹配(#(2))。請(qǐng)注意,在這種情況下,標(biāo)簽將被完全忽略。我的計(jì)劃是,我只希望模型從整個(gè)訓(xùn)練數(shù)據(jù)集中生成它之前看到的任何數(shù)字,甚至不知道它實(shí)際上是什么數(shù)字。最后,此輸出還顯示預(yù)處理工作正常,因?yàn)橄袼刂惮F(xiàn)在介于-1和1之間。
# Codeblock 14 Output
images : torch.Size([2, 1, 28, 28]) #(1)
labels : torch.Size([2]) #(2)
min value : tensor(-1.)
max value : tensor(1.)
如果你想看看我們剛剛加載的圖像是什么樣子,請(qǐng)運(yùn)行以下代碼。
# Codeblock 15
plt.imshow(images[0].squeeze(), cmap='gray')
plt.show()
圖9.Codeblock 15的輸出【引文3】
噪音調(diào)度器
在本節(jié)中,我們將討論如何執(zhí)行前向和后向擴(kuò)散,該過(guò)程本質(zhì)上涉及在每個(gè)時(shí)間步長(zhǎng)上一點(diǎn)一點(diǎn)地添加或消除噪聲。有必要知道,我們基本上希望在所有時(shí)間步長(zhǎng)上噪聲量均勻;其中,在前向擴(kuò)散中,圖像應(yīng)該在時(shí)間步長(zhǎng)1000時(shí)完全充滿噪聲,而在后向擴(kuò)散中,我們必須在時(shí)間步長(zhǎng)0時(shí)獲得完全清晰的圖像。因此,我們需要一些東西來(lái)控制每個(gè)時(shí)間步長(zhǎng)的噪聲量。在本節(jié)后面,我將實(shí)現(xiàn)一個(gè)名為NoiseScheduler的類(lèi)來(lái)執(zhí)行此操作。這可能是本文中數(shù)學(xué)知識(shí)涉及最多的部分,因?yàn)槲覍⒃谶@里展示許多方程式。但不要擔(dān)心,因?yàn)槲覀儗?zhuān)注于實(shí)現(xiàn)這些方程式,而不是討論數(shù)學(xué)推導(dǎo)。
現(xiàn)在,讓我們看一下圖10中的方程式,我將在下面的NoiseScheduler類(lèi)的__init__()方法中實(shí)現(xiàn)它們。
圖10.我們需要在類(lèi)NoiseScheduler的__init__()方法中實(shí)現(xiàn)的方程【引文3】
# Codeblock 16a
class NoiseScheduler:
def __init__(self):
self.betas = torch.linspace(BETA_START, BETA_END, NUM_TIMESTEPS) #(1)
self.alphas = 1. - self.betas
self.alphas_cum_prod = torch.cumprod(self.alphas, dim=0)
self.sqrt_alphas_cum_prod = torch.sqrt(self.alphas_cum_prod)
self.sqrt_one_minus_alphas_cum_prod = torch.sqrt(1. - self.alphas_cum_prod)
上述代碼通過(guò)創(chuàng)建多個(gè)數(shù)字序列來(lái)工作,它們基本上都由BETA_START(0.0001)、BETA_END(0.02)和NUM_TIMESTEPS(1000)控制。我們需要實(shí)例化的第一個(gè)序列是betas本身,它是使用torch.linspace()完成的(#(1))。它本質(zhì)上的作用是生成一個(gè)長(zhǎng)度為1000的一維張量,從0.0001到0.02,其中此張量中的每個(gè)元素都對(duì)應(yīng)一個(gè)時(shí)間步。每個(gè)元素之間的間隔是均勻的,這使我們能夠在所有時(shí)間步長(zhǎng)中生成均勻的噪聲量。有了這個(gè)betas張量,我們?nèi)缓蟾鶕?jù)圖10中的四個(gè)方程計(jì)算alphas、alphas_cum_prod、sqrt_alphas_cum_prod和sqrt_one_minus_alphas_cum_prod。稍后,這些張量將作為在擴(kuò)散過(guò)程中如何產(chǎn)生或消除噪聲的基礎(chǔ)。
擴(kuò)散通常以順序方式進(jìn)行。然而,前向擴(kuò)散過(guò)程是確定性的,因此我們可以將原始方程推導(dǎo)為封閉形式,這樣我們就可以在特定的時(shí)間步長(zhǎng)中獲得噪聲,而不必從頭開(kāi)始迭代添加噪聲。下圖11顯示了前向擴(kuò)散的封閉形式,其中x0表示原始圖像,而epsilon(?)表示由隨機(jī)高斯噪聲組成的圖像。我們可以將此方程視為加權(quán)組合,其中我們根據(jù)時(shí)間步長(zhǎng)確定的權(quán)重將清晰圖像和噪聲組合在一起,從而得到具有特定噪聲量的圖像。
圖11.正向擴(kuò)散過(guò)程的封閉形式【引文3】
該方程的實(shí)現(xiàn)可以在Codeblock 16b中看到。在forward_diffusion()方法中,x0和?表示為original和noise。這里你需要記住這兩個(gè)輸入變量是圖像,而sqrt_alphas_cum_prod_t和sqrt_one_minus_alphas_cum_prod_t是標(biāo)量。因此,我們需要調(diào)整這兩個(gè)標(biāo)量的形狀(代碼行#(1)和#(2)),以便可以執(zhí)行#(3)行中的操作。變量noisy_image將是此函數(shù)的輸出,我想這個(gè)名字是不言自明的。
# Codeblock 16b
def forward_diffusion(self, original, noise, t):
sqrt_alphas_cum_prod_t = self.sqrt_alphas_cum_prod[t]
sqrt_alphas_cum_prod_t = sqrt_alphas_cum_prod_t.to(DEVICE).view(-1, 1, 1, 1) #(1)
sqrt_one_minus_alphas_cum_prod_t = self.sqrt_one_minus_alphas_cum_prod[t]
sqrt_one_minus_alphas_cum_prod_t = sqrt_one_minus_alphas_cum_prod_t.to(DEVICE).view(-1, 1, 1, 1) #(2)
noisy_image = sqrt_alphas_cum_prod_t * original + sqrt_one_minus_alphas_cum_prod_t * noise #(3)
return noisy_image
現(xiàn)在,我們來(lái)談?wù)劮聪驍U(kuò)散。事實(shí)上,這個(gè)比正向擴(kuò)散稍微復(fù)雜一些,因?yàn)槲覀冃枰齻€(gè)方程。在我向你展示這些方程之前,讓我先向你展示一下有關(guān)代碼實(shí)現(xiàn)。請(qǐng)參閱下面的代碼塊16c。
# Codeblock 16c
def backward_diffusion(self, current_image, predicted_noise, t): #(1)
denoised_image = (current_image - (self.sqrt_one_minus_alphas_cum_prod[t] * predicted_noise)) / self.sqrt_alphas_cum_prod[t] #(2)
denoised_image = 2 * (denoised_image - denoised_image.min()) / (denoised_image.max() - denoised_image.min()) - 1 #(3)
current_prediction = current_image - ((self.betas[t] * predicted_noise) / (self.sqrt_one_minus_alphas_cum_prod[t])) #(4)
current_prediction = current_prediction / torch.sqrt(self.alphas[t]) #(5)
if t == 0: #(6)
return current_prediction, denoised_image
else:
variance = (1 - self.alphas_cum_prod[t-1]) / (1. - self.alphas_cum_prod[t]) #(7)
variance = variance * self.betas[t] #(8)
sigma = variance ** 0.5
z = torch.randn(current_image.shape).to(DEVICE)
current_prediction = current_prediction + sigma*z
return current_prediction, denoised_image
在推理階段的后期,該backward_diffusion()方法將在一個(gè)循環(huán)中調(diào)用,該循環(huán)迭代NUM_TIMESTEPS(1000)次,從t=999開(kāi)始,繼續(xù)執(zhí)行t=998,一直到t=0為止。此函數(shù)負(fù)責(zé)根據(jù)current_image(前一個(gè)去噪步驟生成的圖像)、predicted_noise(U-Net在上一步中預(yù)測(cè)的噪聲)和時(shí)間步長(zhǎng)信息t三個(gè)參數(shù)迭代地從圖像中去除噪聲(#(1))。在每次迭代中,使用圖12中所示的公式進(jìn)行噪聲消除,在代碼塊16c中,這對(duì)應(yīng)于代碼行#(4-5)。
圖12.用于從圖像中去除噪聲的方程【引文3】
只要我們還沒(méi)有達(dá)到t=0,我們就會(huì)根據(jù)圖13中的方程計(jì)算方差(代碼行#(7-8))。然后,該方差將用于引入另一個(gè)受控噪聲,以模擬反向擴(kuò)散過(guò)程中的隨機(jī)性,因?yàn)閳D12中的噪聲消除方程是確定性近似。這本質(zhì)上也是我們?cè)谶_(dá)到t=0(代碼行#(6))后不再計(jì)算方差的原因,因?yàn)槲覀儾辉傩枰砑痈嘣肼?,因?yàn)閳D像已經(jīng)完全清晰了。
圖13.用于計(jì)算引入受控噪聲的方差的方程【引文3】
current_prediction與旨在估計(jì)前一個(gè)時(shí)間步長(zhǎng)的圖像(xt-1)不同,張量的目標(biāo)denoised_image是重建原始圖像(x0)。由于這些不同的目標(biāo),我們需要一個(gè)單獨(dú)的方程來(lái)計(jì)算denoised_image,如下圖14所示。方程本身的實(shí)現(xiàn)寫(xiě)在第#(2-3)行。
圖14.重建原始圖像的方程【引文3】
現(xiàn)在,讓我們測(cè)試一下上面創(chuàng)建的NoiseScheduler類(lèi)。在下面的代碼塊中,我實(shí)例化一個(gè)NoiseScheduler對(duì)象并打印出與其關(guān)聯(lián)的屬性,這些屬性都是使用圖10中的公式根據(jù)存儲(chǔ)在betas屬性中的值計(jì)算得出的。請(qǐng)記住,這些張量的實(shí)際長(zhǎng)度是NUM_TIMESTEPS(1000),但在這里我只打印出前6個(gè)元素。
# Codeblock 17
noise_scheduler = NoiseScheduler()
print(f'betas\t\t\t\t: {noise_scheduler.betas[:6]}')
print(f'alphas\t\t\t\t: {noise_scheduler.alphas[:6]}')
print(f'alphas_cum_prod\t\t\t: {noise_scheduler.alphas_cum_prod[:6]}')
print(f'sqrt_alphas_cum_prod\t\t: {noise_scheduler.sqrt_alphas_cum_prod[:6]}')
print(f'sqrt_one_minus_alphas_cum_prod\t: {noise_scheduler.sqrt_one_minus_alphas_cum_prod[:6]}')
# Codeblock 17 Output
betas : tensor([1.0000e-04, 1.1992e-04, 1.3984e-04, 1.5976e-04, 1.7968e-04, 1.9960e-04])
alphas : tensor([0.9999, 0.9999, 0.9999, 0.9998, 0.9998, 0.9998])
alphas_cum_prod : tensor([0.9999, 0.9998, 0.9996, 0.9995, 0.9993, 0.9991])
sqrt_alphas_cum_prod : tensor([0.9999, 0.9999, 0.9998, 0.9997, 0.9997, 0.9996])
sqrt_one_minus_alphas_cum_prod : tensor([0.0100, 0.0148, 0.0190, 0.0228, 0.0264, 0.0300])
上面的輸出表明我們的__init__()方法按預(yù)期工作。接下來(lái),我們將測(cè)試forward_diffusion()方法。如果回看一下圖16b,你將看到forward_diffusion()接受三個(gè)輸入:原始圖像、噪聲圖像和時(shí)間步長(zhǎng)數(shù)。我們只需使用我們之前加載的MNIST數(shù)據(jù)集中的圖像作為第一個(gè)輸入(#(1)),并使用大小完全相同的隨機(jī)高斯噪聲作為第二個(gè)輸入(#(2))。為此,可以運(yùn)行下面的Codeblock 18來(lái)查看這兩個(gè)圖像是什么樣子。
# Codeblock 18
image = images[0] #(1)
noise = torch.randn_like(image) #(2)
plt.imshow(image.squeeze(), cmap='gray')
plt.show()
plt.imshow(noise.squeeze(), cmap='gray')
plt.show()
圖15.用作原始圖像(左)和噪聲圖像(右)的兩幅圖像(其中,左側(cè)的圖像與我之前在圖9【引文3】中展示的圖像相同)
由于我們已經(jīng)準(zhǔn)備好了圖像和噪聲,接下來(lái)我們需要做的就是將它們傳遞給forward_diffusion()方法。我實(shí)際上嘗試多次運(yùn)行下面的Codeblock19:t=50、100、150等等,直到t=300。你可以在圖16中看到,隨著參數(shù)的增加,圖像變得不那么清晰。在這種情況下,當(dāng)t設(shè)置為999時(shí),圖像將完全被噪聲填充。
# Codeblock 19
noisy_image_test = noise_scheduler.forward_diffusion(image.to(DEVICE), noise.to(DEVICE), t=50)
plt.imshow(noisy_image_test[0].squeeze().cpu(), cmap='gray')
plt.show()
圖16. t=50、100、150等時(shí)刻正向擴(kuò)散過(guò)程的結(jié)果,直到t=300【引文3】
不幸的是,我們無(wú)法測(cè)試該backward_diffusion()方法,因?yàn)榇诉^(guò)程需要我們對(duì)U-Net模型進(jìn)行訓(xùn)練。所以,我們現(xiàn)在就跳過(guò)這部分。我將向你展示如何在稍后的推理階段實(shí)際使用此功能。
訓(xùn)練
由于U-Net模型、MNIST數(shù)據(jù)集和噪聲調(diào)度程序已準(zhǔn)備就緒,我們現(xiàn)在可以準(zhǔn)備一個(gè)函數(shù)進(jìn)行訓(xùn)練。在此之前,我在下面的Codeblock 20中實(shí)例化了模型和噪聲調(diào)度程序。
# Codeblock 20
model = UNet().to(DEVICE)
noise_scheduler = NoiseScheduler()
整個(gè)訓(xùn)練過(guò)程在Codeblock 21中顯示的train()函數(shù)中實(shí)現(xiàn)。在執(zhí)行任何操作之前,我們首先初始化優(yōu)化器和損失函數(shù),在本例中我們分別使用Adam和MSE(#(1-2))。我們基本上想要做的是訓(xùn)練模型,使其能夠預(yù)測(cè)輸入圖像中包含的噪聲,稍后,預(yù)測(cè)的噪聲將用作后向擴(kuò)散階段去噪過(guò)程的基礎(chǔ)。要實(shí)際訓(xùn)練模型,我們首先需要使用#(6)行處的代碼執(zhí)行前向擴(kuò)散。此噪聲化過(guò)程將使用#(4)行處生成的隨機(jī)噪聲在images量上完成(#(3))。接下來(lái),我們?yōu)閠(#(5))取0到NUM_TIMESTEPS(1000)之間的隨機(jī)數(shù),這主要是因?yàn)槲覀兿M覀兊哪P湍軌蚩吹讲煌肼曀降膱D像,以此作為提高泛化能力的一種方法。生成噪聲圖像后,我們將它與所選的t(#(7))一起傳遞給U-Net模型。此處的輸入對(duì)模型很有用,因?yàn)樗硎緢D像中的當(dāng)前噪聲水平。最后,我們之前初始化的損失函數(shù)負(fù)責(zé)計(jì)算實(shí)際噪聲與原始圖像的預(yù)測(cè)噪聲之間的差異(#(8))。因此,這次訓(xùn)練的目標(biāo)基本上是使預(yù)測(cè)噪聲盡可能與我們?cè)诘?(4)行生成的噪聲相似。
# Codeblock 21
def train():
optimizer = Adam(model.parameters(), lr=LEARNING_RATE) #(1)
loss_function = nn.MSELoss() #(2)
losses = []
for epoch in range(NUM_EPOCHS):
print(f'Epoch no {epoch}')
for images, _ in tqdm(loader):
optimizer.zero_grad()
images = images.float().to(DEVICE) #(3)
noise = torch.randn_like(images) #(4)
t = torch.randint(0, NUM_TIMESTEPS, (BATCH_SIZE,)) #(5)
noisy_images = noise_scheduler.forward_diffusion(images, noise, t).to(DEVICE) #(6)
predicted_noise = model(noisy_images, t) #(7)
loss = loss_function(predicted_noise, noise) #(8)
losses.append(loss.item())
loss.backward()
optimizer.step()
return losses
現(xiàn)在,讓我們使用下面的代碼塊運(yùn)行上述訓(xùn)練函數(shù)。坐下來(lái)放松一下,等待訓(xùn)練完成。就我而言,我使用了開(kāi)啟了Nvidia GPU P100的Kaggle Notebook,大約需要45分鐘才能完成。
# Codeblock 22
losses = train()
如果我們看一下?lián)p失圖,似乎我們的模型學(xué)習(xí)得很好,因?yàn)閮r(jià)值通常會(huì)隨著時(shí)間的推移而下降,早期階段下降迅速,后期階段趨勢(shì)更穩(wěn)定(但仍在下降)。所以,我認(rèn)為我們可以在推理階段的后期期待好的結(jié)果。
# Codeblock 23
plt.plot(losses)
圖17.損失值如何隨著訓(xùn)練的進(jìn)行而減少【引文3】
推理
此時(shí),我們已經(jīng)訓(xùn)練了模型,因此我們現(xiàn)在可以對(duì)其進(jìn)行推理。查看下面的Codeblock 24以了解我如何實(shí)現(xiàn)該inference()函數(shù)。
# Codeblock 24
def inference():
denoised_images = [] #(1)
with torch.no_grad(): #(2)
current_prediction = torch.randn((64, NUM_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)).to(DEVICE) #(3)
for i in tqdm(reversed(range(NUM_TIMESTEPS))): #(4)
predicted_noise = model(current_prediction, torch.as_tensor(i).unsqueeze(0)) #(5)
current_prediction, denoised_image = noise_scheduler.backward_diffusion(current_prediction, predicted_noise, torch.as_tensor(i)) #(6)
if i%100 == 0: #(7)
denoised_images.append(denoised_image)
return denoised_images
在標(biāo)有#(1)的行中,我初始化了一個(gè)空列表,該列表將用于每100個(gè)時(shí)間步存儲(chǔ)一次去噪結(jié)果(#(7))。這將使我們能夠稍后看到反向擴(kuò)散的進(jìn)行方式。實(shí)際的推理過(guò)程封裝在torch.no_grad()(#(2))內(nèi)。請(qǐng)記住,在擴(kuò)散模型中,我們從完全隨機(jī)的噪聲中生成圖像,我們假設(shè)這些圖像最初是在t=999時(shí)。為了實(shí)現(xiàn)這一點(diǎn),我們可以簡(jiǎn)單地使用torch.randn(),如#(3)行所示。在這里,我們初始化一個(gè)大小為64×1×28×28的張量,表示我們即將同時(shí)生成64張圖像。接下來(lái),我們編寫(xiě)一個(gè)for從999開(kāi)始向后迭代到0的循環(huán)(#(4))。在這個(gè)循環(huán)中,我們將當(dāng)前圖像和時(shí)間步作為訓(xùn)練后的U-Net的輸入,并讓它預(yù)測(cè)噪聲(#(5))。然后在#(6)行執(zhí)行實(shí)際的反向擴(kuò)散。在迭代結(jié)束時(shí),我們應(yīng)該得到類(lèi)似于我們數(shù)據(jù)集中的新圖像?,F(xiàn)在,讓我們?cè)谙旅娴拇a塊中調(diào)用該inference()函數(shù)。
# Codeblock 25
denoised_images = inference()
推理完成后,我們現(xiàn)在可以看到生成的圖像是什么樣子。下面的Codeblock 26用于顯示我們剛剛生成的前42張圖像。
# Codeblock 26
fig, axes = plt.subplots(ncols=7, nrows=6, figsize=(10, 8))
counter = 0
for i in range(6):
for j in range(7):
axes[i,j].imshow(denoised_images[-1][counter].squeeze().detach().cpu().numpy(), cmap='gray') #(1)
axes[i,j].get_xaxis().set_visible(False)
axes[i,j].get_yaxis().set_visible(False)
counter += 1
plt.show()
圖18.在MNIST手寫(xiě)數(shù)字?jǐn)?shù)據(jù)集【引文3】上訓(xùn)練的擴(kuò)散模型生成的圖像
如果我們看一下上面的代碼塊,你會(huì)看到#(1)行的索引器[-1]表示我們只顯示來(lái)自最后一次迭代(對(duì)應(yīng)于時(shí)間步長(zhǎng)0)的圖像。這就是你在圖18中看到的圖像都沒(méi)有噪音的原因。我承認(rèn)這可能不是最好的結(jié)果,因?yàn)椴⒎撬猩傻膱D像都是有效的數(shù)字?!?,這反而表明這些圖像不僅僅是原始數(shù)據(jù)集的重復(fù)。
這里我們還可以使用下面的Codeblock 27可視化反向擴(kuò)散過(guò)程。你可以在圖19中的結(jié)果輸出中看到,我們最初從完全隨機(jī)的噪聲開(kāi)始,隨著我們向右移動(dòng),噪聲逐漸消失。
# Codeblock 27
fig, axes = plt.subplots(ncols=10, figsize=(24, 8))
sample_no = 0
timestep_no = 0
for i in range(10):
axes[i].imshow(denoised_images[timestep_no][sample_no].squeeze().detach().cpu().numpy(), cmap='gray')
axes[i].get_xaxis().set_visible(False)
axes[i].get_yaxis().set_visible(False)
timestep_no += 1
plt.show()
圖19.時(shí)間步900、800、700等…直到時(shí)間步0時(shí)圖像的樣子【引文3】
結(jié)論
你可以基于本文提供的線路,進(jìn)一步進(jìn)行很多方面的研究。首先,如果你想要更好的結(jié)果,你可能需要調(diào)整Codeblock 2中的參數(shù)配置。其次,除了我們?cè)谙虏蓸雍蜕喜蓸与A段使用的卷積層堆棧之外,還可以通過(guò)實(shí)現(xiàn)注意層來(lái)改進(jìn)U-Net模型。這樣做并不能保證你獲得更好的結(jié)果,尤其是對(duì)于像本文所使用的這樣的簡(jiǎn)單數(shù)據(jù)集,但絕對(duì)值得一試。第三,如果你想挑戰(zhàn)自己,你還可以嘗試使用更復(fù)雜的數(shù)據(jù)集。
在實(shí)際應(yīng)用中,擴(kuò)散模型實(shí)際上可以做很多事情。最簡(jiǎn)單的一個(gè)可能是數(shù)據(jù)增強(qiáng)。使用擴(kuò)散模型,我們可以輕松地從特定的數(shù)據(jù)分布中生成新圖像。例如,假設(shè)我們正在進(jìn)行一個(gè)圖像分類(lèi)項(xiàng)目,但類(lèi)別中的圖像數(shù)量不平衡。為了解決這個(gè)問(wèn)題,我們可以從少數(shù)類(lèi)別中取出圖像并將它們輸入到擴(kuò)散模型中。通過(guò)這樣做,我們可以要求訓(xùn)練后的擴(kuò)散模型從該類(lèi)別中生成任意數(shù)量的樣本。
好了,以上就是有關(guān)擴(kuò)散模型的理論和實(shí)現(xiàn)的全部?jī)?nèi)容。感謝閱讀,希望你今天能學(xué)到一些新知識(shí)!
你可以通過(guò)此鏈接訪問(wèn)該項(xiàng)目中使用的代碼。這里還有我之前關(guān)于自動(dòng)編碼器、變分自動(dòng)編碼器(VAE)、神經(jīng)風(fēng)格遷移(NST)和Transformer的文章的鏈接。
參考
【1】Jascha Sohl-Dickstein等人,《利用非平衡熱力學(xué)進(jìn)行深度無(wú)監(jiān)督學(xué)習(xí)》,Arxiv。
【2】Jonathan Ho等人,《去噪擴(kuò)散概率模型》,Arxiv。
【3】Olaf Ronneberger等人,《U-Net:用于生物醫(yī)學(xué)圖像分割的卷積網(wǎng)絡(luò)》,Arxiv。
【4】Yann LeCun等人,《MNIST手寫(xiě)數(shù)字?jǐn)?shù)據(jù)庫(kù)》,Creative Commons Attribution-Share Alike 3.0許可證。
【5】Ashish Vaswani等人,《注意力就是你所需要的一切》,Arxiv。
譯者介紹
朱先忠,51CTO社區(qū)編輯,51CTO專(zhuān)家博客、講師,濰坊一所高校計(jì)算機(jī)教師,自由編程界老兵一枚。
原文標(biāo)題:The Art of Noise,作者:Muhammad Ardi