自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

基于PyTorch從零實(shí)現(xiàn)視覺(jué)轉(zhuǎn)換器(ViT)? 原創(chuàng)

發(fā)布于 2024-8-23 08:30
瀏覽
0收藏

譯者 | 朱先忠

審校 | 重樓

簡(jiǎn)介

視覺(jué)轉(zhuǎn)換器(Vision Transformer,通??s寫(xiě)為“ViT”)可以被視為計(jì)算機(jī)視覺(jué)領(lǐng)域的重大突破技術(shù)。當(dāng)涉及到與視覺(jué)相關(guān)的任務(wù)時(shí),人們通常使用基于CNN(卷積神經(jīng)網(wǎng)絡(luò))的模型來(lái)解決。到目前為止,這些模型的性能總是優(yōu)于任何其他類(lèi)型的神經(jīng)網(wǎng)絡(luò)。直到2020年,Dosovitskiy等人發(fā)表了一篇題為《一張圖頂16×16個(gè)單詞:大規(guī)模圖像識(shí)別的轉(zhuǎn)換器》的論文(參考文獻(xiàn)1),論文中強(qiáng)調(diào)這種轉(zhuǎn)換器能夠提供比傳統(tǒng)卷積神經(jīng)網(wǎng)絡(luò)更好的能力。

傳統(tǒng)卷積神經(jīng)網(wǎng)絡(luò)中的單個(gè)卷積層通過(guò)使用核提取特征來(lái)工作。由于內(nèi)核的大小與輸入圖像相比相對(duì)較小,因此它只能捕獲該小區(qū)域內(nèi)包含的信息。換句話說(shuō),它側(cè)重于提取局部特征。為了理解圖像的全局上下文,需要使用由多個(gè)卷積層組成的一個(gè)棧結(jié)構(gòu)。ViT解決了這個(gè)問(wèn)題,因?yàn)樗鼘?shí)現(xiàn)了直接從初始層捕獲全局信息。因此,在ViT中堆疊多個(gè)卷積層可以實(shí)現(xiàn)更全面的信息提取。

基于PyTorch從零實(shí)現(xiàn)視覺(jué)轉(zhuǎn)換器(ViT)?-AI.x社區(qū)

圖1:通過(guò)堆疊多個(gè)卷積層,CNN可以實(shí)現(xiàn)更大的感受野,這對(duì)于捕捉圖像的全局上下文至關(guān)重要(參考文獻(xiàn)2)

視覺(jué)轉(zhuǎn)換器架構(gòu)

如果你曾經(jīng)學(xué)習(xí)過(guò)轉(zhuǎn)換器,你應(yīng)該熟悉編碼器和解碼器這兩個(gè)術(shù)語(yǔ)。在NLP(自然語(yǔ)言處理)領(lǐng)域,特別是對(duì)于機(jī)器翻譯等任務(wù),編碼器負(fù)責(zé)捕獲輸入序列中標(biāo)記(即單詞)之間的關(guān)系,而解碼器負(fù)責(zé)生成輸出序列。在ViT的情況下,我們只需要編碼器部分,它將圖像的每個(gè)圖塊視為一個(gè)標(biāo)記?;谕瑯拥南敕?,編碼器能夠找到圖塊之間的關(guān)系。

整個(gè)視覺(jué)轉(zhuǎn)換器架構(gòu)如圖2所示。在我們?cè)敿?xì)討論有關(guān)代碼之前,我將先使用以下幾小節(jié)來(lái)解釋此架構(gòu)的每個(gè)組件。

基于PyTorch從零實(shí)現(xiàn)視覺(jué)轉(zhuǎn)換器(ViT)?-AI.x社區(qū)

圖2:視覺(jué)轉(zhuǎn)換器架構(gòu)(參考文獻(xiàn)1)

圖塊扁平化和線性投影

根據(jù)上圖,我們可以看到,要做的第一步是將圖像劃分為圖塊。所有這些圖塊排列成一個(gè)序列。然后,這些圖塊中的每一個(gè)都被扁平化,每個(gè)圖塊都形成一個(gè)一維陣列。然后,通過(guò)線性投影將這些標(biāo)記的序列投影到更高維的空間中。此時(shí),我們可以將投影結(jié)果視為NLP中的單詞嵌入,即表示單個(gè)單詞的向量。從技術(shù)上講,線性投影過(guò)程可以用簡(jiǎn)單的MLP(多層感知機(jī))或卷積層來(lái)完成。稍后,我將在具體的實(shí)施過(guò)程中對(duì)此進(jìn)行更多的解釋。

類(lèi)標(biāo)記和位置嵌入

由于我們正在處理分類(lèi)任務(wù),我們需要在投影的圖塊序列前添加一個(gè)新的標(biāo)記。這個(gè)標(biāo)記稱為類(lèi)標(biāo)記,它將通過(guò)為每個(gè)圖塊分配重要性權(quán)重來(lái)聚合其他圖塊的信息。值得注意的是,圖塊扁平化和線性投影會(huì)導(dǎo)致模型丟失空間信息。因此,為了解決這個(gè)問(wèn)題,所有標(biāo)記(包括類(lèi)標(biāo)記)都添加了位置嵌入,以便重新引入空間信息。

