改變LoRA的初始化方式,北大新方法PiSSA顯著提升微調(diào)效果
隨著大模型的參數(shù)量日益增長,微調(diào)整個模型的開銷逐漸變得難以接受。
為此,北京大學(xué)的研究團隊提出了一種名為 PiSSA 的參數(shù)高效微調(diào)方法,在主流數(shù)據(jù)集上都超過了目前廣泛使用的 LoRA 的微調(diào)效果。
- 論文: PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models
- 論文鏈接: https://arxiv.org/pdf/2404.02948.pdf
- 代碼鏈接: https://github.com/GraphPKU/PiSSA
如圖 1 所示,PiSSA (圖 1c) 在模型架構(gòu)上和 LoRA [1] 完全一致 (圖 1b),只是初始化 Adapter 的方式不同。LoRA 使用高斯噪聲初始化 A,使用 0 初始化 B。而 PiSSA 使用主奇異值和奇異向量 (Principal Singular values and Singular vectors) 來初始化 Adapter 來初始化 A 和 B。
圖 1)從左到右依次為全參數(shù)微調(diào)、LoRA、以及 PiSSA。藍色代表凍結(jié)的參數(shù),橘黃色代表可訓(xùn)練參數(shù)及它們的初始化方式。相比全參數(shù)微調(diào),LoRA 和 PiSSA 都大幅節(jié)省了可訓(xùn)練參數(shù)量。對于相同輸入,這三種方法的初始輸出完全相等。然而,PiSSA 凍結(jié)模型的次要成分,直接微調(diào)主成分(前 r 個奇異值和奇異向量);而 LoRA 可看作凍結(jié)模型的主要部分,而去微調(diào) noise 部分。
在不同的任務(wù)上對比 PiSSA、LoRA 的微調(diào)效果
研究團隊使用 llama 2-7B、Mistral-7B 以及 Gemma-7B 作為基礎(chǔ)模型,通過微調(diào)提升它們的數(shù)學(xué)、代碼和對話能力。其中包括:在 MetaMathQA 上訓(xùn)練,在 GSM8K 和 MATH 數(shù)據(jù)集上驗證模型的數(shù)學(xué)能力;在 CodeFeedBack 上訓(xùn)練,在 HumanEval 和 MBPP 數(shù)據(jù)集上驗證模型的代碼能力;在 WizardLM-Evol-Instruct 上訓(xùn)練,在 MT-Bench 上驗證模型的對話能力。從下表的實驗結(jié)果可以看出,使用相同規(guī)模的可訓(xùn)練參數(shù),PiSSA 的微調(diào)效果顯著超越了 LoRA,甚至超越了全參數(shù)微調(diào)。
對比 PiSSA、LoRA 在不同的可訓(xùn)練參數(shù)量下微調(diào)的效果
研究團隊在數(shù)學(xué)任務(wù)上對模型的可訓(xùn)練參數(shù)量和效果之間的關(guān)系進行消融實驗。從圖 2.1 發(fā)現(xiàn)在訓(xùn)練初期,PiSSA 的訓(xùn)練 loss 下降特別快,而 LoRA 存在不下降,甚至略有上升的階段。此外,PiSSA 的訓(xùn)練 loss 全程低于 LoRA,說明對訓(xùn)練集擬合得更好;從圖 2.2、2.3、2.4 可以看出在每種 setting 下,PiSSA 的 loss 始終比 LoRA 低,準確率始終比 LoRA 高,PiSSA 能夠使用更少的可訓(xùn)練參數(shù)追趕上全參數(shù)微調(diào)的效果。
圖 2.1) 當(dāng)秩為 1 時 PiSSA、LoRA 在訓(xùn)練過程中的 loss。每幅圖的右上角是前 100 步迭代放大的曲線。其中 PiSSA 用橙色線表示,LoRA 用藍色線表示,全參數(shù)微調(diào)用綠線展示了最終的 loss 作為參考。秩為 [2,4,8,16,32,64,128] 時的現(xiàn)象與此一致,詳見文章附錄。
圖 2.2)使用秩為 [1,2,4,8,16,32,64,128] 的 PiSSA 和 LoRA 的最終 training loss。
圖 2.3)使用秩為 [1,2,4,8,16,32,64,128] 的 PiSSA 和 LoRA 微調(diào)的模型在 GSM8K 上的準確率。
圖 2.4)使用秩為 [1,2,4,8,16,32,64,128] 的 PiSSA 和 LoRA 微調(diào)的模型在 MATH 上的準確率。
PiSSA 方法詳解
受到 Intrinsic SAID [2]“預(yù)訓(xùn)練大模型參數(shù)具有低秩性” 的啟發(fā),PiSSA 對預(yù)訓(xùn)練模型的參數(shù)矩陣
進行奇異值分解,其中前 r 個奇異值和奇異向量用來初始化適配器 (adapter) 的兩個矩陣
和
,
;剩余的奇異值和奇異向量用來構(gòu)造殘差矩陣
,使得
。因此,適配器中的參數(shù)包含了模型的核心參數(shù),而殘差矩陣中的參數(shù)是修正參數(shù)。通過微調(diào)參數(shù)量較小的核心適配器 A、B,凍結(jié)參數(shù)量較大的殘差矩陣
,就達成了用很少的參數(shù)近似全參數(shù)微調(diào)的效果。
盡管同樣受到 Intrinsic SAID [1] 啟發(fā),PiSSA 和 LoRA 背后的原理卻截然不同。
LoRA 認為大模型微調(diào)前后矩陣的變化 △W 具有很低的本征秩 r,因此通過
和
相乘得到的低秩矩陣來模擬模型的變化 △W。初始階段,LoRA 使用高斯噪聲初始化 A,使用 0 初始化 B,因此
,以此保證模型初始能力沒有變化,并微調(diào) A 和 B 實現(xiàn)對 W 進行更新。與此相比,PiSSA 不關(guān)心 △W,而是認為 W 具有很低的本征秩 r。因此直接對 W 進行奇異值分解,分解成主成分 A、B,以及殘差項
,使得
。假設(shè) W 的奇異值分解為
,A、B 使用 SVD 分解后奇異值最大的 r 個奇異值、奇異向量進行初始化:
殘差矩陣使用其余的奇異值、奇異向量進行初始化:
PiSSA 直接對 W 的低秩主成分 A、B 進行微調(diào),凍結(jié)次要的修正項。相比 LoRA 用高斯噪聲以及 0 初始化適配器參數(shù)、凍結(jié)核心模型參數(shù),PiSSA 收斂更快、效果更好。
PiSSA 的發(fā)音類似 “披薩”(pizza)--- 如果把整個大模型類比為一個完整的披薩,PiSSA 切掉其中一角,而且是餡料最豐富的一角(主奇異值、奇異向量),重新烘焙(在下游任務(wù)上微調(diào))成喜歡的口味。
由于 PiSSA 采用了和 LoRA 完全相同的架構(gòu),其可以作為 LoRA 的一種可選初始化方式,在 peft 包中很方便的進行修改和調(diào)用 (如以下代碼所示)。相同的架構(gòu)也使得 PiSSA 繼承了大多數(shù) LoRA 的優(yōu)點,如:對殘差模型使用 4bit 量化 [3],減小訓(xùn)練開銷;微調(diào)完成后適配器能合并進殘差模型,不改變推理過程的模型架構(gòu);無需分享完整模型參數(shù),只需要分享參數(shù)量很少的 PiSSA 模塊,使用者直接加載 PiSSA 模塊就能自動進行奇異值分解以及賦值;一個模型可以同時使用多個 PiSSA 模塊等等。一些對 LoRA 方法的改進,也能與 PiSSA 進行結(jié)合:比如不固定每層的秩,通過學(xué)習(xí)找到最佳的秩 [4];用 PiSSA 指導(dǎo)的更新 [5],從而突破秩的限制等等。
# 在 peft 包中 LoRA 的初始化方式后面增加了一種 PiSSA 初始化選項:
if use_lora:
nn.init.normal_(self.lora_A.weight, std=1 /self.r)
nn.init.zeros_(self.lora_B.weight)
elif use_pissa:
Ur, Sr, Vr = svd_lowrank (self.base_layer.weight, self.r, niter=4)
# 注意:由于 self.base_layer.weight 的維度是 (out_channel,in_channel, 所以 AB 的順序相比圖示顛倒了一下)
self.lora_A.weight = torch.diag (torch.sqrt (Sr)) @ Vh.t ()
self.lora_B.weight = Ur @ torch.diag (torch.sqrt (Sr))
self.base_layer.weight = self.base_layer.weight - self.lora_B.weight @ self.lora_A.weight
對比高中低奇異值微調(diào)效果實驗
為了驗證使用不同大小奇異值、奇異向量初始化適配器對模型的影響,研究人員分別使用高、中、低奇異值初始化 LLaMA 2-7B、Mistral-7B-v0.1、Gemma-7B 的適配器,然后在 MetaMathQA 數(shù)據(jù)集上進行微調(diào),實驗結(jié)果展示在圖 3 中。從圖中可以看出,使用主要奇異值初始化的方法訓(xùn)練損失最小,在 GSM8K 和 MATH 驗證集上的準確率更高。這一現(xiàn)象驗證了微調(diào)主要奇異值、奇異向量的有效性。
圖 3)從左到右依次為訓(xùn)練 loss、在 GSM8K 上的準確率、在 MATH 上的準確率。其中藍色表示最大奇異值、橙色表示中等奇異值、綠色表示最小奇異值。
快速奇異值分解
PiSSA 繼承了 LoRA 的優(yōu)點,使用起來方便,效果超越 LoRA。代價是在初始化階段,需要對模型進行奇異值分解。雖然僅需要在初始化時分解一次,但是仍然可能需要幾分鐘甚至幾十分鐘的開銷。因此,研究人員使用一種快速奇異值分解 [6] 方法替代標準的 SVD 分解,通過下表的實驗可以看出,僅需幾秒鐘的時間,就能逼近標準 SVD 分解的訓(xùn)練集擬合效果。其中 Niter 表示迭代次數(shù),Niter 越大,時間越久但是誤差越小。Niter = ∞表示標準 SVD。表格中的平均誤差表示快速奇異值分解與標準 SVD 得到的 A、B 之間的平均 L_1 距離。
總結(jié)與展望
本工作對預(yù)訓(xùn)練模型的權(quán)重進行奇異值分解,通過將其中最重要的參數(shù)用于初始化一個名為 PiSSA 的適配器,微調(diào)這個適配器來近似微調(diào)完整模型的效果。實驗表明,PiSSA 比 LoRA 收斂更快,最終效果更好,唯一的代價僅是需要幾秒的 SVD 初始化過程。
那么,您愿意為了更好的訓(xùn)練效果,多花幾秒鐘時間,一鍵更改 LoRA 的初始化為 PiSSA 嗎?
本文轉(zhuǎn)自 機器之心 ,作者:機器之心
