在機器翻譯等通用任務場景下,如何最有效地引用對比學習?
對比學習和文本生成的結合并不是一個新話題。但是,之前的大多方法都局限于某些特定的任務場景。例如,在一個對話的場景中,可能需要利用對比學習,去區(qū)分說話者,或者說話的主題,達到更好的表示學習的效果。在摘要中,也有一些工作通過構造具有事實性錯誤的負樣本來使用對比學習,增強生成的摘要和原文的一致性。然而,對于比較通用的任務(以使用 Transformer的編碼器和解碼器為生成模型為例),在用于機器翻譯、摘要、數據到文本生成的各種任務下,如何去引用對比學習才最有效?為什么要去用對比學習?關于這方面的研究比較少,本文將就此進行討論。
今天的介紹會圍繞下面四點展開:
- 動機
- 方法
- 實驗
- 討論
01 動機
先來講一下為什么要使用對比學習。
1. 為什么在文本生成上應用對比學習
首先,對比學習是一種很好的表示學習的方式,尤其是在CV的場景下,對比學習更是非?;?,在文本生成任務場景下,如果可以去構造出對于這個任務有意義的、有價值的樣本,可以幫助模型通過不同樣本之間的比較,學到更好的意義和表示。
其次,最近有研究表明,對比學習是有助于緩解曝光偏差問題的一個新思路。所謂曝光偏差,就是指目前大多數的生成框架(大多基于最大似然估計進行訓練的)存在著測試和訓練的不一致性,這個不一致性將會損害模型的泛化性能。模型在訓練階段解碼器只曝光給了正確的輸入,而在測試階段模型不得不基于自己生成的字符來預測,由此形成了測試和訓練的偏差。之前已經有很多工作來解決這個問題,比較有名的就是scheduled-sampling:既然曝光偏差是由于訓練和測試的不一致導致的,那就讓模型在訓練的時候也以一定概率和測試采取同樣的機制。也就是說,以一定概率利用上一步預測的詞語指導下一步的生成。
除此之外,還有一些比較有名的方法,如基于強化學習,生成對抗網絡等。除了token-level監(jiān)督和最大似然的訓練目標以外,還讓模型去顯示的優(yōu)化一個難以微分的目標。但是這兩種技術,在實現中存在著一定的難度,如果不是一個富有經驗的研究者,可能訓練出來的基于強化學習或者生成對抗網絡的模型還不如一個純粹的MLE模型訓練的效果好。
2. 應用對比學習可以緩解自回歸模型的曝光偏差問題
?
對比學習是如何解決這個問題的??
首先回顧下對比學習的目的。對比學習就是在表示上把正例拉近,把負例拉遠。在生成的場景下,對正負樣本的一個非常直觀的定義就是,把比人寫的質量高的樣本當做正例。以翻譯任務為例,人翻譯的結果就是正例,然后再另外去找一些包含錯誤的翻譯結果就是負例。
如何緩解曝光偏差?就是將錯誤的樣本和正確的樣本在訓練階段同時曝光給解碼器,利用對比學習損失函數,讓模型學習到正確標簽的表示和錯誤標簽的表示。相比強化學習和GAN,對比學習的一個好處就是訓練過程沒有不穩(wěn)定的問題。
3. 一個簡單的方法
?
看一個如何應用對比學習的例子。最簡單的方式就是采用CV上SimCLR的方式,即正樣本是給定的人寫的目標語句(也稱為ground truth),將一個batch中其他的樣本當做是負樣本。錨點是生成中的source sequence輸入。
如右圖所示,是一個德英翻譯的例子。有一個德語輸入,目標是要把它翻譯成合適的英語輸出。圖中的綠色框就是人類所寫的標準的翻譯,紅色框是在訓練階段和它同一個batch里進行一個隨機采樣出來的結果。綠色的就是正樣本,其他的就是負樣本。對比學習損失函數可以采用比較常見的NCE loss:一個正樣本是一個分子,整個樣本集是一個分母。最終的訓練目標就是把原始的token-level的NLL損失加上新的對比學習損失。解碼階段采用普通的beam-search算法即可。
4. 其他構造正負樣本的方法
?
這個方法存在一個明顯的問題。在對比學習中,最重要的就是正負樣本是否對任務有意義,可以看出來,這個方法的負樣本的質量實在堪憂,這就導致正負樣本非常容易區(qū)分,使模型學不到更好的表示。右圖是對區(qū)分正負樣本難度的分析。Batch size越大,從中找出正例的概率越低。紅色這條線使用的是T5模型,表示學習效果更好,比Scratch的方法區(qū)分正負樣本的準確率高很多,甚至不需要做對比學習的微調就可以找出正負樣本。這意味著對比學習是沒有挑戰(zhàn)的。所以說直接從batch中選擇正負樣本的方法是不充分的。在實驗中也發(fā)現,這樣訓練損失函數下降的是很快的,很難捕捉到對這個任務比較好的特征。
現在也有相關研究者做出了一些改進。
- SSMBA:在離散空間添加擾動,如隨機mask一些詞,用masked language model 將那些詞預測回去生成新的正樣本。
- Dropout:使用dropout機制類似于SimCSE,將ground truth輸入進帶有dropout機制的decoder兩次,所得到的不同表示為一對正樣本。
- CLAPS: 在embedding空間對ground truth加擾動,通過和原來的序列語義變化的大小作為劃分正負樣本的依據。
5. 目前基于對比學習的文本生成方法仍然存在瓶頸
?
基于對比學習的文本生成方法,仍然存在一系列的瓶頸,還沒有發(fā)揮出其真正的優(yōu)勢。主要有以下三點:
- 正負例構建: 盡管之前的方法已經做出了一定的改進,但是對目標序列進行擾動并不能反映模型當前可能會出現的錯誤。
- 對比學習損失函數: 對比學習損失函數的選擇也存在問題。InfoNCELoss 只區(qū)分正負樣本,但會忽略負樣本之間的差異性。
- 解碼目標: 僅僅是簡單的使用普通的beam search算法意味著這里存在著訓練目標和解碼目標的不一致。
02 如何解決問題
1. 我們的改進
我們提出了一種新的對比學習的框架——CoNT,只做了三件事,就可以使之前的對比學習框架性能取得非常顯著的提升。
上圖是我們的模型概述。左邊的部分就是經典的生成框架,把原語句輸入給編碼器,目標語句輸入給解碼器進行訓練。Zx 和 Zy 分別是編碼器和解碼器輸出的向量表示。
- 第一個改進是使用模型預測的樣例,作為對比學習的樣例
如圖中的這個句子,首先讓模型自己進行推理,會生成一個句子,其概率約為0.48。同時,由于beam search算法,可以解碼出多個輸出,會產生另一個句子,其概率約為0.53。一般來說,只要返回這兩個輸出的句子就已經足夠了,但是在對比學習的場景下,還需要得到他們的表示。
- 第二個改進是使用三元組的對比損失函數
在這里,不同于NCE損失,只考慮一個正例樣本,其他的都是負樣本,而我們的做法是做一個相對的損失函數。比如,當前有一個結果是模型推理生成的,這個結果和人翻譯的結果相比就是負例,但相對于同batch的句子來說,這個結果就是正例。
- 為對比學習的目標所設計的解碼目標?
通過損失函數就可以看出,如果模型推理的結果和gold reference的結果比較接近,那么它和原始輸入的錨點是越相似的。從圖中可以看出,如果只考慮最大似然分數,那么概率為0.53的句子將作為最后的結果,但如果多做一個相似度打分,那么概率為0.48的句子會是最后的輸出,以人為判斷來看,這個結果明顯是更準確的。
這是一個直觀的例子,來自于IWSLT14德英翻譯的一個句子。主要是為了向大家展示來自于同一batch中句子的質量和自生成的樣本的質量的對比。
?
這是對剛才模型的數學表示。?
首先,y+ 和 y- 是正負樣本,都是來自于模型的分布。接下來,是三元組的對比損失函數。把所有的pair都加起來,對于每一pair,它的損失函數是MarginRankingLoss。其中,??是包含??個對比學習樣本的pair集合,大小為k(k-1)/2。對于每個(yi,yj) + 和 - 是由他們各自的 bleu score 決定的。分數高的在這個pair中就為正例,另外一個就為負例。最后,解碼目標是由一個序列相似度的損失加上一個語言模型的損失。在解碼的時候,為了統(tǒng)一性,引入平衡因子進行加權和。平衡因子一般設為0.5即可。
CoNT模型并不是一個完全割裂的設計,而是相互幫助,相互運作的框架。
首先,三元組對比損失函數可以建模樣本差異性,序列相似度可以在解碼時做全局打分,自生成的正負樣本可以反映模型當前的錯誤,都可以提升模型的性能。模型性能提高了以后,就會意味著正負樣本會更加的challenging,隨著模型性能越來越好,正負樣本也越來越來越難以區(qū)分,直到最后收斂。對于解碼的目標,在實驗中也證明了,三元組的對比損失函數,以及自生成的正負樣本,對于序列相似度的計算都是有幫助的。
03 實驗
1. 機器翻譯
?
首先看一下機器翻譯的實驗結果,使用的數據集是IWSLT14德英翻譯、WMT16俄英翻譯和WMT14英德翻譯數據集。第一個block是用純粹的MLE損失訓練的結果,第二個block是用NCE損失訓練的結果,第三個block就用構造的模型訓練的結果。Block2主要比較了不同的那個正負樣本構建方法所帶來差異性。Block2和block3反映的是用不同的損失建模對于學習所帶來的收益,可以看到我們的正負樣本的構建得到的效果顯著提高。橙色的框表示的是單看訓練所帶來的提升。
2. 文本摘要
這是摘要生成的實驗,使用的數據集是XSum和Multi-News。第一個block仍然是比較了不同的對比學習方法,可以看出CoNT的方法比MLE的方法高了三個多點,比之前最好的方法(CLAPS)也高了兩個點。同樣,在PEGASUS上面做了實驗,可以看到,也是取得了目前最好的結果。
3. 代碼注釋
?
這兩個實驗是在代碼注釋生成以及結構化的文本生成的上面做的實驗。?
左面這個block表示對于python和java這兩個數據集的結果。在不引入外部數據的前提下,最好結果是CodeT5+Dual-Gen,可以看到在加上CoNT之后的方法也是取得了一個新的SOTA。當然,在引入外部數據的情況下,可以取得更好結果。右面是比較經典的數據到文本生成的基準,叫WiKiBio,R2D2是之前的SOTA結果,在使用CoNT后,取得了最新的SOTA。
4. 數據到文本的生成—TOTTO
?
這是數據到文本生成的另一個比較有名的數據集TOTTO,相比較WiKiBio,它的數據更加干凈。上面給的就是一個例子。在測試集上,利用CoNT方法,使用T5-base模型是可以取得和T5-3B模型相近的結果。也就是說,使用CoNT方法,可以在保證模型的性能的情況下,用非常節(jié)能的方式和3B模型取得相近的結果,甚至在BLEURT和PARENT兩個指標上還可以取得小幅度的領先。
5. 常識生成—CommonGen
最后一個任務常識生成,即給定幾個關鍵詞,生成一句邏輯連貫且通順的句子。從表中可以看出,使用CoNT方法,比較之前的base的結果,取得了非常大的領先。和large相比也是取得了相近的結果,甚至在某些指標上還要高。
04 討論
1. 可視化表示
?
這是模型學習的表示的可視化結果。
藍色的點代表同一個batch中的樣例,橘色代表是從模型分布中采樣出來的,綠色表示ground truth,顏色越深代表和ground truth越相似。圖a是MLE模型的結果。圍繞綠點旁邊的,大多數都是模型自己推理出來的東西,但是它沒有一個很明顯的角色邊界。當用Na?ve CL的框架后,能夠學習到很明顯的決策邊界,但是對于比較細的粒度,如這個綠點旁邊圍繞的其實并不是一些高質量結果,還是比較錯亂的情況。但對于CoNT來說,也有一個明顯的角色決策邊界,而且在綠色的旁邊圍繞的大多數都是一些深色的橙點,即模型推理出的一些質量比較好的結果。
2. 序列相似度的權重
?
這里探究在解碼時引入相似度計算的影響。這里主要做兩個study,一個是使用不同的損失函數,另外一個是采用不同的正負樣例構建方法。當α等于零的時候,就意味著完全使用似然函數。α等于1的時候,就意味著完全依賴相似度分數??梢钥吹?,對于Pair-wise模型,在0-0.5時,分數是不斷上升的。但是當完全忽略掉似然函數時,性能也會有下降的趨勢。右邊這個圖主要反映了使用不同正負樣本構造方式對序列相似度打分的影響,可以看到,使用CoNT的方式對reanking的目標有比較大的幫助。
3. 如何在你的代碼中使用對比學習
?
這里講一些比較工程化的東西,即假設現有一個基于MLE訓練的模型,如何引入CoNT。由于我們的方法是不需要改變模型結構的,因此只需要把模型的checkpoint加載進來,然后調用你模型的推理階段的代碼,利用pair-wise計算損失函數,直到模型收斂。在推理部分,在beam search時返回每個beam對應的隱層的pooling操作后的向量表示,最后在預測結果的選擇時,利用平衡因子結合cosine距離和似然函數概率,選出最好的結果。
4. CoNT的優(yōu)缺點
在實際推理中,引入Contrastive learning幾乎不會帶來明顯的浮點數運算操作(FLOPs),因此不會造成更多能量的消耗(不費電),并且和MLE框架下訓練的模型推理時長幾乎是一模一樣的(不影響速度)。因此在實際部署中基于Contrastive learning訓練的模型可以容易地替換現有的使用MLE 訓練的模型,但是CoNT 的一個明顯的缺點是:犧牲了訓練的速度。CoNT的訓練速度慢主要有三個方面:
?
第一點,為了獲取足夠有效或者說足夠有意義的樣本,需要先對模型進行一次warmup,即先使用NLL損失微調模型,直到模型微調完成,才可以足夠合格的去產生所需要的正負樣本。
?
第二點,在訓練時候,要引入解碼,使用beam search。在自回歸場景下,這是不可并行的,也會增加模型的一個訓練時長。
?
第三點,在決定正負樣本時,需要計算和ground truth的相似度。這個過程其實是非常慢的,尤其是使用cpu來算,就會更慢,最后我們選擇利用矩陣乘法來近似的計算相似度,極大地降低了時間開銷。
5. 一些trade-off的方法
?
這里提供兩個trade-off思路:
- 減小樣本中來自模型分布的樣本數量,增大來自batch中的樣本數量。
- 在驗證集中對比學習的下降曲線在前1w步比較陡,可以考慮early stop。
6. 利用序列的相似度進行協(xié)助解碼
目前,對于對比學習目標的使用仍然不是最優(yōu)的。目前的生成過程是在beam search完成后加入的,相當于是reranking的作用。當然,這方面也是考慮到代碼的實現的難易度,包括和訓練一致性的問題。當然,使用這套方法非常有潛力去做一套協(xié)助解碼的工作。在beam search過程中,似然函數的打分可能是不可靠的。如圖中的例子,可以發(fā)現,在beam search過程中,由于貪心策略的存在,不可能遍歷所有的結果。一個解決方案是,能否考慮每多少步,引入一個序列相似度的計算。
05 問答環(huán)節(jié)
Q1:序列相似度是如何計算的?
A1:錨點的選擇是編碼器的輸出,該輸出是一個sequence * h的矩陣,沿著sequence緯度進行pooling就可以得到一個維度為h的向量,這就是一個編碼器輸出的源語句的表示。在beam search過程中我們可以得到那些不同的hypothesis的表示也是sequence * h的矩陣,這些序列長度是不同的,我們也沿著長度的維度進行pooling獲得編碼器輸出的hypothesis向量,然后通過這些輸出和源輸入的相似度就可以計算出序列相似度的得分。
Q2:CoNT有運用到對話任務上嗎?
A2:我們在實驗中沒有做對話任務,因為考慮到單輪的對話可能研究價值沒有那么大,但是多輪對話和我們整個框架在訓練和解碼過程中都稍微有一點不一致,所以說沒有去做對話的工作。所以可能也不能給你一個非常絕對的一個回答,歡迎后面進行實驗和討論。
Q3:請問Warmup是訓練到收斂還是訓練到一定效果就可以?
A3:我們在進行實驗時都是訓練到收斂的,當然,訓練到一定效果其實也是可以的。但是,由于CoNT在訓練的時候是需要進行推理的也就導致整個訓練速度會比只做MLE的速度慢很多,所以盡量是勸大家先把warmup訓收斂,因為如果先沒有收斂的話,雖然在后面的訓練過程中NLL依然會接著訓練,但是可能會為后續(xù)的帶對比損失的訓練造成更多的訓練時常開銷,當然,最終效果其實應該是不影響的。
Q4:可以把blue分數直接分類成soft label一樣的東西,放到對比學習的損失函數里面嗎?
A4:可以的。可以通過blue分數控制兩個樣本之間的margin,比如說一個blue分數比較高,一個blue分數比較低,那他們之間的margin就比較大,如果這兩個blue分數差不多,那他們的margin就比較小。當然,我不建議直接去對blue分數進行優(yōu)化,因為在生成上的RL確實在訓練中比較不穩(wěn)定。
Q5:有哪些數據集是驗證生成語言的常識準確性?
A5:我只做了剛剛我們做的常識生成的這個數據集CommonGen,也有一些其他數據集如CommonSense QA。把這個準確性理解成事實一致性的話,在我們的這個任務中其中評測準確性的指標是CIDER和SPICE。如果要自動評價一個常識的準確性,可能是需要人工評價,或者是用模型評價。用模型評價的話目前來說工作還不是很多,在摘要上有一個比較有名的FACTCC,翻譯上好像沒看到。
DataFunSummit