轉(zhuǎn)換器編碼器和MLP頭

在這個(gè)階段,張量已經(jīng)準(zhǔn)備好,將被饋送到轉(zhuǎn)換器編碼器塊中,其詳細(xì)結(jié)構(gòu)可以在圖2的右側(cè)看到。該塊由四個(gè)部分組成:層規(guī)一化、多頭注意力、另一層規(guī)一化和MLP層。值得注意的是,這里實(shí)現(xiàn)了兩個(gè)殘差連接。轉(zhuǎn)換器編碼器塊左上角的L×表示將根據(jù)要構(gòu)建的模型大小重復(fù)L次。

最后,我們將把編碼器塊連接到MLP頭。請(qǐng)記住,要轉(zhuǎn)發(fā)的張量只是從類(lèi)標(biāo)記部分出來(lái)的張量。MLP頭部本身由一個(gè)完全連接的層和一個(gè)輸出層組成,其中輸出層中的每個(gè)神經(jīng)元代表數(shù)據(jù)集中一個(gè)可用的類(lèi)。

視覺(jué)轉(zhuǎn)換器變體

在原始論文中提出了三種ViT變體,即ViT-B、ViT-L和ViT-H,如圖3所示,其中:

  • Layers(L):轉(zhuǎn)換器編碼器的數(shù)量。
  • Hidden size(D):嵌入維度以表示單個(gè)圖塊。
  • MLP size:MLP隱藏層中的神經(jīng)元數(shù)量。
  • Heads:多頭注意力層中的注意力頭數(shù)。
  • Params:模型的參數(shù)數(shù)量。

基于PyTorch從零實(shí)現(xiàn)視覺(jué)轉(zhuǎn)換器(ViT)?-AI.x社區(qū)

圖3:三種視覺(jué)轉(zhuǎn)換器變體的詳細(xì)信息(參考文獻(xiàn)1)

在本文中,我想使用PyTorch框架從頭開(kāi)始實(shí)現(xiàn)一個(gè)ViT-Base架構(gòu)。順便說(shuō)一句,該模塊本身實(shí)際上還提供了幾個(gè)預(yù)訓(xùn)練的ViT模型(參考文獻(xiàn)3),即ViT_b_16、ViT_b_32、ViT_l_16、ViT_l_32和ViT_h_14,其中作為這些模型后綴的數(shù)字是指使用的圖塊大小。

從頭開(kāi)始實(shí)現(xiàn)一個(gè)ViT

現(xiàn)在,讓我們開(kāi)始真正有趣的部分。實(shí)現(xiàn)一個(gè)ViT編程首先要做的是導(dǎo)入模塊。在這種情況下,我們將只依賴PyTorch框架的功能來(lái)構(gòu)建ViT架構(gòu)。從torchinfo加載的summary()函數(shù)將幫助我們顯示模型的詳細(xì)信息。

# 代碼塊1
import torch
import torch.nn as nn
from torchinfo import summary

參數(shù)配置

在代碼塊2中,我們將初始化幾個(gè)變量來(lái)配置模型。在這里,我們假設(shè)單個(gè)批次中要處理的圖像數(shù)量?jī)H為1,其維度為3×224×224(標(biāo)記為#(1))。我們?cè)谶@里要使用的變體是ViT-Base,這意味著我們需要將圖塊大小設(shè)置為16,注意頭數(shù)量設(shè)置為12,編碼器數(shù)量設(shè)置為12,嵌入維度設(shè)置為768(#(2))。通過(guò)使用此配置,圖塊數(shù)量將為196(#(3))。這個(gè)數(shù)字是通過(guò)將大小為224×224的圖像劃分為16×16個(gè)圖塊而獲得的,其中它產(chǎn)生了14×14的網(wǎng)格。因此,一張圖像將有196個(gè)圖塊。

我們還將對(duì)dropout層使用0.1的速率(#(4))。值得注意的是,論文中沒(méi)有明確提及dropout層的使用。由于在構(gòu)建深度學(xué)習(xí)模型時(shí),使用這些層可以被視為一種標(biāo)準(zhǔn)做法,因此我無(wú)論如何都會(huì)實(shí)現(xiàn)它。我們假設(shè)數(shù)據(jù)集中有10個(gè)類(lèi),相應(yīng)地設(shè)置了NUM_classes變量。

# 代碼塊2
#(1)
BATCH_SIZE = 1
IMAGE_SIZE = 224
IN_CHANNELS = 3

#(2)
PATCH_SIZE = 16
NUM_HEADS = 12
NUM_ENCODERS = 12
EMBED_DIM = 768
MLP_SIZE = EMBED_DIM * 4 # 768*4 = 3072

#(3)
NUM_PATCHES = (IMAGE_SIZE//PATCH_SIZE) ** 2 # (224//16)**2 = 196

#(4)
DROPOUT_RATE = 0.1
NUM_CLASSES = 10

由于本文的重點(diǎn)是實(shí)現(xiàn)模型,因此我不會(huì)談?wù)撊绾斡?xùn)練它。但是,如果你想這樣做,你需要確保你的機(jī)器上安裝了GPU,因?yàn)樗梢允褂?xùn)練更快。下面的代碼塊3用于檢查PyTorch是否成功檢測(cè)到你的Nvidia GPU。

