4比特量化三倍加速不掉點!清華即插即用的SageAttention迎來升級
論文共同第一作者張金濤、黃浩峰分別來自清華大學計算機系和交叉信息研究院,論文通訊作者陳鍵飛副教授及其他合作作者均來自清華大學計算機系。
大模型中,線性層的低比特量化已經(jīng)逐步落地。然而,對于注意力模塊,目前幾乎各個模型都還在用高精度(例如 FP16 或 FP32)的注意力運算進行訓練和推理。并且,隨著大型模型需要處理的序列長度不斷增加,Attention(注意力運算)的時間開銷逐漸成為主要開銷。
此前,清華大學陳鍵飛團隊提出的 8-Bit 的即插即用 Attention(SageAttention),將 Attention 中的 QK^T 量化至 INT8,將 PV 保持為 FP16 精度并使用 FP16 精度的矩陣乘法累加器,同時提出 Smooth K 技術(shù)保持了量化 Attention 的精度,實現(xiàn)了 2 倍加速于 FlashAttention2,且在各類大模型上均保持了端到端的精度表現(xiàn)。
目前,SageAttention 已經(jīng)被業(yè)界及社區(qū)廣泛地使用于各種開源及商業(yè)大模型中,比如 CogvideoX、Mochi、Flux、Llama3、Qwen 等。
近日,陳鍵飛團隊進一步提出了 4-Bit 的即插即用 Attention(SageAttention2),相較于 FlashAttention2 和 xformers 分別實現(xiàn)了 3 倍以及 4.5 倍的即插即用的推理加速,且在視頻、圖像、文本生成等大模型上均保持了端到端的精度表現(xiàn)。
- 論文標題:SageAttention2: Efficient Attention with Thorough Outlier Smoothing and Per-thread INT4 Quantization
- 論文鏈接:https://arxiv.org/abs/2411.10958
- 開源代碼:https://github.com/thu-ml/SageAttention
即插即用舉例
SageAttention2 實現(xiàn)了高效的 Attention 算子,可以實現(xiàn)即插即用的推理加速。輸入任意 Q, K, V 矩陣,SageAttention2 可以快速返回 Attention Output (O)。
具體來說,SageAttention2 使用起來很方便,克隆倉庫(git clone https://github.com/thu-ml/SageAttention)并執(zhí)行 python setup.py install 后,只需一行代碼便可以得到 Attention 的輸出,可以使用該接口方便地替換任意模型中的 Attention 函數(shù):
效果上,以開源視頻生成模型 CogvideoX-1.5-5B 為例,使用 SageAttention2 可以端到端加速 1.8 倍,且生成的視頻無損:
使用全精度 Attention
使用 SageAttention2
更重要的是,SageAttention2 提供了比 SageAttention 更廣泛的硬件支持。除了在 RTX 4090 上可以 3 倍加速于 FlashAttention 外,在 L20、L40、L40S 可以實現(xiàn) 2 倍的加速,在 A100、A800、A6000 上可以實現(xiàn) 1.45-1.6 倍的加速(基于 SageAttention)。
接下來,研究團隊將從前言、挑戰(zhàn)、方法以及實驗效果四個方面介紹 SageAttention2(總體流程圖如下圖)。
前言
隨著大模型需要處理的序列長度越來越長,Attention 的速度優(yōu)化變得越來越重要。下圖展示了一個標準的 Transformer 模型中各運算的時間占比隨序列長度的變化:
為了方便指代注意力運算中的矩陣,我們先回顧一下注意力的計算公式:
盡管 SageAttention 提出將 Q,K 量化至 INT8,將 P,V 保持 FP16 精度且采用 FP16 的矩陣乘法累加器來加快 Attention 的速度。然而,這樣做的缺點是:1)INT8 的矩陣乘法只達到了一半的 INT4 矩陣乘法的速度,2)使用 FP16 的乘法累加器的 FP16 的矩陣乘法的加速只在 RTX4090 和 RTX3090 顯卡上有效。
為了克服上述缺點,SageAttention2 提出將 Q, K 量化至 INT4,并將 P, V 量化至 FP8 來加速 Attention。然而,這樣做的挑戰(zhàn)是很大的。
4-Bit 注意力量化有什么問題?
研究團隊發(fā)現(xiàn)直接將注意力運算中的 Q, K 量化為 INT4 后將會導致在幾乎所有模型和任務上都會得到極差的結(jié)果,例如,在 CogVideoX 文生視頻模型中,會得到完全模糊的視頻;Llama2-7B 進行四選一選擇題任務上得到 25% 的準確率。
經(jīng)過仔細分析后,研究團隊發(fā)現(xiàn)主要是兩個原因?qū)е铝肆炕⒁饬Φ牟粶蚀_:
(1)INT4 的數(shù)值范圍相比 INT8 非常小,導致其量化誤差在 Q,K 矩陣中出現(xiàn)一些異常值時會變得十分明顯,恰好大多模型都在 Q, K 中表現(xiàn)出來了較大的通道維度的異常值。這極大削減了 QK^?矩陣乘法的精度。
(2)研究團隊發(fā)現(xiàn) Nvidia 的顯卡上,F(xiàn)P8 的矩陣乘法指令 (mma.f32.f8.f8.f32) 的乘法累加器并不是官方宣稱的 FP32 精度,而是只有 FP22 精度,這導致了 PV 矩陣乘法出現(xiàn)較大的累加誤差。
技術(shù)方案
為了解決上述的兩個挑戰(zhàn),研究團隊提出了對應的解決辦法。
(1)保留 SageAttention 中對 K 進行平滑處理的同時,提出對 Q 進行平滑處理:Q – mean (Q)。其中 mean (Q) 是沿著通道維度的平均值向量。完成該平滑操作后需要在 Attention 計算過程中將 mean (Q) 和 K^T 的向量與矩陣乘法的結(jié)果補償?shù)?S 中。
這使得相比直接量化 Q, K 至 INT4 的準確度有質(zhì)的改變,如下表展示了對比了該方法和直接量化 Q, K 至 INT4 在 Cogvideo 和 Llama3.1 上的端到端表現(xiàn)。
矩陣 Q 平滑前后的數(shù)據(jù)分布可視化的結(jié)果如下,可以發(fā)現(xiàn)平滑后的 Q 對 INT4 數(shù)據(jù)范圍的利用度更高:
(2)對 Q, K 進行 Per-thread 量化。對于矩陣 Q, K,SageAttention2 采用了根據(jù) mma 指令對矩陣內(nèi)存排布的要求,對 Q,K 中的 Token 按照 GPU 線程進行分組,使量化粒度比 SageAttention 中的 per-block 細化 16 倍,極大提高了 4Bit 的 QK^?乘法準確度的同時不引入任何額外開銷。
具體來說,在 SageAttention 中,每個 Q 的塊將被劃分為 c_w 個段,由 GPU 流處理器(SM)中的 c_w 個 GPU warp 處理。然后,每個包含 32 個線程的 warp 會使用 NVIDIA 的 mma.m16n8k64 PTX 指令來執(zhí)行 QK^?運算。根據(jù)這一指令的布局要求,研究團隊發(fā)現(xiàn)一個 warp 內(nèi)的 Q [8×(n%8)] 可以共用一個量化縮放參數(shù),而一個 warp 內(nèi)的 K [8×(n%8)] 和 K [8×(n%8+1)] 也可以共用一個量化縮放參數(shù),其中 n 是 token 索引。
這種量化方法更為細致且不增加額外開銷。這是因為它根據(jù) MMA 指令的布局將不同的 GPU 線程分配到不同的量化 Token 組,每個線程只對應一個量化縮放參數(shù)進行反量化。而非 Per-token 量化那樣,每個線程對應多個量化縮放參數(shù)。
如下表所示,可以發(fā)現(xiàn) per-thread 量化的準確度比 SageAttention 中采用的 per-block 量化高得多,準確度和 per-token 量化幾乎沒有差別。
(3)對 FP8 的 PV 矩陣乘法采用 FP32 的寄存器將每次 FlashAttention 分塊粒度的 PV 的 FP22 的乘法結(jié)果累加起來。這種做法可以有效地避免 FP22 的乘法累加器沿著序列長度累積過多的誤差,將 FP22 累加器帶來的誤差控制在 FlashAttention 分塊的粒度中,提高了 FP8 的 PV 乘法的準確度。
(4)針對 P 和 V,研究團隊對比了多種量化的數(shù)據(jù)類型,對比發(fā)現(xiàn)使用 E4M3 數(shù)據(jù)格式的 FP8 精度最準確,基本接近了 FP16 的準確度。因此采用將 P 和 V 量化至 E4M3。
下圖展示了 SageAttention2 的算法流程:
SageAttention2 共實現(xiàn)了兩種 Kernel,區(qū)別在于對 Q, K 進行 INT4 量化還是 INT8 量化:
此外,SageAttention2 還提出一種可選的對矩陣 V 進行平滑處理的技術(shù),可以進一步提高 PV 矩陣乘法的準確度。具體來說,當某些模型中 V 矩陣具有通道維度的偏移時,可以將 V 減去其通道維度的平均值 mean (V) 來去除偏移,之后進行正常的量化 Attention 運算。只需要對最終 Attention 的 Output 加上 mean (V) 即可保持計算的正確性。
這種做法可以提升準確度的原因如下圖所示。在 FP22 的表示范圍內(nèi),數(shù)值越大,相比 FP32 的誤差越大。而 P 的范圍是 0~1 之間,那么當 V 矩陣的列有較大的數(shù)值偏移時,PV 的 FP22 累加器的精度就越差,通過平滑 V 去除偏移后,就可以加強 PV 矩陣乘法的準確度。
實驗效果
SageAttention 實現(xiàn)了底層的 GPU CUDA Kernel,在算子速度以及各個模型端到端準確度上都有十分不錯的表現(xiàn)。
具體來說,算子速度相比于 FlashAttention2 和 xformers 有大約 3 倍以及 4.5 倍的加速:
算子的準確度方面也是比對 Q, K 進行 SmoothQuant 和 Hadamard 變換要更加準確:
各模型在真實場景的端到端精度表現(xiàn)中,在視頻、圖像、文本生成等大模型上均保持了端到端的精度表現(xiàn):
下圖是在 HunyuanVideo 中的可視化實例:
下圖是在 Cogvideo 中的可視化實例:
下表展示了各個語言、視頻、圖像生成模型中 SageAttention2 的端到端精度表現(xiàn):
端到端的速度表現(xiàn)上,SageAttention2 兩個 Kernel 的實現(xiàn)均可以有效地對長序列模型進行加速,比如可以端到端 1.8 倍加速 CogVideoX1.5-5B,其他模型上也均有 1.6 到 1.8 倍的提速。