大模型也能切片,微軟SliceGPT讓LLAMA-2計(jì)算效率大增
大型語言模型(LLM)通常擁有數(shù)十億的參數(shù),用了數(shù)萬億 token 的數(shù)據(jù)進(jìn)行訓(xùn)練,這樣的模型訓(xùn)練、部署成本都非常高。因此,人們經(jīng)常用各種模型壓縮技術(shù)來減少它們的計(jì)算需求。
一般來講,這些模型壓縮技術(shù)可以分為四類:蒸餾、張量分解(包括低秩因式分解)、剪枝和量化。其中,剪枝方法已經(jīng)存在了一段時(shí)間,但許多方法需要在剪枝后進(jìn)行恢復(fù)微調(diào)(RFT)以保持性能,這使得整個(gè)過程成本高昂且難以擴(kuò)展。
為了解決這一問題,來自蘇黎世聯(lián)邦理工學(xué)院、微軟的研究者提出了一個(gè)名為 SliceGPT 的方法。SliceGPT 的核心思想是刪除權(quán)重矩陣中的行和列來降低網(wǎng)絡(luò)的嵌入維數(shù),同時(shí)保持模型性能。
研究人員表示,有了 SliceGPT,他們只需幾個(gè)小時(shí)就能使用單個(gè) GPU 壓縮大型模型,即使沒有 RFT,也能在生成和下游任務(wù)中保持有競爭力的性能。目前,該論文已經(jīng)被 ICLR 2024 接收。
- 論文標(biāo)題:SLICEGPT: COMPRESS LARGE LANGUAGE MODELS BY DELETING ROWS AND COLUMNS
- 論文鏈接:https://arxiv.org/pdf/2401.15024.pdf
剪枝方法的工作原理是將 LLM 中權(quán)重矩陣的某些元素設(shè)置為零,并(選擇性地)更新矩陣的周圍元素以進(jìn)行補(bǔ)償。其結(jié)果是形成了一種稀疏模式,這意味著在神經(jīng)網(wǎng)絡(luò)前向傳遞所需的矩陣乘法中,可以跳過一些浮點(diǎn)運(yùn)算。
運(yùn)算速度的相對(duì)提升取決于稀疏程度和稀疏模式:結(jié)構(gòu)更合理的稀疏模式會(huì)帶來更多的計(jì)算增益。與其他剪枝方法不同,SliceGPT 會(huì)剪掉(切掉?。?quán)重矩陣的整行或整列。在切之前,他們會(huì)對(duì)網(wǎng)絡(luò)進(jìn)行一次轉(zhuǎn)換,使預(yù)測結(jié)果保持不變,但允許剪切過程帶來輕微的影響。
結(jié)果是權(quán)重矩陣變小了,神經(jīng)網(wǎng)絡(luò)塊之間傳遞的信號(hào)也變小了:他們降低了神經(jīng)網(wǎng)絡(luò)的嵌入維度。
下圖 1 將 SliceGPT 方法與現(xiàn)有的稀疏性方法進(jìn)行了比較。
通過大量實(shí)驗(yàn),作者發(fā)現(xiàn) SliceGPT 可以為 LLAMA-2 70B、OPT 66B 和 Phi-2 模型去除多達(dá) 25% 的模型參數(shù)(包括嵌入),同時(shí)分別保持密集模型 99%、99% 和 90% 的零樣本任務(wù)性能。
經(jīng)過 SliceGPT 處理的模型可以在更少的 GPU 上運(yùn)行,而且無需任何額外的代碼優(yōu)化即可更快地運(yùn)行:在 24GB 的消費(fèi)級(jí) GPU 上,作者將 LLAMA-2 70B 的推理總計(jì)算量減少到了密集模型的 64%;在 40GB 的 A100 GPU 上,他們將其減少到了 66%。
此外,他們還提出了一種新的概念,即 Transformer 網(wǎng)絡(luò)中的計(jì)算不變性(computational invariance),它使 SliceGPT 成為可能。
SliceGPT 詳解
SliceGPT 方法依賴于 Transformer 架構(gòu)中固有的計(jì)算不變性。這意味著,你可以對(duì)一個(gè)組件的輸出應(yīng)用一個(gè)正交變換,只要在下一個(gè)組件中撤銷即可。作者觀察到,在網(wǎng)絡(luò)區(qū)塊之間執(zhí)行的 RMSNorm 運(yùn)算不會(huì)影響變換:這些運(yùn)算是可交換的。
在論文中,作者首先介紹了在 RMSNorm 連接的 Transformer 網(wǎng)絡(luò)中如何實(shí)現(xiàn)不變性,然后說明如何將使用 LayerNorm 連接訓(xùn)練的網(wǎng)絡(luò)轉(zhuǎn)換為 RMSNorm。接下來,他們介紹了使用主成分分析法(PCA)計(jì)算各層變換的方法,從而將區(qū)塊間的信號(hào)投射到其主成分上。最后,他們介紹了刪除次要主成分如何對(duì)應(yīng)于切掉網(wǎng)絡(luò)的行或列。
Transformer 網(wǎng)絡(luò)的計(jì)算不變性
用 Q 表示正交矩陣:
- 注意,向量 x 乘以 Q 不會(huì)改變向量的 norm,因?yàn)樵谶@項(xiàng)工作中,Q 的維度總是與 transformer D 的嵌入維度相匹配。
假設(shè) X_? 是 transformer 一個(gè)區(qū)塊的輸出,經(jīng)過 RMSNorm 處理后,以 RMSNorm (X_?) 的形式輸入到下一個(gè)區(qū)塊。如果在 RMSNorm 之前插入具有正交矩陣 Q 的線性層,并在 RMSNorm 之后插入 Q^?,那么網(wǎng)絡(luò)將保持不變,因?yàn)樾盘?hào)矩陣的每一行都要乘以 Q、歸一化并乘以 Q^?。此處有:
現(xiàn)在,由于網(wǎng)絡(luò)中的每個(gè)注意力或 FFN 塊都對(duì)輸入和輸出進(jìn)行了線性運(yùn)算,可以將額外的運(yùn)算 Q 吸收到模塊的線性層中。由于網(wǎng)絡(luò)包含殘差連接,還必須將 Q 應(yīng)用于所有之前的層(一直到嵌入)和所有后續(xù)層(一直到 LM Head)的輸出。
不變函數(shù)是指輸入變換不會(huì)導(dǎo)致輸出改變的函數(shù)。在本文的例子中,可以對(duì) transformer 的權(quán)重應(yīng)用任何正交變換 Q 而不改變結(jié)果,因此計(jì)算可以在任何變換狀態(tài)下進(jìn)行。作者將此稱為計(jì)算不變性,并在下面的定理中加以定義。
定理 1:設(shè) 和
為 RMSNorm 連接的 transformer 網(wǎng)絡(luò)第 ? 塊線性層的權(quán)重矩陣,
、
為相應(yīng)的偏置(如果有),W_embd 和 W_head 為嵌入矩陣和頭矩陣。設(shè) Q 是維數(shù)為 D 的正交矩陣,那么下面的網(wǎng)絡(luò)就等同于原來的 transformer 網(wǎng)絡(luò):
復(fù)制輸入偏置和頭偏置:
可以通過算法 1 來證明,轉(zhuǎn)換后的網(wǎng)絡(luò)計(jì)算出的結(jié)果與原始網(wǎng)絡(luò)相同。
LayerNorm Transformer 可以轉(zhuǎn)換為 RMSNorm
Transformer 網(wǎng)絡(luò)的計(jì)算不變性僅適用于 RMSNorm 連接的網(wǎng)絡(luò)。在處理使用 LayerNorm 的網(wǎng)絡(luò)之前,作者先將 LayerNorm 的線性塊吸收到相鄰塊中,從而將網(wǎng)絡(luò)轉(zhuǎn)換為 RMSNorm。
圖 3 顯示了 Transformer 網(wǎng)絡(luò)(見圖 2)的這種轉(zhuǎn)換。在每個(gè)區(qū)塊中,作者將輸出矩陣 W_out 與均值減法矩陣 M 相乘,后者考慮了后續(xù) LayerNorm 中的均值減法。輸入矩陣 W_in 被前面 LayerNorm 塊的比例預(yù)乘。嵌入矩陣 W_embd 必須進(jìn)行均值減法,而 W_head 必須按照最后一個(gè) LayerNorm 的比例重新縮放。這只是運(yùn)算順序的簡單改變,不會(huì)影響網(wǎng)絡(luò)輸出。
每個(gè)塊的轉(zhuǎn)換
現(xiàn)在 transformer 中的每個(gè) LayerNorm 都已轉(zhuǎn)換為 RMSNorm,可以選擇任意 Q 來修改模型。作者最初的計(jì)劃是從模型中收集信號(hào),利用這些信號(hào)構(gòu)建一個(gè)正交矩陣,然后刪除部分網(wǎng)絡(luò)。他們很快發(fā)現(xiàn),網(wǎng)絡(luò)中不同區(qū)塊的信號(hào)并沒有對(duì)齊,因此他們需要在每個(gè)區(qū)塊應(yīng)用不同的正交矩陣,即 Q_?。
如果每個(gè)區(qū)塊使用的正交矩陣不同,則模型不會(huì)改變,證明方法與定理 1 相同,但算法 1 第 5 行除外。在這里可以看到,殘差連接和塊的輸出必須具有相同的旋轉(zhuǎn)。為了解決這個(gè)問題,作者通過對(duì)殘差進(jìn)行線性變換 來修改殘差連接。
圖 4 顯示了如何通過對(duì)殘差連接進(jìn)行額外的線性運(yùn)算,對(duì)不同的區(qū)塊進(jìn)行不同的旋轉(zhuǎn)。與權(quán)重矩陣的修改不同,這些附加運(yùn)算無法預(yù)先計(jì)算,并且會(huì)給模型增加少量(D × D)開銷。盡管如此,還是需要通過這些操作來對(duì)模型進(jìn)行切除操作,而且可以看到整體速度確實(shí)加快了。
為了計(jì)算矩陣 Q_?,作者使用了 PCA。他們從訓(xùn)練集中選擇一個(gè)校準(zhǔn)數(shù)據(jù)集,在模型中運(yùn)行(在將 LayerNorm 運(yùn)算轉(zhuǎn)換為 RMSNorm 之后),并提取該層的正交矩陣。更確切地說,如果 他們使用轉(zhuǎn)換后網(wǎng)絡(luò)的輸出來計(jì)算下一層的正交矩陣。更確切地說,如果
是校準(zhǔn)數(shù)據(jù)集中第 i 個(gè)序列的第 ? 個(gè) RMSNorm 模塊的輸出,計(jì)算:
并將 Q_?設(shè)為 C_? 的特征向量,按特征值遞減排序。
切除
主成分分析的目標(biāo)通常是獲取數(shù)據(jù)矩陣 X 并計(jì)算低維表示 Z 和近似重構(gòu):
其中 Q 是 的特征向量,D 是一個(gè) D × D 小刪除矩陣(包含 D × D 同位矩陣的 D 小列),用于刪除矩陣左邊的一些列。從 QD 是最小化
的線性映射的意義上來說,重建是 L_2 最佳(L_2 optimal)的。
當(dāng)對(duì)區(qū)塊間的信號(hào)矩陣 X 應(yīng)用 PCA 時(shí),作者從未將 N × D 信號(hào)矩陣具體化,而是將刪除矩陣 D 應(yīng)用于構(gòu)建該矩陣前后的運(yùn)算。在上述運(yùn)算中,該矩陣已乘以 Q。作者刪除了 W_in 的行以及 W_out 和 W_embd 的列。他們還刪除了插入到殘差連接中的矩陣 的行和列(見圖 4)。
實(shí)驗(yàn)結(jié)果
生成任務(wù)
作者對(duì)經(jīng)過 SliceGPT 和 SparseGPT 剪裁后大小不同的 OPT 和 LLAMA-2 模型系列在 WikiText-2 數(shù)據(jù)集中進(jìn)行了性能評(píng)估。表 1 展示了模型經(jīng)過不同級(jí)別的剪裁后保留的復(fù)雜度。相比 LLAMA-2 模型,SliceGPT 在應(yīng)用于 OPT 模型時(shí)表現(xiàn)出了更優(yōu)越的性能,這與作者根據(jù)模型頻譜的分析得出的推測相符。
SliceGPT 的性能將隨著模型規(guī)模的增大而提升。在對(duì)所有 LLAMA-2 系列模型剪裁 25% 情況下,SparseGPT 2:4 模式的表現(xiàn)都遜于 SliceGPT。對(duì)于 OPT,可以發(fā)現(xiàn)在除 2.7B 模型之外的所有模型中,30% 切除比例的模型的稀疏性都優(yōu)于 2:4 的稀疏性。
零樣本任務(wù)
作者采用了 PIQA、WinoGrande、HellaSwag、ARC-e 和 ARCc 五個(gè)任務(wù)來評(píng)估 SliceGPT 在零樣本任務(wù)上的表現(xiàn),他們?cè)谠u(píng)估中使用了 LM Evaluation Harness 作為默認(rèn)參數(shù)。
圖 5 展示了經(jīng)過剪裁的模型在以上任務(wù)中取得的平均分?jǐn)?shù)。圖中上行顯示的是 SliceGPT 在 WikiText-2 中的平均準(zhǔn)確率,下行顯示的是 SliceGPT 在 Alpaca 的平均準(zhǔn)確率。從結(jié)果中可以觀察到與生成任務(wù)中類似的結(jié)論:OPT 模型比 LLAMA-2 模型更適應(yīng)壓縮,越大的模型經(jīng)過剪裁后精度的下降越不明顯。
作者在 Phi-2 這樣的小模型中測試了 SliceGPT 的效果。經(jīng)過剪裁的 Phi-2 模型與經(jīng)過剪裁的 LLAMA-2 7B 模型表現(xiàn)相當(dāng)。最大型的 OPT 和 LLAMA-2 模型可以被有效壓縮,當(dāng)從 66B 的 OPT 模型中刪除 30% 時(shí),SliceGPT 可以做到僅損失了幾個(gè)百分點(diǎn)。
作者還進(jìn)行了恢復(fù)微調(diào)(RFT)實(shí)驗(yàn)。使用 LoRA 對(duì)剪裁過的 LLAMA-2 和 Phi-2 模型進(jìn)行了少量 RFT。
實(shí)驗(yàn)結(jié)果如圖 6 所示??梢园l(fā)現(xiàn),RFT 的結(jié)果在 WikiText-2 和 Alpaca 數(shù)據(jù)集存在顯著差異,模型在 Alpaca 數(shù)據(jù)集中展現(xiàn)了更好的性能。作者認(rèn)為出現(xiàn)差異的原因在于 Alpaca 數(shù)據(jù)集中的任務(wù)和基準(zhǔn)任務(wù)更接近。
對(duì)于規(guī)模最大的 LLAMA-2 70B 模型,剪裁 30% 再進(jìn)行 RFT 后,最終在 Alpaca 數(shù)據(jù)集中的平均準(zhǔn)確率為 74.3%,原稠密模型的準(zhǔn)確率為 76.6%。經(jīng)過剪裁的模型 LLAMA-2 70B 保留了約 51.6B 個(gè)參數(shù),其吞吐量得到了顯著提高。
作者還發(fā)現(xiàn) Phi-2 無法在 WikiText-2 數(shù)據(jù)集中,從被剪裁過的模型中恢復(fù)原有準(zhǔn)確率,但在 Alpaca 數(shù)據(jù)集中能恢復(fù)幾個(gè)百分點(diǎn)的準(zhǔn)確率。被剪裁過 25% 并經(jīng)過 RFT 的 Phi-2 在 Alpaca 數(shù)據(jù)集中,平均準(zhǔn)確率為 65.2%,原稠密模型的準(zhǔn)確率為 72.2%。剪裁過的模型保留了 2.2B 個(gè)參數(shù),保留了 2.8B 模型準(zhǔn)確率的 90.3%。這表明即使是小型語言模型也可以有效剪枝。
基準(zhǔn)吞吐量
和傳統(tǒng)剪枝方法不同,SliceGPT 在矩陣 X 中引入了(結(jié)構(gòu)化)稀疏性:整列 X 被切掉,降低了嵌入維度。這種方法既增強(qiáng)了 SliceGPT 壓縮模型的計(jì)算復(fù)雜性(浮點(diǎn)運(yùn)算次數(shù)),又提高了數(shù)據(jù)傳輸效率。
在 80GB 的 H100 GPU 上,將序列長度設(shè)置為 128,并將序列長度批量翻倍找到最大吞吐量,直到 GPU 內(nèi)存耗盡或吞吐量下降。作者比較了剪裁過 25% 和 50% 的模型的吞吐量與原稠密模型 80GB 的 H100 GPU 上的吞吐量。剪裁過 25% 的模型最多實(shí)現(xiàn)了 1.55 倍的吞吐量提升。
在剪裁掉 50% 的情況下,最大的模型在使用一個(gè) GPU 時(shí),吞吐量實(shí)現(xiàn)了 3.13 倍和 1.87 倍的大幅增加。這表明在 GPU 數(shù)量固定的情況下,被剪裁過的模型的吞吐量將分別達(dá)到原稠密模型的 6.26 倍和 3.75 倍。
經(jīng)過 50% 的剪裁后,雖然 SliceGPT 在 WikiText2 中的保留的復(fù)雜度比 SparseGPT 2:4 差,但吞吐量卻遠(yuǎn)超 SparseGPT 的方法。對(duì)于大小為 13B 的模型,在內(nèi)存較少的消費(fèi)級(jí) GPU 上,小模型的吞吐量可能也會(huì)有所提高。
推理時(shí)間
作者還研究了使用 SliceGPT 壓縮的模型從端到端的運(yùn)行時(shí)間。表 2 比較了在 Quadro RTX6000 和 A100 GPU 上,OPT 66B 和 LLAMA-2 70B 模型生成單個(gè) token 所需的時(shí)間??梢园l(fā)現(xiàn),在 RTX6000 GPU 上,對(duì)模型剪裁過 25% 后,推理速度提高了 16-17%;在 A100 GPU 上,速度提高了 11-13%。相比原稠密模型,對(duì)于 LLAMA-2 70B,使用 RTX6000 GPU 所需的計(jì)算量減少了 64%。作者將這種提升歸功于 SliceGPT 采用了用較小的權(quán)重矩陣替換原權(quán)重矩陣,并使用了 dense kernels ,這是其他剪枝方案無法實(shí)現(xiàn)的。
作者表示,在撰寫本文時(shí),他們的基線 SparseGPT 2:4 無法實(shí)現(xiàn)端到端的性能提升。相反,他們通過比較 transformer 層中每個(gè)運(yùn)算的相對(duì)時(shí)間,將 SliceGPT 與 SparseGPT 2:4 進(jìn)行比較。他們發(fā)現(xiàn),對(duì)于大型模型,SliceGPT (25%) 與 SparseGPT (2:4) 在速度提升和困惑度方面具有競爭力。
計(jì)算成本
所有 LLAMA-2、OPT 和 Phi-2 模型都可以在單個(gè) GPU 上花費(fèi) 1 到 3 小時(shí)的時(shí)間進(jìn)行切分。如表 3 所示,通過恢復(fù)微調(diào),可以在 1 到 5 個(gè)小時(shí)內(nèi)壓縮所有 LM。
了解更多內(nèi)容,請(qǐng)參考原論文。