# 代碼塊3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# 代碼塊3 output
cuda

圖塊扁平化和線性投影實(shí)現(xiàn)

我之前提到過(guò),圖塊扁平化和線性投影操作可以通過(guò)使用簡(jiǎn)單的MLP或卷積層來(lái)完成。在這里,我將在PatcherUnfold()和PatcherConv()類(lèi)中實(shí)現(xiàn)它們。稍后,你可以選擇在主ViT類(lèi)中實(shí)現(xiàn)這兩個(gè)類(lèi)中的任何一個(gè)。

讓我們先從PatcherUnfold()開(kāi)始,詳細(xì)信息可以在代碼塊4中看到。在這里,我使用了一個(gè)nn.Unfold()層。在標(biāo)注有#(1)的行處可以看到,其kernel_size和步幅均為PATCH_SIZE(16)。通過(guò)這種配置,該層將對(duì)輸入圖像應(yīng)用一個(gè)不重疊的滑動(dòng)窗口。在每一步中,內(nèi)部的圖塊都會(huì)被壓平。請(qǐng)看下面的圖4,以查看此操作的圖形化展示。在該圖中,我們使用大小為2的核和步幅對(duì)大小為4×4的圖像應(yīng)用展開(kāi)操作。

# 代碼塊4
class PatcherUnfold(nn.Module):
 def __init__(self):
 super().__init__()
 self.unfold = nn.Unfold(kernel_size=PATCH_SIZE, stride=PATCH_SIZE) #(1)
 self.linear_projection = nn.Linear(in_features=IN_CHANNELS*PATCH_SIZE*PATCH_SIZE, 
 out_features=EMBED_DIM) #(2)

基于PyTorch從零實(shí)現(xiàn)視覺(jué)轉(zhuǎn)換器(ViT)?-AI.x社區(qū)

圖4:在4×4圖像上應(yīng)用具有核大小和步幅2的展開(kāi)操作

接下來(lái),使用一個(gè)標(biāo)準(zhǔn)的nn.Linear()層(#(2))進(jìn)行線性投影操作。為了使輸入與扁平化的圖塊匹配,我們需要使用In_CHANNELS*patch_SIZE*patch_SIZE作為In_features參數(shù),即16×16×3=768。然后,我使用設(shè)置大小為EMBED_DIM的out_features參數(shù)來(lái)確定投影結(jié)果維度(768)。值得注意的是,投影結(jié)果和扁平化的圖塊具有完全相同的尺寸,如ViT-B架構(gòu)所規(guī)定的。如果要實(shí)現(xiàn)ViT-L或ViT-H,則應(yīng)將投影結(jié)果維度分別更改為1024或1280,其大小可能不再與扁平化的圖塊相同。

因?yàn)閚n.Unfold()和nn.Linear()層已經(jīng)初始化,所以現(xiàn)在我們必須使用下面的forward()函數(shù)連接這些層。我們需要注意的一件事是,展開(kāi)張量的第一和第二軸需要使用permute() 方法進(jìn)行交換(#(1))。這是因?yàn)槲覀兿雽⒈馄降膱D塊視為一系列標(biāo)記,類(lèi)似于NLP模型中處理標(biāo)記的方式。我還打印出代碼塊中每個(gè)進(jìn)程的形狀,以幫助你跟蹤分析數(shù)組的維度。

# 代碼塊5
 def forward(self, x):
 print(f'original\t: {x.size()}')

 x = self.unfold(x)
 print(f'after unfold\t: {x.size()}')

 x = x.permute(0, 2, 1) #(1)
 print(f'after permute\t: {x.size()}')

 x = self.linear_projection(x)
 print(f'after lin proj\t: {x.size()}')

 return x

此時(shí),PatcherUnfold()類(lèi)已經(jīng)完成。為了檢查它是否正常工作,我們可以嘗試向它提供一個(gè)隨機(jī)值的張量,該張量模擬大小為224×224的單個(gè)RGB圖像。

# 代碼塊6
patcher_unfold = PatcherUnfold()
x = torch.randn(1, 3, 224, 224)
x = patcher_unfold(x)

你可以看到下面的輸出,我們的原始圖像已成功轉(zhuǎn)換為形狀1×196×768,其中1表示單批中的圖像數(shù)量,196表示序列長(zhǎng)度(圖塊數(shù)量),768是嵌入維度。

# 代碼塊6 輸出
original : torch.Size([1, 3, 224, 224])
after unfold : torch.Size([1, 768, 196])
after permute : torch.Size([1, 196, 768])
after lin proj : torch.Size([1, 196, 768])

這就是使用PatcherUnfold()類(lèi)實(shí)現(xiàn)圖塊扁平化展開(kāi)和線性投影的過(guò)程。我們實(shí)際上也可以使用PatcherConv()實(shí)現(xiàn)同樣的事情,代碼如下所示:

# 代碼塊7
class PatcherConv(nn.Module):
 def __init__(self):
 super().__init__()
 self.conv = nn.Conv2d(in_channels=IN_CHANNELS, 
 out_channels=EMBED_DIM, 
 kernel_size=PATCH_SIZE, 
 stride=PATCH_SIZE)

 self.flatten = nn.Flatten(start_dim=2)

 def forward(self, x):
 print(f'original\t\t: {x.size()}')

 x = self.conv(x) #(1)
 print(f'after conv\t\t: {x.size()}')

 x = self.flatten(x) #(2)
 print(f'after flatten\t\t: {x.size()}')

 x = x.permute(0, 2, 1) #(3)
 print(f'after permute\t\t: {x.size()}')

 return x

這種方法可能看起來(lái)不像前一種方法那么簡(jiǎn)單,因?yàn)樗鼘?shí)際上并沒(méi)有使圖塊變扁平。相反,它使用具有EMBED_DIM(768)個(gè)內(nèi)核的卷積層,從而產(chǎn)生具有768個(gè)通道的14×14圖像(#(1))。為了獲得與PatcherUnfold()相同的輸出維度,我們將空間維度展平(#(2)),并交換得到張量的第一和第二軸(#(3))。為此,你可以分析下面代碼塊8的輸出,并查看每一步后的詳細(xì)的張量形狀。

# 代碼塊8
patcher_conv = PatcherConv()
x = torch.randn(1, 3, 224, 224)
x = patcher_conv(x)

# 代碼塊8 output
original : torch.Size([1, 3, 224, 224])
after conv : torch.Size([1, 768, 14, 14])
after flatten : torch.Size([1, 768, 196])
after permute : torch.Size([1, 196, 768])

值得注意的是,在PatcherUnfold()中使用nn.Conv2d()實(shí)現(xiàn)單獨(dú)展開(kāi)和線性投影,相比于PatcherConv()更有效,因?yàn)樗鼘蓚€(gè)步驟組合成一個(gè)操作。

