Transformers基本原理—Decoder如何進(jìn)行解碼?
一、Transformers整體架構(gòu)概述
Transformers 是一種基于自注意力機(jī)制的架構(gòu),最初在2017年由Vaswani等人在論文《Attention Is All You Need》中提出。這種架構(gòu)徹底改變了自然語(yǔ)言處理(NLP)領(lǐng)域,因?yàn)樗軌蛴行У靥幚硇蛄袛?shù)據(jù),并且能夠捕捉長(zhǎng)距離依賴關(guān)系。
Transformers整體架構(gòu)如下:
主要架構(gòu)由左側(cè)的編碼器(Encoder)和右側(cè)的解碼器(Decoder)構(gòu)成。
對(duì)于Decoder主要做了一件什么事兒,可參考之前的文章。
本次我們主要來(lái)看解碼器如何工作。
二、編碼器的輸出
假設(shè)有這樣一個(gè)任務(wù),我們要將德文翻譯為英文,這個(gè)就會(huì)使用到Encoder-Decoder整個(gè)結(jié)構(gòu)。
德文:ich mochte ein bier
英文:i want a beer
為了更好地模擬真實(shí)情況(每一句話長(zhǎng)度不可能相同),我們需要將德文進(jìn)行Padding,填充1個(gè)P符號(hào);英文目標(biāo)增加一個(gè)End,表示這句話翻譯完了,用E符號(hào)表示。那么這兩句話就變?yōu)?/span>
德文:ich mochte ein bier P
英文:i want a beer E
這個(gè)Encoder結(jié)束后,會(huì)給我們一個(gè)輸出向量,這個(gè)向量將會(huì)在Decoder結(jié)構(gòu)中會(huì)形成K和V進(jìn)行使用。
Encoder的輸出可以當(dāng)作真實(shí)序列的高級(jí)表達(dá),或者說(shuō)是對(duì)真實(shí)序列的高級(jí)拆分。
三、解碼器如何工作
在Transformer模型中,解碼器(Decoder)的主要作用是通過(guò)自回歸模型生成目標(biāo)序列。
自回歸的含義是這樣的:
在預(yù)測(cè) “東方紅,太陽(yáng)升?!?的時(shí)候:
會(huì)優(yōu)先傳入 “東”,然后預(yù)測(cè)出來(lái)“方”,
接著傳入“東方”, 然后預(yù)測(cè)出來(lái)“東方紅”,
接著傳入“東方紅”,然后預(yù)測(cè)出來(lái)“,”,
接著再傳入“東方紅,”,然后預(yù)測(cè)出來(lái)“太”
...
直到最后預(yù)測(cè)出來(lái)整句話:“東方紅,太陽(yáng)升?!?。
所以,自回歸就是基于已出現(xiàn)的詞預(yù)測(cè)未來(lái)的詞
接下來(lái)我們先來(lái)搞明白解碼器具體做了一件什么事兒,忽略如何做的細(xì)節(jié)。
四、解碼器輸入
4.1 (Shifted)Outputs
解碼器的輸入基本上與編碼器輸入類似,如果忘記了可以去查看:Transformers基本原理—Encoder如何進(jìn)行編碼(1)
4.1.1 如何理解(Shifted)Outputs
(Shifted)Outputs表示輸入序列詞的偏移。
對(duì)于Transformer結(jié)構(gòu)來(lái)說(shuō),我們的1組數(shù)據(jù)應(yīng)包含3個(gè):Encoder輸入是1個(gè),Decoder輸入是1個(gè),目標(biāo)是1個(gè)。
比如:
Encoder輸入enc_inputs為德文 “ich mochte ein bier P”
Decoder輸入dec_inputs為英文 “S i want a beer”
優(yōu)化目標(biāo)輸入target_inputs為英文 “i want a beer E”
對(duì)于Decoder來(lái)說(shuō),就是要通過(guò)“S”預(yù)測(cè)目標(biāo)“i”,通過(guò)“S i”預(yù)測(cè)目標(biāo)“want”,通過(guò)“S i want”預(yù)測(cè)“a”,通過(guò)“S i want a”預(yù)測(cè)“beer”,通過(guò)“S i want a beer”預(yù)測(cè)“E”。
S表示給序列一個(gè)開(kāi)始預(yù)測(cè)的提示,E表示給序列一個(gè)停止預(yù)測(cè)的提示。
4.1.2 如何完成(Shifted)Outputs
這里的處理方法與Encoder是一樣,但是一定要注意輸入的是帶偏移的序列。
我們定義1個(gè)大詞表,它包含了輸入所有的詞的位置,如下所示:
tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'S': 5, 'E': 6}
那對(duì)應(yīng)于 英文 “S i want a beer” ,其編碼端的輸入就是:
dec_inputs: tensor([[5, 1, 2, 3, 4]])
接著我們將這5個(gè)詞通過(guò)nn.Embedding給映射到高維空間,形成詞向量,為了方便展示,我們映射到6維空間(一般映射到2的n次方維度,最高512維)。
此處的計(jì)算邏輯與Encoder一模一樣,不再贅述,直接附結(jié)果。
這樣我們就完成了輸入文本的詞嵌入。
4.2 位置嵌入
位置嵌入原理與Encoder中一模一樣,此處不再贅述,可參考:Transformers基本原理—Encoder如何進(jìn)行編碼(1)?
這樣我們就完成了輸入文本的位置嵌入。
4.3 融合詞嵌入和位置編碼
由于兩個(gè)向量shape都是一樣的,因此直接相加即可,結(jié)果如下:
五、帶掩碼的多頭注意力機(jī)制
在Decoder的多頭注意力機(jī)制中,QKV的計(jì)算方式與Encoder中一模一樣,重點(diǎn)只需要關(guān)注掩碼的方式即可。
由于Encoder自回歸的原因,所以,此處的掩碼應(yīng)遵循預(yù)測(cè)時(shí)不看到未來(lái)信息且不關(guān)注Padding符號(hào)的原則。
5.1 不看未來(lái)信息
以“S i want a beer”為例:
在通過(guò)“S”預(yù)測(cè)時(shí),不應(yīng)該注意后面的“i want a beer”;在通過(guò)“S i”預(yù)測(cè)時(shí),不應(yīng)該注意后面的“want a beer”。
因此這個(gè)注意力矩陣就應(yīng)該遵循這個(gè)規(guī)律:將看不到的詞的注意力設(shè)置為0。
而這個(gè)我們恰巧可以通過(guò)上三角矩陣來(lái)完成,具體代碼如下:
def get_attn_subsequent_mask(seq):
# 取出傳進(jìn)來(lái)的向量的三個(gè)維度,分別是batch_size,seq_length,d_model
attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
# 通過(guò)np.triu形成一個(gè)上三角矩陣,填充為1
subsequent_mask = np.triu(np.ones(attn_shape), k=1)
# 由array轉(zhuǎn)為torch的tensor格式
subsequent_mask = torch.from_numpy(subsequent_mask).byte()
return subsequent_mask
5.2 不看Padding的詞
這里的處理方式與Encoder一致,主要為了讓對(duì)Padding的注意力降為0.
以“S i want P P”為例,具體代碼如下:
# 需要判斷哪些位置是填充的
def get_attn_pad_mask(seq_q, seq_k): # enc_inputs, enc_inputs 告訴后面的句子后面的層 那些是被pad符號(hào)填充的 pad的目的是讓batch里面的每一行長(zhǎng)度一致
# 獲取seq_q和seq_k的batch_size和長(zhǎng)度
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
# 判斷輸入的位置編碼是否是0并升維度,0表明這個(gè)并不是有效的詞,只是填充或者標(biāo)點(diǎn)等,后續(xù)計(jì)算自注意力時(shí)不關(guān)注這些詞
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
# batch_size x 1 x len_k(=len_q), one is maskingtensor([[[False, False, False, False, True]]]) True代表為pad符號(hào)
# 擴(kuò)展mask矩陣,使其維度為batch_size x len_q x len_k
# (1,1,5) ->> (1,5,5)
return pad_attn_mask.expand(batch_size, len_q, len_k)
如果序列為“S i want a beer”為例,則掩碼矩陣如下:
5.3 合并的masked矩陣
顯然,我們?cè)趯?duì)序列進(jìn)行編碼的時(shí)候,肯定希望這個(gè)序列既不能偷看未來(lái)信息,也不能關(guān)注Padding信息,因此,需要將兩個(gè)掩碼進(jìn)行融),得到self-attention的mask。
具體做法為:兩個(gè)mask矩陣直接求和,大于0的地方說(shuō)明要么是未來(lái)信息,要么是Padding信息。
“S i want P P”的合并mask矩陣就如下所示:
“S i want a beer”的合并mask矩陣就如下所示:
這樣,我們就得到了最終的mask矩陣,只需要將融合的mask矩陣的在計(jì)算注意力時(shí)填充為-1e10即可,這樣經(jīng)過(guò)Softmax之后得到的注意力就約等于0,不會(huì)去關(guān)注這部分信息。
5.4 多頭注意力機(jī)制
將【4.3 融合詞嵌入和位置編碼】的結(jié)果當(dāng)作X輸入Decoder的多頭進(jìn)行計(jì)算即可,只需要在計(jì)算自注意力的時(shí)候根據(jù)合并的masked矩陣進(jìn)行掩碼填充,防止關(guān)注不需要關(guān)注的信息即可。
第1步:QKV的計(jì)算
X 為我們輸入的詞嵌入,假設(shè)此處的shape為(2,4)。
WQ、WK、WV矩陣為優(yōu)化目標(biāo),WQ、WK矩陣的維度一定相同,此處均為(4,3);WV矩陣維度可以相同也可以不同,此處為(4,3)。
Q、K、V經(jīng)過(guò)矩陣運(yùn)算后,shape均為(2,3)。
第2步:計(jì)算最終輸出
1、Q矩陣與K矩陣的轉(zhuǎn)置相乘,得到詞與詞之間的注意力分?jǐn)?shù)矩陣,位于(i,j)位置表示第i個(gè)詞與第j個(gè)詞之間的注意力分?jǐn)?shù)。
矩陣運(yùn)算:(2,3) * (3,2) ->> (2,2)
2、歸一化注意力分?jǐn)?shù)后,矩陣shape不會(huì)改變,只改變內(nèi)部數(shù)值大小?!靖鶕?jù)mask的填充就發(fā)生在Softmax前,確保不關(guān)注未來(lái)信息與padding信息】
矩陣大?。?2,2)
3、與V矩陣相乘,得到最終輸出Z矩陣。
矩陣運(yùn)算:(2,2) * (2,3) ->> (2,3)
Z矩陣就是原始輸入經(jīng)過(guò)變換后的輸出,且整個(gè)過(guò)程中只需要優(yōu)化WQ、WK、WV矩陣,計(jì)算量大大減少。
這樣,我們就完成了Decoder的第一個(gè)多頭計(jì)算。
六、殘差連接和層歸一化
6.1 殘差連接
為了解決深層網(wǎng)絡(luò)退化問(wèn)題,殘差連接是經(jīng)常被使用的,在代碼內(nèi)也就一行:
dec_outputs = output + residual
dec_outputs:最終輸出,判斷原始輸出好還是經(jīng)過(guò)多頭后的輸出好
output:經(jīng)過(guò)多頭后的輸出
residual:原始輸入
6.2 層歸一化
層歸一化只需要將最終輸出做一下LayerNorm即可,也是一行代碼即可完成。
dec_outputs = nn.LayerNorm(output + residual)
這樣,我們就完成了殘差連接與層歸一化,如圖所示:
需要注意的就是本層的輸出將會(huì)作為Q,傳遞給下一個(gè)多頭。
七、第二層多頭自注意力機(jī)制
這一層的數(shù)據(jù)來(lái)自于兩個(gè)位置:一個(gè)是Encoder的輸出,一個(gè)是Decoder的第一層多頭輸出。
7.1 Encoder的輸出
在編碼端,會(huì)產(chǎn)生一個(gè)對(duì)原始輸入的高級(jí)表達(dá),假設(shè)變量名為:enc_outputs。
這個(gè)編碼端的輸出將會(huì)被當(dāng)作K矩陣和V矩陣使用。
其中,enc_outputs既是K又是V,兩者完全一樣。
同時(shí),還需要將Encoder的原始輸入傳遞給第二層多頭,需要根據(jù)這個(gè)原始輸入形成掩碼矩陣,讓解碼時(shí)關(guān)注K,以確保解碼器在生成每個(gè)詞時(shí),只利用編碼器輸出中有效的、相關(guān)的信息。
7.2 Decoder第一層多頭輸出
Decoder第一層多頭輸出的結(jié)果就是對(duì)原始解碼輸入的抽象表達(dá),會(huì)被當(dāng)作Q輸入第二層多頭內(nèi)。
7.3 Encoder—Decoder邏輯
當(dāng)Q給多頭后,會(huì)通過(guò)與Encoder的K進(jìn)行點(diǎn)乘,得到與要查詢(Q)的注意力,判斷哪些內(nèi)容與Q類似,最后再與V計(jì)算,得到最終的輸出。
一個(gè)通俗的理解就是你買了一個(gè)變形金剛(transformer是變形金剛)的玩具,拿到的東西會(huì)有:各個(gè)零件、組裝說(shuō)明書(shū)。
Encoder的結(jié)果就是提供了各個(gè)零件的組裝說(shuō)明書(shū);Decoder就是各個(gè)零件的查詢。
當(dāng)我們想要將其組裝好時(shí),會(huì)在Decoder內(nèi)查詢(Q)每個(gè)零件的用法,通過(guò)組裝說(shuō)明書(shū)(Encoder的結(jié)果),尋找與當(dāng)前零件最相似的說(shuō)明(尋找高注意力的零件:Q · K^T),最后根據(jù)零件用法進(jìn)行組裝(輸出最后結(jié)果)成一個(gè)變形金剛。
當(dāng)然,這里的QKV的流程是人類賦予的,只是與現(xiàn)實(shí)中的數(shù)據(jù)庫(kù)查表過(guò)程類似,與組裝變形金剛類似,讓我們更加容易接受的一種說(shuō)法。
個(gè)人對(duì)Encoder—Decoder的理解:
Encoder就是對(duì)原始輸入的高階抽象表達(dá)(低級(jí)表達(dá)的高級(jí)抽象),比如不同語(yǔ)言雖然表面上不一樣(各種各樣的寫(xiě)法),但是在更高維的向量空間是相似的,存在一套大一統(tǒng)高階向量,只是我們無(wú)法用語(yǔ)言來(lái)描述。
就像有些數(shù)據(jù)在二維不可分,但是落在高維空間就可分了一樣,我們無(wú)法觀察更高維的空間,但它們是存在的。
Decoder就是當(dāng)我們?cè)诟呔S空間查詢(Q)時(shí),就能找到相似(注意力強(qiáng)弱)的高級(jí)抽象表達(dá)(Q · K^T),然后通過(guò)V輸出為低級(jí)表達(dá),讓我們能看得懂。
正如作為三維生物的我們無(wú)法理解更高維空間,需要一個(gè)中介將高維空間轉(zhuǎn)化為低維空間表達(dá);或者我們作為三維生物,如果二維生物可溝通,我們亦可當(dāng)作中介將高維空間翻譯給低維生物,我們就是這個(gè)多頭機(jī)制,我們就是這個(gè)中介。
所以,Encoder—Decoder機(jī)制更像是一個(gè)大的翻譯,大的中介,形成了高維與低維通信的方法。
7.4 Decoder第二層多頭的掩碼機(jī)制
由于Q代表當(dāng)前要生成的詞,而K代表編碼器中的詞,我們要去零件庫(kù)尋找與Q相似的零件,那么我們希望的肯定是這個(gè)零件是有效的,以確保解碼器在生成每個(gè)詞時(shí),只利用編碼器輸出中有效的、相關(guān)的信息。
因此,我們?cè)谛纬裳诖a矩陣時(shí),主要根據(jù)Encoder的輸入序列來(lái)形成掩碼矩陣,而不是Decoder輸入的序列,當(dāng)然,Q肯定也有Padding,但這里我們不必過(guò)度關(guān)注。
所以,在第二層多頭掩碼時(shí),我們更關(guān)注K的Padding掩碼,確保解碼得到的信息也是有效的。
具體代碼如下:
def get_attn_pad_mask(seq_q, seq_k):
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
# eq(zero) is PAD token
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # batch_size x 1 x len_k(=len_q), one is masking
return pad_attn_mask.expand(batch_size, len_q, len_k) # batch_size x len_q x len_k
比如,輸入序列為:[1,2,3,4,0],最后一位是Padding,那么掩碼矩陣如下圖示例:
7.5 Decoder第二層多頭的輸出
至此,我們就將Q進(jìn)行了解碼輸出,得到了與Q最為相似的表達(dá)。
如圖所示步驟:
八、殘差連接和層歸一化
接著就是再次殘差連接與層歸一化,與上述步驟完全一樣。(為了保持連貫性,還是再寫(xiě)一遍)
8.1 殘差連接
為了解決深層網(wǎng)絡(luò)退化問(wèn)題,殘差連接是經(jīng)常被使用的,在代碼內(nèi)也就一行:
dec_outputs = output + residual
dec_outputs:最終輸出,判斷原始輸出好還是經(jīng)過(guò)多頭后的輸出好
output:經(jīng)過(guò)多頭后的輸出
residual:原始輸入
8.2 層歸一化
層歸一化只需要將最終輸出做一下LayerNorm即可,也是一行代碼即可完成。
dec_outputs = nn.LayerNorm(output + residual)
這樣,我們就完成了殘差連接與層歸一化,如圖所示:
九、前饋神經(jīng)網(wǎng)絡(luò)
該前饋神經(jīng)網(wǎng)絡(luò)接受來(lái)自歸一化后的輸出,shape為(batch_size, seq_length, dmodel)。
batch_size:批處理大小
seq_length:序列長(zhǎng)度,比如64個(gè)Token
dmodel:詞嵌入映射維度
在Transformer中,由于我們輸入的是序列數(shù)據(jù),因此Conv1d就被優(yōu)先考慮使用了,但是這里需要一定的前置知識(shí)(點(diǎn)擊可看Conv1d原理),才能懂得Transformer中FNN的工作原理。
9.1 Conv1d連接
由于在Pytorch內(nèi)Conv1d要求維度優(yōu)先,因此需要對(duì)輸入數(shù)據(jù)進(jìn)行維度轉(zhuǎn)換。
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1,bias=False)
self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1,bias=False)
# 維度優(yōu)先由6維升維到d_ff維
output = nn.ReLU()(self.conv1(inputs.transpose(1, 2))) # 這里因?yàn)橐痪S卷積需要維度再前所以要交換特征位置
# 再降維到d_model維度,一般為詞嵌入維度
output = self.conv2(output).transpose(1, 2)
假設(shè)輸入的數(shù)據(jù)為(1, 5 , 6),d_ff為2048,那么整個(gè)過(guò)程的數(shù)據(jù)shape流動(dòng)就為:
(1,5,6) ->轉(zhuǎn)置>> (1,6,5)->卷積>> (1,2048,5)->卷積>>(1,6,5)->轉(zhuǎn)置>>(1,5,6)
經(jīng)過(guò)調(diào)整提取之后,輸出shape仍為(1,5,6)。
至此,前饋神經(jīng)網(wǎng)絡(luò)完成。
十、殘差連接和層歸一化
接著就是再次殘差連接與層歸一化,與上述步驟完全一樣。(為了保持連貫性,還是再寫(xiě)一遍)
10.1 殘差連接
為了解決深層網(wǎng)絡(luò)退化問(wèn)題,殘差連接是經(jīng)常被使用的,在代碼內(nèi)也就一行:
dec_outputs = output + residual
dec_outputs:最終輸出,判斷原始輸出好還是經(jīng)過(guò)多頭后的輸出好
output:經(jīng)過(guò)多頭后的輸出
residual:原始輸入
10.2 層歸一化
層歸一化只需要將最終輸出做一下LayerNorm即可,也是一行代碼即可完成。
dec_outputs = nn.LayerNorm(output + residual)
這樣,我們就完成了殘差連接與層歸一化,如圖所示:
十一、線性層及Softmax
在得到歸一化的輸出之后,經(jīng)過(guò)線性層將輸出維度映射為指定類別,比如,你的目標(biāo)詞表有1000個(gè)詞,那么就會(huì)通過(guò)softmax預(yù)測(cè)輸出的這1000個(gè)詞哪個(gè)詞概率最高,最高的即是目標(biāo)輸出。
代碼如下:
self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)
dec_logits = self.projection(dec_outputs).view(-1, dec_logits.size(-1))
可以看到,經(jīng)過(guò)50個(gè)epoch,是能夠?qū)⒌挛某晒Ψg為英文的,至此,我們已經(jīng)將Transformer剖析完成。
圖片