The Annotated BERT, With Extra Annotations: You Only Truly Understand BERT Once You Understand the Code
Previously we implemented the Transformer and the GPT-2 pretraining process from scratch, using code comments and printed tensor shapes to make the process easier to follow. Today I will study BERT in the same way.
The original Transformer is an encoder-decoder architecture, GPT is a decoder-only model, and BERT is an encoder-only model, so we mainly focus on the left half of the Transformer diagram.
Reply "bert" to the official account to get download links for the training dataset, the code, and the paper.
Please read this article together with the code:
https://github.com/AIDajiangtang/annotated-transformer/blob/master/AnnotatedBert.ipynb
0. Preparing the Training Data
0.0 Downloading the Data
The original BERT was pretrained on BooksCorpus and English Wikipedia, but those datasets are too large for our purposes. This time we pretrain on 50,000 movie reviews from the IMDb website: a CSV file with two columns, where the review column holds the review text and the sentiment column holds the sentiment label, positive or negative. We only use the review column here.
Here is one review printed out:
One of the other reviewers has mentioned that after watching just 1 Oz episode you'll be hooked.
They are right, as this is exactly what happened with me.<br /><br />The first thing that struck me about Oz was its brutality and unflinching scenes of violence, which set in right from the word GO.
Trust me, this is not a show for the faint hearted or timid. This show pulls no punches with regards to drugs, sex or violence. Its is hardcore, in the classic use of the word.<br /><br />It is called OZ as that is the nickname given to the Oswald Maximum Security State Penitentary. It focuses mainly on Emerald City, an experimental section of the prison where all the cells have glass fronts and face inwards, so privacy is not high on the agenda. Em City is home to many..Aryans, Muslims, gangstas, Latinos, Christians, Italians, Irish and more....so scuffles, death stares, dodgy dealings and shady agreements are never far away.<br /><br />I would say the main appeal of the show is due to the fact that it goes where other shows wouldn't dare. Forget pretty pictures painted for mainstream audiences, forget charm, forget romance...OZ doesn't mess around. The first episode I ever saw struck me as so nasty it was surreal, I couldn't say I was ready for it, but as I watched more, I developed a taste for Oz, and got accustomed to the high levels of graphic violence. Not just violence, but injustice (crooked guards who'll be sold out for a nickel, inmates who'll kill on order and get away with it, well mannered, middle class inmates being turned into prison bitches due to their lack of street skills or prison experience) Watching Oz, you may become comfortable with what is uncomfortable viewing....thats if you can get in touch with your darker side.
ds = IMDBBertDataset(BASE_DIR.joinpath('data/imdb.csv'), ds_from=0, ds_to=1000)
To speed up training, the ds_from and ds_to parameters restrict loading to the first 1,000 reviews.
0.1 Computing the Context Length
The context length is the maximum length of the input sequence. In the Transformer and GPT-2 walkthroughs it was set directly as a hyperparameter; this time we derive it from the training data: read the 1,000 reviews with pandas, split each review into sentences on '.', collect every sentence's length into an array, and take the value at the 90th percentile of that array.
This yields an optimal sentence length of 27: samples longer than 27 tokens are truncated, and shorter ones are padded with a special token.
As a simple example, if the sentence-length array is [10, 20, 30, 40, 50, 60, 70, 80, 90, 100], the 90th-percentile value is roughly 90 (the exact number depends on the interpolation method). A small sketch of this computation follows.
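Below is a minimal sketch of that computation, assuming the CSV's review column and a plain word count per sentence; the helper name find_optimal_sentence_length is illustrative, and the repository's IMDBBertDataset does the equivalent internally.
import numpy as np
import pandas as pd

def find_optimal_sentence_length(csv_path, n_rows=1000):
    df = pd.read_csv(csv_path, nrows=n_rows)        # only the first 1000 reviews
    lengths = []
    for review in df['review']:
        for sentence in review.split('.'):          # split a review into sentences on '.'
            lengths.append(len(sentence.split()))   # rough per-sentence word count
    return int(np.percentile(lengths, 90))          # 90th-percentile length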
0.2 Tokenization
This walkthrough uses the basic_english tokenizer, a very simple and direct method: it lowercases all text and splits it on whitespace, separating common punctuation marks into tokens of their own (which is why the tokenized sentences later contain ',' and '.' as tokens). For example:
"Hello, world! This is an example sentence."
['hello', ',', 'world', '!', 'this', 'is', 'an', 'example', 'sentence', '.']
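A minimal usage sketch with torchtext's built-in basic_english tokenizer (standard torchtext API; the exact token list may vary slightly across torchtext versions):
from torchtext.data.utils import get_tokenizer

tokenizer = get_tokenizer('basic_english')   # lowercase + split punctuation into separate tokens
print(tokenizer("Hello, world! This is an example sentence."))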
Next, each word is converted to an integer id. This requires building a vocabulary from the training data, i.e. collecting all unique words that appear in it.
Counting shows that these 1,000 reviews contain 9,626 unique tokens.
The following special tokens are then prepended to the vocabulary (a short vocabulary-building sketch follows the list):
CLS = '[CLS]'
PAD = '[PAD]'
SEP = '[SEP]'
MASK = '[MASK]'
UNK = '[UNK]'
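A minimal sketch of building such a vocabulary with torchtext's build_vocab_from_iterator; the specials are listed in the order that reproduces the ids seen later ([CLS]=0, [PAD]=1, [MASK]=2, [SEP]=3), and yield_tokens / all_sentences are illustrative names:
from torchtext.vocab import build_vocab_from_iterator

def yield_tokens(sentences):
    for s in sentences:
        yield tokenizer(s)                 # tokenizer from the previous step

vocab = build_vocab_from_iterator(
    yield_tokens(all_sentences),
    specials=[CLS, PAD, MASK, SEP, UNK],   # prepended, so they get ids 0..4
)
vocab.set_default_index(vocab[UNK])        # unseen words map to [UNK]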
0.3 Constructing the Training Data
BERT is an encoder-only architecture: every token attends to every other token, whether it comes before or after. This lets the model absorb the full context, which is why encoder-only models are well suited to understanding tasks.
A decoder, by contrast, only attends to the tokens before it. In that sense GPT uses only the left context, but this autoregressive setup has its own advantage: it suits generation tasks.
To learn bidirectional representations, the way training data is constructed also has to differ, not just the model structure.
GPT predicts the next word from the current ones. Suppose the training data is token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] with context_length=4, stride=4 and batch_size=2; the inputs and targets look as follows (a small windowing sketch follows the example):
Input IDs: [tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8])]
Target IDs: [tensor([2, 3, 4, 5]), tensor([6, 7, 8, 9])]
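A minimal sketch of this sliding-window construction (a generic illustration, not code from any particular repository):
import torch

token_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
context_length, stride = 4, 4

input_ids, target_ids = [], []
for i in range(0, len(token_ids) - context_length, stride):
    input_ids.append(torch.tensor(token_ids[i:i + context_length]))           # current window
    target_ids.append(torch.tensor(token_ids[i + 1:i + 1 + context_length]))  # shifted by one

print(input_ids)    # [tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8])]
print(target_ids)   # [tensor([2, 3, 4, 5]), tensor([6, 7, 8, 9])]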
BERT constructs its pretraining data in two ways: masked language modeling (MLM) and next sentence prediction (NSP).
MLM randomly replaces some of the words in a sample with [MASK] or with other words from the vocabulary. In this implementation, 15% of the words are selected; of those, 80% are replaced with [MASK] and 20% with random words from the vocabulary.
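A minimal sketch of that masking rule (mask_sentence is a hypothetical helper; the dataset class implements its own version of this logic):
import random

def mask_sentence(tokens, vocab_words, mask_ratio=0.15, mask_token='[MASK]'):
    tokens = list(tokens)
    inverse_mask = [True] * len(tokens)                   # False marks positions to predict
    n_to_mask = max(1, round(len(tokens) * mask_ratio))   # 15% of the words
    for idx in random.sample(range(len(tokens)), n_to_mask):
        if random.random() < 0.8:
            tokens[idx] = mask_token                      # 80%: replace with [MASK]
        else:
            tokens[idx] = random.choice(vocab_words)      # 20%: replace with a random word
        inverse_mask[idx] = False
    return tokens, inverse_mask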
NSP builds positive pairs from adjacent sentences and negative pairs from non-adjacent ones, with a [SEP] separator placed between the two sentences.
BERT is not good at generation, so how does it handle downstream tasks such as question answering? BERT puts a [CLS] token at the start of every sample and performs binary classification on the [CLS] output.
With the method clear, we can construct the training data, starting by iterating over the 1,000 movie reviews.
Take the first review, shown above, as an example.
Split the review into sentences on '.', then iterate over each sentence.
The first sentence:
One of the other reviewers has mentioned that after watching just 1 Oz episode you'll be hooked
The second sentence:
They are right, as this is exactly what happened with me.<br /><br />The first thing that struck me about Oz was its brutality and unflinching scenes of violence, which set in right from the word GO
The first sentence tokenized:
['one', 'of', 'the', 'other', 'reviewers', 'has', 'mentioned', 'that', 'after', 'watching', 'just', '1', 'oz', 'episode', 'you', "'", 'll', 'be', 'hooked']
The second sentence tokenized:
['they', 'are', 'right', ',', 'as', 'this', 'is', 'exactly', 'what', 'happened', 'with', 'me', '.', 'the', 'first', 'thing', 'that', 'struck', 'me', 'about', 'oz', 'was', 'its', 'brutality', 'and', 'unflinching', 'scenes', 'of', 'violence', ',', 'which', 'set', 'in', 'right', 'from', 'the', 'word', 'go']
In each sentence, 15% of the words are randomly masked, [CLS] is added at the start, and the sentence is padded to the context length of 27; the two sentences are then concatenated, separated by [SEP].
['[CLS]', 'one', 'of', 'the', 'other', 'reviewers', 'has', 'mentioned', '[MASK]', 'after', 'watching', 'just', '1', 'oz', 'episode', 'you', "'", '[MASK]', '[MASK]', 'hooked', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[SEP]', '[CLS]', 'they', 'are', 'right', ',', 'as', 'this', 'is', '[MASK]', 'what', 'happened', '[MASK]', 'me', '[MASK]', 'the', '[MASK]', 'financiers', 'that', 'struck', 'me', 'about', 'oz', 'was', 'its', 'brutality', 'and', 'unflinching']
From the masked sequence above we build the token mask: positions that were masked or randomly replaced are set to False, and all other positions to True.
[True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, False, True, False, True, False, False, True, True, True, True, True, True, True, True, True, True]
The masked sequence is converted to token ids; this is the X that is ultimately fed into the model.
[0, 5, 6, 7, 8, 9, 10, 11, 2, 13, 14, 15, 16, 17, 18, 19, 20, 2, 2, 23, 1, 1, 1, 1, 1, 1, 1, 3, 0, 24, 25, 26, 27, 28, 29, 30, 2, 32, 33, 2, 35, 2, 7, 2, 32940, 12, 39, 35, 40, 17, 41, 42, 43, 44, 45]
The original, unmasked sequence is converted to token ids; this is the label Y.
[0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 1, 1, 1, 1, 1, 1, 1, 3, 0, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 7, 37, 38, 12, 39, 35, 40, 17, 41, 42, 43, 44, 45]
The MLM loss is computed from the model's output and the label Y.
What about the NSP loss? When constructing sentence pairs, the label is 1 if the two sentences are adjacent and 0 otherwise, and a binary classification loss is computed from the [CLS] output.
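A minimal sketch of how such a pair and its NSP label might be built (make_nsp_pair is a hypothetical, slightly simplified helper; the real dataset pairs each sentence either with its true successor or with a random sentence):
import random

def make_nsp_pair(sentences, i):
    first = sentences[i]
    if random.random() < 0.5 and i + 1 < len(sentences):
        return first, sentences[i + 1], 1         # adjacent sentences -> label 1
    return first, random.choice(sentences), 0     # random sentence -> label 0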
In the end, a DataFrame is built from the first 1,000 reviews. Each row of the DataFrame is one sample, 17,122 samples in total, and each sample has four columns:
the input X, with shape [1, 55];
the label Y, with shape [1, 55];
the token mask, with shape [1, 55];
the NSP classification label, 0 or 1.
55 is the length of the two sentences plus one [SEP] separator, each sentence being 27 tokens: 27 + 1 + 27 = 55.
1. Pretraining
Hyperparameters
EMB_SIZE = 64      # embedding dimension
HIDDEN_SIZE = 36   # hidden layer size
EPOCHS = 4         # number of training epochs
BATCH_SIZE = 12    # batch size
NUM_HEADS = 4      # number of attention heads
With BATCH_SIZE = 12, each batch contains 12 samples, so the input X has shape [12, 55] and the label Y has shape [12, 55].
1.0 Embeddings
Next the token ids are converted into embeddings. In BERT, every token involves three embeddings: the token embedding, which maps the token id to an embedding vector; the position encoding; and the segment embedding, which indicates which sentence the token belongs to, 0 for the first sentence and 1 for the second.
With EMB_SIZE = 64, the embedding dimension is 64: the token embedding maps the input [12, 55] to [12, 55, 64] through an embedding layer of shape [9626, 64].
9626 is the vocabulary size; the [9626, 64] embedding layer can be viewed as a lookup table with 9,626 slots, each storing a 64-dimensional vector.
The position encoding can either be learned or computed by a fixed formula; here the fixed (sinusoidal) computation is used.
The segment embedding has the same length as the input X: positions in the first sentence get 0, positions in the second sentence get 1.
Finally the three embeddings are summed, and the resulting embedding of shape [12, 55, 64] is fed into the encoder.
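A minimal sketch of such an embedding module, assuming sinusoidal position encodings; the class and parameter names are illustrative, not the repository's exact code:
import math
import torch
import torch.nn as nn

class BertEmbedding(nn.Module):
    def __init__(self, vocab_size=9626, emb_size=64, seq_len=55):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, emb_size)    # [9626, 64] lookup table
        self.segment_emb = nn.Embedding(2, emb_size)           # 0 = first sentence, 1 = second
        pe = torch.zeros(seq_len, emb_size)                    # fixed sinusoidal position encoding
        pos = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, emb_size, 2, dtype=torch.float) * (-math.log(10000.0) / emb_size))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pos_emb', pe)

    def forward(self, input_ids, segment_ids):
        # input_ids, segment_ids: [batch, 55] -> output: [batch, 55, 64]
        return self.token_emb(input_ids) + self.pos_emb.unsqueeze(0) + self.segment_emb(segment_ids)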
1.1 Multi-Head Attention
The encoder's first operation is multi-head attention. Unlike in the Transformer and GPT walkthroughs, attention over [PAD] tokens is suppressed: the attention scores at [PAD] positions are set to a very large negative value so that they become 0 after the softmax.
The multi-head attention output has shape [12, 55, 64].
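A minimal sketch of padding-masked scaled dot-product attention (illustrative only; shapes follow the walkthrough):
import torch
import torch.nn.functional as F

def attention(q, k, v, pad_mask):
    # q, k, v: [batch, heads, 55, head_dim]; pad_mask: [batch, 1, 1, 55], True where the token is [PAD]
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)   # [batch, heads, 55, 55]
    scores = scores.masked_fill(pad_mask, -1e9)              # large negative -> ~0 after softmax
    return F.softmax(scores, dim=-1) @ v                     # [batch, heads, 55, head_dim]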
1.2 MLP
This is the same as in the Transformer and GPT walkthroughs; the MLP output has shape [12, 55, 64].
1.3 Output
The encoder output has shape [12, 55, 64]; next, the losses against the labels are computed to update the parameters.
The MLM loss
The encoder output [12, 55, 64] is projected by a linear layer [64, 9626] into a distribution over the vocabulary, [12, 55, 9626].
Because the loss only needs to be computed at the [MASK] positions, a small trick zeroes out the non-[MASK] positions in both the labels and the outputs.
Finally, the multi-class cross-entropy loss is computed against the label Y.
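A minimal sketch of the MLM loss; instead of zeroing positions out, it equivalently selects only the masked positions, and the variable names are illustrative:
import torch
import torch.nn as nn

mlm_head = nn.Linear(64, 9626)         # projection from the encoder dimension to vocabulary logits
criterion = nn.CrossEntropyLoss()

def mlm_loss(encoder_out, target_ids, token_mask):
    # encoder_out: [12, 55, 64], target_ids: [12, 55], token_mask: [12, 55], True = not masked
    logits = mlm_head(encoder_out)     # [12, 55, 9626]
    masked = ~token_mask               # only the masked / replaced positions contribute
    return criterion(logits[masked], target_ids[masked])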
The NSP loss
Another linear layer [64, 2] maps the output of the leading [CLS] token, [12, 64], to [12, 2], the logits of the positive and negative classes, and a cross-entropy loss is computed against the NSP labels.
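A minimal sketch of the NSP head (illustrative names):
import torch.nn as nn

nsp_head = nn.Linear(64, 2)                    # binary classifier on the [CLS] output
nsp_criterion = nn.CrossEntropyLoss()

def nsp_loss(encoder_out, nsp_labels):
    cls_out = encoder_out[:, 0, :]             # [CLS] is the first token -> [12, 64]
    logits = nsp_head(cls_out)                 # [12, 2]
    return nsp_criterion(logits, nsp_labels)   # nsp_labels: [12] with values 0 or 1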
2.0 Inference
The simplest use is cloze-style fill-in-the-blank: feed in a piece of text [1, 55] with some words replaced by [MASK], and map the output at each [MASK] position through the output head to a [1, 9626] distribution over the vocabulary.
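A minimal sketch of that cloze prediction, reusing the hypothetical names from the sketches above (model, mlm_head, vocab), so it is an illustration rather than runnable repository code:
import torch

model.eval()
with torch.no_grad():
    encoder_out = model(input_ids, segment_ids)            # [1, 55, 64]
    logits = mlm_head(encoder_out)                         # [1, 55, 9626]
    mask_positions = (input_ids == vocab['[MASK]'])        # where the input was masked
    predicted_ids = logits.argmax(dim=-1)[mask_positions]  # most likely token at each [MASK]
    print([vocab.get_itos()[i] for i in predicted_ids])    # map ids back to words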
因?yàn)槲覀冊(cè)陬A(yù)訓(xùn)練時(shí)使用了“next sentence prediction”(NSP),可以構(gòu)造一個(gè)閉集VQA,就是為一個(gè)問(wèn)題事先準(zhǔn)備幾個(gè)答案,分別將問(wèn)題和答案拼接在一起輸入到BERT,通過(guò)[CLS]的輸出去分類(lèi)。
或者去預(yù)測(cè)答案的起始和終止位置,這就涉及到下游任務(wù)的微調(diào)了。
Summary
With this, we have walked through the pretraining of both GPT-2 and BERT. To make such models follow human instructions, the pretrained model still needs instruction fine-tuning.
References
https://arxiv.org/pdf/1810.04805
https://github.com/coaxsoft/pytorch_bert
https://towardsdatascience.com/a-complete-guide-to-bert-with-code-9f87602e4a11
https://medium.com/data-and-beyond/complete-guide-to-building-bert-model-from-sratch-3e6562228891
https://coaxsoft.com/blog/building-bert-with-pytorch-from-scratch
This article is reposted from the WeChat official account 人工智能大講堂.
