FP8 訓(xùn)練新范式:減少 40% 顯存占用,訓(xùn)練速度提高 1.4 倍
近期DeepSeek V3 引爆國(guó)內(nèi)外的社交媒體,他們?cè)谟?xùn)練中成功應(yīng)用了 FP8 精度,顯著降低了 GPU 內(nèi)存使用和計(jì)算開(kāi)銷(xiāo)。這表明,F(xiàn)P8 量化技術(shù)在優(yōu)化大型模型訓(xùn)練方面正發(fā)揮著越來(lái)越重要的作用。
近期,來(lái)自伯克利,英偉達(dá),MIT 和清華的研究者們提出了顯存高效的 FP8 訓(xùn)練方法:COAT(Compressing Optimizer states and Activation for Memory-Efficient FP8 Training),致力于通過(guò) FP8 量化來(lái)壓縮優(yōu)化器狀態(tài)和激活值,從而提高內(nèi)存利用率和訓(xùn)練速度。COAT 實(shí)現(xiàn)了端到端內(nèi)存占用減少 1.54 倍,端到端訓(xùn)練速度提高 1.43 倍,同時(shí)保持模型精度。它還可以使訓(xùn)練批次大小加倍,從而更好地利用 GPU 資源。通過(guò)利用 FP8 精度,COAT 使大型模型的高效全參數(shù)訓(xùn)練在更少的 GPU 上成為可能,并有助于在分布式訓(xùn)練環(huán)境中加倍批次大小,為大規(guī)模模型訓(xùn)練的擴(kuò)展提供了實(shí)用的解決方案。最重要的是,他們的訓(xùn)練代碼完全開(kāi)源。
論文第一作者席浩誠(chéng)本科畢業(yè)于清華大學(xué)姚班,目前在伯克利攻讀博士學(xué)位,他在英偉達(dá)實(shí)習(xí)期間完成了這篇工作。論文共同通訊作者為 MIT 韓松副教授和清華大學(xué)陳鍵飛副教授。
- 論文標(biāo)題:COAT: Compressing Optimizer States and Activation for memory efficient FP8 Training
- 論文鏈接:https://arxiv.org/abs/2410.19313
- 開(kāi)源代碼:https://github.com/NVlabs/COAT
一、FP8 優(yōu)化器狀態(tài)
1. FP8 量化優(yōu)化器狀態(tài)的難點(diǎn)
論文作者發(fā)現(xiàn),當(dāng)前的量化方法無(wú)法充分利用 FP8 的表示范圍,因此在使用每組量化(per-group quantization)對(duì)優(yōu)化器狀態(tài)進(jìn)行量化時(shí)會(huì)導(dǎo)致較大的量化誤差。對(duì)于 FP8 的 E4M3 格式,我們希望量化組 X 的動(dòng)態(tài)范圍覆蓋 E4M3 的最小可表示值(0.00195)和最大可表示值(448)之間的整個(gè)跨度,以充分利用其表示能力。然而,E4M3 的動(dòng)態(tài)范圍通常未被充分利用:E4M3 的動(dòng)態(tài)范圍約為 200000,但一階動(dòng)量的每個(gè)量化組的最大值最小值之比通常為 1000,二階動(dòng)量的該比值則通常為 10,遠(yuǎn)小于 E4M3 的動(dòng)態(tài)范圍。這使得用 FP8 來(lái)量化優(yōu)化器狀態(tài)的誤差非常大。
2. 解決方案:動(dòng)態(tài)范圍擴(kuò)展
論文作者發(fā)現(xiàn),在量化之前引入一個(gè)擴(kuò)展函數(shù) f (?),能夠擴(kuò)大量化組的動(dòng)態(tài)范圍,并使其與 E4M3 對(duì)齊。使用的擴(kuò)展函數(shù)為:
其中,k 是即時(shí)計(jì)算的參數(shù),每個(gè)量化組共享一個(gè) k。當(dāng) k > 1 時(shí),動(dòng)態(tài)范圍將被擴(kuò)大,并更接近 E4M3 的動(dòng)態(tài)范圍。在每一步訓(xùn)練中,都可以即時(shí)的計(jì)算出最優(yōu)的 k,從而可以充分利用 E4M3 的表示范圍,而原始的量化方法只能利用其中的一小部分。
動(dòng)態(tài)范圍擴(kuò)展方法可以大大減少量化誤差,并充分利用 E4M3 的動(dòng)態(tài)范圍。除此之外,還發(fā)現(xiàn),E4M3 比 E5M2 更適合一階動(dòng)量。而對(duì)于二階動(dòng)量,雖然在原始設(shè)置中 E4M3 優(yōu)于 E5M2,但在應(yīng)用我們的擴(kuò)展函數(shù)后,它們的量化誤差幾乎相同。因此,建議在量化優(yōu)化器狀態(tài)時(shí)使用 E4M3 + E4M3 量化策略或 E4M3 + E5M2 量化策略。
二、FP8 激活
1. 動(dòng)機(jī):非線性層占用大量?jī)?nèi)存
在語(yǔ)言模型的前向傳播中,必須保留激活值以用于反向傳播計(jì)算梯度。在 Llama 模型系列中,非線性層通常占內(nèi)存占用的約 50%。相比之下,線性層的貢獻(xiàn)不到 25%。因此,優(yōu)化線性和非線性層以減少激活內(nèi)存占用至關(guān)重要。
2. 解決方案:混合粒度 FP8 精度流
FP8 精度流要求所有線性和非線性層的輸入和輸出采用 FP8 格式。通過(guò)直接以 FP8 格式保存輸入張量用于反向傳播,這消除了額外的量化操作需求,從而減少了相關(guān)開(kāi)銷(xiāo)。FP8 精度流自然地將非線性和線性層的內(nèi)存占用減少了 50%,因?yàn)樗鼈冎恍枰4?FP8 激活值,而不是 BF16。為了進(jìn)一步提高該方法的準(zhǔn)確性,作者提出在不同層中變化量化粒度,以混合粒度的方式平衡精度和效率。
三、實(shí)驗(yàn)結(jié)果
COAT 在多個(gè)任務(wù)中展示了其在內(nèi)存占用和訓(xùn)練速度方面的優(yōu)勢(shì),同時(shí)保持了模型性能。
1. 訓(xùn)練加速 1.43 倍,顯存降低 1.54 倍
在使用 4 張 H100 訓(xùn)練 Llama-2-13B 模型時(shí),COAT 將每個(gè) GPU 的內(nèi)存占用從 BF16 的 76.1GB 減少到 49.1GB,實(shí)現(xiàn)了 1.54 倍的內(nèi)存縮減。同時(shí),COAT 將訓(xùn)練速度從 BF16 的每秒 2345 個(gè) token 提升至每秒 5295 個(gè) token,達(dá)到 1.43 倍的加速。在幾乎所有的訓(xùn)練場(chǎng)景下,COAT 都能夠使 Batch Size 翻倍,或是讓訓(xùn)練所需的卡數(shù)減小。
2. 訓(xùn)練完全不掉點(diǎn),F(xiàn)P8 訓(xùn)練表現(xiàn)和 BF16 吻合
COAT 在各種應(yīng)用場(chǎng)景下,均展現(xiàn)出了出色的精度,完全不會(huì)導(dǎo)致模型性能下降。例如,在大語(yǔ)言模型預(yù)訓(xùn)練任務(wù)中,COAT 可以保持近乎無(wú)損的模型性能,訓(xùn)練中的 loss 曲線也和 BF16 完全吻合。
COAT 在視覺(jué)語(yǔ)言模型微調(diào)中同樣實(shí)現(xiàn)了和 BF16 訓(xùn)練完全一致的表現(xiàn)。無(wú)論是 loss 曲線,還是下游任務(wù)上的表現(xiàn),COAT 均和 BF16 基準(zhǔn)相持平。
在一些實(shí)際的下游任務(wù)例子中,經(jīng)過(guò) COAT 訓(xùn)練過(guò)的模型也有著相當(dāng)優(yōu)秀的生成和總結(jié)能力。
四、總結(jié)
COAT 的核心價(jià)值在于使用 FP8 進(jìn)行訓(xùn)練的同時(shí)做到了顯存優(yōu)化。動(dòng)態(tài)范圍擴(kuò)展減少量化誤差,混合粒度量化優(yōu)化激活存儲(chǔ),兩者協(xié)同作用使得端到端內(nèi)存占用降低 1.54 倍。這種優(yōu)化不僅適用于單機(jī)訓(xùn)練,更在分布式訓(xùn)練中發(fā)揮關(guān)鍵作用 —— 通過(guò)批量大小翻倍,可在相同硬件條件下處理更多數(shù)據(jù),顯著提升訓(xùn)練效率。而對(duì)于顯存資源緊張的研究者,COAT 也提供了全參數(shù)訓(xùn)練的可行路徑,降低了大模型訓(xùn)練的門(mén)檻。