極少數(shù)據(jù)就能微調(diào)大模型,一文詳解LoRA等方法的運(yùn)作原理
本文經(jīng)AI新媒體量子位(公眾號(hào)ID:QbitAI)授權(quán)轉(zhuǎn)載,轉(zhuǎn)載請(qǐng)聯(lián)系出處。
最近和大模型一起爆火的,還有大模型的微調(diào)方法。
這類(lèi)方法只用很少的數(shù)據(jù),就能讓大模型在原本表現(xiàn)沒(méi)那么好的下游任務(wù)中“脫穎而出”,成為這個(gè)任務(wù)的專(zhuān)家。
而其中最火的大模型微調(diào)方法,又要屬LoRA。
但包括LoRA在內(nèi),這類(lèi)方法的核心原理究竟是什么?它和大模型之間的關(guān)系又是什么?我們具體來(lái)看。
一、前言
先從最近大火的LoRA (《LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGEMODELS》)說(shuō)起。
圖片
該文章在ICLR2022中提出,說(shuō)的是利用低秩適配(low-rankadaptation)的方法,可以在使用大模型適配下游任務(wù)時(shí)只需要訓(xùn)練少量的參數(shù)即可達(dá)到一個(gè)很好的效果。
LoRA是怎么去微調(diào)適配下游任務(wù)的?
流程很簡(jiǎn)單,LoRA利用對(duì)應(yīng)下游任務(wù)的數(shù)據(jù),只通過(guò)訓(xùn)練新加部分參數(shù)來(lái)適配下游任務(wù)。
而當(dāng)訓(xùn)練好新的參數(shù)后,利用重參的方式,將新參數(shù)和老的模型參數(shù)合并,這樣既能在新任務(wù)上到達(dá)fine-tune整個(gè)模型的效果,又不會(huì)在推斷的時(shí)候增加推斷的耗時(shí)。
LoRA的示意圖如下:
圖片
圖中藍(lán)色部分為預(yù)訓(xùn)練好的模型參數(shù),LoRA在預(yù)訓(xùn)練好的模型結(jié)構(gòu)旁邊加入了A和B兩個(gè)結(jié)構(gòu),這兩個(gè)結(jié)構(gòu)的參數(shù)分別初始化為高斯分布和0,那么在訓(xùn)練剛開(kāi)始時(shí)附加的參數(shù)就是0。
A的輸入維度和B的輸出維度分別與原始模型的輸入輸出維度相同,而A的輸出維度和B的輸入維度是一個(gè)遠(yuǎn)小于原始模型輸入輸出維度的值,這也就是low-rank的體現(xiàn)(有點(diǎn)類(lèi)似Resnet的結(jié)構(gòu)),這樣做就可以極大地減少待訓(xùn)練的參數(shù)了。
在訓(xùn)練時(shí)只更新A、B的參數(shù),預(yù)訓(xùn)練好的模型參數(shù)是固定不變的。在推斷時(shí)可以利用重參數(shù)(reparametrization)思想,將AB與W合并,這樣就不會(huì)在推斷時(shí)引入額外的計(jì)算了。
而且對(duì)于不同的下游任務(wù),只需要在預(yù)訓(xùn)練模型基礎(chǔ)上重新訓(xùn)練AB就可以了,這樣也能加快大模型的訓(xùn)練節(jié)奏。
由于本文不具體介紹LoRA,所以詳細(xì)信息可以查看LoRA原文。我們只需要知道LoRA文章后續(xù)的實(shí)驗(yàn)已經(jīng)論證該方法的有效性。
那么進(jìn)一步思考,為什么LoRA的這種思路能work得不錯(cuò)呢?
答案就是接下來(lái)要講的本征維度 (Intrinsic dimension)了。
這點(diǎn)LoRA原文也提到過(guò),該文章靈感來(lái)源于下面兩篇文章:
1、MEASURING THE INTRINSIC DIMENSION OF OBJECTIVE LANDSCAPES,發(fā)表在ICLR2018,為了方便接下來(lái)該論文稱(chēng)為【論文1】
2、INTRINSIC DIMENSIONALITY EXPLAINS THE EFFECTIVENESS OF LANGUAGEMODEL FINE-TUNING,發(fā)表在ACL2021,為了方便接下來(lái)該論文稱(chēng)為【論文2】
二、本征維度是什么?
本征維度的概念在【論文1】中提出。
訓(xùn)練一個(gè)神經(jīng)網(wǎng)絡(luò)往往包含如下幾步:
1、對(duì)于一個(gè)給定的數(shù)據(jù)集,先設(shè)計(jì)網(wǎng)絡(luò)的結(jié)構(gòu)和選擇對(duì)應(yīng)的loss
2、對(duì)網(wǎng)絡(luò)中的參數(shù)進(jìn)行隨機(jī)的初始化
3、訓(xùn)練網(wǎng)絡(luò)使得loss越來(lái)越低
而訓(xùn)練階段可以認(rèn)為是在一個(gè)固定的目標(biāo)圖(objective landscape)上,尋找出有效的路徑。
這里解釋一下為什么是固定的目標(biāo)圖。因?yàn)樵跀?shù)據(jù)集和網(wǎng)絡(luò)結(jié)構(gòu)固定下來(lái)后,待優(yōu)化的問(wèn)題就已經(jīng)定義好了,所以目標(biāo)圖也就是確定的了。
如下圖所示:
那么對(duì)于一個(gè)參數(shù)量為D的模型,我們訓(xùn)練該模型,也就意味著在D維空間上尋找有效的解。文章認(rèn)為D可能是冗余的,可能實(shí)際上只需要優(yōu)化其中的d個(gè)參數(shù)就可以找到一個(gè)有效的解。
用公式表示如下:
其中表示D維的優(yōu)化參數(shù),
表示隨機(jī)初始化的一個(gè)參數(shù)并且在訓(xùn)練時(shí)是不進(jìn)行更新的,P是一個(gè)隨機(jī)初始化的D×d大小的矩陣且訓(xùn)練時(shí)也不進(jìn)行更新,
表示待優(yōu)化的d維參數(shù)。
也就是說(shuō)可以在訓(xùn)練網(wǎng)絡(luò)時(shí)只更新d維參數(shù),就可以達(dá)到該網(wǎng)絡(luò)應(yīng)有的效果。那么這個(gè)d就是所謂的該模型的本征維度。
這里講完可能還有點(diǎn)暈,我們看一下如下這張圖:
上圖中,藍(lán)色部分為初始化好的網(wǎng)絡(luò)參數(shù),綠色為
,紅色為
。網(wǎng)絡(luò)訓(xùn)練的時(shí)候只訓(xùn)練紅色部分,其它參數(shù)都是固定的。d就是本征維度。
上面講的只更新d維參數(shù),讓網(wǎng)絡(luò)達(dá)到應(yīng)有的效果,那么什么應(yīng)有的效果呢?文章定義,在只更新d維參數(shù)的情況下,網(wǎng)絡(luò)效果達(dá)到訓(xùn)練原始模型時(shí)效果的90%時(shí),那么就認(rèn)為達(dá)到了“應(yīng)有的效果”,并且d就為本征維度。
例如在做mnist這個(gè)數(shù)字分類(lèi)任務(wù)時(shí),如果原始模型精度能到0.9,那么在只更新d維參數(shù)的時(shí)候,精度能夠達(dá)到90%×0.9=0.81,就認(rèn)為這時(shí)候的d為本征維度記為。
三、使用本征維度思考大模型微調(diào)的有效性
【論文2】將之前提出的本征維度用來(lái)思考大模型微調(diào)的有效性,為什么現(xiàn)在用幾百或者幾千張圖片就可以對(duì)大模型進(jìn)行有效的微調(diào)?
根據(jù)【論文1】闡述,對(duì)于某一類(lèi)問(wèn)題,在一定精度上(比如達(dá)到90%的精度)有本征特征的存在。對(duì)于大模型而言,進(jìn)行本征維度的測(cè)試就能知道在解決某一類(lèi)下游問(wèn)題時(shí),需要調(diào)整多少參數(shù)就能近似的解決當(dāng)前的問(wèn)題。
如果真的有實(shí)驗(yàn)?zāi)茏C明僅僅調(diào)整少數(shù)的參數(shù)就能很好的解決下游問(wèn)題,那么也就能回答上述問(wèn)題,即對(duì)大模型做少量的微調(diào)(調(diào)整少量的參數(shù)),就能解決當(dāng)前的問(wèn)題。
下面無(wú)特殊說(shuō)明的話(huà),“文章”指的都是【論文2】
3.1 對(duì)于大模型而言,是否存在本征維度?
同【論文1】一樣,【論文2】也利用公式來(lái)進(jìn)行模型的訓(xùn)練,即訓(xùn)練時(shí)只調(diào)整d維參數(shù)。但與【論文1】的實(shí)驗(yàn)有點(diǎn)不同的是,【論文1】中是隨機(jī)初始化的,而【論文2】中是預(yù)訓(xùn)練好的參數(shù)。
【論文2】首先選擇BERT-Base\BERT-Large\RoBERTa-Base\RoBERTa-Large四個(gè)模型,并選擇GLUE benchmark中的MRPC和QQP兩個(gè)數(shù)據(jù)集(兩個(gè)數(shù)據(jù)集都是用來(lái)測(cè)試句子對(duì)是否相同意義的任務(wù))。
上下兩個(gè)子圖分別表示MRPC和QQP兩個(gè)任務(wù),每個(gè)子圖有四條實(shí)線(xiàn)表示四個(gè)模型的準(zhǔn)確率,四條虛線(xiàn)表示達(dá)到fine-tune整個(gè)模型90%的準(zhǔn)確率的值,橫坐標(biāo)表示訓(xùn)練d維的大小。從圖中可以看出兩個(gè)任務(wù),四個(gè)不同的模型,只需要訓(xùn)練較小的d維參數(shù)就可以達(dá)到90%的精度。本征維度這個(gè)概念在大模型中是成立的。
所以在訓(xùn)練某個(gè)下游任務(wù)時(shí),只需要訓(xùn)練少量參數(shù)就能達(dá)到不錯(cuò)的效果了。這時(shí)文章開(kāi)頭的問(wèn)題就已經(jīng)解決了。但是作者做了一些其他的實(shí)驗(yàn),發(fā)現(xiàn)了一些有意思的結(jié)論。
3.2 預(yù)訓(xùn)練的好壞與本征維度的關(guān)系
文章提出這樣一個(gè)假設(shè),預(yù)訓(xùn)練模型能夠隱式地降低模型在NLP各個(gè)任務(wù)的本征維度。
基于這個(gè)猜想,文章做了下面實(shí)驗(yàn),在預(yù)訓(xùn)練RoBERTa-base模型的時(shí)候,每隔10K保存下對(duì)應(yīng)的預(yù)訓(xùn)練模型,然后測(cè)試保存下來(lái)的預(yù)訓(xùn)練模型在MRPC、QQP、Yelp Polarity、SST-2、MNLI、ANLI六個(gè)數(shù)據(jù)集本征維度。
結(jié)果如下:
可以看出,在不同數(shù)據(jù)集上有相同的趨勢(shì),就是預(yù)訓(xùn)練次數(shù)越多,模型在各個(gè)任務(wù)上的本征維度越低。實(shí)驗(yàn)并沒(méi)有特意去優(yōu)化所謂的本征維度,只是預(yù)訓(xùn)練久一點(diǎn)而已。所以印證了預(yù)訓(xùn)練模型的表征能力越強(qiáng)(訓(xùn)練得越好),本征維度越小。
3.3 預(yù)訓(xùn)練模型參數(shù)與本征維度的關(guān)系
本來(lái)在做預(yù)訓(xùn)練參數(shù)與本征維度關(guān)系的時(shí)候,需要統(tǒng)一模型的結(jié)構(gòu),這樣更有說(shuō)服力。但是作者說(shuō),這樣要訓(xùn)練很多大模型的實(shí)驗(yàn),為了更方便的對(duì)比文章根據(jù)已有的結(jié)構(gòu)來(lái)做實(shí)驗(yàn)。從實(shí)驗(yàn)結(jié)果的趨勢(shì)來(lái)看,不同結(jié)構(gòu)也能得到有效的結(jié)論。
文章利用已有的預(yù)訓(xùn)練模型,在MRPC數(shù)據(jù)集上計(jì)算本征維度。
實(shí)驗(yàn)結(jié)果如下:
上圖中縱坐標(biāo)表示本征維度的值,橫坐標(biāo)表示模型的參數(shù)量。從圖中的趨勢(shì)可以明顯看出,模型越大本征維度越小,即越強(qiáng)的模型本征維度越低。
3.4 本征維度與泛化能力的關(guān)系
上面介紹了fine-tune(3.1)、預(yù)訓(xùn)練(3.2)和本征維度的關(guān)系,但本征維度與泛化能力的關(guān)系還沒(méi)有驗(yàn)證。即我們現(xiàn)在知道了讓本征維度小的方式,但是本征維度小了,泛化能力就能上去嗎?
文章又做了下面的實(shí)驗(yàn),把3.2保存下來(lái)的模型,在對(duì)應(yīng)的的本征維度上,進(jìn)行不同數(shù)據(jù)集的測(cè)試,結(jié)果如下:
圖片
可以看出本征維度低的模型,訓(xùn)練出來(lái)的模型準(zhǔn)確率是更高的。也就是說(shuō)本征維度越低,泛化性能越好。
回到引言的問(wèn)題:為什么LoRA思路能work?
因?yàn)榇竽P痛嬖诒菊骶S度的概念,只需要調(diào)整少量參數(shù)就能在下游任務(wù)上得到很好的效果。