用任務(wù)向量做模型編輯為何有效?這篇ICLR 2025 Oral論文給出了理論分析
本文作者李宏康,博士畢業(yè)于美國倫斯勒理工大學(xué),本科畢業(yè)于中國科學(xué)技術(shù)大學(xué),并即將前往賓夕法尼亞大學(xué)擔(dān)任博士后研究員。研究方向包括深度學(xué)習(xí)理論、大語言模型理論等等。本文的通訊作者為倫斯勒理工大學(xué)的汪孟教授。
任務(wù)向量(task vector)方法近來在許多視覺和語言任務(wù)中表現(xiàn)出了在效率與可遷移性方面的優(yōu)勢。但是由于人們尚未深入理解任務(wù)向量的理論機制,其在更廣泛與更大規(guī)模的應(yīng)用中面臨挑戰(zhàn)。
近期,一個來自美國倫斯勒理工大學(xué)、密歇根州立大學(xué) OPTML 實驗室、和 IBM 研究院的研究團(tuán)隊從神經(jīng)網(wǎng)絡(luò)的優(yōu)化和泛化理論的角度分析了任務(wù)向量在模型編輯中的有效性。該工作已經(jīng)被 ICLR 2025 錄取,并被選為前 1.8% 的 Oral 論文。
- 論文標(biāo)題:When is Task Vector Provably Effective for Model Editing? A Generalization Analysis of Nonlinear Transformers
- 論文地址:https://openreview.net/pdf?id=vRvVVb0NAz
背景介紹
任務(wù)向量(task vector)是指微調(diào)得到的模型與預(yù)訓(xùn)練模型之間的權(quán)重差值。人們發(fā)現(xiàn),將不同的任務(wù)向量進(jìn)行線性算術(shù)運算后疊加在一個預(yù)訓(xùn)練模型上可以直接賦予此模型多種全新的能力,例如多任務(wù)學(xué)習(xí)(multi-task learning)、機器遺忘(machine unlearning)、以及分布外泛化(out-of-domain generalization),其優(yōu)勢是無需使用下游任務(wù)的訓(xùn)練數(shù)據(jù)對模型進(jìn)行微調(diào)。
這種基于任務(wù)向量的直接運算對模型進(jìn)行編輯從而做下游任務(wù)預(yù)測的方法被稱為任務(wù)運算(task arithmetic)。
由于缺乏對該方法的理論研究,本文重點探索任務(wù)向量方法能夠被有效且高效使用的深層原因。我們的貢獻(xiàn)如下:
- 我們?yōu)槿蝿?wù)加法和減法運算的有效性提供了一個特征學(xué)習(xí)的理論分析框架。
- 我們給出了任務(wù)運算在分布外泛化的理論保證。
- 解釋了任務(wù)向量的低秩近似和模型剪枝的理論機制。
初步觀察
我們從一個簡單的問題出發(fā):組合多個任務(wù)向量的系數(shù)會受到哪些因素的影響?
直覺告訴我們,任務(wù)間的關(guān)系可能是一個關(guān)鍵因素。比如說,在多任務(wù)學(xué)習(xí)中,讓一個模型具備兩個相似任務(wù)的能力,理應(yīng)是更容易的。
為了論證這一點,我們用 Colored-MNIST 數(shù)據(jù)集構(gòu)建了一組二分類實驗。其中,分類的標(biāo)準(zhǔn)是數(shù)字的奇偶性。我們通過調(diào)整數(shù)字的顏色來控制任務(wù)之間的關(guān)系。
于是,我們設(shè)計了「相似任務(wù)」(aligned tasks)、「無關(guān)任務(wù)」(irrelevant tasks)、「相反任務(wù)」(contradictory tasks) 的任務(wù)關(guān)系。
根據(jù)上圖所示的實驗結(jié)果,我們有以下觀察:
- 在多任務(wù)學(xué)習(xí)和機器遺忘的實驗中,最佳的任務(wù)運算系數(shù)會隨著給定的任務(wù)向量間的關(guān)系的不同而改變。
- 在分布外泛化的實驗中,目標(biāo)任務(wù)與給定任務(wù)的正反相關(guān)性可以被最佳的任務(wù)運算系數(shù)的正負(fù)性反映出來。
以上的兩點發(fā)現(xiàn)引向了一個重要的研究方向:任務(wù)關(guān)系會如何影響任務(wù)運算。
理論分析
我們在二分類問題的設(shè)定下研究該問題。我們以一層單頭的帶有 softmax attention 的 Transformer 為理論分析的基本模型,用 Ψ 來表示所有權(quán)重參數(shù)的集合,其中包括 attention 層的參數(shù) W 以及 MLP 層的參數(shù) V。仿照許多特征學(xué)習(xí)(feature learning)的理論工作,我們做如下的數(shù)據(jù)建模:定義 μ_T 為當(dāng)前任務(wù)的 discriminative pattern。數(shù)據(jù) X 中的每一個 token 都是從 μ_T、-μ_T 以及無關(guān)的 pattern 中選擇的。如果對應(yīng)于 μ_T 的 token 個數(shù)多于 -μ_T 的個數(shù),那么 X 的標(biāo)簽 y=1。如果對應(yīng)于 -μ_T 的 token 個數(shù)多于 μ_T 的個數(shù),那么 X 的標(biāo)簽 y=-1。
接下來我們給出使用兩個任務(wù)向量進(jìn)行多任務(wù)學(xué)習(xí)和機器遺忘的理論結(jié)果。
具體而言,給定預(yù)訓(xùn)練模型 以及兩個已經(jīng)被訓(xùn)練到可以取得 ? 的泛化誤差的模型所對應(yīng)的任務(wù)向量
和
,融合得到的模型被計算為
。我們定義
表示任務(wù) T_1 與 T_2 之間的相關(guān)性。α>0,=0,<0 分別表示任務(wù)之間的相似、無關(guān)、以及相反關(guān)系。β 為一個很小的數(shù)值。那么我們有以下結(jié)果:
定理 1 的結(jié)果表明:當(dāng)兩個任務(wù)是相似的關(guān)系的時候,將任務(wù)向量疊加可以得到理想的多任務(wù)學(xué)習(xí)性能,即泛化誤差在兩個任務(wù)上都達(dá)到 ?。
定理 2 的結(jié)果表明:當(dāng)兩個任務(wù)是相反關(guān)系時,用 T_1 的任務(wù)向量減去 T_2 的任務(wù)向量可以得到理想的機器遺忘性能,即 T_1 的泛化誤差達(dá)到?,而 T_2 的泛化誤差較大。
然后,我們給出利用一組任務(wù)向量 對一個從未見過的分布外的目標(biāo)任務(wù) T'進(jìn)行預(yù)測的理論結(jié)果。我們假設(shè)所有給定任務(wù) T_i 的 discriminative pattern 互相正交,目標(biāo)任務(wù) T' 的 discriminative pattern 可以被寫為各個給定任務(wù)的 discriminative pattern 的線性組合,并以 γ_i 為第 i 個任務(wù)的 discriminative pattern 的系數(shù)。假設(shè) γ_i 不全為 0。我們有定理 3 的結(jié)果:
定理 3 的結(jié)果表明:總是存在一組 λ_i,使得融合多個任務(wù)向量得到的模型可以在目標(biāo)任務(wù) T' 上取得理想的泛化性能。
我們還在理論上論證了對任務(wù)向量進(jìn)行高效應(yīng)用的方法。在我們的一層 Transformer 以及二分類問題的框架下,我們得出了推論 1:任務(wù)向量可以被低秩近似,同時只會造成很小的預(yù)測誤差。這意味著人們可以將各種低秩訓(xùn)練和推斷方法用在任務(wù)向量中,從而大大節(jié)省任務(wù)向量的計算和存儲開銷。
我們還可以得到推論 2:訓(xùn)練得到的任務(wù)向量在 MLP 層中的部分神經(jīng)元權(quán)重較大,而剩余的神經(jīng)元權(quán)重很小。對這些小的神經(jīng)元進(jìn)行剪枝只會引起很小的誤差,從而使得前面所有定理依然成立。這個推論為對于任務(wù)向量進(jìn)行權(quán)重剪枝與稀疏化提供了理論保障。
實驗驗證
我們首先用 ViT-small/16 模型對任務(wù)向量的分布外泛化能力進(jìn)行了測試。我們使用 Colored-MNIST 數(shù)據(jù)集設(shè)計訓(xùn)練任務(wù) T_1,T_2,以及目標(biāo)測試任務(wù) T',用訓(xùn)練任務(wù)的任務(wù)向量合成一個模型,即 。我們對 T'分別與 T_1,T_2 之間的相關(guān)性 γ_1,γ_2 進(jìn)行了估計。
我們下圖的結(jié)果表明:實驗中得到的能夠帶來出色的分布外泛化性能的 λ_1,λ_2 區(qū)域(圖 A 的紅色部分)與定理 3 中證明得到的(圖 B 的紅色部分)一致。
我們接下來用 Phi-3-small (7B) 模型對任務(wù)向量在機器遺忘中的表現(xiàn)進(jìn)行驗證,所使用的數(shù)據(jù)集為《哈利波特 I》(HP1),《哈利波特 II》(HP2),《傲慢與偏見》(PP)。其中,由于出自相同的作者 J.K. 羅琳,《哈利波特 I》與《II》的語義相似度較高,而《傲慢與偏見》與另外兩個數(shù)據(jù)集不太相似。
下表的結(jié)果展示了使用從《哈利波特 I》訓(xùn)練得到的低秩任務(wù)向量 構(gòu)建模型
對三個數(shù)據(jù)集進(jìn)行機器遺忘的表現(xiàn)。我們發(fā)現(xiàn)通過疊加反向的(λ<0)任務(wù)向量,新模型在相似任務(wù)上也可以取得很好的遺忘效果,而在不相似任務(wù)上的遺忘效果較差。
總結(jié)
本文定量證明了如何根據(jù)任務(wù)間關(guān)系確定任務(wù)運算系數(shù),從而實現(xiàn)理想的多任務(wù)學(xué)習(xí)、機器遺忘、以及分布外泛化的方法,解釋了使用低秩和稀疏任務(wù)向量的可靠性。本文的理論通過實驗得到了驗證。