從《你所需要的就是注意力》到《你所需要的就是多頭潛在注意力》,TransMLA開啟AI技術(shù)新篇章
自2017年谷歌提出了Transformer架構(gòu),以及那篇著名的論文《Attention Is All You Need》后,注意力機(jī)制迅速成為自然語言處理領(lǐng)域的核心技術(shù)。大型語言模型(LLMs)借助Transformer的自注意力機(jī)制,實(shí)現(xiàn)了對復(fù)雜語言模式的捕捉,在機(jī)器翻譯、文本生成、對話系統(tǒng)等領(lǐng)域取得了革命性的突破。它們不僅改變了學(xué)術(shù)研究的方向,更深刻地影響了生產(chǎn)力工具的發(fā)展,提高了人們的工作效率和生活質(zhì)量。
隨著模型規(guī)模和數(shù)據(jù)量的不斷增長,LLMs面臨的挑戰(zhàn)也日益凸顯。盡管計算能力的提升使得訓(xùn)練更大規(guī)模的模型成為可能,但硬件的通信瓶頸卻成為制約模型進(jìn)一步發(fā)展的主要障礙。這種瓶頸并非源于計算資源的匱乏,而是由于自注意力機(jī)制在處理長序列時,內(nèi)存和通信開銷呈指數(shù)級增長。
在自注意力機(jī)制中,每個Token都需要與序列中的所有其他Token進(jìn)行互動,這導(dǎo)致了計算量和存儲需求的急劇上升,特別是鍵值(Key-Value,KV)緩存的大小隨著序列長度的增加而迅速增長。當(dāng)處理超長文本時,KV緩存的內(nèi)存需求可能超過現(xiàn)有硬件的承載能力,導(dǎo)致模型無法高效地進(jìn)行推理和生成。
面對這一挑戰(zhàn),研究人員嘗試了多種解決方案。其中,組查詢注意力(Group Query Attention,GQA)作為一種優(yōu)化方法,通過減少鍵和值的頭數(shù),降低了KV緩存的尺寸。然而GQA在降低內(nèi)存開銷的同時,也不可避免地限制了模型的表達(dá)能力,導(dǎo)致性能下降。如何在不增加KV緩存大小的情況下,提高模型的表達(dá)能力,成為亟待解決的問題。
基于此,來自北京大學(xué)和小米公司的研究團(tuán)隊(duì)提出了多頭潛在注意力(Multi-head Latent Attention,MLA)。他們認(rèn)為,通過引入低秩矩陣,可以在鍵值層實(shí)現(xiàn)KV緩存的壓縮,同時通過上投影矩陣,增強(qiáng)模型的表達(dá)能力,從而在不增加KV緩存開銷的情況下,提升模型性能。
他們進(jìn)一步提出了TransMLA,這是一種將現(xiàn)有的GQA模型轉(zhuǎn)換為MLA模型的后訓(xùn)練方法。通過理論證明,他們證明了在相同的KV緩存開銷下,MLA的表達(dá)能力始終優(yōu)于GQA。實(shí)驗(yàn)結(jié)果也驗(yàn)證了這一觀點(diǎn),轉(zhuǎn)換后的MLA模型在下游任務(wù)中表現(xiàn)出色,顯著優(yōu)于原始的GQA模型。他們開源了TransMLA架構(gòu):https://github.com/fxmeng/TransMLA
這項(xiàng)研究的重要性在于,它為大型語言模型的優(yōu)化提供了一條新的路徑。在硬件限制難以突破的情況下,如何通過模型結(jié)構(gòu)的調(diào)整來提升性能,就顯得尤為關(guān)鍵。TransMLA的提出,不僅為模型開發(fā)者提供了一種低成本、高效率的優(yōu)化方案,也為LLMs的未來發(fā)展指明了方向。
值得一提的是,這個研究團(tuán)隊(duì)背景深厚。主要作者Fanxu Meng, Zengwei Yao, Muhan Zhang來自北京大學(xué)人工智能研究院和通用人工智能國家重點(diǎn)實(shí)驗(yàn)室,擁有扎實(shí)的理論基礎(chǔ)和豐富的研究經(jīng)驗(yàn)。同時,團(tuán)隊(duì)成員也包括來自小米公司的工程師,具備將先進(jìn)技術(shù)應(yīng)用于實(shí)際產(chǎn)品的能力。這種產(chǎn)學(xué)研結(jié)合的團(tuán)隊(duì)構(gòu)成,使得他們的研究既有前瞻性的理論價值,又具備實(shí)際應(yīng)用的可行性。
相關(guān)工作
自注意力機(jī)制的內(nèi)存瓶頸
自注意力機(jī)制是Transformer模型的核心組件,它允許模型在序列中的任意位置建立關(guān)聯(lián)。然而,其計算和內(nèi)存需求與序列長度呈二次關(guān)系。具體來說,對于長度為 TT 的序列,自注意力機(jī)制需要計算并存儲一個 T×TT \times T 的注意力矩陣。此外,還需要存儲每個Token的鍵(Key)和值(Value)向量,這些鍵值對的數(shù)量隨著序列長度的增加而迅速增加,導(dǎo)致KV緩存的內(nèi)存占用也迅速增長。
當(dāng)處理長序列時,每個Token的KV對都需要重新計算,并存儲在緩存中,這種緩存的內(nèi)存需求會隨著序列長度呈指數(shù)級增長。這種增長超出了現(xiàn)有硬件的內(nèi)存容量,成為系統(tǒng)的瓶頸,限制了大模型在處理長文本任務(wù)中的應(yīng)用。例如,在長文檔摘要或長篇文章生成任務(wù)中,KV緩存的內(nèi)存需求可能達(dá)到甚至超過現(xiàn)有硬件的內(nèi)存上限,導(dǎo)致模型無法有效運(yùn)行。
現(xiàn)有解決方案及局限性
為了緩解自注意力機(jī)制帶來的內(nèi)存瓶頸,研究人員提出了多種方法。這些方法在降低內(nèi)存和計算需求的同時,也帶來了各自的局限性。
圖1:將GQA等效轉(zhuǎn)換為MLA的過程。(a) 在GQA中,在計算注意力之前,重復(fù)K以匹配Q頭的數(shù)量。(b) 相反,WK被重復(fù);與X相乘后,直接產(chǎn)生K′,相當(dāng)于GQA。(c) 分解W′K產(chǎn)生一個r維表示,它與GQA中的KV緩存維度相匹配,同時保持等價性。
線性注意力:例如Linear Transformer、RWKV、Mamba等方法,通過線性化注意力機(jī)制,使計算和內(nèi)存需求與序列長度線性相關(guān)。這些方法通過重新設(shè)計注意力機(jī)制,將復(fù)雜的矩陣乘法轉(zhuǎn)化為向量內(nèi)積,從而避免了大量的矩陣計算。然而,這種方法可能降低模型對長距離依賴的捕捉能力,導(dǎo)致在處理需要長距離依賴的任務(wù)時,模型性能下降。
動態(tài)Token剪枝:如LazyLLM、A2SF和SnapKV,通過動態(tài)地移除不重要的Token,以減少KV緩存的大小。這些方法選擇性地保留對后續(xù)預(yù)測最重要的Token,從而減少內(nèi)存使用。然而,動態(tài)剪枝可能錯失關(guān)鍵的上下文信息,尤其是在需要全局理解的任務(wù)中。此外,為了確定哪些Token可以被剪枝,需要額外的計算和復(fù)雜的策略設(shè)計,這增加了系統(tǒng)的復(fù)雜性。
剪枝頭部維度:方法如SliceGPT、Sheared和Simple Pruning,通過減少注意力頭的數(shù)量或降低每個頭的維度,來減少計算和內(nèi)存需求。這些方法通過剪去模型中冗余或不重要的頭部來優(yōu)化內(nèi)存使用。然而,過度的剪枝可能削弱模型的表達(dá)能力,因?yàn)樽⒁饬︻^的多樣性對于捕捉復(fù)雜的語言模式至關(guān)重要,過少的頭部可能導(dǎo)致模型無法充分表示復(fù)雜的關(guān)系,進(jìn)而影響性能。
跨層共享KV表示:如YONO、MiniCache和MLKV,通過在不同Transformer層之間共享KV表示,減少需要存儲的KV緩存的總量。這種方法通過讓每層使用相同的KV表示來節(jié)省內(nèi)存。然而,由于不同層可能需要不同的表示來捕獲不同層次的特征,共享KV表示可能會限制模型的學(xué)習(xí)能力,并引入不同層間的干擾,影響性能的穩(wěn)定性。
KV量化:如GPTQ、Kivi和KVQuant,通過對KV向量進(jìn)行量化,以較低的比特數(shù)表示,從而減少內(nèi)存占用。例如,將浮點(diǎn)數(shù)表示的KV向量量化為8位或更低的精度。這種方法可以顯著減少內(nèi)存使用和計算開銷。然而,量化可能引入誤差,導(dǎo)致表示精度下降,影響模型的準(zhǔn)確性。過度量化還可能導(dǎo)致梯度消失或爆炸,影響模型的訓(xùn)練和穩(wěn)定性。
組查詢注意力(GQA)的應(yīng)用與局限
組查詢注意力(Group Query Attention,GQA)是一種優(yōu)化策略,通過將多個查詢頭分組,使每個組共享相同的鍵和值表示,從而減少了KV緩存的大小。在這種機(jī)制下,鍵和值的頭數(shù)減少了,從而降低了內(nèi)存需求。
GQA的應(yīng)用優(yōu)勢
降低內(nèi)存開銷:通過共享鍵和值,顯著減少KV緩存的大小,適用于內(nèi)存受限的環(huán)境。
提升計算效率:減少計算量,加快模型的推理速度。
然而GQA也存在一定的局限性,在減少內(nèi)存開銷的同時,也限制了模型的表達(dá)能力。由于多個查詢共享相同的鍵和值,模型的表達(dá)多樣性受到限制,這使得模型難以捕捉復(fù)雜的模式和關(guān)系。在一些需要細(xì)粒度注意力的任務(wù)中,GQA的性能可能無法達(dá)到理想效果。
盡管GQA在降低KV緩存方面表現(xiàn)出色,但其對模型表達(dá)能力的限制使得在一些應(yīng)用場景中性能受損。為此,尋找一種既能降低KV緩存大小,又不犧牲模型表達(dá)能力的方法,成為了亟待解決的問題。這正是多頭潛在注意力(MLA)所要解決的核心挑戰(zhàn)。
TransMLA方法詳解
MLA的核心思想
多頭潛在注意力(MLA)的核心思想是通過在鍵值層使用低秩矩陣來壓縮KV緩存,同時引入上投影矩陣以增強(qiáng)模型的表達(dá)能力。MLA通過低秩矩陣降低了KV緩存的尺寸,在減少內(nèi)存占用的同時,保持了必要的注意力信息。而上投影矩陣則提供了更高的表達(dá)能力,使得模型可以在不增加KV緩存大小的情況下,提高注意力機(jī)制的靈活性和多樣性。
理論證明:MLA表達(dá)能力優(yōu)于GQA
定理1:在相同的KV緩存開銷下,MLA的表達(dá)能力始終大于GQA。
證明思路:
GQA通過復(fù)制鍵值頭,使得組內(nèi)的查詢共享相同的鍵和值,這在一定程度上限制了模型的表達(dá)多樣性。而MLA則通過低秩分解,使得鍵和值的表示更加豐富,允許更多的通道間多樣性。這種差異使得MLA在保持相同KV緩存開銷的情況下,能夠提供更強(qiáng)的表達(dá)能力。
具體證明步驟:
首先,我們展示如何在GQA中通過復(fù)制鍵來匹配查詢頭的數(shù)量。設(shè)X為輸入序列,WK為鍵的投影矩陣,則鍵矩陣KK可以表示為:
在GQA中,為了匹配查詢頭的數(shù)量,鍵矩陣需要被復(fù)制。設(shè)復(fù)制因子為s,則復(fù)制后的鍵矩陣可以表示為:
接下來,我們將復(fù)制過程移到參數(shù)側(cè),即先復(fù)制投影矩陣,然后再進(jìn)行計算。設(shè)WK′為復(fù)制后的投影矩陣,則有:
這樣做可以實(shí)現(xiàn)相同的效果,同時不增加KV緩存的大小。
為了進(jìn)一步增強(qiáng)表示能力,我們引入低秩分解。使用奇異值分解(SVD),將復(fù)制的投影矩陣分解成低秩矩陣:
其中,UK和VK是正交矩陣,SK是對角矩陣。我們可以截斷SVD,僅保留前r個奇異值,從而實(shí)現(xiàn)低秩表示:
設(shè)低秩矩陣為WKa和WKb,則有:
最終,鍵矩陣表示為:
這樣,MLA通過低秩分解,不僅壓縮了KV緩存,還增強(qiáng)了模型的表達(dá)能力。通過構(gòu)造一個WKbW^b_K中向量正交的例子,可以證明MLA在某些情況下的表達(dá)能力是GQA無法達(dá)到的。
TransMLA的模型轉(zhuǎn)換方法
TransMLA提供了一種將基于GQA的預(yù)訓(xùn)練模型轉(zhuǎn)換為MLA模型的方法。具體步驟如下。
調(diào)整權(quán)重矩陣:將原始GQA模型的鍵和值投影矩陣拆分,并引入額外的低秩矩陣。通過低秩分解,我們得到兩個新的投影矩陣WKaW^a_K和WKbW^b_K,用以替換原始的投影矩陣。
保持KV緩存大小不變:在轉(zhuǎn)換過程中,通過低秩分解和矩陣調(diào)整,確保轉(zhuǎn)換后的MLA模型的KV緩存大小與原始GQA模型相同。這意味著,我們在增強(qiáng)模型表達(dá)能力的同時,不會增加內(nèi)存開銷。
增強(qiáng)模型表達(dá)能力:轉(zhuǎn)換后的MLA模型可以在不增加內(nèi)存開銷的情況下,提升注意力機(jī)制的多樣性和靈活性。具體地,通過引入上投影矩陣WKbW^b_K,模型可以更好地捕捉復(fù)雜的模式和關(guān)系,提升整體性能。
通過上述方法,TransMLA不僅成功地將GQA模型轉(zhuǎn)換為MLA模型,還顯著提升了模型的表達(dá)能力和實(shí)際性能,為大模型的優(yōu)化提供了一種高效的解決方案。
實(shí)驗(yàn)結(jié)果與分析
實(shí)驗(yàn)設(shè)置
為了驗(yàn)證TransMLA的有效性,研究團(tuán)隊(duì)設(shè)計了一系列實(shí)驗(yàn),選擇了兩個基于GQA的預(yù)訓(xùn)練模型:Qwen2.5-7B和Qwen2.5-14B。通過將這些模型轉(zhuǎn)換為MLA模型,并調(diào)整相應(yīng)的權(quán)重矩陣維度,研究團(tuán)隊(duì)希望評估MLA在實(shí)際應(yīng)用中的表現(xiàn)。
圖2:TransMLA和基于GQA的Qwen模型之間訓(xùn)練損失和準(zhǔn)確性的比較。
在實(shí)驗(yàn)中,使用了包含數(shù)學(xué)和編程任務(wù)的指令微調(diào)數(shù)據(jù)集SmolTalk,具體包括MetaMathQA和SelfOSS-Starcoder2-Instruct數(shù)據(jù)集。實(shí)驗(yàn)設(shè)置的訓(xùn)練配置為:批次大小為16,學(xué)習(xí)率為2e-5,訓(xùn)練2個周期。為了減少對原始模型的影響,實(shí)驗(yàn)中僅對鍵值層進(jìn)行訓(xùn)練。這樣的配置旨在保持實(shí)驗(yàn)的公平性和一致性,使得結(jié)果具有較高的可信度。
微調(diào)性能比較
在訓(xùn)練過程中,TransMLA模型的損失下降速度明顯快于原始GQA模型。這表明,TransMLA在數(shù)據(jù)擬合方面具有明顯優(yōu)勢。在相同的訓(xùn)練條件下,TransMLA能夠更快速地適應(yīng)數(shù)據(jù),并在較短的時間內(nèi)達(dá)到更低的損失值。這一結(jié)果驗(yàn)證了TransMLA在訓(xùn)練效率上的改進(jìn),表明其優(yōu)化方法在實(shí)際應(yīng)用中具有較高的效能。
在具體的任務(wù)表現(xiàn)上,TransMLA模型也展現(xiàn)出了優(yōu)越的性能。在數(shù)學(xué)任務(wù)(如GSM8K)和編程任務(wù)(如HumanEval)中,TransMLA模型的準(zhǔn)確率顯著高于原始GQA模型。例如,在7B模型的實(shí)驗(yàn)中,GSM8K的準(zhǔn)確率從原始GQA模型的81%提升到了TransMLA模型的86%。這種顯著的性能提升,不僅表明TransMLA在特定任務(wù)中的適應(yīng)性強(qiáng),也展示了其在提升模型表達(dá)能力方面的潛力。
TransMLA的性能提升可以歸因于以下幾個方面。
增強(qiáng)的鍵值維度在模型表達(dá)能力上的提升起到了關(guān)鍵作用,通過調(diào)整權(quán)重矩陣,TransMLA擴(kuò)大了鍵和值的維度,使得模型能夠捕捉到更多的特征和模式,從而提升了整體性能。
正交分解在優(yōu)化模型結(jié)構(gòu)方面也發(fā)揮了重要作用,實(shí)驗(yàn)結(jié)果表明,僅增加參數(shù)量并不能帶來顯著的性能提升,正交分解方法通過優(yōu)化模型的表示,使得模型在保持較低內(nèi)存開銷的情況下,實(shí)現(xiàn)了性能的顯著提升。這一結(jié)果強(qiáng)調(diào)了結(jié)構(gòu)優(yōu)化在提升模型性能中的重要性。
綜上所述,TransMLA通過優(yōu)化KV緩存的表示和引入先進(jìn)的矩陣分解技術(shù),實(shí)現(xiàn)了在不增加內(nèi)存開銷的情況下,顯著提升模型性能的目標(biāo)。這不僅為大規(guī)模語言模型的進(jìn)一步發(fā)展提供了新的思路,也為未來的研究和應(yīng)用指明了方向。
結(jié)論與展望
TransMLA成功地展示了多頭潛在注意力(MLA)的強(qiáng)大優(yōu)勢,證明了其在表達(dá)能力上顯著優(yōu)于組查詢注意力(GQA)。通過理論分析和實(shí)驗(yàn)驗(yàn)證,研究團(tuán)隊(duì)展示了MLA在相同KV緩存開銷的情況下,能夠提供更強(qiáng)的模型性能。轉(zhuǎn)換后的MLA模型在多種下游任務(wù)中表現(xiàn)出色,為大型語言模型(LLMs)的注意力機(jī)制設(shè)計提供了全新的視角和思路。研究表明,TransMLA不僅有效地減少了KV緩存的內(nèi)存需求,還大幅提升了模型的表達(dá)能力和推理速度。
TransMLA的研究不僅在技術(shù)層面上取得了重要突破,也為整個行業(yè)帶來了深遠(yuǎn)的影響。
TransMLA顯著提升了資源效率,在不增加硬件負(fù)擔(dān)的情況下,通過優(yōu)化注意力機(jī)制和鍵值緩存,提升了模型的性能。這為企業(yè)和研究機(jī)構(gòu)在有限的計算資源下,進(jìn)一步提高模型的表現(xiàn)提供了新的解決方案。
TransMLA對生態(tài)環(huán)境也具有友好性,通過減少模型對內(nèi)存和計算資源的需求,降低了能源消耗和碳排放。這一特性使得TransMLA在綠色計算和可持續(xù)發(fā)展方面表現(xiàn)出色,符合當(dāng)今社會對環(huán)保和節(jié)能的追求。
TransMLA提供了一種低成本的模型轉(zhuǎn)換方法,使現(xiàn)有的基于GQA的模型可以輕松遷移到MLA架構(gòu)。這樣的遷移不僅能夠保留原有模型的優(yōu)點(diǎn),還能大幅提升其性能和效率。這為模型開發(fā)者和應(yīng)用者提供了極大的便利,加速了新技術(shù)的推廣和應(yīng)用。
研究團(tuán)隊(duì)計劃在以下幾個方面繼續(xù)拓展和優(yōu)化TransMLA。
團(tuán)隊(duì)將擴(kuò)展到更大規(guī)模的模型,包括LLaMA、Qwen、Mistral等大型語言模型,驗(yàn)證TransMLA在更大規(guī)模下的有效性。這將進(jìn)一步檢驗(yàn)和提升MLA在處理超長序列和大規(guī)模數(shù)據(jù)上的表現(xiàn)。
研究團(tuán)隊(duì)將使用DeepSeek R1蒸餾技術(shù),進(jìn)一步提升轉(zhuǎn)換后模型的性能。通過蒸餾技術(shù),可以優(yōu)化模型的參數(shù)和結(jié)構(gòu),使其在保持高性能的同時,具有更高的計算效率和穩(wěn)定性。
針對MLA的特性,研究團(tuán)隊(duì)計劃開發(fā)特定的推理加速策略。通過設(shè)計高效的推理算法,最大程度地發(fā)揮MLA的優(yōu)勢,保持模型的低延遲和高響應(yīng)速度。這將為實(shí)時應(yīng)用和高頻交互場景中的LLMs提供更加優(yōu)質(zhì)的性能支持。
總的來說,TransMLA的成功不僅為現(xiàn)有的大模型優(yōu)化提供了有效的解決方案,也為未來的研究和應(yīng)用開辟了新的路徑。隨著技術(shù)的不斷進(jìn)步和優(yōu)化,相信MLA將在更多領(lǐng)域發(fā)揮重要作用,推動人工智能技術(shù)的發(fā)展和創(chuàng)新。
參考資料:???https://arxiv.org/pdf/2502.07864??
本文轉(zhuǎn)載自??獨(dú)角噬元獸??,作者: FlerkenS ????
