從模型原理到代碼實(shí)踐,深入淺出上手 Transformer,叩開(kāi)大模型世界的大門
作者 | Plus
一、序言
作為非算法同學(xué),最近被Cursor、DeepSeek搞的有點(diǎn)焦慮,同時(shí)也非常好奇這里的原理,所以花了大量業(yè)余時(shí)間自學(xué)了Transformer并做了完整的工程實(shí)踐。希望自己心得和理解可以幫到大家~
如有錯(cuò)漏,歡迎指出~
本文都會(huì)以用Transformer做中英翻譯的具體實(shí)例進(jìn)行闡述。
二、從宏觀邏輯看Transformer
讓我們先從宏觀角度解釋一下這個(gè)架構(gòu)。
首先 Transformer也是一個(gè)神經(jīng)網(wǎng)絡(luò),神經(jīng)網(wǎng)絡(luò)的本質(zhì)是模擬人腦神經(jīng)元的思考過(guò)程,數(shù)學(xué)上是一種擬合,當(dāng)然,人腦內(nèi)部的信號(hào)處理是否連續(xù)或者可擬合我們不得而知,但Transformer在我的機(jī)器上實(shí)實(shí)在在地思考并輸出了正確的答案。
Transformer 主要是設(shè)計(jì)用來(lái)做翻譯的,分兩大塊,如上圖,左邊的編碼器和右邊的解碼器。
編碼器負(fù)責(zé)提取原文的特征, 解碼器負(fù)責(zé)提取當(dāng)前已有譯文序列的特征,并結(jié)合原文特征(編碼器解碼器的連線部分),給出下一個(gè)詞的預(yù)測(cè)。
GPT基本可以認(rèn)為就是Transformer的解碼器部分。
接下來(lái)我們分幾個(gè)部分逐步講解并附上代碼實(shí)踐,很長(zhǎng) 得慢慢看....
三、輸入和輸出
我認(rèn)為大部分文章都沒(méi)有把輸入和輸出講得很細(xì),其實(shí)理解了輸入輸出,你就基本可以理解Transformer的大半了。
1. 模型的輸入
以中英翻譯為例,Transformer的輸入分兩部分:
- 源文序列(中文) 即圖1左側(cè)部分,輸入到編碼器。 舉例: 我愛(ài)00700
- 目標(biāo)譯文(英文)即圖1右側(cè)部分,輸入到解碼器。 一開(kāi)始只有一個(gè),僅用于告訴模型開(kāi)始翻譯
2. 模型的輸出
很顯然,對(duì)于上述輸入,我們期待的輸出是 I love 00700
的作用是表示翻譯結(jié)束,因?yàn)榉g的英文不一定和輸入的中文等長(zhǎng),所以需要有一個(gè)結(jié)束符。
模型并不能一下子輸出完整的句子,他是一個(gè)詞一個(gè)詞( /token/ )吐出來(lái)的,并且每一個(gè)詞都需要作為下一詞的輸入,這也是為什么大模型都是打字機(jī)交互的原因。 具體例子:
第一個(gè)循環(huán)
編碼器輸入 我 愛(ài) 00700
解碼器輸入 <bos>
輸出 I
第二個(gè)循環(huán)
編碼器輸入 我 愛(ài) 00700
解碼器輸入 <bos> I
輸出 love
第三個(gè)循環(huán)
編碼器輸入 我 愛(ài) 00700
解碼器輸入 <bos> I love
輸出 00700
第四個(gè)循環(huán)
編碼器輸入 我 愛(ài) 00700
解碼器輸入 <bos> I love 00700
輸出 <eos>
// 輸出了結(jié)束符,翻譯完成
請(qǐng)注意上述是為了簡(jiǎn)化理解的一個(gè)陳述,實(shí)際上模型真正的輸出并不是一個(gè)詞,而是整個(gè) /詞表/ 內(nèi)任意一個(gè)詞可能的概率,也就是圖2所示的probabilities。
具體說(shuō),就是一個(gè)數(shù)組[0.2, 0.7, 0.1], 序號(hào)為0的詞的概率是0.2,序號(hào)為1的詞的概率是0.7,序號(hào)為2的詞的概率是0.1。 假設(shè)這是第二輪循環(huán),詞表是[I, love, 00700] 取最大概率0.7的詞,然后得到第二輪輸出是love, 和第一輪輸出拼起來(lái)就是 I love。
3. 詞表 & token
劃線的兩個(gè)詞 /token/ /詞表/ 你可能仍有疑惑。這正是我們真正弄清楚整個(gè)輸入輸出的關(guān)鍵。
顯然,計(jì)算機(jī)只能理解二進(jìn)制數(shù)據(jù),我們說(shuō)輸入"我愛(ài)00700"的時(shí)候,實(shí)際上輸入的是處理好的二進(jìn)制數(shù)據(jù)。
如何把句子轉(zhuǎn)化成模型可以理解的二進(jìn)制數(shù)據(jù)呢? 不妨先想想我們?cè)趺磳W(xué)英語(yǔ)的, 沒(méi)錯(cuò),背單詞?。?我們需要先認(rèn)識(shí)詞,然后理解句子,模型也是一樣的。
我們背的英文單詞表,和這里的模型 /詞表/ 其實(shí)是一個(gè)意思,當(dāng)然,形式略有不同。 詞表里的每一個(gè)詞,就是我們說(shuō)的 /token/ ,請(qǐng)注意 token 不一定是一個(gè)英語(yǔ)單詞。上述例子的英文詞表可能是: [I, lo, ve, 00, 7] 。 顯然,token不是按照空格分的。 原因是算力難以覆蓋,英文單詞有幾十萬(wàn)個(gè)(不權(quán)威),老黃聽(tīng)了都搖頭??。 而更低維度的分詞,可以壓縮詞表的數(shù)量,[I, lo, ve, 00, 7]是我瞎寫的,理解意思就行,讓DS給我們舉一個(gè)更好的例子:
假設(shè)語(yǔ)料庫(kù)包含以下高頻單詞:
- play (10次)
- player (8次)
- playing (6次)
- plays (5次)
- replay (7次)
- replaying (4次)
直接按空格分詞會(huì)生成獨(dú)立詞表:
空格分詞詞表大?。?
[play, player, playing, plays, replay, replaying]
# bpe分詞算法最終詞表(目標(biāo)大小=4)
# ?表示這個(gè)token只會(huì)出現(xiàn)在開(kāi)始,想一想 還原句子的時(shí)候你需要知道兩個(gè)token是拼起來(lái)還是插入空格
[?play, re, ing, er]
對(duì)比空格分詞, /bpe分詞算法/ (想了解的自行ds,限于篇幅不贅述)可以有效壓縮詞表大小,在大量語(yǔ)料的情況下會(huì)更明顯,40GB數(shù)據(jù)的GPT2也僅5w詞表大小
4. 嵌入(embedding)
現(xiàn)在,我們有了詞表,很容易就可以得到二進(jìn)制的序列了
還是老例子:
我 愛(ài) 00700 被編碼為 [1, 2, 3] (為了方便,這里假設(shè)分詞就這樣。)
我們可以直接輸入到網(wǎng)絡(luò)訓(xùn)練了嗎? 答案是否定的,不過(guò)我們也終于來(lái)到了有意思的地方。 在在訓(xùn)練之前需要做一個(gè) /嵌入/
來(lái)一個(gè)youtobe的動(dòng)圖看下 /嵌入/ 大概是啥。
再來(lái)一個(gè)DS的靈魂解釋:
嵌入(Embedding)的核心思想是 將復(fù)雜、高維的數(shù)據(jù)(如文字、圖像)轉(zhuǎn)化為低維、連續(xù)的數(shù)值向量,同時(shí)保留其內(nèi)在的語(yǔ)義或關(guān)系 。這種轉(zhuǎn)化讓計(jì)算機(jī)能像人類一樣“理解”數(shù)據(jù)的含義,并用數(shù)學(xué)方式計(jì)算相似性、分類或生成。
額... 有點(diǎn)抽象,似乎講了些什么,又似乎什么都沒(méi)講......
沒(méi)關(guān)系,讓我來(lái)舉一個(gè)形象的例子 觀察下面的句子:
紅色
當(dāng)紅明星
生意好紅火啊
注意"紅"這個(gè)字,發(fā)現(xiàn)了嗎,在不同語(yǔ)境(或者說(shuō)維度)它有不同的意思,如果我們直接輸入token編碼,這些信息是缺失的,模型就不太可能理解句子的意思。
那么,一個(gè)字或者說(shuō)一個(gè)token到底有多少維度的語(yǔ)義呢?不知道啊??,但是沒(méi)關(guān)系,猜一個(gè), 512。 沒(méi)錯(cuò),就是這么樸實(shí)無(wú)華!
于是我們就有了嵌入矩陣 (self.weight)
: 詞表大小(num_embeddings,如32000)
: 嵌入維度(embedding_dim,如512)
即E是一個(gè)32000x512的二維矩陣 (代碼里就是數(shù)組)
E[1,1] E[1,2] ...... E[1,512] 的值表示詞表中序號(hào)為1的詞在維度1、2、512的語(yǔ)義, 由此同樣一個(gè)“紅”的token,就可以有多種語(yǔ)義,這樣模型就可以理解詞和句子了。
當(dāng)我們輸入 我 愛(ài) 00700 被編碼為 [1, 2, 3] , 實(shí)際上我們應(yīng)該輸入:
請(qǐng)注意請(qǐng)注意,所謂一個(gè)字有不同維度的語(yǔ)義,是我現(xiàn)編的比喻,如果它幫到你理解這個(gè)東西那最最好,如果覺(jué)得比喻得不太對(duì),大可一笑了之。
繼續(xù)
E[1,1] E[1,2] ...... E[1,512] 我們可以把它看作是以原點(diǎn)為起點(diǎn)的 512維的向量 ,一般就叫做詞的嵌入向量,也就是平時(shí)可能聽(tīng)到的向量化,沒(méi)什么高大上的對(duì)吧。 向量化之后可以做什么呢?看圖
(以三維舉例,512維可以發(fā)揮想象力自行腦補(bǔ))
為了讓模型理解詞和詞的關(guān)系,我們用數(shù)學(xué)語(yǔ)言去描述這種關(guān)系,就是向量的內(nèi)積:
-
是向量a的長(zhǎng)度(范數(shù))
是向量b的長(zhǎng)度(范數(shù))
是兩個(gè)向量之間的夾角
就算忘記了向量?jī)?nèi)積怎么算也沒(méi)關(guān)系,我們只需要理解,夾角越小,向量?jī)?nèi)積越大, 兩個(gè)詞的關(guān)聯(lián)越緊密就行了。 實(shí)際上大部分復(fù)雜公式的計(jì)算都是封裝好了的。 這里的內(nèi)積,在自注意力的時(shí)候我們也會(huì)用到。
到了這里就可以解釋一下 的元素值是怎么來(lái)的了,我們定義了每一個(gè)詞有512個(gè)維度,那么每一個(gè)維度的值是什么呢?
答案是模型自己學(xué)習(xí)出來(lái)的,E是模型的參數(shù)的一部分,一個(gè)詞每個(gè)維度的值,初始化是一個(gè)隨機(jī)值,模型訓(xùn)練的時(shí)候會(huì)被更新。
怎么學(xué)出來(lái)的? 試想一下: 語(yǔ)料里面有無(wú)數(shù)的詞,但是他們是有一定的關(guān)聯(lián)的,比如 I love 這兩個(gè)詞經(jīng)常同時(shí)出現(xiàn),那么他們的內(nèi)積就應(yīng)該較大,反之也一樣。實(shí)際上有點(diǎn)像聚類,怎么理解都行,模型學(xué)得差不多的時(shí)候,token的分布或者說(shuō)嵌入向量的關(guān)系,一定是有規(guī)律的。
讓我們來(lái)用數(shù)學(xué)語(yǔ)言總結(jié)一下嵌入:
就是我們的查表或者說(shuō)嵌入操作:輸入張量
- B: 批次大小(batch_size)
- L: 序列長(zhǎng)度(sequence_length)
- D: 嵌入維度
嵌入查找操作上述例子 就是 ,L是句子長(zhǎng)度也就是3
為什么這里也會(huì)叫查表? 實(shí)際上對(duì)索引為1的token做嵌入,就是在嵌入矩陣?yán)锏秸业谝恍腥缓竽贸鰜?lái),就是 E[1,1] E[1,2] ...... E[1,512] 。這個(gè)過(guò)程不就是查表么。
5. Batch 處理
細(xì)心的你可能發(fā)現(xiàn)了問(wèn)題,這里多了一個(gè)維度B,輸入的結(jié)果 ,是三維的。這里解釋一下,我們的例子只有一句話,但是在模型的訓(xùn)練中,其實(shí)是一次性輸入多句話并行計(jì)算的,batch_size 的值就是你一次性輸入多少句子來(lái)訓(xùn)練,一般來(lái)說(shuō),這取決于你的顯存大小,自行嘗試調(diào)整就可以了。
假設(shè)batch 為2,一次輸入兩句:
兩句話的embedding矩陣合并,得到一個(gè) 2x3x512的三維矩陣 就是我們給編碼器的輸入了。 此時(shí) B=2 L=3 D=512
6. padding掩碼矩陣
還沒(méi)完,讓我們來(lái)一點(diǎn)刁鉆的問(wèn)題。 上述例子L=3,兩個(gè)句子長(zhǎng)度都一樣=3。 如果句子長(zhǎng)度不一樣呢?batch矩陣豈不是拼不出來(lái)? 當(dāng)然不會(huì),實(shí)際寫代碼的時(shí)候,L一般都是取一個(gè)較大值,比如語(yǔ)料中最長(zhǎng)句子的值。L=100。 不夠長(zhǎng)的句子怎么辦呢? 補(bǔ)padding,注意不是0,因?yàn)樯窠?jīng)網(wǎng)絡(luò)的某些中間計(jì)算如softmax輸入0也是有值的。 一般是用一個(gè)特殊的token作為padding。
在自注意力計(jì)算的時(shí)候(下文會(huì)講)會(huì)根據(jù)padding的位置,生成mask 矩陣,計(jì)算時(shí)候padding替換為一個(gè)極小值比如-1e9,就可以不影響計(jì)算了。
7. PositionalEncoding
現(xiàn)在這個(gè)帶padding的矩陣可以輸入模型開(kāi)始訓(xùn)練了嗎? 還是不行..... 我們還缺少一個(gè)重要的信息,位置。 舉一個(gè)例子:
[ [0.3, 0.5, 0.1, 0.4]
[貓,吃,魚] => Transformer => [0.1, -0.6, -0.2, 0.3],
[0.3, 0.5, 0.3, -0.1] ]
[ [0.3, 0.5, 0.1, 0.4]
[魚,吃,貓] => Transformer => [0.1, -0.6, -0.2, 0.3],
[0.3, 0.5, 0.3, -0.1] ]
很明顯,雖然只是交換了一個(gè)字的順序,但其實(shí)是兩個(gè)完全不同意義的句子, 而 Transformer輸出的分布卻是不變的,說(shuō)明網(wǎng)絡(luò)沒(méi)有辦法識(shí)別位置信息。
為了解決這個(gè)問(wèn)題,Transformer的論文里的方案是添加位置編碼,也就是PositionalEncoding。
其實(shí)沒(méi)有那么神秘,我們把 Transformer 看成函數(shù)T , 現(xiàn)在的問(wèn)題是:
我們可以加上一個(gè)位置編碼P去規(guī)避這個(gè)問(wèn)題:
: 當(dāng)前的位置序號(hào)
來(lái)一個(gè)簡(jiǎn)單粗暴易理解的 P(i) = i
貓 0
吃 1
魚 2
然后呢?就是簡(jiǎn)單的直接加上去
沒(méi)錯(cuò),這樣其實(shí)我們的輸入就包含了位置信息了,只要信息在那里,模型總能學(xué)明白
當(dāng)然,這里只是為了方便說(shuō)明位置編碼本身,所以用了最簡(jiǎn)單的方案,你說(shuō)它能不能跑,那肯定也是能跑的,就是會(huì)有不少問(wèn)題,實(shí)際上論文里實(shí)現(xiàn)是正弦torch.sin(position * div_term)和余弦函數(shù)torch.cos(position * div_term)來(lái)生成位置編碼, 解釋起來(lái)就篇幅太長(zhǎng),大家可以自己探索。
位置編碼不改變矩陣的維度,只是改變了數(shù)值, 我們給編碼器的最終輸入總結(jié)如下:
我愛(ài)我愛(ài)分詞嵌入
但是對(duì)于解碼器的輸入,還需要額外處理,繼續(xù)往下。
8. 并行計(jì)算 & Teacher Forcing
回顧下一開(kāi)始的例子
第一個(gè)循環(huán)
編碼器輸入 我 愛(ài) 00700
解碼器輸入 <bos>
輸出 I
第二個(gè)循環(huán)
編碼器輸入 我 愛(ài) 00700
解碼器輸入 <bos> I
輸出 love
第三個(gè)循環(huán)
編碼器輸入 我 愛(ài) 00700
解碼器輸入 <bos> I love
輸出 00700
第四個(gè)循環(huán)
編碼器輸入 我 愛(ài) 00700
解碼器輸入 <bos> I love 00700
輸出 <eos>
// 輸出了結(jié)束符,翻譯完成
我們給編碼器的輸入始終是整個(gè)句子序列的embedding。而給解碼器的輸入,則是當(dāng)前已經(jīng)預(yù)測(cè)出來(lái)的 n-1個(gè)詞的序列的embedding。
(1) 并行計(jì)算
當(dāng)你真的上手去寫代碼的時(shí)候,你就會(huì)發(fā)現(xiàn), 在訓(xùn)練階段,代碼里我們給解碼器直接輸入完整的句子I love 00700 ,且模型輸出直接就是整個(gè)句子。 只有一個(gè)循環(huán)。你現(xiàn)在肯定滿腦子問(wèn)號(hào),直接輸入正確答案了,還學(xué)什么?
別急,實(shí)際上我們輸入不是全部的ground truth,是 n-1序列的ground truth。怎么理解呢?真正的ground truth 是 I love 00700 eos 。 我們的輸入是I love 00700。 當(dāng)然這樣僅僅是少了最后一個(gè)token的答案而已。
實(shí)際上解碼的輸入(姑且叫trg_input)在做計(jì)算的時(shí)候還需要做一個(gè)causal mask,叫做因果掩碼,字面意思就是為了防止計(jì)算的時(shí)候知道未來(lái)的信息。
>>> trg_input # 假設(shè)我們的input長(zhǎng)這樣 1234代表 <bos> I love 00700
tensor([[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4],
[1, 2, 3, 4]])
# 因果掩碼矩陣是這樣的,對(duì)角線上移一位的上三角矩陣,上三角的元素是極小值
>>> mask = torch.triu(torch.ones(4, 4), diagonal=1)
>>> mask = mask.float().masked_fill(mask == 1, float(-1e9))
>>> mask
tensor([[0., -inf, -inf, -inf],
[0., 0., -inf, -inf],
[0., 0., 0., -inf],
[0., 0., 0., 0.]])
>>>
#執(zhí)行 causal mask
>>> trg_input + mask
tensor([[1., -inf, -inf, -inf],
[1., 2., -inf, -inf],
[1., 2., 3., -inf],
[1., 2., 3., 4.]])
看看causal mask結(jié)果的第一行,[1., -inf, -inf, -inf] 其實(shí)就是相當(dāng)于第一輪循環(huán)的輸入 bos, 因?yàn)闊o(wú)限小在計(jì)算的時(shí)候可以忽略。 后續(xù)的2、3、4行剛好也就是 對(duì)應(yīng)的2、3、4輪循環(huán)。 get到了嗎, 直接作為一個(gè)矩陣輸入,在模型里并行計(jì)算,4個(gè)循環(huán)優(yōu)化成了一個(gè)!能夠并行計(jì)算,這正是transformer一個(gè)重要特性。
(2) Teacher Forcing
想一想為什么我們可以實(shí)現(xiàn)并行計(jì)算? 因?yàn)橛?xùn)練的時(shí)候,我們提前知道了正確答案,在生產(chǎn)環(huán)境推理過(guò)程中,我們不可能知道第一個(gè)token應(yīng)該輸出I,第二個(gè)token應(yīng)該輸出love ...., 所以推理的時(shí)候,只能順序執(zhí)行循環(huán)。
再多想一點(diǎn),在訓(xùn)練的時(shí)候,模型還并不成熟,第一個(gè)字符輸出的未必是I,假設(shè)輸出了 He 呢? 然后第二輪循環(huán)我們的輸入就是 bos He, 這樣繼續(xù)循環(huán)下去只會(huì)越來(lái)越錯(cuò),非常不利于學(xué)習(xí)的收斂。 所以,我們第二輪循環(huán)不輸入He, 而是用正確答案 I 來(lái)繼續(xù)下一個(gè)token預(yù)測(cè),實(shí)驗(yàn)表明這樣學(xué)習(xí)的速度會(huì)更快。 這種訓(xùn)練方法,我們就稱之為Teacher Forcing。
關(guān)于并行計(jì)算和Teacher Forcing感覺(jué)大部分文章都是一筆帶過(guò),而我感覺(jué)這里很重要,包括理解架構(gòu)本身或者是去理解為什么transformer帶來(lái)了技術(shù)的爆發(fā),并行計(jì)算是一個(gè)很重要的因素。
9. 輸入輸出的代碼實(shí)現(xiàn)
作為一個(gè)嚴(yán)謹(jǐn)?shù)墓こ處?,最后我貼一下我自己的實(shí)現(xiàn),并做簡(jiǎn)單講解:
# 關(guān)于語(yǔ)料 我的數(shù)據(jù)集來(lái)自 wmt17 v13 zh-en 大約100m 30+w行長(zhǎng)句子
# 從csv讀出來(lái), 第一列是中文 第0列是英文
train_data = read_tsv_file(opt.train_tsv, 1, 0, opt.delimiter)
valid_data = read_tsv_file(opt.valid_tsv, 1, 0, opt.delimiter) if opt.valid_tsv else []
test_data = read_tsv_file(opt.test_tsv, 1, 0, opt.delimiter) if opt.test_tsv else []
# 分詞器初始化 我這里是bpe bytelevel 32000 詞表大小,具體實(shí)現(xiàn)太長(zhǎng)就不貼了
src_tokenizer_path = os.path.join(opt.data_dir, "tokenizer_zh.json")
src_tokenizer = train_tokenizer([src for src, _ in train_data], opt.vocab_size, src_tokenizer_path)
trg_tokenizer_path = os.path.join(opt.data_dir, "tokenizer_en.json")
trg_tokenizer = train_tokenizer([trg for _, trg in train_data], opt.vocab_size, trg_tokenizer_path)
# 有了分詞器之后,對(duì)訓(xùn)練數(shù)據(jù)句子分詞 打包成pkl格式,方便模型訓(xùn)練的時(shí)候讀取
processed_train = process_data(train_data, src_tokenizer, trg_tokenizer, opt.max_len)
# train階段
# 從pkl加載數(shù)據(jù) 略
# 訓(xùn)練輸入
for i, batch in enumerate(dataloader):
# 準(zhǔn)備數(shù)據(jù) 每次循環(huán)處理一個(gè)batch的數(shù)據(jù)
src = batch['src'].to(device) # 取出語(yǔ)料原文。我 愛(ài) 00700 (舉例,實(shí)際上batch是多句的數(shù)組,但是矩陣運(yùn)算的時(shí)候都是一起并行算的)
trg = batch['trg'].to(device) # 取出語(yǔ)料目標(biāo)譯文。bos I love 00700 eos
trg_input = trg[:, :-1] # 去掉最后一個(gè)token, bos I love 00700 n-1的Teacher Forcing序列
trg_output = trg[:, 1:] # 去掉第一個(gè)token, I love 00700 eos 真實(shí)的groud truth,用于計(jì)算損失
# 前向傳播 清空梯度
optimizer.zero_grad()
# 輸入模型計(jì)算一次
# teacher forcing: trg_input 實(shí)際上就是ground truth(正確的值)
# train 模式下,output是整個(gè)句子的預(yù)測(cè)值
output = model(src, trg_input)
......
# 嵌入
"""初始化嵌入矩陣"""
# 使用正態(tài)分布(高斯分布)初始化張量
# 均值為0,標(biāo)準(zhǔn)差為0.02
# 在數(shù)學(xué)上,512維的隨機(jī)向量,可以認(rèn)為是互相正交的(就是任意兩個(gè)token之間一點(diǎn)關(guān)系都沒(méi)有)
nn.init.normal_(self.weight, mean=0, std=0.02)
# 基本的嵌入查找
# weight 是模型參數(shù),自動(dòng)參與學(xué)習(xí)
embeddings = self.weight[x]
# mask
"""創(chuàng)建源序列和目標(biāo)序列的掩碼"""
src_padding_mask = (src == self.pad_idx) # src padding掩碼
trg_padding_mask = (trg == self.pad_idx) # trg padding掩碼
# 因果掩碼
seq_len = trg.size(1)
trg_mask = torch.triu(
torch.ones((seq_len, seq_len), device=src.device), diagonal=1
).bool()
return src_padding_mask, trg_padding_mask, trg_mask
# 應(yīng)用mask 我這里的mask是bool矩陣,true表示需要mask
# 如果mask是true,scores對(duì)應(yīng)位置的元素就填充為極小值 -1e9
if mask is not None:
scores = scores.masked_fill(mask == True, -1e9)
輸入輸出終于寫完 看到這里就真的不容易 感謝~!
四、自注意力機(jī)制
到這里我們已經(jīng)完成了輸入的所有前置處理,訓(xùn)練數(shù)據(jù)開(kāi)始進(jìn)入到網(wǎng)絡(luò)訓(xùn)練。 而理解Transformer網(wǎng)絡(luò)的關(guān)鍵就在于自注意力機(jī)制,如下圖所示的三個(gè)模塊,他們是整個(gè)網(wǎng)絡(luò)的核心。
1. 注意力機(jī)制
注意力其實(shí)非常好理解, 如下圖: 萬(wàn)綠叢中一點(diǎn)紅,你首先看到了紅,這就是注意力。
在訓(xùn)練網(wǎng)絡(luò)中,注意力的作用是什么呢?
老例子:“ 我 愛(ài) 00700 ” 翻譯為 “ I love () ” () 為當(dāng)前正在預(yù)測(cè)的詞。
如上圖, 前面我們介紹過(guò)Transformer預(yù)測(cè)下一個(gè)詞的時(shí)候,會(huì)結(jié)合整個(gè)上下文,即“我 愛(ài) 00700 I love”,假如沒(méi)有注意力機(jī)制,那么所有詞對(duì)預(yù)測(cè)結(jié)果的貢獻(xiàn)是一樣的! 假設(shè)語(yǔ)料的言情文偏多,那輸出很可能是 I love you, 很明顯不合理, 此時(shí)我們希望網(wǎng)絡(luò)把注意力集中在“00700”這個(gè)詞上,其他詞應(yīng)該忽略掉。
那么在數(shù)學(xué)上注意力是什么呢? 非常簡(jiǎn)單,就是權(quán)重系數(shù)寫出來(lái)就是:
- x ∈ R^(seq_len, d_model) seq_len是句子長(zhǎng)度,d_model=512是embdedding的緯度,忘了的話往回翻復(fù)習(xí)下吧
- W ∈ R^(512, 512) 是權(quán)重矩陣
??? 解釋了這么久 就這? 沒(méi)錯(cuò)是的,取個(gè)高大上的名字而已。
2. 自注意力機(jī)制
自注意力畢竟多了一個(gè)字,肯定還是有點(diǎn)不一樣的,不一樣在哪呢,他乘的不是權(quán)重矩陣,它乘它自己!
- 即權(quán)重
- x ∈ R^(seq_len, d_model) seq_len是句子長(zhǎng)度,d_model=512是embdedding的緯度,忘了的話往回翻復(fù)習(xí)下吧
重點(diǎn)看一下這個(gè)
如果你對(duì)矩陣乘法熟,你馬上就會(huì)意識(shí)到,這不就是句子的詞和詞之間的內(nèi)積嗎?
回憶一下嵌入的內(nèi)容,我們說(shuō)過(guò)詞向量的內(nèi)積反應(yīng)的是詞的關(guān)系遠(yuǎn)近。
本質(zhì)是,Transformer通過(guò)句子本身的詞和詞之間的關(guān)系計(jì)算注意力權(quán)重! 所以我們叫它 - 自注意力 !還算自洽吧~
當(dāng)然,直接相乘沒(méi)有參數(shù)可以訓(xùn)練啊,來(lái)個(gè)線性變換吧 ,于是
不好記,得起個(gè)名字, 第一個(gè)式子是原始請(qǐng)求 就叫Query吧,簡(jiǎn)寫為Q
感覺(jué)不用糾結(jié)這個(gè)命名,反正論文沒(méi)提命名的道理 自行想象吧!
第二個(gè)式子是被用來(lái)計(jì)算關(guān)系的,就叫Key吧,簡(jiǎn)寫為K
于是:
按照論文所提,實(shí)驗(yàn)發(fā)現(xiàn),QK點(diǎn)積可能會(huì)過(guò)大,數(shù)據(jù)方差太大,導(dǎo)致梯度不太穩(wěn)定,所以需要把分布均勻一下,于是:
我們需要的是權(quán)重,所以需要轉(zhuǎn)化為 0.0~1.0 這樣的值,這正是 /softmax/ 函數(shù)的能力 于是:
代入 , y改寫為Attention,更高大上
另外為了增加網(wǎng)路的復(fù)雜度,我們不直接乘x,老辦法,線性變換一下, 記為V
最終
最終我們得到了和論文一模一樣的公式~
3. 多頭注意力機(jī)制
先看一個(gè)圖,這是一個(gè)自注意力計(jì)算的示例:
問(wèn)題是: 我 - 我 愛(ài)-愛(ài) 00700-00700 的分?jǐn)?shù)最高, 自己最關(guān)注自己,聽(tīng)起來(lái)是合理的,但論文作者應(yīng)該是覺(jué)得這不利于收斂。 所以提出了多頭機(jī)制。
其實(shí)這里論文并沒(méi)有太多解釋,我理解是一種實(shí)驗(yàn)性的經(jīng)驗(yàn)。
那么什么是多頭? 其實(shí)也很簡(jiǎn)單,就是分塊計(jì)算注意力,好比原來(lái)是一個(gè)頭看整個(gè)矩陣, 現(xiàn)在是變身哪吒三頭六臂,一個(gè)頭看一小塊,然后信息合并。 畫圖說(shuō)明:
這種分頭其實(shí)本質(zhì)上計(jì)算量是一樣的,只是注意力分塊了,直接看公式就好了。
- Q, K, V 是輸入矩陣,形狀為 (batch_size, seq_len, d_model)
- W_i^Q ∈ R^(d_model × d_k)
- W_i^K ∈ R^(d_model × d_k)
- W_i^V ∈ R^(d_model × d_v)
- head_i 的形狀為 (batch_size, seq_len, d_v)
論文中 i=8, 8個(gè)頭, dk =dv = d_model/8 = 64
因果掩碼-多頭自注意力機(jī)制:
其實(shí)輸入輸出的時(shí)候就講了,就是解碼器做多頭計(jì)算的時(shí)候 使用teacher forcing,需要加上因果掩碼mask。 下面代碼說(shuō)明。
4. 交叉-多頭自注意力機(jī)制
解碼器如何結(jié)合編碼器的上下文就在這里了,如果是交叉-多頭自注意力機(jī)制,那么 Q = x, K=V=編碼器的輸出,直接看代碼最直觀。
# 1. 線性變換 這里的 query=key=value 就是x
# 如果是解碼器的 交叉-多頭自注意力機(jī)制,那么 query = x, key=value=編碼器的輸出
Q = self.q_linear(query) # 這里的linear就是線性變換 wx+b,下同
K = self.k_linear(key)
V = self.v_linear(value)
# 2. 分割成多頭 [batch_size, seq_len, d_model] -> [batch_size, seq_len, num_heads, d_k]
# 實(shí)際上 torch的張量(矩陣) 是用一緯數(shù)組存的,所謂分割多頭、轉(zhuǎn)置,最終就是改一下數(shù)組元素的順序
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 3. 計(jì)算注意力
# K.transpose(-2, -1) 就是 k的轉(zhuǎn)置
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# 4. 應(yīng)用mask 因果掩碼 或者 padding 掩碼
if mask is not None:
scores = scores.masked_fill(mask == True, -1e9)
# 5. softmax獲取注意力權(quán)重
attn = F.softmax(scores, dim=-1)
# dropout是隨機(jī)丟棄一些值(寫0),為了增加隨機(jī)性,訓(xùn)練時(shí)才會(huì)開(kāi)啟,避免過(guò)擬合
attn = self.dropout(attn)
# 6. 注意力加權(quán)求和 就是乘v
out = torch.matmul(attn, V)
# 7. 重新拼起來(lái)多頭 把維度轉(zhuǎn)置回來(lái)
out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 8. 最后又加了一個(gè)線性變換 wx+b, 不要為問(wèn)我為什么 ...... 實(shí)驗(yàn)性經(jīng)驗(yàn)
out = self.out_linear(out)
return out
五、前向傳播
到這里就很簡(jiǎn)單了,核心的部分都講完了。 還剩下幾個(gè)小塊簡(jiǎn)單講一下。
1. Add & Norm
Add 叫殘差,超級(jí)簡(jiǎn)單,直接上代碼:
# self_attn 就是算注意力,tgt_mask是因果掩碼。
# 本來(lái)應(yīng)該是 x = self.dropout1(self.self_attn(x2, x2, x2, tgt_mask))
# 殘差就是多加了 x, x = x + self.dropout1(self.self_attn(x2, x2, x2, tgt_mask))
# 簡(jiǎn)單解釋就是 加這個(gè)x,防止梯度下降太快到后面沒(méi)了。詳細(xì)就不展開(kāi)了,可自行g(shù)pt
x = x + self.dropout1(self.self_attn(x2, x2, x2, tgt_mask))
Norm 是歸一化,簡(jiǎn)單解釋一下,歸一化是讓數(shù)據(jù)分布更均衡,一般是消除一些離譜值、不同數(shù)據(jù)量綱的影響。 比如分析年齡和資產(chǎn)對(duì)相親成功率的影響,年齡只有0-100, 資產(chǎn)的范圍就很大,計(jì)算的時(shí)候就不好弄,需要把資產(chǎn)也縮放到一個(gè)合理的數(shù)值范圍 比如 0-100 萬(wàn)。
歸一化有很多種,這里用的是 /Layernorm/ ,這個(gè)公式還挺復(fù)雜,有興趣可以自己研究,這里我是偷懶的:
# 直接使用 pytorch的自帶公式
self.norm1 = nn.LayerNorm(d_model)
# 調(diào)包就完事
x2 = self.norm1(x)
2. FeedForward
前饋全連接層(feed-forward linear layer) ,好多文章試圖解釋這個(gè)意義,其實(shí)沒(méi)啥意義,就是普通的兩層網(wǎng)絡(luò), 一般來(lái)說(shuō)就是一層神經(jīng)網(wǎng)絡(luò)就是 線性變換+激活函數(shù)。
線性變換,無(wú)非就是升維和降維。舉例:升維是類似于你看圖片的時(shí)候,雙指放大,然后滑來(lái)滑去找到你需要的,是特征放大。降維么,就是截圖保存放大部分,然后輸出,是特征提取。
一升一降就是FeedForward啦,直接看代碼:
class FeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
# 從dmodel升維到d_ff = 2048 論文的值
self.linear1 = nn.Linear(d_model, d_ff)
# 降回來(lái)提取特征
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 怎么找到關(guān)注特征的 那不就是激活函數(shù) F.relu ,很明顯,神經(jīng)元被激活的地方就是感興趣的特征啊。
return self.linear2(self.dropout(F.relu(self.linear1(x))))
3. 解碼器
到這里就沒(méi)啥了,每個(gè)塊我們都理解了,后續(xù)就是實(shí)現(xiàn)邏輯了。 直接上代碼,太長(zhǎng)了好像沒(méi)必要,文末尾放git地址把。
4. 編碼器
同上。
六、反向傳播
前向傳播計(jì)算出一次訓(xùn)練的結(jié)果,反向傳播就是根據(jù)結(jié)果的好or壞,更新參數(shù),循環(huán)往復(fù),最終得到一個(gè)滿意的模型。
自己實(shí)現(xiàn) 實(shí)在是有點(diǎn)累,調(diào)包只需要一句:
# 前向傳播
output = model(src, trg_input)
......
# 計(jì)算損失
loss = criterion(output_flat, trg_output_flat)
# 反向傳播
loss.backward()
pytorch框架會(huì)記錄你整個(gè)模型前向傳播的圖,然后根據(jù)損失,使用 /鏈?zhǔn)椒▌t / & /梯度下降算法/ 幫你回溯計(jì)算更新每一個(gè)參數(shù),太舒服了~