DeepSeek的MLA,任意大模型都能輕松遷移了
復(fù)旦 NLP 實驗室博士后紀燾是這篇文章的第一作者,研究方向為大模型高效推理、多模態(tài)大模型,近期代表工作為首個NoPE外推HeadScale、注意力分塊外推LongHeads、多視覺專家大模型MouSi,發(fā)表ACL、ICLR、EMNLP等頂會頂刊論文 20 余篇。
DeepSeek-R1 作為 AI 產(chǎn)業(yè)顛覆式創(chuàng)新的代表轟動了業(yè)界,特別是其訓(xùn)練與推理成本僅為同等性能大模型的數(shù)十分之一。多頭潛在注意力網(wǎng)絡(luò)(Multi-head Latent Attention, MLA)是其經(jīng)濟推理架構(gòu)的核心之一,通過對鍵值緩存進行低秩壓縮,顯著降低推理成本 [1]。
然而,現(xiàn)有主流大模型仍然基于標(biāo)準注意力架構(gòu)及其變種(e.g., MHA, GQA, MQA),推理成本相比 MLA 呈現(xiàn)顯著劣勢。使預(yù)訓(xùn)練的任意 LLMs 快速遷移至 MLA 架構(gòu)而無需從頭預(yù)訓(xùn)練,這既有重大意義又具有挑戰(zhàn)性。
復(fù)旦 NLP 實驗室、華東師大、上海 AI Lab、??低暵?lián)合提出 MHA2MLA 框架,通過部分 RoPE 保留(Partial-RoPE)和鍵值聯(lián)合表示低秩近似(Low-rank Approximation)兩個關(guān)鍵步驟,成功將任意 MHA/GQA 架構(gòu)遷移到 MLA。
目前,MHA2MLA 已位列??alphaXiv 熱度榜??
復(fù)旦 NLP 實驗室博士后紀燾為第一作者,副研究員桂韜為通訊作者。
- 論文標(biāo)題:Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs
- 論文鏈接:https://arxiv.org/abs/2502.14837
- 開源代碼:https://github.com/JT-Ushio/MHA2MLA
論文概覽
本文聚焦如何將預(yù)訓(xùn)練的基于 MHA/GQA 的大語言模型高效遷移到 DeepSeek 提出的經(jīng)濟推理架構(gòu) —— 多頭潛在注意力(MLA)。
MHA 與 MLA 在多處存在差異,使得 MHA2MLA 極具挑戰(zhàn):
- 位置編碼不同:MHA 采用全維度位置編碼(PE),MLA 僅少量維度采用 PE,剩余維度則 PE 無關(guān)
- 緩存對象不同:MHA 緩存分離的鍵向量及值向量,MLA 緩存帶 PE 的鍵向量及 PE 無關(guān)的鍵值聯(lián)合低維表示向量
- 參數(shù)矩陣不同:MHA 包含查詢、鍵、值三個線性變換矩陣,MLA 則更加復(fù)雜、多達七個目的不同的線性變換矩陣
- 運算形式不同:MHA 的運算受限于訪存瓶頸,MLA 則能通過矩陣吸收等優(yōu)化實現(xiàn)更高的訪存效率
本文提出的 MHA2MLA 為了最大化利用 MHA 預(yù)訓(xùn)練參數(shù)矩陣并對齊 MLA 的緩存對象和運算形式,首先通過部分 RoPE 保留(Partial-RoPE)分離出 PE 相關(guān)表示(少量維度,如 1/8)和 PE 無關(guān)表示(大量維度),其中 PE 相關(guān)的鍵向量對齊 MLA。其次拼接值的變換矩陣(W_v)和 PE 無關(guān)的鍵的變換矩陣(W_{k, nope}),并進行 SVD 分解得到降維變換矩陣和升維變化矩陣,中間的鍵值聯(lián)合低秩表示對齊 MLA,完成了緩存對象的對齊以及運算形式的對齊。
在 135M~7B 上的實驗表明,僅需使用預(yù)訓(xùn)練數(shù)據(jù)的 0.3% 到 0.6% 進行高效微調(diào),即可基本還原架構(gòu)遷移帶來的性能損失。并且 MHA2MLA 還能結(jié)合其他高效推理技術(shù),例如結(jié)合 4-bit KV 緩存量化,Llama2-7B 減少了 92.19% KV 緩存,而 LongBench 上的性能僅下降 0.5%。
部分 RoPE 保留(Partial-RoPE)
為了實現(xiàn)從標(biāo)準的 MHA(多頭注意力機制)到 MLA(多頭潛在注意力機制)的遷移,作者提出了部分 RoPE 微調(diào)(partial-RoPE finetuning)策略,該策略通過從大量維度中移除 RoPE(旋轉(zhuǎn)位置編碼)并將其轉(zhuǎn)換為 NoPE(無位置編碼)來解決 MLA 和 RoPE 沖突的問題。
作者主要嘗試了四種移除 RoPE 的策略:1)保留高頻位置信息 S_high,該方法最簡單直接,保留了局部語義特征相關(guān)的高頻特征 [2];2)保留低頻位置信息 S_low,與保留高頻位置信息的策略形成對比,檢驗低頻成分在語義理解任務(wù)中的潛在作用;3)均勻采樣策略 S_uniform,等間隔均勻采樣頻率保留位置頻率;4)使用查詢、鍵向量范數(shù)乘積 (2-norm) 近似注意力貢獻值 [2] 的篩選策略 S_{2-norm},針對每個注意力頭,計算所有頻率的平均 2-norm 分數(shù),隨后選擇得分較高的頻率保留位置信息。該策略能自適應(yīng)識別對模型性能關(guān)鍵的特征頻率。
Partial-RoPE 的消融實驗表明:1)保留低頻位置信息的 S_low 導(dǎo)致了最大的性能損失,保留高頻位置信息的 S_high 導(dǎo)致的性能損失明顯小于保留低頻,說明了高頻維度的重要性;2)S_uniform 和 S_{2-norm} 均展現(xiàn)出更優(yōu)的性能,分別在 135M 模型和 1.7B 模型上取得了最少的性能損失。最終作者選擇 S_{2-norm} 作為默認配置,是因為注意力貢獻分數(shù)較低的維度在結(jié)合低秩近似時損失更少。
鍵值聯(lián)合表示低秩近似
移除了大量維度的 RoPE 之后,MHA2MLA 就可以對值向量和 PE 無關(guān)的鍵向量進行低秩近似,從而大幅減少緩存空間。為最大化保留預(yù)訓(xùn)練知識,本文提出兩種基于奇異值分解 (SVD) 的投影矩陣初始化策略:1)SVD_split,分別對矩陣進行低秩分解,保持各自的表征特性;2)SVD_joint,考慮鍵值矩陣之間的關(guān)聯(lián)性,參數(shù)矩陣拼接后整體進行低秩分解。
消融實驗表明:無論是在 GQA 基座還是 MHA 基座上,SVD_joint 方法始終優(yōu)于 SVD_split 方法。
實驗結(jié)果
作者在多種規(guī)模的語言模型(SmolLM-135M/360M/1B7 和 Llama2-7B)以及不同壓縮比例的配置下評估了所提出的方法。實驗表明:1)相同微調(diào)設(shè)置下,壓縮比例越高,性能損失越大,特別是對于兩個 GQA 模型;2)相同壓縮比例下,原始模型參數(shù)越多,性能損失越小,揭示了 MHA2MLA 的潛在 scaling law。3)MHA2MLA 的微調(diào)數(shù)據(jù)量僅需預(yù)訓(xùn)練數(shù)據(jù)的 0.3%~0.6%,避免了從頭預(yù)訓(xùn)練 MLA 模型的高昂成本。
作者在 LongBench 長文本生成任務(wù)中評估了結(jié)構(gòu)遷移后的 Llama2-7B 模型,將 KV 緩存量化作為基準對比方案。實驗表明,MHA2MLA 能在 d_{kv}=16 的情況下實現(xiàn)與 2-bit 量化相同的壓縮比例(87.5%),同時僅損失一半的性能(-3.0% vs. -6.2%);進一步結(jié)合 4-bit 量化后,不僅壓縮比例超過 2-bit 量化,性能損失也都優(yōu)于所有 2-bit 的基線方法,例如 92.19% 壓縮比例僅掉 0.5%,96.87% 壓縮比例僅掉 3.2%,證明了 MHA2MLA 能顯著減少推理時的訪存瓶頸。
總結(jié)與展望
本文主要研究如何將基于 MHA 的預(yù)訓(xùn)練 LLMs(或其變體)適配為 KV 緩存高效的 MLA 架構(gòu),以顯著降低推理時的訪存瓶頸。通過精心的架構(gòu)設(shè)計,MHA2MLA 僅需 0.3% 至 0.6% 預(yù)訓(xùn)練數(shù)據(jù)。該框架展現(xiàn)了與現(xiàn)有壓縮技術(shù)的強兼容性,同時保持了常識推理和長上下文處理能力,為部署資源高效的 LLMs 提供了一條實用路徑。
作者提到該研究受限于硬件條件,當(dāng)前實驗未能覆蓋 Llama3 等需 128K 長上下文微調(diào)的模型,也未突破 7B 參數(shù)規(guī)模的驗證瓶頸。擴展至更多的基座將作為未來工作之一。作者還計劃結(jié)合參數(shù)高效微調(diào)策略,進一步降低架構(gòu)遷移過程中的參數(shù)更新規(guī)模。