Intel Smooth-SwiGLU:FP8 LLM 訓(xùn)練,34% 加速
一、背景
本文中我們繼續(xù)介紹一個(gè) Intel 最新的關(guān)于 FP8 訓(xùn)練相關(guān)的工作,其在一定程度上分析并解決了 FP8 訓(xùn)練中的不收斂問(wèn)題,進(jìn)一步推進(jìn)了 FP8 訓(xùn)練落地(尤其是在 H100/H800 GPU 上)的可行性。
對(duì)應(yīng)的論文:[2409.12517] Scaling FP8 training to trillion-token LLMs [1]
二、摘要
本文中,作者首次在 2T Token 的數(shù)據(jù)集上使用 FP8 精度訓(xùn)練了 LLM,比以前的限制增加了 20 倍。通過(guò)這些擴(kuò)展訓(xùn)練實(shí)驗(yàn),作者發(fā)現(xiàn)了 FP8 訓(xùn)練中的關(guān)鍵不確定性,這些不確定性在早期持續(xù)時(shí)間較短的訓(xùn)練中是無(wú)法觀察到的。作者進(jìn)一步追溯到 SwiGLU 激活函數(shù)的異常值放大問(wèn)題。有趣的是,從分析和經(jīng)驗(yàn)上表明,這些放大只發(fā)生在較長(zhǎng)的訓(xùn)練中,并將其與 SwiGLU 的權(quán)重對(duì)齊過(guò)程聯(lián)系起來(lái)。
為了解決這個(gè)新發(fā)現(xiàn)的問(wèn)題,作者引入了 Smooth-SwiGLU,這是一種新穎的修改,可確保穩(wěn)定的 FP8 訓(xùn)練而不改變函數(shù)的行為。作者還首次演示了兩個(gè) Adam 優(yōu)化器參數(shù)(一階矩和二階矩)的 FP8 量化。
結(jié)合這些創(chuàng)新,作者在 256 個(gè) Intel Gaudi2 加速器上使用 FP8 精度成功訓(xùn)練了一個(gè) 7B 參數(shù)量的模型,實(shí)現(xiàn)了與 BF16 基線相當(dāng)?shù)慕Y(jié)果,同時(shí)提供了高達(dá) 34% 的吞吐提升。
三、引言
3.1 浮點(diǎn)數(shù)值表示
我們之前的文章中提到過(guò),雖然都是 FP8 精度,但是不同硬件上存在不同的表示方式,主要包括 E5M2 和 E4M3,其中 E 表示指數(shù)位(決定了動(dòng)態(tài)范圍),M 表示尾數(shù)位(決定了表示精度)。此外,雖然都是 E5M2 或者 E4M3,不同的硬件可能采用不同的格式。比如 NVIDIA GPU 上的 E5M2 符合 IEEE 754 Style,而 E4M3 卻不符合 IEEE 754 Style,可以稱(chēng)為 ARM-Intel-Nvidia Style。此外,AMD-Graphcore-Qualcomm 的表示也有所不同。如下圖所示,IEEE 754 Style 的 E4M3 范圍為 [-240, 240],而 ARM-Intel-Nvidia Style 的 E4M3 范圍是 [-448, 448]:
PS:由于 Intel 和 NVIDIA 采用一致的表示方式,這也就意味著 Intel 的結(jié)論能夠很容易擴(kuò)展到 NVIDIA 的 GPU。
3.2 SwiGLU
在 [2002.05202] GLU Variants Improve Transformer [2] 中作者提出了在 Transformer 模型中使用各種 GLU 變體激活,也提到其在許多下游理解任務(wù)上獲得了更好的結(jié)果。然而,作者也提到并沒(méi)有解釋為什么這些修改會(huì)有效,將其歸功于上帝的恩賜。在后續(xù) Google 的 PaLM,Meta 的 LLaMA 系列等模型中廣泛采用了 SwiGLU 激活。
如下圖所示為 FFN 中應(yīng)用 SwiGLU 的公式:
其中 Swish 為激活函數(shù),可以表示為:
其 Pytorch 的實(shí)現(xiàn)也很簡(jiǎn)單,如下所示,這里用 silu 代替了 SwiGLU,對(duì)應(yīng) β 為 1:
因?yàn)檫@里有三個(gè)參數(shù):w1,w2,w3,為了保證總參數(shù)量和正常 FFN 一致,所以 LLaMA 中這里的 Hidden Dim 不是 4d,而是 4d*2/3=8d/3。
3.3 GLU 類(lèi)激活的離群點(diǎn)
如下圖 Figure 1 所示,在 [2405.14428] Mitigating Quantization Errors Due to Activation Spikes in GLU-Based LLMs [3] 中作者發(fā)現(xiàn),各種 GLU 變體的激活函數(shù)容易在特定層(比如基于 SwiGLU 激活的 FFN 的最后一個(gè) Liner 層的輸入)出現(xiàn)激活的 Spike。此外,作者發(fā)現(xiàn)這些激活的 Spike 與中間層隱藏狀態(tài)(Hidden Stage,每個(gè) Transfomer Block 的輸出)之間也存在高度相關(guān)性。并且 FFN 可能會(huì)通過(guò)殘差連接中的加法運(yùn)算放大 Hidden Stage。一旦 Hidden Stage 被放大,它就會(huì)在各層中持續(xù)存在,直到之后的層中再次遇到激活 Spike。
3.4 延遲縮放
在 FP8 的 Per Tensor Scaling 技術(shù)中,有兩種常見(jiàn)的方式:Just-in-time Scaling 和 Delayed Scaling(可以參考 NVIDIA Transformer Engine 中的實(shí)現(xiàn) Using FP8 with Transformer Engine [4])。
- Just-in-time Scaling(實(shí)時(shí)縮放):直接計(jì)算 Tensor 絕對(duì)值的最大值(amax),然后得到 Scaling 值,再對(duì) Tensor 進(jìn)行 Scaling。此種方法更加精確,但是,額外引入的開(kāi)銷(xiāo)會(huì)大幅降低 FP8 帶來(lái)的收益。
- Delayed Scaling(延遲縮放):核心思路是使用額外的 Tensor 來(lái)存儲(chǔ)之前的 amax 歷史,然后根據(jù)歷史最大值估計(jì)當(dāng)前的最大值。
如下圖為 NVIDIA Transformer Engine 中的 Delayed Scaling 實(shí)現(xiàn)方案,amax history 最多可以存儲(chǔ) 1024 個(gè) history。在進(jìn)行當(dāng)前 Tensor 的 Scaling 操作時(shí),使用當(dāng)前 Tensor 之前的 amax history 來(lái)預(yù)測(cè)當(dāng)前的 amax,然后再進(jìn)行 Scaling 操作;Scaling 操作的同時(shí)會(huì)計(jì)算當(dāng)前的 amax,并更新 amax history。
四、方案
4.1 洞察
上面提到的 GLU 類(lèi)激活引出的離群點(diǎn)(Outlier)問(wèn)題會(huì)為 FP8 訓(xùn)練帶來(lái)很多挑戰(zhàn)。本文中,作者揭示,在大規(guī)模數(shù)據(jù)集上訓(xùn)練 LLM 的后期階段,這些異常值變得尤為顯著。
如下圖 Figure 1 所示,(a)為訓(xùn)練的起始階段,沒(méi)有 Outlier;(b)為訓(xùn)練 200B Token 之后,出現(xiàn)偶發(fā)性的 Outlier。這些 Outlier 僅在訓(xùn)練中處理了很長(zhǎng)一段時(shí)間才出現(xiàn),此現(xiàn)象對(duì)維持訓(xùn)練中的數(shù)值穩(wěn)定性帶來(lái)極大挑戰(zhàn),進(jìn)一步增加了 FP8 訓(xùn)練穩(wěn)定性的難度,尤其是像 Megatron-LM 中廣泛采用的延遲縮放方案中(如上述的介紹,這些方案假設(shè)了迭代期間的一致性)。
如前所述,SwiGLU 激活函數(shù)可能導(dǎo)致 FFN 組件的最后一個(gè) Linear 層的輸入出現(xiàn) Outlier。在使用 FP8 訓(xùn)練,并采用延遲縮放技術(shù)時(shí),SwiGLU 引發(fā)的 Outlier 會(huì)打破延遲縮放的統(tǒng)計(jì)一致性假設(shè),導(dǎo)致訓(xùn)練過(guò)程的不穩(wěn)定性。如下圖 Figure 3 所示,作者展示了在 FFN 的最后一個(gè) Linear (也就是 SwiGLU 的輸出)禁用量化后的訓(xùn)練收斂性,LLaMA2 FP8 的訓(xùn)練能夠成功地在大規(guī)模數(shù)據(jù)集上收斂,從而解決先前觀察到的發(fā)散問(wèn)題,這也驗(yàn)證了 SwiGLU 對(duì) FP8 訓(xùn)練穩(wěn)定性的影響。
4.2 SwiGLU 相關(guān)問(wèn)題證明
如下圖所示為 SwiGLU 的定義,其中,SwiGLU 由輸入 x 與權(quán)重 w1、w2 進(jìn)行乘積,分別得到 xTw1 和 xTw2。然后使用 Swish 激活函數(shù)對(duì) xTw2 進(jìn)行變換,將其與 xTw1 相乘。
其他標(biāo)準(zhǔn)激活函數(shù)(如 ReLU、GeLU 和 Swish)在輸入幅度較大時(shí)最多是線性的。這意味著當(dāng)輸入 u 逐漸趨于正無(wú)窮或負(fù)無(wú)窮時(shí),這些激活函數(shù)的比值(即 ∣f(u)/u∣)會(huì)趨于小于等于 1 的某個(gè)值。然而,SwiGLU 是一個(gè)二次函數(shù),可以達(dá)到更大的值,尤其在權(quán)重 w1、w2 相互“對(duì)齊”時(shí)(例如 w1=w2,并且 ||w1||=1)。
當(dāng) w1=w2 時(shí),上式可以表示為:
假設(shè) xTw=c,則當(dāng) c 較大時(shí),σ(c) 趨近于 1,此時(shí)下述結(jié)果約等于 c2,也就是上述所說(shuō)的二次放大特性。
作者也進(jìn)一步通過(guò)理論分析證明了上述 w1、w2 相互“對(duì)齊”現(xiàn)象,這里不再贅述,具體可以查看論文。當(dāng)然,作者也進(jìn)一步通過(guò)實(shí)驗(yàn)驗(yàn)證了相關(guān)問(wèn)題。如下圖 Figure 2 所示:
- (a):LLaMA2-7B 模型在 BF16 和 FP8 精度下的訓(xùn)練損失,其中 FP8 的訓(xùn)練在達(dá)到 200B Token 后開(kāi)始出現(xiàn)明顯的發(fā)散現(xiàn)象。
- (b):展示了某一個(gè)特定 Channel 在訓(xùn)練過(guò)程中 w1 和 w2 范數(shù)的動(dòng)態(tài)變化及其相關(guān)性。
- (c):某一個(gè) Outlier 通道在訓(xùn)練初期(8B Token)與后期(330B Token)w1 和 w2 元素散點(diǎn)圖。可以看出,后期相關(guān)性顯著提高,趨近于 w1=w2。
- 某一個(gè) Outlier 通道在訓(xùn)練初期(8B Token)與后期(330B Token)w1 的直方圖分布。?
除了上圖 (c) 那樣的正相關(guān)性外,作者也觀察到了明顯的負(fù)相關(guān)性,也就是趨近于 w1=-w2,如下圖所示:
4.3 Smooth-SwiGLU
為了在解決 Outlier 問(wèn)題的同時(shí)保持完整的 FP8 加速,作者提出了 Smooth-SwiGLU 方法。如下圖 Figure 4 所示,其展示了 Smooth-SwiGLU 的核心理念:對(duì) SwiGLU 函數(shù)的線性分支施加一個(gè)縮放因子,并在最后一個(gè) Linear 后將其重新縮放。此方法防止了輸入到最后一個(gè) Linear 的量化過(guò)程中出現(xiàn) Outlier,同時(shí)保留了 SwiGLU 激活函數(shù)的整體功能,使能夠在整個(gè)網(wǎng)絡(luò)中充分利用 FP8 精度。
為降低計(jì)算開(kāi)銷(xiāo),作者采用了一種高效的并行方法來(lái)計(jì)算縮放因子 si,也就是 Per Channel 量化:
1. 將 Tensor 分割成若干塊,每塊對(duì)應(yīng)一個(gè)通道。
2. 對(duì)每個(gè)塊(通道),并行計(jì)算其最大值。
3. 利用這些每個(gè)通道的最大值,確定各通道的獨(dú)立縮放因子 si。
此方法實(shí)現(xiàn)了高效的逐通道縮放,因?yàn)槊總€(gè)通道的縮放因子是獨(dú)立并行計(jì)算的。相較于 Linear 層中的矩陣乘法,這種方法的計(jì)算成本適中,尤其是在并行化處理下,即便在非優(yōu)化實(shí)現(xiàn)中也是如此。在推理階段,這些縮放因子可以合并到包含 SwiGLU 層及其后 Linear 的 FFN 的第一和第三Linear 的權(quán)重中。
五、實(shí)驗(yàn)&結(jié)果
5.1 FP8 優(yōu)化器
Adam 優(yōu)化器及其變體在深度學(xué)習(xí)中得到廣泛應(yīng)用,Adam 優(yōu)化器的一個(gè)關(guān)鍵特征是其存儲(chǔ)兩個(gè)矩,傳統(tǒng)上采用高精度(FP32),這顯著增加了內(nèi)存開(kāi)銷(xiāo),尤其對(duì)于大規(guī)模模型而言。盡管先前研究中已證明將其一階矩降至 FP8 精度的可行性,但仍保留二階矩為 FP16。本文中作者則更進(jìn)一步,成功將二階矩也量化至 FP8,顯著提升了大型語(yǔ)言模型優(yōu)化器的效率。
如下圖 Figure 5 所示,作者探索發(fā)現(xiàn)一階矩 Mom1 采用 E4M3,二階矩 Mom2 采用 E5M2 可以很好的維持訓(xùn)練精度:
5.2 訓(xùn)練效果實(shí)驗(yàn)
如下圖 Figure 6 所示,作者在 256 個(gè) Gaudi2 上訓(xùn)練 LLaMA2 7B 模型,共訓(xùn)練了 330B 左右 Token。本文的 Smooth-SwiGLU + FP8 優(yōu)化器可以和 BF16 訓(xùn)練維持相當(dāng)?shù)?Loss,而傳統(tǒng)的 FP8 訓(xùn)練在 200B Token 時(shí)開(kāi)始發(fā)散。
如下圖 Table 2 所示,作者進(jìn)一步對(duì)比了下游任務(wù)的 Zero Shot 精度以及困惑度,可以看出,本文 FP8 訓(xùn)練出的模型可以很好的保持精度:
5.3 訓(xùn)練速度
如下圖 Table 3 所示,可以看出,本文方案在保持收斂性的同時(shí)獲得了更高的加速比,相比 BF16 訓(xùn)練可以加速 33.52%(不及 FP8 主要是因?yàn)橐肓艘恍╊~外的開(kāi)銷(xiāo))。
雖然額外引入了一些計(jì)算開(kāi)銷(xiāo),但顯存并沒(méi)有明顯增加:
六、參考鏈接
