清華稀疏Attention,無需訓(xùn)練加速一切模型!
在當(dāng)今各類大語言模型以及視頻模型中,長序列場景越來越普遍,而 Attention 的計算復(fù)雜度隨著序列長度呈平方增長,成為長序列任務(wù)下的主要計算瓶頸。此前,清華大學(xué)陳鍵飛團隊提出的即插即用量化的 SageAttention 系列工作已實現(xiàn) 3 倍加速于 FlashAttention,且在各類大模型上均保持了端到端的精度,已被業(yè)界和社區(qū)廣泛使用。為了進一步加速 Attention,清華大學(xué)陳鍵飛團隊進一步提出了無需訓(xùn)練可直接使用的稀疏 Attention(SpargeAttn)可用來加速任意模型。實現(xiàn)了 4-7 倍相比于 FlashAttention 的推理加速,且在語言,視頻、圖像生成等大模型上均保持了端到端的精度表現(xiàn)。
論文標(biāo)題:SpargeAttn: Accurate Sparse Attention Accelerating Any Model Inference
下圖展示了 SpargeAttn 的速度,可以發(fā)現(xiàn)在 RTX4090 上,SpargeAttn 在 60% 稀疏度的情況下可以達到 900TOPS 的速度,甚至是使用 A100 顯卡速度的 4.5 倍(A100 上 FlashAttention 只有 200TOPS)。
在 SpargeAttn 的 Github 倉庫中可以發(fā)現(xiàn),SpargeAttn 的使用方法比較簡潔,只需要進行一次簡單的超參數(shù)搜索過程,就可以永久地對任意的模型輸入進行推理加速。
接下來,將從前言,挑戰(zhàn),方法,以及實驗效果四個方面介紹 SpargeAttn。
前言
隨著大模型需要處理的序列長度越來越長,Attention 的速度優(yōu)化變得越來越重要。這是因為相比于網(wǎng)絡(luò)中其它操作的 O (N) 的時間復(fù)雜度,Attention 的時間復(fù)雜度是 O (N^2)。盡管 Attention 的計算復(fù)雜度為 O (N^2),但幸運的是 Attention 具備很好的稀疏性質(zhì),即 P 矩陣的很多值都接近 0。如何利用這種稀疏性來節(jié)省計算就成為了 attention 加速的一個重要方向。大多數(shù)現(xiàn)有的工作都集中在利用 P 矩陣在語言模型中表現(xiàn)出來的固定的稀疏形狀(如滑動窗口)來節(jié)省計算,或是需要重新訓(xùn)練模型,比如 DeepSeek 的 NSA 以及 Kimi 的 MoBA。此外,現(xiàn)有稀疏 Attention 通常需要較大的上下文窗口(如 64K~1M)才能有明顯加速。SpargeAttn 的目標(biāo)是開發(fā)一個無需訓(xùn)練、對各種模型(語言 / 視頻 / 圖像)通用、精度無損、對中等長度的上下文(如 4-32K)也有加速效果的注意力機制。
圖 1: 不同的模型表現(xiàn)出不同的稀疏形狀
實現(xiàn)通用的,無需訓(xùn)練的稀疏 Attenion 有哪些挑戰(zhàn)?
挑戰(zhàn) 1
通用性:Attention 雖然具備稀疏性質(zhì),但是其稀疏形狀在不同的模型甚至同一模型的不同層中都是不同的,體現(xiàn)出很強的動態(tài)性。如圖 1 所示,前兩種模型分別為視頻模型和圖像生成模型,這兩個模型中的 Attention 的稀疏形狀相比語言模型更加沒有規(guī)律。設(shè)計一種各種模型通用的稀疏 Attention 是困難的。
挑戰(zhàn) 2
可用性:對于各種 Attention 的輸入,很難同時實現(xiàn)準確且高效的稀疏 Attention。這是因為準確性要求了完全精確地預(yù)測 P 中的稀疏區(qū)域,高效性則要求了此預(yù)測的時間開銷極短。在一個極短的時間內(nèi)完全精準地預(yù)測 P 的稀疏形狀是困難的。
方法
為了解決上述的兩個挑戰(zhàn),研究團隊提出了對應(yīng)的解決辦法。
- 研究團隊提出了一種各模型通用的快速的對 P 矩陣稀疏部分進行預(yù)測的算法。該方法選擇性地對 Q, K 矩陣進行壓縮并預(yù)測 P 矩陣,接著使用 TopCdf 操作省略 P 中稀疏部分對應(yīng)的 QK^T 與 PV 的矩陣乘法。
- 研究團隊提出了在 GPU Warp 級別上的稀疏 Online Softmax 算法,該算法通過利用 Online Softmax 中全局最大值與局部最大值之間的差異,進一步省略了一些 PV 的矩陣乘法計算。
- 可選的,針對視頻和圖像模型,研究團隊充分利用圖像以及視頻中的 Token 局部相似性質(zhì),使用希爾伯特重排的方法對 Attention 前的 Token 進行重新排列,進一步提高稀疏度。
- 最后,研究團隊將這種稀疏方法與基于量化的 SageAttention 融合到一起,進一步加速 Attention。
圖 2: SpargeAttn 的算法流程圖
SpargeAttn 的算法流程如下所示:
實驗效果
總的來說,SpargeAttn 在視頻、圖像、文本生成等大模型均可以實現(xiàn)無需訓(xùn)練的加速效果,同時保證了各任務(wù)上的端到端的精度。
下表展示了 SpargeAttn 在各模型上的稀疏度,Attention 速度,以及各任務(wù)上的端到端精度,可以發(fā)現(xiàn) SpargeAttn 在保證了加速的同時沒有影響模型精度:(注:此論文中的所有實驗都是基于 SageAttention 實現(xiàn),目前 Github 倉庫中已有基于 SageAttention2 的實現(xiàn),進一步提供了 30% 的加速。
值得一提的是,此前的稀疏 Attention 工作很多無法實際使用的原因之一是稀疏預(yù)測部分的 Overhead 較大,而 SpargeAttn 團隊還將稀疏預(yù)測部分的代碼進行了極致優(yōu)化,將 Overhead 壓縮到了幾乎在各種長度的序列下都可以忽略的地步:
下表展示了對于各模型的端到端的加速效果,以視頻生成模型 Mochi 為例,SpargeAttn 提供了近兩倍的端到端加速效果:(注:此論文中的所有實驗都是基于 SageAttention 實現(xiàn),目前 Github 倉庫中已有基于 SageAttention2 的實現(xiàn),進一步提供了 30% 的加速)