單卡A100實(shí)現(xiàn)百萬(wàn)token推理,速度快10倍,這是微軟官方的大模型推理加速
大型語(yǔ)言模型 (LLM) 已進(jìn)入長(zhǎng)上下文處理時(shí)代,其支持的上下文窗口從先前的 128K 猛增到 10M token 級(jí)別。
然而,由于注意力機(jī)制的二次復(fù)雜度,模型處理輸入提示(即預(yù)填充階段)并開(kāi)始產(chǎn)生第一個(gè) token 可能需要幾分鐘時(shí)間。導(dǎo)致首個(gè) token 生成的時(shí)間過(guò)長(zhǎng),從而嚴(yán)重影響了用戶(hù)體驗(yàn),這也極大地限制了長(zhǎng)上下文 LLM 的廣泛應(yīng)用。
舉例來(lái)說(shuō)(如圖 2a 所示),在單臺(tái)裝有 A100 的機(jī)器上為 LLaMA-3-8B 提供服務(wù)時(shí),如果提示有 30 萬(wàn)個(gè) token,模型需要 6 分鐘才能完成預(yù)填充( pre-filling)階段,如果提示增加到 100 萬(wàn)個(gè) token,這個(gè)數(shù)字將增加到 30 分鐘。
自注意力計(jì)算的開(kāi)銷(xiāo)占到了總預(yù)填充延遲的 90% 以上,這使其成為 LLM 處理長(zhǎng)上下文時(shí)的主要瓶頸?,F(xiàn)有的加速預(yù)填充方法在應(yīng)用于長(zhǎng)上下文 LLM 時(shí)通常無(wú)法保持可接受的準(zhǔn)確性或效率。
為了解決上述問(wèn)題,來(lái)自微軟、薩里大學(xué)的研究者提出了一種旨在加速長(zhǎng)序列處理預(yù)填充的稀疏計(jì)算方法:MInference( Milliontokens Inference )。
- 論文地址:https://arxiv.org/pdf/2407.02490
- 論文主頁(yè):https://hqjiang.com/minference.html
- 論文標(biāo)題:MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention
MInference 可以直接應(yīng)用于現(xiàn)有 LLM,無(wú)需對(duì)預(yù)訓(xùn)練設(shè)置進(jìn)行修改或額外的微調(diào)。
通過(guò)對(duì)各種下游任務(wù)(包括 InfiniteBench、RULER、PG-19 和 Needle In A Haystack)以及模型(包括 LLaMA-3-1M、Yi-200K、GLM-4-1M、Phi-3-128K 和 Qwen2-128K)進(jìn)行評(píng)估,實(shí)驗(yàn)證明 MInference 可有效將 A100 上的預(yù)填充推理延遲降低多達(dá) 10 倍,同時(shí)保持準(zhǔn)確性。
使用 MInference 1.0 ,長(zhǎng)上下文 LLM(如 LLaMA-3-8B-1M、GLM-4-1M)在單個(gè) A100 上的推理速度實(shí)現(xiàn)了 10 倍提升,并且準(zhǔn)確度更高。
方法介紹
作者提出了 MInference,這個(gè)名字反映了他們希望在一臺(tái) A100 機(jī)器上實(shí)現(xiàn)百萬(wàn)(million)token 推理的雄心。
MInference 是一種無(wú)需訓(xùn)練的高效方法,用于基于動(dòng)態(tài)稀疏注意力的長(zhǎng)上下文 LLM 的預(yù)填充階段。
研究者認(rèn)為注意力,特別是在長(zhǎng)上下文中,是稀疏和動(dòng)態(tài)的,即在不同的輸入中,稀疏模式有很大的不同。這種動(dòng)態(tài)稀疏性呈現(xiàn)出三種適用于所有輸入的獨(dú)特空間聚合模式:A 形(A-shape)、垂直 - 斜線(Vertical-Slash)和塊狀 - 稀疏(Block-Sparse)。
MInference 首先使用內(nèi)核感知稀疏模式搜索算法為每個(gè)頭部離線確定最佳動(dòng)態(tài)稀疏模式,如算法 1 所示。在推理過(guò)程中,它會(huì)根據(jù)頭部的模式動(dòng)態(tài)逼近動(dòng)態(tài)稀疏指數(shù),如算法 2、3 所示。最后,作者使用優(yōu)化后的 GPU 內(nèi)核執(zhí)行高效的動(dòng)態(tài)稀疏注意力計(jì)算,大大減少了長(zhǎng)上下文 LLM 的預(yù)填充階段延遲。
例如,對(duì)于「垂直 - 斜線」模式,作者首先利用最后一個(gè) Q 和 K 之間的注意力計(jì)算來(lái)估計(jì)垂直線和斜線的最佳指數(shù)。然后,他們利用動(dòng)態(tài)稀疏編譯器 PIT 和 Triton 構(gòu)建垂直 - 斜線 FlashAttention 內(nèi)核,加速注意力計(jì)算。對(duì)于 A 形、垂直 - 斜線和塊狀 - 稀疏模式,作者首先在注意力計(jì)算中使用 Q 和 K 的均值池。利用均值池和 MatMul 的交換屬性,可以估算出塊狀 - 稀疏指數(shù)。然后,他們使用 Triton 構(gòu)建塊稀疏 FlashAttention 內(nèi)核,加速注意力計(jì)算。有關(guān)內(nèi)核的詳細(xì)實(shí)現(xiàn),請(qǐng)參閱附錄 C.4 和代碼。
在長(zhǎng)上下文基準(zhǔn)中的評(píng)估結(jié)果
作者在一系列場(chǎng)景中測(cè)試了 MInference,包括 QA、編碼、基于檢索的任務(wù)、multi-hop QA、總結(jié)和數(shù)學(xué)任務(wù)。RULER 基準(zhǔn)包括幾個(gè)復(fù)雜的 multi-hop 或 multi-needle 任務(wù),有效地反映了 LLM 的實(shí)際上下文窗口大小。如表 1 所示,MInference 有效地保留了 LLM 的實(shí)際上下文窗口處理能力,甚至將實(shí)際上下文窗口大小略微擴(kuò)展到 32K。
作者還使用平均 token 長(zhǎng)度為 214K 的 InfiniteBench 在更廣泛的任務(wù)中測(cè)試了 MInference,如表 2 所示。與 SoTA 基線相比,MInference 在所有任務(wù)中都始終保持了良好的性能。值得注意的是,在更具挑戰(zhàn)性的檢索任務(wù)(如 KV 檢索任務(wù))中,所有基線都無(wú)法做出準(zhǔn)確預(yù)測(cè),準(zhǔn)確率低于 1.2%。但是,MInference 成功地保留了處理動(dòng)態(tài) KV 對(duì)檢索的能力。
為了進(jìn)一步評(píng)估不同上下文長(zhǎng)度和關(guān)鍵信息在提示中不同位置時(shí)的性能,作者使用「大海撈針」任務(wù)測(cè)試了各種模型和方法。如圖 1 所示,MInference 在不同的模型、上下文窗口和提示信息位置下都表現(xiàn)良好,與原始模型相比,其性能保持不變甚至略有提高。在 LLaMA-3-8B 和 GLM-4-9B-1M 的情況下,MInference 在高達(dá) 1M 的上下文窗口中實(shí)現(xiàn)了完全綠色的性能。相比之下,即使在 70K 上下文窗口中,StreamingLLM 和 InfLLM 在提示的中間段性能也會(huì)下降到 20% 以下。
作者還使用 PG-19 在語(yǔ)言模型任務(wù)中測(cè)試了 MInference,其中包括多達(dá) 100k 的 token。如圖 2 所示,MInference 有效地保持了 LLaMA-3-8B 和 Yi-9B-200K 的困惑度,而所有基線都出現(xiàn)了不同程度的困惑度下降。此外,與標(biāo)準(zhǔn)的 StreamingLLM 相比,使用膨脹和步長(zhǎng)配置的 StreamingLLM 更好地保持了困惑度性能。
延遲和內(nèi)核中的稀疏模式
圖 3 展示了本文提出的三種注意力模式以及 FlashAttention 的微基準(zhǔn)測(cè)試結(jié)果??梢钥闯?,Vertical-Slash 是三種模式中最慢的,但在 1M 上下文窗口下,相比 FlashAttention 仍然實(shí)現(xiàn)了 13 倍的加速。
圖 4 展示了 Vertical-Slash 頭部?jī)?nèi)核中的稀疏索引。垂直線通過(guò) PIT FlashAttention 使用 1x64 塊計(jì)算,而斜線通過(guò)塊級(jí) FlashAttention 使用 64x64 塊計(jì)算。