類(lèi)標(biāo)記和位置嵌入實(shí)現(xiàn)

在將所有圖塊投影到嵌入維度并排列成序列后,下一步是將類(lèi)標(biāo)記放在序列中的第一個(gè)圖塊標(biāo)記之前。此過(guò)程與PosEmbedding()類(lèi)中的位置嵌入實(shí)現(xiàn)打包在一起,如代碼塊9所示:

# 代碼塊9
class PosEmbedding(nn.Module):
 def __init__(self):
 super().__init__()
 self.class_token = nn.Parameter(torch.randn(size=(BATCH_SIZE, 1, EMBED_DIM)), 
 requires_grad=True) #(1)
 self.pos_embedding = nn.Parameter(torch.randn(size=(BATCH_SIZE, NUM_PATCHES+1, EMBED_DIM)), 
 requires_grad=True) #(2)
 self.dropout = nn.Dropout(p=DROPOUT_RATE) #(3)

類(lèi)標(biāo)記本身使用nn.Parameter()初始化。本質(zhì)上,nn.Parameter()是一個(gè)權(quán)重張量(#(1))。此張量的大小需要與嵌入維度和批大小相匹配,以便它可以與現(xiàn)有的標(biāo)記序列連接。這個(gè)張量最初包含隨機(jī)值,這些值將在訓(xùn)練過(guò)程中更新。為了允許更新它,我們需要將requires_grad參數(shù)設(shè)置為T(mén)rue。同樣,我們也需要使用nn.Parameter()來(lái)創(chuàng)建位置嵌入(#(2)),但形狀不同。在這種情況下,我們將序列維度設(shè)置為比原始序列長(zhǎng)一個(gè)標(biāo)記,以容納我們剛剛創(chuàng)建的類(lèi)標(biāo)記。不僅如此,在這里,我還使用我們之前指定的速率(#(3))初始化了一個(gè)dropout層。

之后,我將用下面代碼塊10中的forward()函數(shù)連接這些層。此函數(shù)接受的張量將使用torch.cat()與class_token連接,如#(1)標(biāo)記的行所示。接下來(lái),我們將在結(jié)果輸出和位置嵌入張量(#(2))之間執(zhí)行元素相加,然后再將其傳遞到dropout層(#(3))。

# 代碼塊10
 def forward(self, x):

 class_token = self.class_token
 print(f'class_token dim\t\t: {class_token.size()}')

 print(f'before concat\t\t: {x.size()}')
 x = torch.cat([class_token, x], dim=1) #(1)
 print(f'after concat\t\t: {x.size()}')

 x = self.pos_embedding + x #(2)
 print(f'after pos_embedding\t: {x.size()}')

 x = self.dropout(x) #(3)
 print(f'after dropout\t\t: {x.size()}')

 return x

像往常一樣,讓我們嘗試通過(guò)這個(gè)網(wǎng)絡(luò)向前傳播一個(gè)張量,看看它是否按預(yù)期工作。請(qǐng)記住,pos_embedding模型的輸入本質(zhì)上是PatcherUnfold()或PatcherConv()產(chǎn)生的張量。

# 代碼塊11
pos_embedding = PosEmbedding()
x = pos_embedding(x)

如果我們仔細(xì)看看每一步的張量維數(shù),我們可以觀察到張量x的大小最初是1×196×768。在類(lèi)標(biāo)記之前添加后,維度變?yōu)?×197×768。

