DeepSeek中的多頭潛在注意力(MLA)淺嘗
MLA是MHA的變體,因此先來看看MHA。
MHA(多頭注意力)
MHA通過將輸入向量分割成多個(gè)并行的注意力“頭”,每個(gè)頭獨(dú)立地計(jì)算注意力權(quán)重并產(chǎn)生輸出,然后將這些輸出通過拼接和線性變換進(jìn)行合并以生成最終的注意力表示。
Transformer 編碼器塊內(nèi)的縮放點(diǎn)積注意力機(jī)制和多頭注意力機(jī)制
MHA計(jì)算過程
MHA 能夠理解輸入不同部分之間的關(guān)系。然而,這種復(fù)雜性是有代價(jià)的——對(duì)內(nèi)存帶寬的需求很大,尤其是在解碼器推理期間。主要問題的關(guān)鍵在于內(nèi)存開銷。在自回歸模型中,每個(gè)解碼步驟都需要加載解碼器權(quán)重以及所有注意鍵和值。這個(gè)過程不僅計(jì)算量大,而且內(nèi)存帶寬也大。隨著模型規(guī)模的擴(kuò)大,這種開銷也會(huì)增加,使得擴(kuò)展變得越來越艱巨。
MLA(多頭潛在注意力)
概念:
- 多頭注意力機(jī)制:Transformer 的核心模塊,能夠通過多個(gè)注意力頭并行捕捉輸入序列中的多樣化特征。
- 潛在表示學(xué)習(xí):通過將高維輸入映射到低維潛在空間,可以提取更抽象的語(yǔ)義特征,同時(shí)有效減少計(jì)算復(fù)雜度。
問題:
MLA 的提出:MLA 將多頭注意力機(jī)制 與 潛在表示學(xué)習(xí) 相結(jié)合,解決MHA在高計(jì)算成本和KV緩存方面的局限性。
MLA的具體做法(創(chuàng)新點(diǎn)): 采用低秩聯(lián)合壓縮鍵值技術(shù),優(yōu)化了鍵值(KV)矩陣,顯著減少了內(nèi)存消耗并提高了推理效率。
如上圖,在MHA、GQA中大量存在于keys values中的KV緩存——帶陰影表示,到了MLA中時(shí),只有一小部分的被壓縮Compressed的Latent KV了。
并且,在推理階段,MHA需要緩存獨(dú)立的鍵(Key)和值(Value)矩陣,這會(huì)增加內(nèi)存和計(jì)算開銷。而MLA通過低秩矩陣分解技術(shù),顯著減小了存儲(chǔ)的KV(Key-Value)的維度,從而降低了內(nèi)存占用。
MLA的核心步驟: