LLM 剪枝+蒸餾:NVIDIA 的最佳實踐 精華
一、背景
模型剪枝、蒸餾是傳統(tǒng) AI 模型壓縮常用的方案,尤其是模型要用于端側(cè)部署的場景,相比模型量化,模型剪枝和蒸餾的代價比較高,而且往往在指標上也沒有特別明顯的優(yōu)勢,因此真正落地的場景要少得多。當然,也傳言有些模型會蒸餾 OpenAI 的 ChatGPT,不過主要是用其生成高質(zhì)量數(shù)據(jù)。本文中,我們介紹 NVIDIA 最近發(fā)布的 LLM 剪枝和蒸餾的最佳實踐。
對應(yīng)的論文為:[2408.11796] LLM Pruning and Distillation in Practice: The Minitron Approach
我們之前也介紹過 NVIDIA 的兩篇相關(guān)文章,可以參考:
- NVIDIA LLM 持續(xù)預(yù)訓(xùn)練的最佳實踐
- NVIDIA 開源 LLM:迷你杯 Nemotron-4 15B & 超大杯 Nemotron-340B
二、摘要
本文中,作者探索了兩種不同的 LLM 剪枝策略:(1)深度剪枝和(2)聯(lián)合 Hidden、Attention、MLP 的寬度剪枝,并使用 LM 評估框架 Harness 進行常見基準評估。然后將模型與 NeMo Aligner 對齊,并在 Instruct-tuned 版本中進行測試。使用提出的剪枝和蒸餾將 LLaMA 3.1 8B 和 Mistral NeMo 12B 模型分別壓縮為 4B 和 8B 參數(shù)的模型。此外,作者發(fā)現(xiàn),在無法訪問原始數(shù)據(jù)的情況下,在蒸餾數(shù)據(jù)集上適當微調(diào) Teacher 模型也有幫助。
PS:基于當前論文的工作還有許多可以嘗試的方案,比如:
- 論文中寬度剪枝后每個 Transformer Block 層的參數(shù)還保持一致,而根據(jù)之前的研究(本文中也提到了),模型的最開始 2 層和最后 2 層可能更加的重要,是否可以通過某種方式實現(xiàn)只剪枝中間層,而保留最前 2 層和最后 2 層,效果是否會更好?
- 如果使用 LLaMA 3.1 70B 作為 Teacher 模型,是否能將 4B 參數(shù)量規(guī)模的模型精度與 8B 模型對齊?
三、方案
3.1 概覽
如下圖 Figure 1 所示為本文方法的概覽,其包含 3 個階段:
- Teacher 校正:在目標數(shù)據(jù)集(127B)上對預(yù)訓(xùn)練模型進行微調(diào),生成校正的 Teacher 模型,以便用于蒸餾。
- 剪枝:應(yīng)用剪枝技術(shù)壓縮模型,生成 Student 模型。(PS:需要說明的是,非結(jié)構(gòu)化剪枝往往導(dǎo)致模型無法充分發(fā)揮 GPU 算力,所以在使用 GPU 推理的場景中相對較少,這里作者主要是使用的結(jié)構(gòu)化剪枝,非常適合 GPU 運算)
- 蒸餾:使用 Teacher 模型蒸餾 Student 模型,以恢復(fù)剪枝損失的模型準確性。?
3.2 剪枝
權(quán)重剪枝是一種強大且眾所周知的模型壓縮技術(shù)。本文中,作者重點介紹結(jié)構(gòu)化剪枝(并不是為了適配 GPU 稀疏算力的 2:4 稀疏結(jié)構(gòu)化剪枝,其可以參考附錄),也就是從模型中刪除 Block 或 Channel(PS:不是將其置為 0),包括 Neuron、Attention Head、Convolutional Filter 和深度剪枝。如下圖 Figure 2 所示,對于 LLM 而言,首先計算每個層、神經(jīng)元、Head 和 Embedding 維度的重要性;然后對這些重要性分數(shù)進行排序;最后進行剪枝操作,并多次迭代。
3.2.1 重要性預(yù)估
作者采用純粹基于激活的重要性評估策略,該策略使用小型校準數(shù)據(jù)集,通過前向推理來計算所有軸(深度、神經(jīng)元、Head,嵌入通道)的靈敏度信息。此外,作者將深度修剪作為一種特殊情況,不會與其他壓縮維度結(jié)合使用。
具體來說,作者使用一個 1024 個 Sample 的小型校準數(shù)據(jù)集,通過分別檢查 MHA、MLP 和 LayerNorm 層產(chǎn)生的激活來計算每個 Head、神經(jīng)元和 Channel 的重要性。
對于深度剪枝,作者使用 3 個指標評估 Layer 的重要性:
- LM 驗證損失。
- Block 重要性(BI)。
- 下游任務(wù)的準確性。
對于基于 Loss 的排序,只需刪除單個或連續(xù)的 Block,并計算其對 LM Loss 的影響,這可以作為層的“重要性”或“敏感度”。BI 使用 Layer 或 Layer Blocks 的輸入和輸出之間的余弦距離來計算。作者注意到 BI 和 LM 損失指標高度相關(guān),但并沒有在下游任務(wù)上生成最準確的剪枝模型。因此,作者使用 Winogrande 基準來評估 Layer 的重要性。
3.2.2 模型修剪
對于給定的模型,獲得每個軸的重要性排名之后,可以直接對相應(yīng)的權(quán)重矩陣進行修剪。對于神經(jīng)元和 Head 修剪,分別修剪 MLP 和 MHA 層權(quán)重;對于 Embedding Channel,修剪 MLP、MHA 和 LayerNorm 中權(quán)重矩陣的 Embedding 維度。
3.2.3 蒸餾訓(xùn)練
對修剪后的模型進行 ReTraining 以恢復(fù)準確性。本文中,作者探索了兩種 ReTraining 策略:
- 利用 Ground truth 標簽的常規(guī)訓(xùn)練。
- 使用未修剪模型(Teacher)進行監(jiān)督知識蒸餾。
蒸餾的過程如下圖 Figure 3 所示,作者只在最后的 Logits 上添加 KL 散度損失。
3.3 訓(xùn)練詳情
3.3.1 預(yù)訓(xùn)練
使用預(yù)訓(xùn)練的 LLaMA 3.1 8B([2407.21783] The Llama 3 Herd of Models) 和 Mistral Nemo 12B 模型(Mistral NeMo | Mistral AI | Frontier AI in your hands)。
3.3.2 數(shù)據(jù)集
所有實驗使用 Nemotron-4 ([2402.16819] Nemotron-4 15B Technical Report 和 [2407.07263] Reuse, Don't Retrain: A Recipe for Continued Pretraining of Language Models)的 Continuous Training(CT) 數(shù)據(jù)集。
3.3.3 剪枝
作者采用的簡化剪枝方案來自 Minitron 論文([2407.14679] Compact Language Models via Pruning and Knowledge Distillation)中的最佳實踐。具體來說:
- 寬度剪枝:
分別使用 l2-norm 和 mean 作為跨 Batch 和 Sequence 維度的聚合函數(shù)。
執(zhí)行單次修剪,避免迭代方案。
- 深度剪枝:
- 遵循[2403.17887] The Unreasonable Ineffectiveness of the Deeper Layers 中的觀察結(jié)果,刪除一個連續(xù) subgroup,該 subgroup 會使 Winogrande 的準確性下降最小。
- 沒有采用 NAS 搜索的結(jié)構(gòu)。
修剪后的 3 個模型參數(shù)如下圖所示,這里也可以看出,寬度剪枝后各層的超參數(shù)一致:
3.3.4 蒸餾
Teacher 校正:直接使用 Mistral Nemo 12B 模型在作者自己的數(shù)據(jù)集上表現(xiàn)不佳,這是由于 Teacher 模型訓(xùn)練的原始數(shù)據(jù)集與蒸餾數(shù)據(jù)集的分布不一致。為了解決這個問題,作者首先使用數(shù)據(jù)集中 >= 127B 的 Token 微調(diào) Teacher 模型。如下圖 Figure 4 所示,使用經(jīng)校正的 Teacher 模型蒸餾,Student 模型在驗證集上的 Loss 明顯低于使用原始 Teacher 模型。因此,作者將這種方案應(yīng)用到了 Mistral-Nemo 和 LLaMA 3.1 Teacher 模型。當然,微調(diào) Teacher 模型也會導(dǎo)致其在一些指標上有所提升,而在一些指標上有所下降:
Retrining:根據(jù) Minitron 中的方案,作者選擇僅 Logit 蒸餾(PS:之前很多工作也會蒸餾 Feature Map),最大限度減少 Teacher 和 Student 的 KL 散度損失,并完全忽略 LM 交叉熵損失。蒸餾的超參數(shù)如下圖 Table 4 所示,在 32 個 DGX H100 節(jié)點上訓(xùn)練。
指令微調(diào):為了評估蒸餾模型的指令跟隨能力,作者使用 Nemo-Aligner 和用于 Nemotron-4 340B 的指令微調(diào)數(shù)據(jù)集對 LLaMA 3.1 Minitron 4B 模型進行 SFT,結(jié)果如下圖 Table 2 所示:
四、分析和評估
4.1 分析
4.1.1 寬度和深度剪枝
如下圖 Figure 5 所示為根據(jù)寬度和深度剪枝的 LLaMA-3.1-Minitron-4B 的訓(xùn)練曲線,可以看出,兩者具有相同的參數(shù)量,但是寬度剪枝對應(yīng)的初始損失更小,并且始終優(yōu)于深度剪枝。
4.1.2 剪枝和蒸餾
如下圖 Figure 6 展示了剪枝和蒸餾方法的正交優(yōu)勢。作者比較了下述 4 種方案,可以看出,與隨機初始化相比,剪枝的起點明顯更好,而基于蒸餾的訓(xùn)練優(yōu)于傳統(tǒng)的訓(xùn)練方法,同時需要訓(xùn)練的 Token 明顯減少:
- Random Init + Distillation:隨機權(quán)重初始化和蒸餾。
- Random Pruning + Distillation:隨機剪枝和蒸餾。其中的組件被隨機修剪而不是依賴重要性分數(shù)。
- Pruning + LM Loss:使用本文的修剪方案,但使用基于交叉熵的 LM Loss 訓(xùn)練。
- Pruning + Distillation:本文的剪枝和蒸餾方案。LM 驗證損失最低。?
4.1.3 Teacher 校正
如下圖 Figure 7 所示,作者對比了兩種 Teacher 校正方法,結(jié)果表明,Teacher 校正并不影響剪枝的最優(yōu)性,用校正后的 Teacher 至關(guān)重要。Teacher 校正也可以與蒸餾同時進行,以彌合差距:
- Prune corrected teacher + distill corrected teacher:剪枝和蒸餾校正的 Teacher 模型。
- Prune original teacher + distill continuously corrected teacher:剪枝原始的 Teacher 模型,并使用不斷校正的 Teacher 模型來蒸餾。?
4.1.4 深度剪枝度量
在檢查 LM 驗證損失如何隨著連續(xù) Layer Block 的刪除而增加時,如下圖 Figure 8 所示,作者觀察到開始和結(jié)尾的層是最重要的。刪除非連續(xù)層可能導(dǎo)致更好的 LM 驗證損失。
但是在評估下游任務(wù)時,上述結(jié)論不一定成立。如下圖 Figure 9 所示,根據(jù)每層重要性刪除 16 層,Winogrande 精度為 0.5,而連續(xù)刪除 16-31 層的精度為0.595。在基于蒸餾的 Retraining 中,差距仍然存在,作者選擇了后一種方法。
4.2 評估
4.2.1 Base 模型評估
如下圖 Table 1 所示為 Base 模型的評估,與類似大小的模型相比,MN-Minitron8B 在各方面都表現(xiàn)出卓越的準確性,并且訓(xùn)練 Token 數(shù)小 40x(380B vs 15T)。同樣,與 Teacher LLaMA 3.1 8B 模型相比,LLaMA-3.1 4B 模型表現(xiàn)良好,并且使用的訓(xùn)練 Token 減少 150x(94B vs 15T)。剪枝后的 LLaMA-3.1 4B 也優(yōu)于之前的 Minitron 4B。此外,從中也可以看出,基于寬度剪枝的變體優(yōu)于基于深度剪枝的變體。這些結(jié)果充分表明了方案的有效性。
PS:不過從 8B -> 4B 的損失依然比較大,甚至和直接進行 AWQ(W4A16)的量化損失差不多。當然 8B AWQ 的推理效率可能不如 4B,然而一些 W8A8 的方案也能獲得相當?shù)木?,詳情可以參見后文“附錄”的量化部分。此外,量化的成本可能遠低于剪枝+蒸餾。如下圖所示為 neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w4a16 · Hugging Face 中使用 W4A16 的 Instruct 模型精度:
4.2.2 Instruct 模型評估
如下圖 Table 2 所示為 LLaMA-3.1 Minitron 4B 指令微調(diào)的性能??梢钥闯?,其寬度剪枝變體在所有指標上優(yōu)于原始的 Minitron 4B:
五、附錄
5.1 模型壓縮
模型壓縮有 4 種常見的方案:量化,剪枝,蒸餾,低秩分解。在 LLM 場景中,模型量化的方案非常多,比如 llm.int8()、AWQ、GPTQ、SmoothQuant 等等,其實現(xiàn)簡單,代價小,是最常見的方案。而其他幾種方案應(yīng)用相對比較少,它們的區(qū)別如下圖 Figure 2 所示,圖片來自 [2308.07633] A Survey on Model Compression for Large Language Models:
5.2 量化
量化是模型壓縮中最常用的技術(shù),不同的量化方案會針對不同的精度、Tensor 類型等,比如有常見的 KV Cache Only 量化,Weight Only 量化,以及幾種方案的結(jié)合,具體如下圖所示:
不同的量化方案在不同模型上的量化損失也會有所不同,但是大體上來說,壓縮后的 Bit 數(shù)越低損失越大。如下圖 Table 1 所示為 [2404.14047] An Empirical Study of LLaMA3 Quantization: From LLMs to MLLMs 中對 LLaMA3-8B 模型的量化評估(都使用 INT 類型,未使用 FP8/FP6/FP4),可以看出:
- W8A16 的量化損失都比較小,幾乎無損
- W4A16 和 W8A8 的損失相比 W8A16 會大一些,大部分情況也基本可以接受,但也和量化方法有關(guān),比如,GPTQ 和 QuIP 的 W4A16 相比 AWQ 的損失會更大一些。
- 更低 Bit 的量化損失會比較大,比如 W3A16 和 W2A16 的精度會明顯下降。?
NVIDIA 的 GPU 從 Hopper 架構(gòu)開始可以支持 FP8 計算,使用 FP8 后其精度相比 SmoothQuant 的 INT8 以及其他的 W4A16 損失更小,更具有使用價值(數(shù)據(jù)來自 https://github.com/NVIDIA/TensorRT-LLM/blob/v0.9.0/docs/source/blogs/quantization-in-TRT-LLM.md):
那么這些量化方案的速度怎么樣呢,如下圖所示,在 [2404.14294] A Survey on Efficient Inference for Large Language Models 中作者評估了 TensorRT-LLM 和 LMDeploy 推理框架在不同場景的 W4A16 推理性能,使用的 GPU 為 NVIDIA A100,圖中的數(shù)據(jù)分別為 Prefill/Decoding/End2End 的加速比,可以看出,基本都可以實現(xiàn) 2 倍左右加速,當序列比較長或者 Batch size 比較大時會略低一些,當然兩個框架也各有千秋:
- TensorRT-LLM 在 Batch size 比較小時優(yōu)勢更明顯。
- LMDeploy 在 Batch size 比較大時優(yōu)勢更明顯。?
如下圖所示為使用 FP8 相比 FP16 可以加速 1.4-1.5 倍,這是比較早的數(shù)據(jù),進一步優(yōu)化后可以到 1.6-1.7 倍:
如下表所示為 TensorRT-LLM 中作者對不同量化方案的總結(jié),可以看出,如果 GPU 支持 FP8,則使用 FP8 是最理想的選擇,如果允許少量精度損失,則可以進一步使用 INT4-FP8 AWQ 的方案:
5.3 剪枝
NVIDIA 在 Ampere 架構(gòu)的 Tensor Core 中引入了稀疏矩陣乘法支持,理論最多可以提升 2 倍性能,實際上可能只有 1.5 倍,而且對稀疏化的方式有要求,如下圖所示,每 4 個 Weight 中需要有 2 個是 0 值,也就是可以剪掉的值:
基本上稀疏化都會帶來一定的精度損失,如下圖 Table 2 所示,論文 [2310.15929] E-Sparse: Boosting the Large Language Model Inference through Entropy-based N:M Sparsity 中作者評估了 2:4 稀疏化對模型精度的影響,可以看出(PS:論文圖片中的部分數(shù)字有誤,比如 13B 模型上 Magnitude 平均精度達到了 57.71,實際平均只有 49.36):
- 每種稀疏化方法都有一定的精度損失,即使最優(yōu)的方案損失也比較大。
- 不同規(guī)模的模型上精度損失差異比較大,小模型損失比較大。
- 13B 模型稀疏化后的精度還不如原始的 7B 模型,那么 13B 模型的稀疏化基本沒有太大的意義。?
在這樣的損失下能帶來什么收益呢?其收益主要體現(xiàn)在速度的提升和顯存的節(jié)約,如下圖 Table 5 所示,其矩陣乘法可以加速 20%-30%,而端到端延時只能加速 16% 左右:
當然,顯存的節(jié)約還是比較客觀的,基本可以節(jié)約 43% 左右的顯存空間,也許可以通過增加 Batch Size 來增加吞吐:
六、參考鏈接
- https://arxiv.org/abs/2408.11796
- https://arxiv.org/abs/2407.21783
- https://mistral.ai/news/mistral-nemo/
- https://arxiv.org/abs/2402.16819
- https://arxiv.org/abs/2407.07263
- https://www.arxiv.org/abs/2407.14679
- https://arxiv.org/abs/2403.17887
- https://arxiv.org/abs/2308.07633
- https://arxiv.org/abs/2404.14047
- https://arxiv.org/abs/2404.14294
- https://arxiv.org/abs/2310.15929
本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談
