新注意力讓大模型上下文內(nèi)存占用砍半!精度不減還能加速2倍
大模型同樣的上下文窗口,只需一半內(nèi)存就能實(shí)現(xiàn),而且精度無(wú)損?
前蘋(píng)果ASIC架構(gòu)師Nils Graef,和一名UC伯克利在讀本科生一起提出了新的注意力機(jī)制Slim Attention。
它以標(biāo)準(zhǔn)多頭注意力(MHA)為基準(zhǔn),對(duì)其中的value緩存處理過(guò)程進(jìn)行了調(diào)整,實(shí)現(xiàn)了更少的內(nèi)存占用。
具體來(lái)說(shuō),Slim Attention既可以讓KV緩存大小減半,也可以在KV緩存大小不變的情況下讓上下文翻倍,都不會(huì)帶來(lái)精度損失。
此外,在內(nèi)存帶寬受限的場(chǎng)景下,它還可以將模型的推理過(guò)程加速1.5-2倍。
網(wǎng)友評(píng)價(jià),Slim Attention雖然簡(jiǎn)單,但卻是一個(gè)很酷的想法。
還有AI創(chuàng)業(yè)者評(píng)論說(shuō),這是一項(xiàng)重大突破,可能重塑對(duì)模型訓(xùn)練和部署的看法。
K-Cache is All You Need
在標(biāo)準(zhǔn)的MHA機(jī)制當(dāng)中,對(duì)于輸入X會(huì)通過(guò)線性變換,經(jīng)由三個(gè)投影矩陣W_Q、W_K、W_V得到Q(query)、K(key)和V(value)三個(gè)矩陣。
在推理階段,每個(gè)輸入token計(jì)算得到的K和V向量都需要緩存起來(lái),形成KV cache供后續(xù)token計(jì)算時(shí)使用。
Slim Attention的核心思路是,利用MHA中W_K和W_V通常都是方陣的性質(zhì),只存儲(chǔ)K而不直接存儲(chǔ)V,然后實(shí)時(shí)利用K計(jì)算出V。
△原始MHA(左)與改進(jìn)版(右)對(duì)比
在訓(xùn)練階段,Slim Attention與標(biāo)準(zhǔn)MHA一樣,會(huì)對(duì)輸入X計(jì)算Q、K、V三個(gè)矩陣,注意力計(jì)算和梯度回傳也與標(biāo)準(zhǔn)MHA完全一致。
在W_K可逆的前提下,Slim Attention引入一個(gè)新的參數(shù)矩陣W_KV:
W_KV = W_K^(-1)·W_V
據(jù)此,可以得到:
V = X·W_V = X·W_K·W_K^(-1)·W_V = K·W_KV
推理過(guò)程則主要分為兩個(gè)階段——提示階段(并行計(jì)算)和生成階段(自回歸)。
提示階段與標(biāo)準(zhǔn)MHA一樣,將輸入的所有token并行計(jì)算Q、K矩陣,但不同的是,這里不直接計(jì)算V,而是將中間結(jié)果K緩存供后續(xù)使用。
生成階段每個(gè)時(shí)間步生成一個(gè)新token,首先計(jì)算該時(shí)間步的Q向量q,然后基于q和之前時(shí)間步緩存的K矩陣,計(jì)算注意力得(即softmax的輸入)。
在softmax之前,Slim Attention通過(guò)公式V = K · W_KV實(shí)時(shí)計(jì)算V矩陣。具體有兩種方式:
- 直接計(jì)算V,然后將softmax結(jié)果與V相乘(矩陣乘法)得到注意力輸出;
- 先將softmax結(jié)果與K相乘,然后再與W_KV相乘,當(dāng)序列較長(zhǎng)時(shí)這種方式更高效。
剩余流程(殘差連接、前饋層等)與標(biāo)準(zhǔn)MHA一致,最后將當(dāng)前步的k向量添加到K緩存中,供下一時(shí)間步使用。
總之,Slim Attention是標(biāo)準(zhǔn)MHA的精確數(shù)學(xué)重寫(xiě),因此與近似方法不同,可確保準(zhǔn)確率不會(huì)下降。
以此為前提,Slim Attention實(shí)現(xiàn)了KV緩存減半或上下文翻倍的效果。
前蘋(píng)果架構(gòu)師與UC伯克利本科生成果
Slim Attention的作者是AI初創(chuàng)公司OpenMachine的創(chuàng)始人兼CEO Nils Graef,以及UC伯克利在讀本科生Andrew Wasielewski。
Nils的主業(yè)是機(jī)器學(xué)習(xí)加速器的架構(gòu)和設(shè)計(jì),曾發(fā)表兩篇IEEE期刊論文和30多項(xiàng)專(zhuān)利,引用次數(shù)超過(guò)900次。
創(chuàng)立OpenMachine前,Nils在知名推理加速平臺(tái)Groq(注意不是馬斯克的Grok)擔(dān)任芯片架構(gòu)師。
更早的時(shí)候,他先后擔(dān)任過(guò)谷歌ML加速器架構(gòu)&設(shè)計(jì)工程師和蘋(píng)果ASIC架構(gòu)師。
Andrew Wasielewski是UC伯克利在讀本科生,專(zhuān)業(yè)是物理和EECs(電氣工程與計(jì)算機(jī)科學(xué)),預(yù)計(jì)將于明年畢業(yè)。
根據(jù)論文署名信息顯示,Slim Attention的工作是Andrew在OpenMachine完成的。
去年7月,Nils和Andrew還與其他人合作,發(fā)表了一篇名為Flash normalization的論文,提出了一種更快的RNS歸一化方式。
此外在Slim Attention的致謝中還提到,艾倫實(shí)驗(yàn)室的Dirk Groeneveld,以及SGLang三作謝志強(qiáng),對(duì)其工作提供了有益討論;Transformer作者之一、Character.AI創(chuàng)始人Noam Shazeer給出了積極反饋。
論文地址:https://arxiv.org/abs/2503.05840
參考鏈接:https://x.com/rohanpaul_ai/status/1901092052282339474