FlashAttention3:“苗條”的就是比較好! 原創(chuàng)
?1.加速Hopper GPU
注意力機制是Transformer架構(gòu)的核心能力,也是大型語言模型和長上下文應(yīng)用的瓶頸。FlashAttention(和 FlashAttention-2)開創(chuàng)了一種通過最小化內(nèi)存讀/寫來加速 GPU 注意力的方法,現(xiàn)在大多數(shù)庫都使用它來加速 Transformer 訓(xùn)練和推理。
這導(dǎo)致了過去兩年上下文LLM長度的大幅增加,從2-4K(GPT-3,OPT)增加到128K(GPT-4),甚至1M(Llama 3)。然而,盡管取得了成功,但 FlashAttention 尚未利用現(xiàn)代硬件中的新功能,F(xiàn)lashAttention-2在H100 GPU上僅實現(xiàn)了35%的理論最大FLOP 利用率。
在這篇博文中介紹了三種主要技術(shù)來加快對Hopper GPU的關(guān)注:利用 Tensor Core和TMA的異步性來
1) 通過扭曲專業(yè)化重疊整體計算和數(shù)據(jù)移動
2)交錯塊matmul和softmax操作
3)利用硬件支持實現(xiàn) FP8 低精度的非連貫處理
2.快速了解
FlashAttention-3比使用FP16的FlashAttention-2快1.5-2.0倍,高達(dá)740 TFLOPS,即H100理論最大FLOPS利用率為 75%。使用FP8時,F(xiàn)lashAttention-3達(dá)到接近 1.2 PFLOPS,誤差比基線FP8注意小2.6倍。
- 更高效的GPU利用率:新技術(shù)可利用高達(dá)75%的H100 GPU最大功能,而之前僅為35%。這導(dǎo)致在訓(xùn)練和運行大型語言模型方面,比以前的版本快得多(1.5-2 倍LLMs)。
- 以較低的精度獲得更好的性能:FlashAttention-3可以處理精度較低的數(shù)字FP8,同時保持精度。這樣可以實現(xiàn)更快的處理速度,并可能降低內(nèi)存使用率,從而為運行大規(guī)模AI操作的客戶節(jié)省成本并提高效率。
- 能夠在以下位置LLMs使用更長的上下文:通過加速注意力機制,F(xiàn)lashAttention-3使AI模型能夠更有效地處理更長的文本片段。這可以使應(yīng)用程序能夠在不減慢速度的情況下理解和生成更長、更復(fù)雜的內(nèi)容。
3.GEMM和SOFTMAX
注意力有兩個主要操作GEMMs(GEMMs是指廣義矩陣乘法General Matrix Multiply),例如注意力機制中Q和K之間以及注意力矩陣P和V之間的矩陣乘法。
GPU上面現(xiàn)代加速器上,非matmul操作比matmul操作慢得多。例如softmax中的指數(shù)運算等特殊函數(shù)的吞吐量遠(yuǎn)遠(yuǎn)低于浮點乘加。這些特殊運算(函數(shù))SF一般是由多功能(計算)單元負(fù)責(zé),多功能(計算)單元是獨立于浮點乘-加(例如y=wx+b)或矩陣乘加之外。
例如,H100 GPU SXM5具有989TFLOPS的FP16矩陣乘法,但對于特殊的函數(shù)SF,只有 3.9TFLOPS的吞吐,吞吐量低 256 倍。
CUDA 編程指南規(guī)定,特殊函數(shù)的吞吐量為每個時鐘周期每個流式多處理器 (SM) 16次操作。將16乘以132SM和1830 Mhz(用于計算 FP16 matmul 的989TFLOPS 的時鐘速度)得到 3.9TFLOPS!
假如注意力機制的head維度為128,matmul FLOPS比指數(shù)運算多512倍,這意味著與matmul運算相比,花費在指數(shù)運算的時間需要比矩陣運算多50%的時間。Matmul在FP8的精度下速度比FP16還要快多兩倍,這樣一來就被指數(shù)運算嚴(yán)重的拖后腿!能有魔法棒實現(xiàn)兩者并行么?
上面文縐縐的話翻譯成白話就是:GEMM比Softmax快,如何讓兩者并駕齊驅(qū)?
Warp是SM中的基本概念,可以先回去溫習(xí)下GPU的組成。Warp其實已經(jīng)做了一些調(diào)度的事宜,某些Warp被阻塞,其他翹曲可以運行。?
例如存在 2個warpgroup(標(biāo)記為 1 和 2),每個warpgroup是4個warp 的組),這時候通過使用同步屏障 (bar.sync),以便warpgroup 1首先執(zhí)行它的GEMM。例如,一次迭代的GEMM1和下一次迭代的 GEMM0。然后warpgroup 2執(zhí)行它的GEMM,而warpgroup 1執(zhí)行它的softmax, 等等。這個類似乒乓球的調(diào)度方式,確保了兩者并駕齊驅(qū)。上圖相同顏色的為相同的迭代。
這種方式在實踐中,調(diào)度并不是真的這么妥帖,但是這樣的調(diào)度可以將 FP16 注意力前向傳遞從大約 570 TFLOPS提高到620 TFLOPS(頭部head 128維,序列長度8K)。
即使在一個Warpgroup中,可以在這個群組運行GEMM的時候運行softmax的某些部分。如下圖所示:
<非工科讀者跳過!>具體的原理在于在注意力算法中,內(nèi)部循環(huán)(主循環(huán))內(nèi)的操作具有順序依賴性,這些依賴性會阻礙單次迭代中的并行化。例如,(本地)softmax 18-19行依賴于第一個 GEMM 的輸出,而第二個 GEMM 將其結(jié)果作為操作數(shù)。實際上,算法 1 的第 17- 21行中的等待語句序列化了softmax 和GEMM的執(zhí)行。但是可以通過寄存器中的額外緩沖區(qū)在迭代之間流水線來打破這些依賴關(guān)系。遵循這一思路,F(xiàn)L3提出了以下兩階段GEMM-softmax流水線算法:
<繼續(xù)>這種流水線將吞吐量從大約620 TFLOPS提高到大約640-660 TFLOPS,用于FP16注意力向前轉(zhuǎn)移,但代價是更高的寄存器壓力,因為需要更多的寄存器來容納GEMM的累加器和softmax的輸入/輸出。
擴展上述 2 階段算法,F(xiàn)L3繼續(xù)提出了一個3階段變體,該變體將進一步重疊第二個WGMMA與softmax。雖然這種方法提供了更高的 Tensor Core 利用率的潛力,但它需要更多的寄存器。
4.FP8的量化支持?
FP8和FP32在寄存器中的存儲布局的不一致給FL3的算法帶來了挑戰(zhàn)。
對于 FP8 FlashAttention-3, ??在將分片加載到SMEM后進行內(nèi)核內(nèi)轉(zhuǎn)置。對于內(nèi)核內(nèi)轉(zhuǎn)置,我們利用了LDSM ( ldmatrix ) 和STSM ( stmatrix )指令,它們涉及一系列線程共同加載 SMEM到RMEM,并以 128 字節(jié)的粒度存儲 RMEM 到 SMEM。
LDSM/STSM指令都是高效的,允許在warpgroup中執(zhí)行,并且能夠在執(zhí)行內(nèi)存復(fù)制時轉(zhuǎn)置布局。在第一次迭代之后,可以在前一個??切片和當(dāng)前 ??切片的WGMMA運算中,加入下一個??切片的轉(zhuǎn)置。
使用 FP8 (e4m3) 格式,僅使用3位來存儲尾數(shù),使用4位來存儲指數(shù)。這導(dǎo)致比FP16/BF16更高的數(shù)值誤差。此外,大型模型通常具有異常值,它的量級比大多數(shù)其他值大得多,這使得量化變得困難。為了減少 FP8中注意力機制的誤差,F(xiàn)L3采用了兩種技術(shù):
- 塊量化:為每個塊保留一個標(biāo)量,以便對于每個Q,K,V 將其張量拆分為大小????×?? ????×?? 塊,然后獨立量化。這種量化可以與注意力之前的操作融合,而不會額外減慢速度。由于FlashAttention-3算法都是基于快進行計算,因此可以縮放每個S塊進行量化,而無需計算成本。
- 利用QuIP的非相干處理,將Q和K與隨機正交矩陣相乘,以“分散”異常值并減少量化誤差。<不明白可以跳過,后面專欄介紹這種算法>。
在實驗中,Q、K、V是由標(biāo)準(zhǔn)正態(tài)分布生成的,但0.1%的條目具有較大的量級(模擬異常值),我們發(fā)現(xiàn)非相干處理可以將量化誤差減少 2.6倍。下表為數(shù)值誤差比較。
5.性能對比
下面展示了FlashAttention-3的一些結(jié)果,并將其與FlashAttention-2以及 Triton和cuDNN中的實現(xiàn)進行了比較(兩者都已經(jīng)使用了Hopper GPU 的新硬件功能)。對于FP16,F(xiàn)lashAttention-2的加速約為1.6倍至 2.0倍。
本文轉(zhuǎn)載自 ??魯班模錘??,作者: 龐德公
