MiniCache 和 PyramidInfer 等 6 種優(yōu)化 LLM KV Cache 的最新工作
一、背景
在 LLM 推理中,常常會(huì)采用 KV Cache 來(lái)緩存之前 Token 的中間結(jié)果,以顯著減少重復(fù)計(jì)算,從而降低自回歸生成中的延遲。然而,KV Cache 的大小與序列長(zhǎng)度成正比,在處理長(zhǎng)序列時(shí)會(huì)面臨極大的挑戰(zhàn)。尤其當(dāng)前許多模型開始支持幾百 K 甚至幾 M 的序列長(zhǎng)度,進(jìn)一步凸顯了 KV Cache 的問(wèn)題,因此很多研究工作致力于降低 KV Cache 的占用。
本文中簡(jiǎn)單介紹幾個(gè)最新的工作,包括 SnapKV、YOCO、CLA、Layer-Condensed KV Cache、MiniCache 以及 PyramidInfer,它們都試圖降低緩解 KV Cache 的壓力。關(guān)于 GQA、MQA、DeepSeek MLA 以及量化相關(guān)的工作我們已經(jīng)在之前進(jìn)行了介紹,這里不再贅述。
二、KV Cache 大小
KV Cache 的大小與模型配置(層數(shù),hidden_size,Attention head 個(gè)數(shù)等)以及序列長(zhǎng)度、Batch Size 成正比。其中單個(gè) Token 對(duì)應(yīng)的 KV Cache 大小與模型配置相關(guān),并且是固定的,這里將其稱為單位 KV Cache 計(jì)算公式為:
sum_token = (hidden_size / num_attention_heads * num_key_value_heads) * num_hidden_layers * 2(k, v)
而總的 KV Cache 大小為:
sum = sum_token * seq_len * batch_size
batch_size 和 seq_len 越大,KV Cache 越大,如下圖所示為 LLaMA2-7B 模型的 batch_size 和 seq_len 對(duì)應(yīng)的 KV Cache 大?。J(rèn) FP16 精度):
- 當(dāng) batch_size * seq_len 為32K時(shí),比如 batch_size 為 1,seq_len 為 32K,其 KV Cache 大小為16GB,甚至超過(guò)模型權(quán)重大小 14GB。
- 當(dāng) batch_size * seq_len 為128K時(shí),比如 batch_size 為 1,seq_len 為 128K,其 KV Cache 大小為 64GB,加上模型權(quán)重 14GB 甚至快要超過(guò) A100 GPU 的 80GB 顯存限制。?
三、SnapKV
[2404.14469] SnapKV: LLM Knows What You are Looking for Before Generation 的核心思路比較簡(jiǎn)單,如下圖 Figure 1 所示,在 Prefill 階段不是保留所有輸入 Token 的 KV Cache,而是采用稀疏化的方式,針對(duì)每個(gè) Attention Head 將 Prompt 分為 Prefix 和 Window 兩部分;然后,通過(guò) Window 中 Token 與 Prefix 中 Token 的 Attention Score 來(lái)選擇稀疏化的 Token;最后,將它們的 KV Cache 和 Window 中 Token 的 KV Cache 一起作為 Prompt 的 KV Cache。需要說(shuō)明的是:每個(gè) Attention Head 中從 Prefix 里挑選的 Token 可能不同。此外,Decoding 階段也不會(huì)再更新 Prompt 的 KV Cache。
SnapKV 在處理 16K Token 的輸入時(shí),可以獲得 3.6x 的加速,內(nèi)存效率提升 8.2x。同時(shí)在 16 個(gè)長(zhǎng)序列數(shù)據(jù)集上保持了與基線模型相當(dāng)?shù)木?。此外,使?Huggingface 可以在單個(gè) A100-80GB GPU 上處理 380K 上下文 Token 的任務(wù)。
四、YOCO
在 [2405.05254] You Only Cache Once: Decoder-Decoder Architectures for Language Models 中,作者只保留一層全局的 KV Cache。這種設(shè)計(jì)可以大大降低 GPU 顯存的需求,加快 Prefill 階段。如下圖所示,YOCO 模型與常規(guī) Decoder-Only LLM 的區(qū)別有幾點(diǎn):
- 前 L/2 層(Self-Decoder)使用Efficient Self-Attention,實(shí)際上就是滑動(dòng)窗口 Self-Attention或作者之前論文提出的Multi-Scale Retention。其只用保存窗口內(nèi)的 KV Cache 即可。
- 第 L/2 層的 KV Cache 作為Global KV Cache。也就是只有一層有全局 KV Cache。
- 后 L/2 層(Cross-Decoder)使用Global Cross Attention,對(duì)應(yīng)的 KV 為上一步的 Global KV Cache,也就是后續(xù)所有 L/2 層的 Cross Attention 的 KV Cache 都是相同的。?
五、CLA
[2405.12981] Reducing Transformer Key-Value Cache Size with Cross-Layer Attention 中作者同樣采用 Cross-Attention 機(jī)制來(lái)降低 KV Cache。不同的是作者并非采用固定層作為 Cross-Attention 的輸入,而是采用相鄰層,如下圖左圖所示。最簡(jiǎn)單的方式就是隔層共享,稱作 CLA2,實(shí)際也可以每 3 層共享,稱作 CLA3,如下圖右圖所示。此外,這種方法與 MQA 和 GQA 等修改 Attention Head 的方案是兼容的。CLA2 顯存減小 2x,CLA3 顯存減小 3x。
作者訓(xùn)練 1B 和 3B 參數(shù)模型模型實(shí)驗(yàn)表明,CLA 相比傳統(tǒng)的 MQA 在顯存占用、準(zhǔn)確性方面可以實(shí)現(xiàn)帕累托改進(jìn),從而實(shí)現(xiàn)更長(zhǎng)的序列長(zhǎng)度和更大的 Batch Size。(PS:但并不意味著可以優(yōu)于現(xiàn)在廣泛采用的 GQA?)
六、Layer-Condensed KV Cache
在 [2405.10637] Layer-Condensed KV Cache for Efficient Inference of Large Language Models 中,作者同樣采用了僅計(jì)算和緩存少量層 KV Cache 的方案,從而顯著節(jié)約顯存消耗并提升吞吐量。如下圖 Figure 1 所示,僅保留最后一個(gè) Transfomer Block 層的 KV Cache,當(dāng)生成后續(xù) Token 時(shí)其對(duì)應(yīng)的 KV Cache 都從最后一層取。
七、MiniCache
在 [2405.14366] MiniCache: KV Cache Compression in Depth Dimension for Large Language Models 中,作者觀察到 KV Cache 在 LLM 中的深層部分的相鄰層之間表現(xiàn)出了高度相似性,可以基于這些相似性對(duì) KV Cache 進(jìn)行壓縮。此外,作者還引入了 Token 保留策略,對(duì)高度不同的 KV Cache 不進(jìn)行合并。并且這種方法可以與其他的 KV Cache 量化方案正交使用。
作者在 LLaMA-2、LLaMA-3、Phi-3、Mistral 和 Mixtral 等模型上進(jìn)行實(shí)驗(yàn),在 ShareGPT 數(shù)據(jù)集上,采用 4 Bit MiniCache LLaMA–7B 與 FP16 全量 KV Cache 相比實(shí)現(xiàn)了 5.02x 的壓縮比,推理吞吐提高約 5 倍,顯存占用減少 41%,同時(shí)性能幾乎無(wú)損。
如下圖 Figure 3 所示為其壓縮策略和保留策略:
如下圖 Figure A 所示為其詳細(xì)的執(zhí)行流程:
- 1. 獲取 KV Cache:在 Prefill 階段,逐層生成 KV Cache。
- 2. 跨層合并:當(dāng)?shù)竭_(dá)合并開始層 S 時(shí),將當(dāng)前層 L 的 KV Cache 與前一層 L-1 的 KV Cache 進(jìn)行合并,以減少冗余。
- 3. 緩存:將合并后的 KV Cache 存儲(chǔ)起來(lái),以便將來(lái)使用。
- 4. 刪除:在 Decoding 階段,刪除不必要的或冗余的 KV Cache,以優(yōu)化內(nèi)存使用。
- 5. 加載和生成:獲取所需的 KV Cache,用于生成輸出。
- 6. 恢復(fù):對(duì)獲取的 KV Cache 應(yīng)用誤差抑制機(jī)制,包括 rescaling 和 retention recovery,以最小化合并和壓縮過(guò)程中引入的誤差。
- 7. 更新:在恢復(fù)階段后,使用最終的 KV Cache 更新共享的 KV Cache。
八、PyramidInfer
在 [2405.12532] PyramidInfer: Pyramid KV Cache Compression for High-throughput LLM Inference 中,作者發(fā)現(xiàn)影響未來(lái)生成的關(guān)鍵 KV 的數(shù)量逐層減少,并且可以通過(guò)注意力權(quán)重的一致性來(lái)提取這些關(guān)鍵 KV?;谶@些發(fā)現(xiàn),作者提出了 PyramidInfer,通過(guò)逐層保留關(guān)鍵上下文來(lái)壓縮 KV Cache。PyramidInfer 在不犧牲性能的情況下計(jì)算更少的 KV,并節(jié)約大量顯存。實(shí)驗(yàn)結(jié)果表明,與 Accelerate 相比,PyramidInfer 的吞吐提高了 2.2 倍,KV Cache 的顯存占用減少了 54% 以上。
如下圖 Figure 2 所示為 PyramidInfer 與 StreamingLLM 和 H2O 的區(qū)別,PyramidInfer 中 KV Cache 會(huì)逐層遞減,越往后越稀疏(PS:如果是這樣,那么 Layer-Condensed KV Cache 中只保留最后一層的方案是不是不太合理):
PyramidInfer 的執(zhí)行過(guò)程如下圖 Figure 6 所示:
- 在 Prefill 階段,PyramidInfer 只保留每層的關(guān)鍵上下文(Pivotal Context, PvC)來(lái)壓縮 KV Cache。
- 在 Decoding 階段,PyramidInfer 根據(jù)新的最近的 Token 來(lái)更新 PvC。?
如下圖 Table 1 所示,PyramidInfer 在使用更少 KV Cache 的情況下獲得更快的推理速度:
如下圖 Figure 11 所示,作者進(jìn)一步測(cè)試了 PyramidInfer 在更多 Batch Size 下的表現(xiàn),其在比較小 Batch Size 時(shí)幾乎沒(méi)有加速,主要是因?yàn)闇p少 KV Cache 還需要一些額外的計(jì)算;而在比較大的 Batch Size 能獲得更大的加速比。而 Full Cache 當(dāng) Batch Size 大于 32 吞吐反而降低:(PS:這個(gè)降低不太符合預(yù)期,通常來(lái)說(shuō)隨著 Batch Size 的增加,計(jì)算密度會(huì)更高,相應(yīng)的吞吐也應(yīng)該更高,而且在 32 左右還遠(yuǎn)沒(méi)有到 Compute Bound)。
九、參考鏈接
- ??https://arxiv.org/abs/2404.14469??
- ??https://arxiv.org/abs/2405.05254??
- ??https://arxiv.org/abs/2405.12981??
- ??https://arxiv.org/abs/2405.10637??
- ??https://arxiv.org/abs/2405.14366??
- ??https://arxiv.org/abs/2405.12532??
本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談
