2.5%KV緩存保持大模型90%性能,大模型金字塔式信息匯聚模式探秘
用KV緩存加速大模型的顯存瓶頸,終于迎來(lái)突破。
北大、威斯康辛-麥迪遜、微軟等聯(lián)合團(tuán)隊(duì)提出了全新的緩存分配方案,只用2.5%的KV cache,就能保持大模型90%的性能。
這下再也不用擔(dān)心KV占用的顯存容量過(guò)高,導(dǎo)致顯卡不夠用了。
圖片
該方法名為PyramidKV,顧名思義,在KV緩存壓縮的過(guò)程中融入了金字塔型的信息匯聚方式。
在內(nèi)存受限的情況下,PyramidKV表現(xiàn)非常出色,既保留了長(zhǎng)上下文理解能力,又顯著減少了內(nèi)存使用。
目前,PyramidKV相關(guān)代碼已經(jīng)在GitHub開源。
引入金字塔信息匯聚方式
隨著模型尺寸的增大,推理需要的時(shí)間越來(lái)越多。KV cache作為推理加速的關(guān)鍵技術(shù),通過(guò)緩存之前的解碼步驟中計(jì)算出的Transformer的K和V矩陣減少后續(xù)解碼時(shí)間。
但是,隨著序列長(zhǎng)度增大,需要緩存的KV cache會(huì)快速增長(zhǎng),占用大量顯存。針對(duì)這一問(wèn)題,之前的工作設(shè)計(jì)策略是對(duì)KV cache進(jìn)行壓縮。
實(shí)際上,長(zhǎng)文本的推理加速和顯存節(jié)省作為一個(gè)重要的話題,這涉及到廣泛的大模型下游應(yīng)用,比如檢索增強(qiáng)生成(Retrieval-Augmented Generation)、上下文學(xué)習(xí)(In-Context Learning)受到廣泛關(guān)注。
KV cache及KV cache的壓縮能否有效幫助長(zhǎng)文本實(shí)現(xiàn)推理加速成為廣受關(guān)注的研究方向。
采用均一壓縮策略,是最佳方案嗎?
傳統(tǒng)壓縮方法的一個(gè)共同特點(diǎn)是,均對(duì)每個(gè)Transformer層使用同樣的KV cache壓縮設(shè)置,使用同樣的方法壓縮到同樣的長(zhǎng)度。
圖片
但PyramidKV團(tuán)隊(duì)發(fā)現(xiàn),對(duì)KV cache進(jìn)行極致壓縮情況下上述方法的表現(xiàn),發(fā)現(xiàn)當(dāng)超長(zhǎng)文本壓縮到極致小的KV大小時(shí)(從32k 長(zhǎng)度壓縮到64,即保留0.2%的KV cache長(zhǎng)度)時(shí),會(huì)面臨嚴(yán)重的性能減弱。
于是作者提出了疑問(wèn):對(duì)每個(gè)Transformer層將KV cache壓縮到同樣的大小是否為最優(yōu)方案?
為了回答上述問(wèn)題,研究團(tuán)隊(duì)對(duì)大模型進(jìn)行檢索增強(qiáng)生成的機(jī)制進(jìn)行深入分析。
作者研究了Llama模型進(jìn)行多文檔問(wèn)答的逐層注意力圖,發(fā)現(xiàn)了注意力層中的金字塔形信息匯聚模式(Pyramidal Information Funneling)的存在:
- 在模型的低層(例如第0層)中,注意力得分呈現(xiàn)近似均勻分布,這表明模型在較低層時(shí)從所有可用內(nèi)容中全局聚合信息,而不會(huì)優(yōu)先關(guān)注特定的段落。
- 當(dāng)編碼信息進(jìn)行到中間層(6-18)時(shí),逐漸轉(zhuǎn)變?yōu)榫劢乖诙温鋬?nèi)部的注意力模式 (Localized Attention)。在這個(gè)階段,注意力主要集中在同一文檔內(nèi)的Token上,表明模型在單個(gè)段落內(nèi)進(jìn)行了段落內(nèi)部的信息聚合。
- 這種趨勢(shì)在上層(24-30)繼續(xù)并加強(qiáng),本文觀察到了“Attention Sink”和“Massive Activation”現(xiàn)象。
在這些層中,注意力機(jī)制極大地集中在少數(shù)幾個(gè)關(guān)鍵Token上,因此只需要保留這些關(guān)鍵Token就能讓輸出保持一致并且減少顯存占用。
圖片
這種注意力分配模式,即極高的注意力得分,表明模型已將信息聚合到這些關(guān)鍵標(biāo)記中。
這種注意力現(xiàn)象顯示了大模型對(duì)大量復(fù)雜的信息的進(jìn)行編碼的機(jī)制,最終得到生成準(zhǔn)確答案所需的最關(guān)鍵信息。
根據(jù)以上的發(fā)現(xiàn),作者認(rèn)為之前的工作對(duì)所有Transformer層統(tǒng)一處理是低效的,因此不同Transformer層的注意力稀疏程度并不相同。在低層能觀察到特別稠密的注意力,而在較高層則可以觀察到非常稀疏的注意力。
因此,在不同層之間使用固定的 KV 緩存大小可能會(huì)導(dǎo)致性能不佳。這些方法可能在較高層的稀疏注意力中保留許多不重要的 tokens,而忽略了較低層密集注意力中的許多重要的 tokens。
每層注意力特點(diǎn)不同,分層施策才是正解
于是,作者選擇了通過(guò)基于注意力模式動(dòng)態(tài)分配緩存預(yù)算來(lái)提高壓縮效率。
具體而言,PyramidKV在信息更加分散的較低層分配更多的KV cache緩存,而在信息集中于少數(shù)關(guān)鍵tokens的較高層減少KV cache緩存。
一旦為每一層確定了KV緩存預(yù)算,PyramidKV在每一個(gè)Transformer層中選擇根據(jù)注意力選擇要緩存的KV。
最后的部分Token的KV緩存,即Instruction Token,會(huì)在所有Transformer層中保留。
根據(jù)UIUC、普林斯頓等提出的SnapKV方法,剩余的KV的選擇由從這些Instruction Token中獲得的對(duì)其他的Token注意力分?jǐn)?shù)來(lái)指導(dǎo)——
接收到更高注意力分?jǐn)?shù)的Token被認(rèn)為與生成過(guò)程更相關(guān),因此其KV狀態(tài)優(yōu)先保存在GPU緩存中。
圖片
2.5%的KV cache,保持90%模型性能
為了評(píng)估PyramidKV的表現(xiàn),作者使用最新的開源大模型Llama-3-8B-Instruct和Mistral-7B-Instruct,來(lái)對(duì)PyramidKV和其他方法進(jìn)行對(duì)比。
測(cè)試示例以生成格式進(jìn)行評(píng)估,所有任務(wù)的答案均通過(guò)貪婪解碼生成,并使用 LongBench來(lái)評(píng)估PyramidKV在處理長(zhǎng)上下文輸入任務(wù)中的表現(xiàn)。
LongBench是一個(gè)精心設(shè)計(jì)的基準(zhǔn)測(cè)試套件,用于測(cè)試語(yǔ)言模型處理長(zhǎng)文檔和復(fù)雜信息序列的能力。
該基準(zhǔn)測(cè)試旨在對(duì)長(zhǎng)上下文輸入進(jìn)行多任務(wù)評(píng)估,包括17個(gè)數(shù)據(jù)集,涵蓋單文檔問(wèn)答、多文檔問(wèn)答、摘要生成、少樣本學(xué)習(xí)、合成數(shù)據(jù)和代碼生成等任務(wù)。
數(shù)據(jù)集的平均輸入長(zhǎng)度從1235個(gè)到18409個(gè)tokens不等,需要大量的內(nèi)存來(lái)管理KV緩存。
對(duì)于所有這些任務(wù),作者都遵循 LongBench推薦的標(biāo)準(zhǔn)指標(biāo)。
結(jié)果,在64、96、128、256和512個(gè)KV cache緩存大小的設(shè)定下,PyramidKV在LongBench中均取得了優(yōu)于baseline的效果。
圖片
在此基礎(chǔ)上,作者還研究了兩種不同的操作場(chǎng)景——節(jié)省內(nèi)存場(chǎng)景(Memory-Efficient Scenario)和保持性能場(chǎng)景(Performance-Preserving Scenario),分別用于在內(nèi)存和模型性能之間進(jìn)行權(quán)衡。
PyramidKV在Longbench的多個(gè)任務(wù)和平均得分上均取得了優(yōu)于baseline的效果。
值得注意的是,PyramidKV在size為128的設(shè)定下,在TREC任務(wù)(上下文學(xué)習(xí)問(wèn)答挑戰(zhàn))中表現(xiàn)出顯著優(yōu)越的性能,相較于baseline,提高了20.的ACC結(jié)果。
圖片
總體而言,PyramidKV僅用12%的KV緩存就能保持完整的性能,并且在各種KV緩存大小的設(shè)定下和不同主干模型中始終優(yōu)于其他方法,特別是在僅保留約128(0.7%)KV cache緩存的節(jié)省內(nèi)存場(chǎng)景中,其性能優(yōu)勢(shì)尤為明顯。
在具體任務(wù)的檢查中,PyramidKV在TREC任務(wù)(上下文學(xué)習(xí)問(wèn)答挑戰(zhàn))中表現(xiàn)出顯著優(yōu)越的性能,僅僅使用64的KV cache緩存大小(原始輸入是5k長(zhǎng)度)就能達(dá)到90%的性能。
這表明模型有效地聚合了樣本中的任務(wù)信息,突出了在上下文學(xué)習(xí)任務(wù)上進(jìn)一步研究的潛力。
下面的表則展示了PyramidKV使KV緩存的占用減少的情況。作者評(píng)估了Llama-3-8B-Instruct的內(nèi)存消耗。
具體來(lái)說(shuō),作者發(fā)現(xiàn)在固定批量大小為1、輸入長(zhǎng)度為8192、模型權(quán)重為fp16格式的情況下,PyramidKV在不同緩存大小下顯著減少了KV緩存的內(nèi)存,還一定程度上保留了任務(wù)性能。
圖片
為了進(jìn)一步理解PyramidKV在LongBench上的性能,作者還進(jìn)行了“大海撈針”實(shí)驗(yàn),將PyramidKV與SnapKV進(jìn)行比較,并且對(duì)比128大小的KV緩存和完整的KV緩存。
在輸入序列長(zhǎng)度在2000到4000之間的中等上下文情況下,SnapKV在“大海撈針”測(cè)試中產(chǎn)生了越來(lái)越多的錯(cuò)誤案例。
在輸入序列長(zhǎng)度超過(guò)6000的長(zhǎng)上下文情況下,SnapKV顯著降低了LLMs在評(píng)估中的性能。
相比之下,PyramidKV在大多數(shù)情況下減輕了這種弱化效應(yīng)。下圖展示了定量結(jié)果。分?jǐn)?shù)越高、顏色越淺,表示著檢索能力越強(qiáng)。
在該任務(wù)的平均得分中,完整KV得分為65.0,PyramidKV得分為62.6,而SnapKV得分為57.3。
圖片
此外,作者的實(shí)驗(yàn)表明,PyramidKV在上下文學(xué)習(xí)(In-Context Learning)的少樣本學(xué)習(xí)任務(wù)中顯著優(yōu)于其他方法。
這表明KV cache緩存壓縮在上下文學(xué)習(xí)中的應(yīng)用前景廣闊,這種方法有可能在受限的內(nèi)存條件下實(shí)現(xiàn)更多樣本的引入。
論文地址:https://arxiv.org/abs/2406.02069項(xiàng)目主頁(yè):
https://zefan-cai.github.io/PyramidKV.github.io/