微軟 RetrievalAttention: LLM+ANN, LLM 推理速度與精度的平衡
一、背景
本文我們繼續(xù)介紹一個(gè)針對(duì)超長(zhǎng)上下文的 LLM 推理加速工作,同樣是 Token 稀疏化的方案,來(lái)解決 LLM 在超長(zhǎng)序列場(chǎng)景計(jì)算量大、GPU 顯存消耗大的問(wèn)題,不過(guò)結(jié)合了 ANN 檢索,可以實(shí)現(xiàn)更高的精度。
對(duì)應(yīng)的論文為:[2409.10516] RetrievalAttention: Accelerating Long-Context LLM Inference via Vector Retrieval
二、摘要
本文中作者提出了 RetrievalAttention,無(wú)需訓(xùn)練就可以加速 Attention 計(jì)算。為了利用 Attention 的動(dòng)態(tài)稀疏特性,RetrievalAttention 在 CPU 內(nèi)存中使用 KV Cache 構(gòu)建近似檢索(ANN)索引,并在生成過(guò)程中通過(guò)向量檢索識(shí)別最相關(guān)的索引。由于 Query 向量和 Key 向量之間存在 Out-Of-Distribution(OOD)問(wèn)題,現(xiàn)成的 ANN 檢索仍然需要掃描 O(N) 數(shù)據(jù)(通常占所以 Key 的 30%)進(jìn)行準(zhǔn)確檢索,無(wú)法利用高稀疏性。
為了解決這個(gè)挑戰(zhàn),RetrievalAttention 采用注意力感知向量檢索算法,可以調(diào)整 Query 只訪問(wèn) 1-3% 的數(shù)據(jù),從而實(shí)現(xiàn)亞線性時(shí)間復(fù)雜度。RetrievalAttention 大幅降低了長(zhǎng)上下文 LLM 推理的成本,大幅降低 GPU 顯存需求,同時(shí)保持模型準(zhǔn)確性。特別的,RetrievalAttention 只需要 16GB 內(nèi)存就可以在具有 8B 參數(shù)的 LLM 上支持 128K Token的推理,在單個(gè) RTX4090(24GB)上可以在 0.188s 內(nèi)生成一個(gè) Token。
如下圖 Figure 1 所示為本文方法與幾種常見(jiàn)方案的對(duì)比(PS:可以看出,本文方案相比之前 Token 稀疏化方案,是在犧牲一定推理速度的情況下提升精度):
三、方法
3.1 背景
使用 ANN 來(lái)識(shí)別關(guān)鍵 Token 有個(gè)獨(dú)特的挑戰(zhàn):當(dāng)前大部分的 ANN 引擎都假設(shè) Query 向量和 Key 向量滿(mǎn)足相同的分布,以此來(lái)實(shí)現(xiàn)高召回率。作者在這篇論文中首次提出這種假設(shè)在 Attention 機(jī)制中不成立。Query 的這種 OOD 特性損壞了 ANN 的預(yù)期檢索質(zhì)量,從而導(dǎo)致不得不訪問(wèn)更多的數(shù)據(jù)來(lái)保持正確性,作者實(shí)驗(yàn)表明,為了維持可接受的準(zhǔn)確率,至少需要掃描 30% 的 Key 向量。
如下圖 Figure 2 所示:
- (a)Attention Score 具有非常高的稀疏性,64000 個(gè) Token,只有不到 500 Token 的 Score 大于 10-6。
- (b)Q 和 Q 或者 K 和 K 的相關(guān)性很高,而 Q 和 K 的相關(guān)性很差,需要掃描 30% 左右的 Token 才能保證 0.8 左右的召回率。
- (c)同樣說(shuō)明了 Q 和 K 的距離比較遠(yuǎn)。?
3.2 概覽
本文的工作主要聚焦于 Token Decoding 階段,會(huì)假設(shè) Prefill 階段已經(jīng)執(zhí)行完成,比如通過(guò) Context Caching 方案或 Prefill 和 Decoding 分離方案。
如下圖 Figure 3(a)所示為本文方案 RetrievalAttention 的概覽,其利用 CPU 側(cè)的 ANN 檢索來(lái)實(shí)現(xiàn)近似 Attention 計(jì)算,為了支持長(zhǎng)序列,也會(huì)將所有 KV Cache Offload 到 CPU 內(nèi)存以便構(gòu)建索引。如圖(b)所示是為了解決 OOD 問(wèn)題而采用的索引機(jī)制。
3.3 近似 Attention
具體來(lái)說(shuō),不使用完整的 Attention Score,而是采用最相關(guān)的 KV 向量來(lái)近似 Attention Score:
3.4 Attention 感知向量檢索
對(duì)于每對(duì) Key 和 Value,首先確定是放在 CPU Memory 還是 GPU Memory(方法見(jiàn)下一小節(jié))。然后 Offload 到 CPU 內(nèi)存的 Key 和 Value 會(huì)使用 Key 來(lái)構(gòu)建索引,并使用 Query 來(lái)檢索。
為了加速 Token 生成過(guò)程中的向量檢索速度,RetrievalAttention 利用 Prefill 階段的現(xiàn)有 Query 來(lái)指導(dǎo) Key 向量的索引構(gòu)建。如上圖 Figure 3(b)所示,RetrievalAttention 顯式的建立從 Query 向量到其最近的 Key 向量的連接(即精確的 K 個(gè)最近鄰,或 KNN)。KNN 結(jié)果可以通過(guò) GPU 高效計(jì)算,形成從 Query 向量分布到 Key 向量分布的映射。使用這種結(jié)構(gòu),Decoding 的 Query 向量查詢(xún)時(shí)可以首先查詢(xún)最近的 Query 向量,然后將其映射為 Key 向量。
因此,之前的 Query 向量充當(dāng)了解決 OOD 問(wèn)題的橋梁。然而,這種結(jié)構(gòu)在內(nèi)存開(kāi)銷(xiāo)和搜索效率方面仍然存在缺陷,因?yàn)槌?Key 向量之外,還需要存儲(chǔ)和訪問(wèn) Query 向量。為了解決這個(gè)問(wèn)題,作者利用先進(jìn)的跨模態(tài) ANN 索引 RoarGraph 中的投影技術(shù)來(lái)消除 Query 向量。具體來(lái)說(shuō),通過(guò)使用 Query 向量和 Key 向量的連接關(guān)系,將 KNN 連接投影到 Key 向量中,從而有效地簡(jiǎn)化搜索。此外,此方法也允許對(duì)未來(lái)的 Query 向量進(jìn)行高效的索引遍歷。
作者實(shí)驗(yàn)結(jié)果表明,通過(guò)這種 Query 和 Key 的連接關(guān)系進(jìn)行有效建模,向量數(shù)據(jù)庫(kù)只需掃描 1-3% 的 Key 向量即可達(dá)到高召回率,與 IVF 索引相比,索引搜索延遲大幅降低 74%。
3.5 CPU 和 GPU 協(xié)同執(zhí)行
為了利用 GPU 并行性加速注意力計(jì)算,RetrievalAttention 將注意力計(jì)算分解為兩組不相交的 KV Cache 向量:GPU 上的可預(yù)測(cè)向量和 CPU 上的動(dòng)態(tài)向量,然后將兩部分 Attention 輸出合并在一起作為完整的 Attention 輸出。
具體來(lái)說(shuō),利用 Prefill 階段觀察到的模式來(lái)預(yù)測(cè) Token 生成過(guò)程中持續(xù)激活的 KV 向量。與 StreamingLLM 類(lèi)似,作者將固定的幾個(gè)初始 Token 和最近窗口內(nèi)的 Token 作為靜態(tài) Token,持久化在 GPU 上。RetrievalAttention 也可以適配更復(fù)雜的靜態(tài)模式,以便實(shí)現(xiàn)低推理成本和高準(zhǔn)確性的平衡。為了最大限度減少通過(guò)慢速 PCIe 的數(shù)據(jù)傳輸,RetrievalAttention 在 CPU 和 GPU 上獨(dú)立計(jì)算 Attention,然后將其組合起來(lái),這個(gè)靈感來(lái)自 FastAttention([2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness)。
四、評(píng)估
4.1 實(shí)驗(yàn)配置
機(jī)器包含三個(gè):
- RTX 4090 GPU(24G 顯存),Intel i9-10900X CPU(20 Core),128 GB 內(nèi)存。
- A100 GPU(80GB 顯存),AMD EPYC CPU(24 Core)。
- A100 GPU(80GB 顯存),AMD EPYC 7V12 CPU(48 Core),1.72TB 內(nèi)存。
模型包含三個(gè):
- LLaMA-3-8B-Instruct-262K
- Yi-6B-200K
- Yi-9B-200K
對(duì)比框架包括:
- Full Attention 的 vLLM
- StreamingLLM
- SnapKV
- InfLLM
基準(zhǔn)測(cè)試包括:
- ∞-Bench
- RULER
- Needle-in-a-haystack
4.2 長(zhǎng)文本任務(wù)精度
如下圖 Table 2 所示,本文提出的 RetrievalAttention 明顯優(yōu)于之前的方案,平均精度非常接近 Full Attention。當(dāng)然,部分模型上 Flat(暴露檢索索引數(shù)據(jù)) 會(huì)略好于 RetrievalAttention,不過(guò)差距不大。
4.3 時(shí)延評(píng)估
如下圖 Table 6 所示,作者首先驗(yàn)證了本文提出的檢索方式的有效性,可以看出,提出的 RetrievalAttention 相比 Flat 和 IVF 可以提供 4.9x 和 1.98x 的加速,證明了檢索機(jī)制的有效性:
如下圖 Table 4 所示,作者也與之前的其他稀疏化方案進(jìn)行對(duì)比,可以看出,之前的方案往往采用固定的 Token 數(shù),因此隨著序列變長(zhǎng)并沒(méi)有明顯增加時(shí)延,而本文的方法會(huì)略微增加。同時(shí),本文方法推理 Latency 相比之前方法明顯增加,大概是之前方法 Latency 的 3x-6x。然而其仍然明顯低于 Full Attention(FlexGen) 的結(jié)果。相當(dāng)于在效果和速度之間的折衷。
如下圖 Table 7 和 Table 8 為在 A100 上的結(jié)果,結(jié)論類(lèi)似,不過(guò)在 100K 和 200K 時(shí)其 Latency 會(huì)超過(guò) vLLM:
五、參考鏈接
本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談
