Meta 新作:FlashAttention 的數(shù)值偏差有多大?
一、背景
最近 Meta 的研究員開發(fā)了一個新的框架來了解 LLM 訓練中數(shù)值偏差的影響,并基于該框架評估了 LLM 中廣泛采用的 FlashAttention 的數(shù)值偏差。
對應的論文為:[2405.02803] Is Flash Attention Stable?
PS:其實論文很簡單,結(jié)論也很簡單:使用 FlashAttention 相比 Baseline Attention 確實會帶來數(shù)值偏差。但帶來的數(shù)值偏差比從 FP32 到 FP16 的數(shù)值偏差小得多,甚至小于不同初始化方法帶來的偏差。吐槽一下,論文中的圖都比較模糊。
二、摘要
LLM 預訓練的代價很高,也更加的復雜。很多 LLM 在預訓練中都遇到了訓練過程不穩(wěn)定的情況,通常表示為損失的毛刺(Spike)。數(shù)值偏差(Numeric Deviation)被認為是導致這種訓練不穩(wěn)定的潛在原因,但由于訓練的成本很高,量化這一點非常有挑戰(zhàn)性。
本文中,作者開發(fā)了一種系統(tǒng)性的方法來理解數(shù)值偏差的影響,并使用廣泛采用的 FlashAttention 來驗證了該框架。作者發(fā)現(xiàn),與 Baseline Attention 相比,在單個前向傳播中,BF16 下的 FlashAttention 會有超過一個數(shù)量級的數(shù)值偏差。然而,使用基于 Wasserstein 距離的數(shù)據(jù)驅(qū)動分析來提供數(shù)值偏差對訓練過程中模型權(quán)重影響的上限,發(fā)現(xiàn) FlashAttention 中的數(shù)值偏差比低精度訓練的影響小 2-5 倍。
三、引言
3.1 數(shù)值精度
如下圖為常見的浮點數(shù)值精度,其中 sign 表示符號位,exponent 表示指數(shù)位,fraction 表示尾數(shù)位。相比 float32,float16 的指數(shù)位和尾數(shù)位都更小,而 bfloat16 的指數(shù)位和 float32 相同,只是尾數(shù)位更少。因此,通常 float32 轉(zhuǎn) float16 時通常會帶來較大的精度損失,而 float32 轉(zhuǎn) bfloat16 通常只需要做小數(shù)位的截斷,損失相對較小。現(xiàn)在的 LLM 預訓練中通常都會使用 bfloat16。
- Float32:指數(shù)位 8 位,尾數(shù)位 23 位,數(shù)據(jù)范圍為[1.18e-38, 3.40e+38]
- float16:指數(shù)位 5 位,尾數(shù)位 10 位,數(shù)據(jù)范圍為[6.10e-05, 6.55e+04]
- bfloat16:指數(shù)位 8 位,尾數(shù)位 7 位,數(shù)據(jù)范圍為[1.18e-38, 3.39e+38]
3.2 數(shù)值誤差
在浮點數(shù)的計算中會存在兩種常見的誤差:
- 溢出誤差(Overflow Error):浮點都有一個有限的表示范圍,當計算結(jié)果超出這個表示范圍時就會產(chǎn)生溢出錯誤,往往表現(xiàn)為無窮大。比如,令 float a = FLT_MAX * 2,此時 a 的值為正無窮大。
- 舍入誤差(Rounding Error):浮點數(shù)有固定的有效位數(shù),當一個數(shù)值不能被精確表示時,就會被舍入到最接近的可表示的浮點數(shù)。這種輸入在數(shù)值計算中是不可避免的,因為大多數(shù)實數(shù)在計算機中無法被精確表示。比如在 C 中打印 0.1f,printf("a = %.20f\n", 0.1f),其輸出結(jié)果為 0.10000000149011611938,是一個近似值。
除此之外,有時也會提到下溢誤差(Underflow Error):當一個非常小的非零結(jié)果小于浮點數(shù)表示范圍下限時發(fā)生,通常導致結(jié)果被舍入為零。
由于 float16 和 bfloat16 的不同指數(shù)位和尾數(shù)位,也就導致它們出現(xiàn)誤差的場景不太一樣。
- float16:指數(shù)位較少,尾數(shù)位較多,表示范圍有限,但表示精度更高,因此更容易發(fā)生溢出誤差。
- bfloat16:指數(shù)位較多,尾數(shù)位較少,表示范圍更大,但表示精度有限,因此更容易發(fā)生舍入誤差。下溢誤差也更多一些。
3.3 訓練損失毛刺
在 Meta OPT、BigScience Bloom、Google PaLM、TII Falcon 以及智源 GLM 訓練中都出現(xiàn)了訓練損失出現(xiàn)毛刺的情況,也有一些有效的手段可以緩解,但依舊不知道其根因。比如 Google PaLM 中驗證了其并非是單個樣本導致的。
如下圖所示,是 [2211.05100] BLOOM: A 176B-Parameter Open-Access Multilingual Language Model 中遇到的毛刺現(xiàn)象:
3.4 評估指標
Wasserstein 距離,也稱為 Earth Mover’s Distance (EMD),是一種衡量兩個概率分布之間差異的方法。這種距離的直觀含義是,將一個概率分布轉(zhuǎn)變成另一個概率分布所需要的“工作量”或“成本”,其中“工作量”可以理解為將一堆形狀不同的沙子(一個概率分布)鏟動并重塑為另一堆沙子(另一個概率分布)所需要的努力。
Wasserstein 距離基于最優(yōu)運輸理論。給定兩個概率分布 P 和 ??,以及一個成本函數(shù) ??(??,??),Wasserstein 距離定義為將分布 P 轉(zhuǎn)變?yōu)?Q 所需的最小成本。數(shù)學上,它表示為:
這里的 π 是 P 和 ?? 之間的所有可能的聯(lián)合分布的集合,而 Π(P,Q) 表示所有這些聯(lián)合分布中,邊際分布分別是 P 和 Q 的集合。
相比其他距離度量(如歐氏距離或 KL 散度),Wasserstein 距離的一個主要優(yōu)勢在于其能夠更加有效地處理概率分布之間的微小變化,特別是當這些分布不重疊或僅部分重疊時。這使得 Wasserstein 距離在數(shù)據(jù)稀疏或異構(gòu)的情況下特別有用。
四、方法&實驗
4.1 方法
作者開發(fā)了一個 microbenchmark 來隔離和研究 FlashAttention 引起的數(shù)值偏差。其設(shè)計如下圖 Fig 2 所示,在原始的 FlashAttention 中只支持 FP16 和 BF16 格式,因此作者重新實現(xiàn)了 FlashAttention,以便分析不同的數(shù)值精度的影響。作者進一步修改模型,可以在每次調(diào)用 Attention 時計算 Baseline Attention 和 FlashAttention 的注意力矩陣輸出,從而可以使用最大差異(max difference)以及 Wasserstein 距離來度量差異。作者也進行了一系列訓練來度量整個訓練過程中模型權(quán)重的差異。
4.2 數(shù)據(jù)類型的影響
如下圖 Fig.3 所示,作者對比了不同數(shù)據(jù)類型下 Baseline Attention 和 FlashAttention 的數(shù)值偏差,可以看出,數(shù)值精度越高,偏差越?。?/p>
為了進一步分析這種數(shù)值偏差,作者探索了序列長度對數(shù)值偏差的影響,其中會保持 FlashAttention 的 tile 大小和 SRAM 大小相同。如下圖所示,隨著序列長度的增加,數(shù)值偏差也會適當增加。其中左圖(a)表示最大誤差,右圖(b)表示誤差的均值。由于序列變長,也就需要更多的 tile,相應也有更多的 resaling,這也就可能產(chǎn)生更多的誤差:
4.3 算法配置的影響
如下圖 Fig 6 所示,作者進一步探索了 FlashAttention 中不同配置的影響:
- (a)和(c)針對不同的 Block/tile Area 大小的影響,使用比較大的 Block 后 Baseline Attention 和 FlashAttention 的差異很小,主要是因為 rescaling 計算更少一些。
- (b)使用 Square Block 對 Baseline Attention 和 FlashAttention 的影響不大。?
4.4 模型權(quán)重的變化
作者進一步驗證了訓練中模型權(quán)重的變化(對比 Baseline Attention 和 FlashAttention),如下圖 Fig 7 所述,不管是最大誤差還是 Wasserstein 距離都會隨著訓練的迭代而逐漸變大,并且趨勢類似:
如下圖 Fig.8 所示,作者進一步驗證了整個訓練中其他變量帶來的模型權(quán)重的偏差。可以看出,雖然 Baseline Attention 和 FlashAttention 會導致權(quán)重產(chǎn)生誤差,但是其甚至比不同初始化方法帶來的誤差還小,更是遠小于 FP16 vs BF16 和 FP16 vs FP32 帶來的誤差:
五、參考鏈接
本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談
