又快又準(zhǔn),即插即用!清華8比特量化Attention,兩倍加速于FlashAttention2,各端到端任務(wù)均不掉點(diǎn)!
論文第一作者張金濤來自清華大學(xué)計(jì)算機(jī)系,論文通訊作者陳鍵飛副教授及其他合作作者均來自清華大學(xué)計(jì)算機(jī)系。
大模型中,線性層的低比特量化(例如 INT8, INT4)已經(jīng)逐步落地;對于注意力模塊,目前幾乎各個(gè)模型都還在用高精度(例如 FP16 或 FP32)的注意力運(yùn)算進(jìn)行訓(xùn)練和推理。然而,隨著大型模型需要處理的序列長度不斷增加,Attention(注意力運(yùn)算)的時(shí)間開銷逐漸成為網(wǎng)絡(luò)優(yōu)化的主要瓶頸。
為了提高注意力運(yùn)算的效率,清華大學(xué)陳鍵飛團(tuán)隊(duì)提出了 8Bit 的 Attention(SageAttention)。實(shí)現(xiàn)了 2 倍以及 2.7 倍相比于 FlashAttention2 和 xformers 的即插即用的推理加速,且在視頻、圖像、文本生成等大模型上均沒有端到端的精度損失。
- 論文標(biāo)題:SageAttention: Accurate 8-Bit Attention for Plug-and-play Inference Acceleration
- 論文鏈接:https://arxiv.org/abs/2410.02367
- 開源代碼:https://github.com/thu-ml/SageAttention
即插即用舉例
SageAttention 可以一行代碼輕松替換掉 torch 中當(dāng)前最優(yōu)的 Attention 接口(scaled_dot_product_attention),實(shí)現(xiàn)即插即用的推理加速。
具體來說,SageAttention 的使用非常方便,使用 pip install sageattention 后,
只需要在模型的推理腳本前加入以下三行代碼即可:
效果上,以開源視頻生成模型 CogvideoX 為例,使用 SageAttention 可以端到端加速 35%,且生成的視頻無損:
全精度 Attention
SageAttention
接下來,將從背景與挑戰(zhàn),技術(shù)方案,以及實(shí)驗(yàn)效果介紹 SageAttention。
背景
隨著大模型需要處理的序列長度越來越長(比如 Llama3.1 支持 128K 的序列長度),Attention 的速度優(yōu)化變得越來越重要。下圖展示了一個(gè)標(biāo)準(zhǔn)的 Transformer 模型中各運(yùn)算隨著序列長度變化的時(shí)間占比:
挑戰(zhàn)
為了方便指代注意力元算中包含的矩陣,我們先回顧一下注意力的計(jì)算公式:
將神經(jīng)網(wǎng)絡(luò)中各運(yùn)算的數(shù)值類型從高比特量化至低比特是一種有效提升計(jì)算和訪存效率的方法。然而,研究團(tuán)隊(duì)發(fā)現(xiàn)直接將注意力運(yùn)算中的 Q, K, P, V 從 FP16 量化為 INT8 或者 FP8 后將會導(dǎo)致在幾乎所有模型和任務(wù)上都會得到極差的結(jié)果,例如,在 Unidiffuser 文生圖模型中,會得到一張完全模糊的圖像;在 Llama2-7B 進(jìn)行四選一選擇題任務(wù)上得到 25.5% 的準(zhǔn)確率。
經(jīng)過仔細(xì)分析后,研究團(tuán)隊(duì)發(fā)現(xiàn)主要是兩個(gè)原因?qū)е铝肆炕⒁饬Φ牟粶?zhǔn)確:
- 大多視頻、圖像生成模型中,矩陣 K 表現(xiàn)出了極強(qiáng)的通道維度的異常值分布,直接使用 INT8 或者 FP8 數(shù)據(jù)類型對其進(jìn)行量化會導(dǎo)致巨大的誤差。
- 在所有模型中,對矩陣 P, V 進(jìn)行量化不能保證一個(gè)模型中所有層的精度。下表展示了對 P, V 量化后,Llama2-7B 和 Unidiffuser 模型所有層中,最差情況的層對應(yīng)的量化注意力的準(zhǔn)確度,(該準(zhǔn)確度為量化注意力相比全精度注意力的誤差),可以發(fā)現(xiàn)不管對 P, V 矩陣進(jìn)行何種 8Bit (INT8,E4M3,E5M2)量化,總有些層的準(zhǔn)確率非常差,導(dǎo)致了端到端效果的下降。
技術(shù)方案
為了解決上述的兩個(gè)關(guān)鍵問題,研究團(tuán)隊(duì)提出了對應(yīng)的解決辦法。
- 對 K 進(jìn)行平滑處理。SageAttention 采用了一個(gè)簡單但非常實(shí)用的方法來消除矩陣 K 的異常值:K = K – mean (K) 其中 mean (K) 是沿著通道維度求平均值。這個(gè)簡單的做法不僅不會影響注意力計(jì)算的正確性 Softmax (QK^T) = Softmax (Q (K-mean (K))^T) ;且對整個(gè) Attention 速度的影響只有 0.2%;同時(shí)還保證了量化后的注意力運(yùn)算的精度:
- 對 Q, K 進(jìn)行分塊 INT8 量化。對于矩陣 Q, K,SageAttention 采用了以 FlashAttention 的分塊大小為粒度的 INT8 量化。這是因?yàn)椋?. 對 Q, K 矩陣進(jìn)行 INT8 量化相比于進(jìn)行 FP8 量化,注意力的精度更高。2. 在一些常用卡上,比如 RTX4090,INT8 矩陣乘法(INT32 為累加器)的速度是 FP8(FP32 為累加器)的兩倍。
- 對 P, V 采用 FP16 數(shù)據(jù)類型的矩陣乘法累加器。對于矩陣 P, V,SageAttention 采用了保留 P, V 為 FP16 的類型,但進(jìn)行矩陣乘法時(shí)采用 FP16 數(shù)據(jù)類型的累加器。這是因?yàn)椋?. PV 矩陣乘法的數(shù)值范圍始終在 FP16 的表示范圍內(nèi),且經(jīng)過大量實(shí)驗(yàn)驗(yàn)證,F(xiàn)P16 作為累加器的數(shù)據(jù)類型不會帶來任何精度損失(見下表)。2. 在一些常用卡上,比如 RTX4090,以 FP16 為累加器數(shù)據(jù)類型的矩陣乘法的速度是 FP32 作為累加器的兩倍。
SageAttention 的流程圖及算法如下所示:
實(shí)驗(yàn)效果
SageAttention 實(shí)現(xiàn)了底層的 GPU Kernel,在算子速度以及各個(gè)模型的端到端精度上都有十分不錯(cuò)的表現(xiàn)。
具體來說,算子速度相比于 FlashAttention2 和 xformers 有 2.1 以及 2.7 倍的加速。以下 4 張圖展示了在 RTX4090 上,不同的序列長度下 SageAttention 的各種 Kernel 與其他方法的速度比較。
以下 4 張圖展示了在 RTX3090 上,不同的序列長度下 SageAttention 的各種 Kernel 與其他方法的速度比較。
下表展示了在 RTX4090 上,各模型中的注意力模塊中 SageAttention 相比于使用模型原始的注意力的加速比。
真實(shí)任務(wù)的精度上,下表展示了 SageAttention 在視頻、圖像、文本生成等大模型上均沒有端到端的精度損失: