大模型精準(zhǔn)反哺小模型,知識蒸餾助力提高 AI 算法性能
01 知識蒸餾誕生的背景
來,深度神經(jīng)網(wǎng)絡(luò)(DNN)在工業(yè)界和學(xué)術(shù)界都取得了巨大成功,尤其是在 計算機(jī)視覺任務(wù) 方面。深度學(xué)習(xí)的成功很大程度上歸功于其具有數(shù)十億參數(shù)的用于編碼數(shù)據(jù)的可擴(kuò)展性架構(gòu),其訓(xùn)練目標(biāo)是在已有的訓(xùn)練數(shù)據(jù)集上建模輸入和輸出之間的關(guān)系,其性能高度依賴于網(wǎng)絡(luò)的復(fù)雜程度及有標(biāo)注訓(xùn)練數(shù)據(jù)的數(shù)量和質(zhì)量。
相比于計算機(jī)視覺領(lǐng)域的傳統(tǒng)算法,大多數(shù)基于 DNN 的模型都因為 過參數(shù)化 而具備強(qiáng)大的 泛化能力 ,這種泛化能力體現(xiàn)在對于某個問題輸入的所有數(shù)據(jù)上,模型能給出較好的預(yù)測結(jié)果,無論是訓(xùn)練數(shù)據(jù)、測試數(shù)據(jù),還是屬于該問題的未知數(shù)據(jù)。
在當(dāng)前深度學(xué)習(xí)的背景下,算法工程師為了提升業(yè)務(wù)算法的預(yù)測效果,常常會有兩種方案:
使用過參數(shù)化的更復(fù)雜的網(wǎng)絡(luò),這類網(wǎng)絡(luò)學(xué)習(xí)能力非常強(qiáng),但需要大量的計算資源來訓(xùn)練,并且推理速度較慢。
集成模型,將許多效果弱一些的模型集成起來,通常包括參數(shù)的集成和結(jié)果的集成。
這兩種方案能顯著提升現(xiàn)有算法的效果,但都提升了模型的規(guī)模,產(chǎn)生了較大的計算負(fù)擔(dān),需要的計算和存儲資源很大。
在工作中,各種算法模型的最終目的都是要 服務(wù)于某個應(yīng)用 。就像在買賣中我們需要控制收入和支出一樣。在工業(yè)應(yīng)用中,除了要求模型要有好的預(yù)測以外, 計算資源的使用也要嚴(yán)格控制,不能只考慮結(jié)果不考慮效率。在輸入數(shù)據(jù)編碼量高的計算機(jī)視覺領(lǐng)域,計算資源更顯有限,控制算法的資源占用就更為重要。
通常來說,規(guī)模較大的模型預(yù)測效果更好,但訓(xùn)練時間長、推理速度慢的問題使得模型難以實時部署。尤其是在視頻監(jiān)控、自動駕駛汽車和高吞吐量云端環(huán)境等計算資源有限的設(shè)備上,響應(yīng)速度顯然不夠用。規(guī)模較小的模型雖然推理速度較快,但是因為參數(shù)量不足,推理效果和泛化性能可能就沒那么好。如何權(quán)衡大規(guī)模模型和小規(guī)模模型一直是一個熱門話題,當(dāng)前的解決方法大多是 根據(jù)部署環(huán)境的終端設(shè)備性能選擇合適規(guī)模的 DNN 模型。
如果我們希望有一個規(guī)模較小的模型,能在保持較快推理速度的前提下,達(dá)到和大模型相當(dāng)或接近的效果該如何做到呢?
在機(jī)器學(xué)習(xí)中,我們常常假定輸入到輸出有一個潛在的映射函數(shù)關(guān)系,從頭學(xué)習(xí)一個新模型就是輸入數(shù)據(jù)和對應(yīng)標(biāo)簽中一個 近似 未知的映射函數(shù)。在輸入數(shù)據(jù)不變的前提下,從頭訓(xùn)練一個小模型,從經(jīng)驗上來看很難接近大模型的效果。為了提升小模型算法的性能,一般來說最有效的方式是標(biāo)注更多的輸入數(shù)據(jù),也就是提供更多的監(jiān)督信息,這可以讓學(xué)習(xí)到的映射函數(shù)更魯棒,性能更好。舉兩個例子,在計算機(jī)視覺領(lǐng)域中,實例分割任務(wù)通過額外提供掩膜信息,可以提高目標(biāo)包圍框檢測的效果;遷移學(xué)習(xí)任務(wù)通過提供在更大數(shù)據(jù)集上的預(yù)訓(xùn)練模型,顯著提升新任務(wù)的預(yù)測效果。因此 提供更多的監(jiān)督信息 ,可能是縮短小規(guī)模模型和大規(guī)模模型差距的關(guān)鍵。
按照之前的說法,想要獲取更多的監(jiān)督信息意味著標(biāo)注更多的訓(xùn)練數(shù)據(jù),這往往需要巨大的成本,那么有沒有一種低成本又高效的監(jiān)督信息獲取方法呢?2006 年的文獻(xiàn)[1]中指出,可以讓新模型近似(approximate)原模型(模型即函數(shù))。因為原模型的函數(shù)是已知的,新模型訓(xùn)練時等于天然地增加了更多的監(jiān)督信息,這顯然要更可行。
進(jìn)一步思考,原模型帶來的監(jiān)督信息可能蘊(yùn)含著不同維度的知識,這些與眾不同的信息可能是新模型自己不能捕捉到的,在某種程度上來說,這對于新模型也是一種“跨域”的學(xué)習(xí)。
2015年Hinton在論文《Distilling the Knowledge in a Neural Network》[2] 中沿用近似的思想,率先提出“ 知識蒸餾 (Knowledge Distillation, KD)”的概念:可以先訓(xùn)練出一個大而強(qiáng)的模型,然后將其包含的知識轉(zhuǎn)移給小的模型,就實現(xiàn)了“保持小模型較快推理速度的同時,達(dá)到和大模型相當(dāng)或接近的效果”的目的。這其中先訓(xùn)練的大模型可以稱之為教師模型,后訓(xùn)練的小模型則被稱之為學(xué)生模型,整個訓(xùn)練過程可以形象地比喻為“師生學(xué)習(xí)”。隨后幾年,涌現(xiàn)了大量的知識蒸餾與師生學(xué)習(xí)的工作,為工業(yè)界提供了更多新的解決思路。目前,KD 已廣泛應(yīng)用于兩個不同的領(lǐng)域:模型壓縮和知識遷移[3]。
02 Knowledge Distillation
簡介
Knowledge Distillation 是一種基于“教師-學(xué)生網(wǎng)絡(luò)”思想的模型壓縮方法,由于簡單有效,在工業(yè)界被廣泛應(yīng)用。其目的是將已經(jīng)訓(xùn)練好的大模型包含的知識——蒸餾(Distill),提取到另一個小的模型中去。那怎么讓大模型的知識,或者說泛化能力轉(zhuǎn)移到小模型身上去呢?KD 論文把大模型對樣本輸出的概率向量作為軟目標(biāo)(soft targets)提供給小模型,讓小模型的輸出盡量去向這個軟目標(biāo)靠(原來是往 one-hot 編碼上靠),去近似學(xué)習(xí)大模型的行為。
在傳統(tǒng)的硬標(biāo)簽訓(xùn)練過程中,所有負(fù)標(biāo)簽都被統(tǒng)一對待,但這種方式把類別間的關(guān)系割裂開了。比如說識別手寫數(shù)字,同是標(biāo)簽為“3”的圖片,可能有的比較像“8”,有的比較像“2”,硬標(biāo)簽區(qū)分不出來這個信息,但是一個訓(xùn)練良好的大模型可以給出。大模型 softmax 層的輸出,除了正例之外,負(fù)標(biāo)簽也帶有大量的信息,比如某些負(fù)標(biāo)簽對應(yīng)的概率遠(yuǎn)遠(yuǎn)大于其他負(fù)標(biāo)簽。近似學(xué)習(xí)這一行為使得每個樣本給學(xué)生網(wǎng)絡(luò)帶來的信息量大于傳統(tǒng)的訓(xùn)練方式。
因此,作者在訓(xùn)練學(xué)生網(wǎng)絡(luò)時修改了一下?lián)p失函數(shù),讓小模型在擬合訓(xùn)練數(shù)據(jù)的真值(ground truth)標(biāo)簽的同時,也要擬合大模型輸出的概率分布。這個方法叫做知識 蒸餾訓(xùn)練 (Knowledge Distillation Training, KD Training)。知識蒸餾過程所用的訓(xùn)練樣本可以和訓(xùn)練大模型用的訓(xùn)練樣本一樣,或者另找一個獨立的 Transfer set。
方法詳解
具體來說,知識蒸餾使用的是 Teacher—Student 模型,其中 teacher 是“知識”的輸出者,student 是“知識”的接受者。知識蒸餾的過程分為 2 個階段:
- 教師模型訓(xùn)練:訓(xùn)練”Teacher 模型“, 簡稱為 Net-T,它的特點是模型相對復(fù)雜,也可以由多個分別訓(xùn)練的模型集成而成。對“Teacher模型”不作任何關(guān)于模型架構(gòu)、參數(shù)量、是否集成方面的限制,因為該模型不需要部署,唯一的要求就是,對于輸入 X, 其都能輸出 Y,其中 Y 經(jīng)過 softmax 的映射,輸出值對應(yīng)相應(yīng)類別的概率值。
- 學(xué)生模型訓(xùn)練:訓(xùn)練“Student 模型”, 簡稱為 Net-S,它是參數(shù)量較小、模型結(jié)構(gòu)相對簡單的單模型。同樣的,對于輸入 X,其都能輸出 Y,Y 經(jīng)過 softmax 映射后同樣能輸出對應(yīng)相應(yīng)類別的概率值。
由于使用 softmax 的網(wǎng)絡(luò)的結(jié)果很容易走向極端,即某一類的置信度超高,其他類的置信度都很低,此時學(xué)生模型關(guān)注到的正類信息可能還是僅屬于某一類。除此之外,因為不同類別的負(fù)類信息也有相對的重要性,所有負(fù)類分?jǐn)?shù)都差不多也不好,達(dá)不到知識蒸餾的目的。為了解決這個問題,引入溫度(Temperature)的概念,使用高溫將小概率值所攜帶的信息蒸餾出來。具體來說,在 logits 過 softmax 函數(shù)前除以溫度 T。
訓(xùn)練時首先將教師模型學(xué)習(xí)到的知識蒸餾給小模型,具體來說對樣本 X,大模型的倒數(shù)第二層先除以一個溫度 T,然后通過 softmax 預(yù)測一個軟目標(biāo) Soft target,小模型也一樣,倒數(shù)第二層除以同樣的溫度 T,然后通過 softmax 預(yù)測一個結(jié)果,再把這個結(jié)果和軟目標(biāo)的交叉熵作為訓(xùn)練的 total loss 的一部分。然后再將小模型正常的輸出和真值標(biāo)簽(hard target)的交叉熵作為訓(xùn)練的 total loss 的另一部分。Total loss 把這兩個損失加權(quán)合起來作為訓(xùn)練小模型的最終的 loss。
在小模型訓(xùn)練好了要預(yù)測時,就不需要再有溫度 T 了,直接按照常規(guī)的 softmax 輸出就可以了。
03 FitNet
簡介
FitNet 論文在蒸餾時引入了中間層隱藏映射(intermediate-level hints)來指導(dǎo)學(xué)生模型的訓(xùn)練。使用一個寬而淺的教師模型來訓(xùn)練一個窄而深的學(xué)生模型。在進(jìn)行 hint 引導(dǎo)時,提出使用一個層來匹配 hint 層和 guided 層的輸出 shape,這在后人的工作里面常被稱為 adaptation layer。
總的來說,相當(dāng)于是在做知識蒸餾時,不僅用到了教師模型的 logit 輸出,還用到了教師模型的中間層特征圖作為監(jiān)督信息。可以想到的是,直接讓小模型在輸出端模仿大模型,這個對于小模型來說太難了(模型越深越難訓(xùn),最后一層的監(jiān)督信號要傳到前面去還是挺累的),不如在中間加一些監(jiān)督信號,使得模型在訓(xùn)練時可以從逐層接受學(xué)習(xí)更難的映射函數(shù),而不是直接學(xué)習(xí)最難的映射函數(shù);除此之外,hint 引導(dǎo)加速了學(xué)生模型的收斂,在一個非凸問題上找到更好的局部最小值,使得學(xué)生網(wǎng)絡(luò)能更深的同時,還能訓(xùn)練得更快。這感覺就好像是,我們的目的是讓學(xué)生做高考題,那么就先把初中的題目給他教會了(先讓小模型用前半個模型學(xué)會提取圖像底層特征),然后再回到本來的目的、去學(xué)高考題(用 KD 調(diào)整小模型的全部參數(shù))。
這篇文章是提出蒸餾中間特征圖的始祖,提出的算法很簡單,但思路具有開創(chuàng)性。
方法詳解
FitNets 的具體做法是:
- 確定教師網(wǎng)絡(luò),并訓(xùn)練成熟,將教師網(wǎng)絡(luò)的中間特征層 hint 提取出來。
- 設(shè)定學(xué)生網(wǎng)絡(luò),該網(wǎng)絡(luò)一般較教師網(wǎng)絡(luò)更窄、更深。訓(xùn)練學(xué)生網(wǎng)絡(luò)使得學(xué)生網(wǎng)絡(luò)的中間特征層與教師模型的 hint 相匹配。由于學(xué)生網(wǎng)絡(luò)的中間特征層和與教師 hint 尺寸不同,因此需要在學(xué)生網(wǎng)絡(luò)中間特征層后添加回歸器用于特征升維,以匹配 hint 層尺寸。其中匹配教師網(wǎng)絡(luò)的 hint 層與回歸器轉(zhuǎn)化后的學(xué)生網(wǎng)絡(luò)的中間特征層的損失函數(shù)為均方差損失函數(shù)。
實際訓(xùn)練的時候往往和上一節(jié)的 KD Training 聯(lián)合使用,用兩階段法訓(xùn)練:先用 hint training 去 pretrain 小模型前半部分的參數(shù),再用 KD Training 去訓(xùn)練全體參數(shù)。由于蒸餾過程中使用了更多的監(jiān)督信息, 基于中間特征圖的蒸餾方法比基于結(jié)果 logits 的蒸餾方法效果要好 ,但是訓(xùn)練時間更久。
04 總結(jié)
知識蒸餾對于將知識從集成或從高度正則化的大型模型轉(zhuǎn)移到較小的模型中非常有效。即使在用于訓(xùn)練蒸餾模型的遷移數(shù)據(jù)集中缺少任何一個或多個類的數(shù)據(jù)時,蒸餾的效果也非常好。在經(jīng)典之作 KD 和 FitNet 提出之后,各種各樣的蒸餾方法如雨后春筍般涌現(xiàn)。未來我們也希望能在模型壓縮和知識遷移領(lǐng)域做出更進(jìn)一步的探索。
作者簡介
馬佳良,網(wǎng)易易盾高級計算機(jī)視覺算法工程師,主要負(fù)責(zé)計算機(jī)視覺算法在內(nèi)容安全領(lǐng)域的研發(fā)、優(yōu)化和創(chuàng)新。