模型知識蒸餾新SOTA!告別傳統(tǒng)散度蒸餾|騰訊優(yōu)圖&中科大出品
用大模型“蒸餾”小模型,有新招了!
甚至能在不同類型和架構(gòu)的LLMs(大語言模型)上達(dá)到新SOTA。
這就是來自中科大、騰訊優(yōu)圖實驗室提出的一種基于Sinkhorn距離的知識蒸餾方法,能把大的、復(fù)雜的教師模型的知識“蒸餾”到小的、簡單的學(xué)生模型中,從而讓小模型也能像大模型一樣工作。
之所以提出新方法,主要是現(xiàn)有的知識蒸餾(KD)方法都有各自的局限性:
當(dāng)兩個模型的輸出差異較大時,它們就不太管用了。
- KL散度:會導(dǎo)致學(xué)生模型的輸出變得過于平滑,失去了區(qū)分性;
- RKL散度:會讓學(xué)生的輸出變得太簡單,不能很好地模仿教師模型;
- JS散度:會讓學(xué)生模型低估稀有事件的概率;
而基于Sinkhorn距離的新方法能更準(zhǔn)確地衡量和縮小教師模型和學(xué)生模型之間的差異,從而提高了學(xué)生模型的性能。
此外,研究還提出了一種基于批量的重構(gòu)方法,從而在高維空間中捕捉跨樣本分布的幾何復(fù)雜性。
最終,通過在兩個流行的自然語言處理測試集(GLUE和SuperGLUE)上測試,新方法在編碼器、編碼器-解碼器以及解碼器等不同架構(gòu)的所有類型LLMs上均優(yōu)于當(dāng)前的最先進(jìn)方法。
研究背景
知識蒸餾的提出是為了通過對齊教師模型的軟目標(biāo)(例如輸出logits和中間層表示)來將教師模型內(nèi)在固有的知識傳遞給學(xué)生模型。
給定訓(xùn)練集中的一個樣本x_i及其真實標(biāo)簽???? ∈ ???,來自教師模型????和學(xué)生模型????的輸出logits ???? ∈ ???和???? ∈ ???可以由以下式子得到:
其中為softmax函數(shù), τ是溫度參數(shù), d是輸出logits的維度?;趌ogit的知識蒸餾的目標(biāo)是σΤ最小化測量散度J(????,????)以實現(xiàn)知識傳遞。
研究動機(jī)
現(xiàn)有研究已經(jīng)嘗試使用Kullback-Leibler(KL)散度、反Kullback-Leibler(RKL)散度和Jensen-Shannon(JS)散度。
所有這些度量都可以被視為f-散度度量的變體,而f-散度度量在量化缺乏實質(zhì)性交集的任何兩個分布時都存在明顯局限性。
此外,每種度量都有其自身的缺陷:
KL蒸餾會導(dǎo)致模式平均,使學(xué)生學(xué)習(xí)到一個過于平滑的分布,涵蓋了教師的整個支撐集;
RKL會引起模式塌陷,學(xué)生僅關(guān)注教師分布中高概率的顯著區(qū)域,而忽視了其余部分;
JS蒸餾會產(chǎn)生模式低估,由于懲罰不足,學(xué)生會低估稀有事件的概率。
為了解決傳統(tǒng)散度度量的問題,研究做出了以下貢獻(xiàn):
- 提出了一種知識蒸餾方法SinKD,采用Sinkhorn距離作為散度度量。它不僅解決了KL、RKL和JS散度在極端場景下的局限性,而且避免了計算Wasserstein距離的負(fù)擔(dān)。
- 深入探討了Sinkhorn距離的性質(zhì),并將SinKD重新reformulated為batch-wise OT,擴(kuò)展了它在NLP任務(wù)中的適用性。
- 通過大量的可比性、有效性和泛化性實驗證明了SinKD相較于目前最先進(jìn)的方法的優(yōu)越性。并為實際應(yīng)用提供了使用SinKD進(jìn)行蒸餾的實用指導(dǎo)方針。
傳統(tǒng)散度度量的缺陷
首先,KL散度是不對稱的,表現(xiàn)為JKL(????,????)≠ JKL(????,????),這一性質(zhì)違反了距離度量的對稱性特性,從而引入了一些不一致性。
其次,由于使用KL損失進(jìn)行優(yōu)化,學(xué)生模型試圖對教師模型的多模態(tài)分布進(jìn)行平均化,從而導(dǎo)致對這些模式的擬合不足。這被稱為“模式平均問題”(mode-averaging problem)。
因此,學(xué)生模型無法捕獲數(shù)據(jù)中的所有關(guān)鍵模式,最終影響模型性能。
第三,KL散度對應(yīng)的是一個非平滑函數(shù),這為優(yōu)化過程帶來了挑戰(zhàn)。
與KL散度一樣,具有內(nèi)在的不對稱性,從而導(dǎo)致在捕捉分布差異時出現(xiàn)不一致性。
此外,優(yōu)化的學(xué)生模型傾向于僅關(guān)注教師分布中概率較高的事件,這被稱為“模式崩塌問題”(mode-collapsing)。
如果教師對某個事件賦予零概率,學(xué)生模型也被迫做出相同的預(yù)測。
其中m?? = 1/2(????+????)受制于非平滑性,JS損失在優(yōu)化過程中面臨挑戰(zhàn)。
另外,由于JS損失在低概率區(qū)域的匹配上懲罰不足,學(xué)生模型可能會過度低估稀有事件的概率。
對于分布之間重疊較少甚至完全不重疊的情況退化為常數(shù)時,還存在梯度消失的風(fēng)險。
最優(yōu)傳輸距離的優(yōu)勢
Wasserstein距離通過求解兩個分布之間的最優(yōu)傳輸計劃來量化它們的差異。
直觀地看,它可以被認(rèn)為是將一個分布(即學(xué)生的logits分布)轉(zhuǎn)換為另一個分布(即教師的logits分布)所需的最小“代價”,其中“代價”可以定義為被移動的質(zhì)量與移動距離的乘積。
與傳統(tǒng)的散度度量相比,Wasserstein距離作為蒸餾的成本函數(shù)更為合理,因為它不依賴于對被測量分布的隱式假設(shè)。此外,它幾乎處處可微,從而便于優(yōu)化。
另外,現(xiàn)有的散度度量只能獨(dú)立處理每個樣本對,進(jìn)行逐一logit的匹配,對于一批樣本,這些方法無法定位來自同一樣本的教師和學(xué)生的logits對,從而無法實現(xiàn)整體距離的最小化。
由于計算Sinkhorn距離的過程可以實現(xiàn)來自同一樣本的兩個輸出之間的精確逐元素匹配,研究提出了“批量化”的SinKD方法(batchified SinKD)。
通過這種方式,即使通過低維觀測,也能夠捕捉復(fù)雜且隱式分布的幾何結(jié)構(gòu)。
方法介紹
這里簡要介紹SinKD的核心方法,詳細(xì)推導(dǎo)過程可以參閱原論文。
批量重構(gòu)的Sinkhorn距離
對于本問題,Wasserstein距離的定義如下:
其中,
Wasserstein距離本身在解析計算上存在困難,其計算成本對于蒸餾大型語言模型來說高得難以承受。
在這種情況下,研究使用Sinkhorn距離作為一種高效的近似方法。它不僅保留了Wasserstein距離的所有優(yōu)點(diǎn),同時也大大緩解了其在在線蒸餾中所面臨的成本問題。
Sinkhorn距離的定義如下:
逐樣本蒸餾將每個實例獨(dú)立處理,但忽略了一個批次樣本中的整體趨勢。
研究摒棄了僅在每對教師-學(xué)生樣本對上工作的逐樣本知識蒸餾方法,轉(zhuǎn)而在教師和學(xué)生樣本組上執(zhí)行知識蒸餾。
一個包含b個樣本的批次會整體參與散度度量。通過批量重構(gòu),這種方法有效地增加了“觀測”空間的維度,特別是在d遠(yuǎn)小于b的情況下表現(xiàn)尤為顯著。
對于常規(guī)分類任務(wù)的蒸餾,研究使用如下“batchified”代價函數(shù):
并初始化如下候選傳輸矩陣:
通過重構(gòu)和化簡,研究可以使用如下迭代式計算最優(yōu)傳輸矩陣(具體推導(dǎo)過程參見論文):
由此,可以算出最優(yōu)傳輸距離:
SinKD的變體
拓展到回歸任務(wù):對于回歸任務(wù),模型不會為每個選項生成概率,而是僅生成一個標(biāo)量(d=1)。對于一個包含b個樣本的批次,教師模型和學(xué)生模型的輸出分別表示為?? ∈ ?bx1和?? ∈ ?bx1。
為了計算教師和學(xué)生之間的批量化Sinkhorn距離,成本矩陣的元素由“批量化”回歸輸出之間的絕對差值確定:
拓展到獨(dú)熱標(biāo)簽微調(diào):SinKD方法也適用于僅有獨(dú)熱(one-hot)標(biāo)簽且無法獲取教師模型logits的模型微調(diào)。
在這種情況下,可以將單熱標(biāo)簽視為“假想”的單熱教師模型的logits。由于單熱logits中以零為主,傳統(tǒng)的散度度量(例如KL散度)在處理這種極端情況下的散度量化時顯得無能為力。
實驗與分析
(1)數(shù)值結(jié)果。與基線和SOTA方法對比,論文方法在大部分任務(wù)上均取得了更好的性能。
(2)消融實驗。得出的結(jié)論如下:
- Sinkhorn損失在所有損失中對學(xué)生模型的收益最大
- 批量化的SinKD優(yōu)于逐樣本的SinKD
- SinKD超越了基于f-散度變體的蒸餾方法
(3)生成式大語言模型實驗。SinKD可以推廣到生成式大語言模型,并在基于類GPT架構(gòu)的模型的蒸餾上取得不俗的成績表現(xiàn)。
但同時研究也觀察到,蒸餾效果的影響會隨著PROMPT模板的變化而改變。
這意味著,同樣的任務(wù)設(shè)置下,更加合理的PROMPT設(shè)計能夠更充分地利用教師模型的固有知識。
(4)可視化結(jié)果如下。
為了增強(qiáng)內(nèi)在評估,研究還進(jìn)行了以下附加分析:
- 隱藏狀態(tài)的表示
- 注意力機(jī)制的模式
- 層級性能分析
(5)拓展到獨(dú)熱標(biāo)簽微調(diào)。與現(xiàn)有的散度度量方法(例如KL散度)不同,SinKD方法還可以擴(kuò)展用于使用獨(dú)熱標(biāo)簽 (one-hot label) 微調(diào)語言模型。
(6)拓展到計算機(jī)視覺領(lǐng)域深度網(wǎng)絡(luò)。SinKD在所有測試的配置中均穩(wěn)定地超越了所有基線方法。
總結(jié)
研究引入了SinKD以解決現(xiàn)有蒸餾方法的局限性。此外,作者們提出了基于批次的重構(gòu)方法,以捕捉高維空間中樣本分布的幾何復(fù)雜性。最后,研究在各類任務(wù)、數(shù)據(jù)集和模型架構(gòu)上進(jìn)一步驗證SinKD的有效性。
更多細(xì)節(jié)歡迎查閱原論文。
COLING 2024會議論文:https://arxiv.org/abs/2402.17110
IEEE TNNLS期刊論文:https://hal.science/hal-04803835