# 代碼塊11輸出
class_token dim : torch.Size([1, 1, 768])
before concat : torch.Size([1, 196, 768])
after concat : torch.Size([1, 197, 768])
after pos_embedding : torch.Size([1, 197, 768])
after dropout : torch.Size([1, 197, 768])

轉(zhuǎn)換器編碼器實(shí)現(xiàn)

如果我們回顧一下圖2,可以看到轉(zhuǎn)換器編碼器塊由四個(gè)組件組成。我們將在下面顯示的TransformerEncoder()類(lèi)中定義所有這些組件。

# 代碼塊12
class TransformerEncoder(nn.Module):
 def __init__(self):
 super().__init__()

 self.norm_0 = nn.LayerNorm(EMBED_DIM) #(1)

 self.multihead_attention = nn.MultiheadAttention(EMBED_DIM, #(2) 
 num_heads=NUM_HEADS, 
 batch_first=True, 
 dropout=DROPOUT_RATE)

 self.norm_1 = nn.LayerNorm(EMBED_DIM) #(3)

 self.mlp = nn.Sequential( #(4)
 nn.Linear(in_features=EMBED_DIM, out_features=MLP_SIZE), #(5)
 nn.GELU(), 
 nn.Dropout(p=DROPOUT_RATE), 
 nn.Linear(in_features=MLP_SIZE, out_features=EMBED_DIM), #(6) 
 nn.Dropout(p=DROPOUT_RATE)
 )

標(biāo)記為#(1)和#(3)的行處的兩個(gè)歸一化步驟是使用nn.LayerNorm()實(shí)現(xiàn)的。請(qǐng)記住,我們?cè)谶@里使用的層規(guī)一化不同于我們?cè)贑NN中常見(jiàn)的批規(guī)一化。批歸一化是通過(guò)對(duì)批中所有樣本中單個(gè)特征內(nèi)的值進(jìn)行歸一化來(lái)實(shí)現(xiàn)的。同時(shí),在層歸一化中,單個(gè)樣本中的所有特征都將被歸一化。請(qǐng)看下圖5,以更好地說(shuō)明這一概念。在這個(gè)例子中,我們假設(shè)每一行代表一個(gè)樣本,而每一列都是一個(gè)特征。相同顏色的單元格表示它們的值一起歸一化。

基于PyTorch從零實(shí)現(xiàn)視覺(jué)轉(zhuǎn)換器(ViT)?-AI.x社區(qū)


圖5:批次歸一化和層歸一化之間差異展示(批規(guī)一化在批維度上進(jìn)行規(guī)一化,而層規(guī)一化在特征維度上進(jìn)行標(biāo)準(zhǔn)化)

隨后,我們初始化一個(gè)nn.Multihead Attention()層,在代碼塊12中標(biāo)記為#(2)的行處輸入大小為EMBED_DIM(768)。batch_first參數(shù)設(shè)置為T(mén)rue,表示批處理維度位于輸入張量的第0軸。一般來(lái)說(shuō),多頭注意力本身允許模型同時(shí)捕捉圖像塊之間的各種關(guān)系。多頭注意力中的每一個(gè)頭都集中在這些關(guān)系的不同方面。稍后,該層接受三個(gè)輸入:查詢、鍵和值,這些都是計(jì)算所謂的注意力權(quán)重所必需的。通過(guò)這樣做,這一層可以了解每個(gè)圖塊應(yīng)該在多大程度上關(guān)注其他圖塊。換句話說(shuō),這種機(jī)制允許該層捕獲兩個(gè)或多個(gè)圖塊之間的關(guān)系。ViT中采用的注意力機(jī)制可以被視為整個(gè)模型的核心,因?yàn)檫@個(gè)組件本質(zhì)上是允許ViT在圖像識(shí)別任務(wù)中超越CNN性能的組件。

轉(zhuǎn)換器編碼器內(nèi)的MLP組件是使用nn.Sequential()構(gòu)造的(#(4))。在這里,我們實(shí)現(xiàn)了兩個(gè)連續(xù)的線性層,每個(gè)層后面都有一個(gè)dropout層。我們還需要將GELU激活函數(shù)放在第一個(gè)線性層之后。第二個(gè)線性層不使用激活函數(shù),因?yàn)樗哪康闹皇菍埩客队盎卦记度刖S度。

現(xiàn)在,是時(shí)候使用下面的代碼塊連接我們剛剛初始化的所有層了。

# 代碼塊13
 def forward(self, x):

 residual = x #(1)
 print(f'residual dim\t\t: {residual.size()}')

 x = self.norm_0(x) #(2)
 print(f'after norm\t\t: {x.size()}')

 x = self.multihead_attention(x, x, x)[0] #(3)
 print(f'after attention\t\t: {x.size()}')

 x = x + residual #(4)
 print(f'after addition\t\t: {x.size()}')

 residual = x #(5)
 print(f'residual dim\t\t: {residual.size()}')

 x = self.norm_1(x) #(6)
 print(f'after norm\t\t: {x.size()}')

 x = self.mlp(x) #(7)
 print(f'after mlp\t\t: {x.size()}')

 x = x + residual #(8)
 print(f'after addition\t\t: {x.size()}')

 return x

在上述forward()函數(shù)中,我們首先將輸入張量x存儲(chǔ)到殘差變量(#(1))中,在該變量中,它用于創(chuàng)建殘差連接。接下來(lái),我們?cè)趯⑤斎霃埩浚?(2))輸入到多頭注意力層(#(3))之前對(duì)其進(jìn)行歸一化。正如我之前提到的,這一層將查詢、鍵和值作為輸入。在這種情況下,張量x將被用作三個(gè)參數(shù)的參數(shù)。請(qǐng)注意,我在代碼的同一行也寫(xiě)了[0]。這主要是因?yàn)橐粋€(gè)nn.MultiheadAttention()對(duì)象返回兩個(gè)值:注意力輸出和注意力權(quán)重;在這種情況下,我們只需要前者。接下來(lái),在標(biāo)記為#(4)的行處,我們?cè)诙囝^注意力層的輸出和原始輸入張量之間執(zhí)行元素相加。然后,在執(zhí)行第一次殘差運(yùn)算后,我們直接用當(dāng)前張量x(#(5))更新殘差變量。在將張量饋送到MLP塊(#(7))并執(zhí)行另一個(gè)元素相加操作(#(8))之前,在第#(6)行完成第二次歸一化操作。

我們可以使用下面的代碼塊14檢查我們的轉(zhuǎn)換器編碼器塊實(shí)現(xiàn)是否正確。請(qǐng)記住,transformer_encoder模型的輸入必須是PosEmbedding()產(chǎn)生的輸出。

# 代碼塊14
transformer_encoder = TransformerEncoder()
x = transformer_encoder(x)

# 代碼塊14 output
residual dim : torch.Size([1, 197, 768])
after norm : torch.Size([1, 197, 768])
after attention : torch.Size([1, 197, 768])
after addition : torch.Size([1, 197, 768])
residual dim : torch.Size([1, 197, 768])
after norm : torch.Size([1, 197, 768])
after mlp : torch.Size([1, 197, 768])
after addition : torch.Size([1, 197, 768])

從上面的輸出中可以看出,每一步之后張量維度都沒(méi)有變化。但是,如果你仔細(xì)看看MLP塊是如何在代碼塊12中構(gòu)造的,你會(huì)發(fā)現(xiàn)它的隱藏層在#(5)標(biāo)記的行處擴(kuò)展為MLP_SIZE(3072)。然后,我們直接將其投影回其原始尺寸,即第6行的EMBED_DIM(768)。

實(shí)現(xiàn)MLP頭編程

我們要實(shí)現(xiàn)的最后一個(gè)類(lèi)是MLPHead()。就像轉(zhuǎn)換器編碼器塊內(nèi)的MLP層一樣,MLPHead()也包括一些全連接層、GELU激活函數(shù)和層規(guī)一化。這個(gè)類(lèi)的完整的實(shí)現(xiàn)代碼如下所示:

# 代碼塊15
class MLPHead(nn.Module):
 def __init__(self):
 super().__init__()

 self.norm = nn.LayerNorm(EMBED_DIM)
 self.linear_0 = nn.Linear(in_features=EMBED_DIM, 
 out_features=EMBED_DIM)
 self.gelu = nn.GELU()
 self.linear_1 = nn.Linear(in_features=EMBED_DIM, 
 out_features=NUM_CLASSES) #(1)

 def forward(self, x):
 print(f'original\t\t: {x.size()}')

 x = self.norm(x)
 print(f'after norm\t\t: {x.size()}')

 x = self.linear_0(x)
 print(f'after layer_0 mlp\t: {x.size()}')

 x = self.gelu(x)
 print(f'after gelu\t\t: {x.size()}')

 x = self.linear_1(x)
 print(f'after layer_1 mlp\t: {x.size()}')

 return x

在上面實(shí)現(xiàn)代碼中,需要注意的一點(diǎn)是,第二個(gè)全連接層基本上是整個(gè)ViT架構(gòu)的輸出(#(1))。因此,我們需要確保神經(jīng)元的數(shù)量與我們要訓(xùn)練模型的數(shù)據(jù)集中可用的種類(lèi)的數(shù)量相匹配。在這種情況下,我假設(shè)我們有EMBED_DIM(10)個(gè)類(lèi)。此外,值得注意的是,我最后沒(méi)有使用softmax層,因?yàn)樗呀?jīng)在nn網(wǎng)絡(luò)中實(shí)現(xiàn)了。如果你想真正訓(xùn)練這個(gè)模型,可以使用一下CrossEntropyLoss()。

為了測(cè)試MLPHead()模型,我們首先需要對(duì)轉(zhuǎn)換器編碼器塊產(chǎn)生的張量進(jìn)行切片,如代碼塊16中的第#(1)行所示。這是因?yàn)槲覀兿氆@取符號(hào)序列中的第0個(gè)元素,它對(duì)應(yīng)于我們之前在圖塊符號(hào)序列前面添加的類(lèi)標(biāo)記。

# 代碼塊16
x = x[:, 0] #(1)
mlp_head = MLPHead()
x = mlp_head(x)

# 代碼塊16 output
original : torch.Size([1, 768])
after norm : torch.Size([1, 768])
after layer_0 mlp : torch.Size([1, 768])
after gelu : torch.Size([1, 768])
after layer_1 mlp : torch.Size([1, 10])

當(dāng)運(yùn)行上述測(cè)試代碼時(shí),我們可以看到最終的張量形狀是1×10,這正是我們所期望的。

整個(gè)ViT架構(gòu)

此時(shí),所有ViT組件都已成功創(chuàng)建。因此,我們現(xiàn)在可以使用它們來(lái)構(gòu)建整個(gè)視覺(jué)轉(zhuǎn)換器架構(gòu)了。請(qǐng)分析一下下面的代碼塊,看看我是怎么做到的。

# 代碼塊17
class ViT(nn.Module):
 def __init__(self):
 super().__init__()

 #self.patcher = PatcherUnfold()
 self.patcher = PatcherConv() #(1) 
 self.pos_embedding = PosEmbedding()
 self.transformer_encoders = nn.Sequential(
 *[TransformerEncoder() for _ in range(NUM_ENCODERS)] #(2)
 )
 self.mlp_head = MLPHead()

 def forward(self, x):

 x = self.patcher(x)
 x = self.pos_embedding(x)
 x = self.transformer_encoders(x)
 x = x[:, 0] #(3)
 x = self.mlp_head(x)

 return x

關(guān)于上述代碼,我想強(qiáng)調(diào)幾點(diǎn)。首先,在第1行,我們可以使用PatcherUnfold()或PatcherConv(),因?yàn)樗鼈兌加邢嗤淖饔?,即?zhí)行圖塊展平和線性投影步驟。在這種情況下,我選用了后者。其次,轉(zhuǎn)換器編碼器塊將重復(fù)NUM_Encoder(12)次(#(2)),因?yàn)槲覀儗?shí)現(xiàn)如圖3所示的ViT-Base。最后,不要忘記對(duì)轉(zhuǎn)換器編碼器輸出的張量進(jìn)行切片,因?yàn)槲覀兊腗LP頭只會(huì)處理輸出的類(lèi)標(biāo)記部分(#(3))。

我們可以使用以下代碼測(cè)試ViT模型是否正常工作。

# 代碼塊18
vit = ViT().to(device)
x = torch.randn(1, 3, 224, 224).to(device)
print(vit(x).size())

你可以在這里看到,維度為1×3×224×224的輸入已轉(zhuǎn)換為1×10,這表明我們的模型按預(yù)期工作。

注意:你需要注釋掉所有打印內(nèi)容,使輸出結(jié)果看起來(lái)更簡(jiǎn)潔一些。

# 代碼塊18 輸出
torch.Size([1, 10])

此外,我們還可以使用我們?cè)诖a開(kāi)頭導(dǎo)入的summary()函數(shù)查看網(wǎng)絡(luò)的詳細(xì)結(jié)構(gòu)。你可以觀察到,參數(shù)的總數(shù)約為8600萬(wàn),與圖3中所示的數(shù)字相匹配。

# 代碼塊19
summary(vit, input_size=(1,3,224,224))

# 代碼塊19 輸出
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
ViT [1, 10] --
├─PatcherConv: 1-1 [1, 196, 768] --
│ └─Conv2d: 2-1 [1, 768, 14, 14] 590,592
│ └─Flatten: 2-2 [1, 768, 196] --
├─PosEmbedding: 1-2 [1, 197, 768] 152,064
│ └─Dropout: 2-3 [1, 197, 768] --
├─Sequential: 1-3 [1, 197, 768] --
│ └─TransformerEncoder: 2-4 [1, 197, 768] --
│ │ └─LayerNorm: 3-1 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-2 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-3 [1, 197, 768] 1,536
│ │ └─Sequential: 3-4 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-5 [1, 197, 768] --
│ │ └─LayerNorm: 3-5 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-6 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-7 [1, 197, 768] 1,536
│ │ └─Sequential: 3-8 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-6 [1, 197, 768] --
│ │ └─LayerNorm: 3-9 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-10 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-11 [1, 197, 768] 1,536
│ │ └─Sequential: 3-12 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-7 [1, 197, 768] --
│ │ └─LayerNorm: 3-13 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-14 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-15 [1, 197, 768] 1,536
│ │ └─Sequential: 3-16 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-8 [1, 197, 768] --
│ │ └─LayerNorm: 3-17 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-18 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-19 [1, 197, 768] 1,536
│ │ └─Sequential: 3-20 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-9 [1, 197, 768] --
│ │ └─LayerNorm: 3-21 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-22 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-23 [1, 197, 768] 1,536
│ │ └─Sequential: 3-24 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-10 [1, 197, 768] --
│ │ └─LayerNorm: 3-25 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-26 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-27 [1, 197, 768] 1,536
│ │ └─Sequential: 3-28 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-11 [1, 197, 768] --
│ │ └─LayerNorm: 3-29 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-30 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-31 [1, 197, 768] 1,536
│ │ └─Sequential: 3-32 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-12 [1, 197, 768] --
│ │ └─LayerNorm: 3-33 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-34 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-35 [1, 197, 768] 1,536
│ │ └─Sequential: 3-36 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-13 [1, 197, 768] --
│ │ └─LayerNorm: 3-37 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-38 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-39 [1, 197, 768] 1,536
│ │ └─Sequential: 3-40 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-14 [1, 197, 768] --
│ │ └─LayerNorm: 3-41 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-42 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-43 [1, 197, 768] 1,536
│ │ └─Sequential: 3-44 [1, 197, 768] 4,722,432
│ └─TransformerEncoder: 2-15 [1, 197, 768] --
│ │ └─LayerNorm: 3-45 [1, 197, 768] 1,536
│ │ └─MultiheadAttention: 3-46 [1, 197, 768] 2,362,368
│ │ └─LayerNorm: 3-47 [1, 197, 768] 1,536
│ │ └─Sequential: 3-48 [1, 197, 768] 4,722,432
├─MLPHead: 1-4 [1, 10] --
│ └─LayerNorm: 2-16 [1, 768] 1,536
│ └─Linear: 2-17 [1, 768] 590,592
│ └─GELU: 2-18 [1, 768] --
│ └─Linear: 2-19 [1, 10] 7,690
==========================================================================================
Total params: 86,396,938
Trainable params: 86,396,938
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 173.06
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 102.89
Params size (MB): 231.59
Estimated Total Size (MB): 335.08
==========================================================================================

總結(jié)

上面所有這些內(nèi)容幾乎都與視覺(jué)轉(zhuǎn)換器架構(gòu)有關(guān)。如果你發(fā)現(xiàn)代碼中存在任何錯(cuò)誤,歡迎隨時(shí)發(fā)表評(píng)論。

本文中使用的所有代碼也可以在我的GitHub存儲(chǔ)庫(kù)中找到。此代碼的鏈接地址是https://github.com/MuhammadArdiPutra/medium_articles/blob/main/Paper%20Walkthrough%20-%20Vision%20Transformer%20(ViT).ipynb。

參考資料

【1】Alexey Dosovitskiy等人?!禔n Image is Worth 16×16 Words: Transformers for Image Recognition at Scale》(一張圖頂16×16個(gè)單詞:用于大規(guī)模圖像識(shí)別的轉(zhuǎn)換器)。Arxiv,https://arxiv.org/pdf/2010.11929。

【2】林浩寧等?!禡aritime Semantic Labeling of Optical Remote Sensing Images with Multi-Scale Fully Convolutional Network》(基于多尺度全卷積網(wǎng)絡(luò)的光學(xué)遙感圖像海洋語(yǔ)義標(biāo)注)。Research Gate,https://www.researchgate.net/publication/316950618_Maritime_Semantic_Labeling_of_Optical_Remote_Sensing_Images_with_Multi-Scale_Fully_Convolutional_Network。

【3】《Vision Transformer. PyTorch》(基于PyTorch框架的視覺(jué)轉(zhuǎn)換器實(shí)現(xiàn))。
???Https://pytorch.org/vision/main/models/vision_transformer.html。??

譯者介紹

朱先忠,51CTO社區(qū)編輯,51CTO專家博客、講師,濰坊一所高校計(jì)算機(jī)教師,自由編程界老兵一枚。

原文標(biāo)題:??Paper Walkthrough: Vision Transformer (ViT)??,作者:Muhammad Ardi

?著作權(quán)歸作者所有,如需轉(zhuǎn)載,請(qǐng)注明出處,否則將追究法律責(zé)任
已于2024-8-23 10:00:49修改
收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