自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)

發(fā)布于 2024-5-13 09:38
瀏覽
0收藏

眾所周知,大語言模型的訓(xùn)練常常需要數(shù)月的時(shí)間,使用數(shù)百乃至上千個(gè) GPU。以 LLaMA2 70B 模型為例,其訓(xùn)練總共需要 1,720,320 GPU hours。由于這些工作負(fù)載的規(guī)模和復(fù)雜性,導(dǎo)致訓(xùn)練大模型存在著獨(dú)特的系統(tǒng)性挑戰(zhàn)。


最近,許多機(jī)構(gòu)在訓(xùn)練 SOTA 生成式 AI 模型時(shí)報(bào)告了訓(xùn)練過程中的不穩(wěn)定情況,它們通常以損失尖峰的形式出現(xiàn),比如谷歌的 PaLM 模型訓(xùn)練過程中出現(xiàn)了多達(dá) 20 次的損失尖峰。


數(shù)值偏差是造成這種訓(xùn)練不穩(wěn)定性的潛在原因,由于大語言模型訓(xùn)練運(yùn)行成本極高,如何量化數(shù)值偏差儼然成為關(guān)鍵問題。


在最新的一項(xiàng)工作中,來自 Meta、哈佛大學(xué)的研究者開發(fā)了一個(gè)原則性定量方法來理解訓(xùn)練優(yōu)化中的數(shù)值偏差,以此評估不同的最新優(yōu)化技術(shù),并確定它們在用于訓(xùn)練大模型時(shí)是否可能引入意外的不穩(wěn)定性。


Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)-AI.x社區(qū)


  • 論文標(biāo)題:Is Flash Attention Stable?
  • 論文鏈接:https://arxiv.org/pdf/2405.02803


結(jié)果發(fā)現(xiàn),在一次單獨(dú)的前向傳遞過程中,F(xiàn)lash Attention 的數(shù)值偏差比 BF16 的 Baseline Attention 大一個(gè)數(shù)量級。


具體而言,該方法包括兩個(gè)階段,包括:


(1)開發(fā)一個(gè)微基準(zhǔn)來擾動(dòng)給定優(yōu)化中的數(shù)值精度;

(2)通過基于 Wasserstein 距離的數(shù)據(jù)驅(qū)動(dòng)分析評估數(shù)值偏差如何轉(zhuǎn)化為模型權(quán)重的變化。


研究者分析了 SOTA 優(yōu)化技術(shù) Flash Attention ,并量化了可能引入的數(shù)值偏差。Flash Attention 是一種廣泛用于加速注意力機(jī)制的技術(shù),通常被認(rèn)為是 Transformer 模型中的系統(tǒng)瓶頸。Flash Attention 在提高速度和減少內(nèi)存訪問量的同時(shí),也依賴于算法優(yōu)化,而算法優(yōu)化有可能導(dǎo)致數(shù)值偏差的增加。


研究者假設(shè)添加重新縮放因子(rescaling factors )可能會引入無意的近似,導(dǎo)致數(shù)值折衷,這可能會在后續(xù)影響訓(xùn)練穩(wěn)定性。


他們在多模態(tài)文本到圖像工作負(fù)載的背景下分析了 Flash Attention,以確定 Flash Attention 與其基線之間數(shù)值偏差的潛在重要性。最終,他們引入了一個(gè)框架來量化訓(xùn)練優(yōu)化的數(shù)值偏差及其下游影響。


研究者在數(shù)值偏差量化上主要作出了以下兩點(diǎn)貢獻(xiàn):


(1)設(shè)計(jì)了一個(gè)微基準(zhǔn)來分離數(shù)值精度對數(shù)值偏差的影響。


研究者所設(shè)計(jì)的微基準(zhǔn)作為一種技術(shù),用于衡量和量化傳統(tǒng)黑盒優(yōu)化(如 Flash Attention)所導(dǎo)致的數(shù)值偏差。通過擾動(dòng)通常在提供的內(nèi)核中不可用的方面,他們開創(chuàng)性地發(fā)現(xiàn)在低數(shù)值精度(BF16)下,與 Baseline Attention 相比,F(xiàn)lash Attention 的數(shù)值偏差大約高出一個(gè)數(shù)量級。


(2)基于 Wasserstein Distance 度量進(jìn)行了數(shù)據(jù)驅(qū)動(dòng)的分析。


通過該分析,研究者將觀察到的數(shù)值偏差置于上下文,并為其對下游模型屬性的影響形成一個(gè)上限(upper bound)。在研究者的案例研究中,他們能夠限制觀察到的數(shù)值偏差的影響,并發(fā)現(xiàn):「Flash Attention 引入的模型權(quán)重偏差大約為低精度訓(xùn)練的 1/2 至 1/5 倍?!?/p>


這項(xiàng)研究強(qiáng)調(diào)了開發(fā)一種原則性方法的重要性:「不僅要量化,而且要將訓(xùn)練優(yōu)化對數(shù)值偏差的影響置于上下文中。」通過構(gòu)建代理(proxies)來將數(shù)值偏差置于上下文中,旨在推斷通常難以衡量的下游模型效果(即訓(xùn)練不穩(wěn)定性)的可能性。


實(shí)驗(yàn)方法


研究者首先開發(fā)了一個(gè)微基準(zhǔn)來分離并研究 Flash Attention 引起的數(shù)值偏差。如圖 2 所示,他們通過對 Flash Attention 進(jìn)行數(shù)值上的重新實(shí)現(xiàn),以分析不同的數(shù)值精度,并在算法的每個(gè)步驟應(yīng)用潛在的優(yōu)化措施。


Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)-AI.x社區(qū)

圖 2: 微基準(zhǔn)設(shè)計(jì)摘要。


這是必要的,因?yàn)?Flash Attention 內(nèi)核目前僅支持 FP16 和 BF16 數(shù)值格式。該內(nèi)核還是 CUDA 代碼的包裝 API 調(diào)用,這使得擾動(dòng)算法以檢查數(shù)值偏差的影響變得具有挑戰(zhàn)性。


相比之下,他們的微基準(zhǔn)設(shè)計(jì)允許在算法內(nèi)部進(jìn)行精度輸入和修改。研究者將微基準(zhǔn)與原始的 Flash Attention kernel 進(jìn)行了驗(yàn)證。


他們進(jìn)一步設(shè)計(jì)了一種技術(shù),以比較模型執(zhí)行過程中每個(gè)步驟的 Attention 矩陣的輸出。并修改了模型代碼,每次調(diào)用注意力時(shí)都計(jì)算 Baseline Attention 和 Flash Attention,這允許對相同的輸入矩陣進(jìn)行精確的輸出矩陣比較。


為了將其置于上下文中,研究者還通過相同和獨(dú)立的訓(xùn)練運(yùn)行,使用 Max difference 和 Wasserstein Distance 度量來量化模型權(quán)重在整個(gè)訓(xùn)練過程中的差異。


對于訓(xùn)練實(shí)驗(yàn),研究者則使用一種將文本輸入轉(zhuǎn)換為圖像的生成式 AI workload(即文本到圖像模型)。他們使用 Shutterstock 數(shù)據(jù)集重新訓(xùn)練模型,并在一組英偉達(dá) 80GB A100 GPU 集群上運(yùn)行此實(shí)驗(yàn)。


通過微基準(zhǔn)量化數(shù)值偏差


研究者首先分析了 Flash Attention 在前向傳遞過程中的影響。他們利用微基準(zhǔn)測試,在隨機(jī)初始化查詢、鍵、值向量相同的情況下,檢驗(yàn)不同數(shù)值精度對 Attention 計(jì)算的輸出矩陣的影響。


正如圖 3 所示,當(dāng)研究者使用從 BF16 到 FP64 變化的不同數(shù)值格式時(shí),F(xiàn)lash Attention 和 Baseline Attention 之間的數(shù)值偏差隨著尾數(shù)位數(shù)的增加而減小。這表明數(shù)值差異是由于較少的尾數(shù)位數(shù)所固有的近似造成的。


Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)-AI.x社區(qū)

圖 3:數(shù)值格式對于 Flash Attention 的數(shù)值偏差所產(chǎn)生的效果。


之后,研究者為進(jìn)行標(biāo)準(zhǔn)比較,在 FP64 數(shù)值格式下的 Baseline Attention 設(shè)置了「黃金值」,然后將不同數(shù)值格式下的 Attention 輸出與該值進(jìn)行了比較(如圖 4 所示)。


Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)-AI.x社區(qū)

圖 4:FP64 下 Baseline Attention「黃金值」的比較。


結(jié)果表明,F(xiàn)lash Attention 的數(shù)值偏差大約是在 BF16 下 Baseline 的 10 倍。


為了進(jìn)一步分析這種觀察到的數(shù)值偏差,研究者保持 tile 大小和 SRAM 大小不變的同時(shí),掃描了矩陣的序列長度(如圖 5 所示)。


Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)-AI.x社區(qū)

圖 5: 序列長度對 Flash Attention 數(shù)值偏差的影響。


如圖所示,隨著序列長度的增加,無論是通過(a)最大差異上限的測量,還是通過(b)差異的平均值和標(biāo)準(zhǔn)差的測量,F(xiàn)lash Attention 和 Baseline Attention 之間的數(shù)值偏差都在增加。


除此之外,研究者還利用微基準(zhǔn)設(shè)計(jì)進(jìn)行不同優(yōu)化的實(shí)驗(yàn),以便更好地了解數(shù)值偏差的影響(如圖 6 所示)。


圖 6a 顯示了調(diào)換 block 維數(shù)的順序如何導(dǎo)致 Flash Attention 和 Baseline Attention 之間的數(shù)值差異增大。圖 6b 中的其他擾動(dòng),比如限制 tile 大小為正方形,不會對數(shù)值偏差產(chǎn)生影響。圖 6c 表明了 block/tile 大小越大,數(shù)值偏差越小。


Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)-AI.x社區(qū)

圖 6: 算法的改變及其對觀察到的數(shù)值偏差的影響。


通過權(quán)重差異來了解數(shù)值偏差


雖然在前向傳遞過程中,F(xiàn)lash Attention 可能會導(dǎo)致 Attention 輸出的數(shù)值偏差,但這項(xiàng)研究的最終目標(biāo)是確定這是否會在模型訓(xùn)練過程中產(chǎn)生任何影響,以研究它是否會導(dǎo)致訓(xùn)練的不穩(wěn)定性。


因此,研究者希望量化 Flash Attention 是否在訓(xùn)練過程中改變了模型,即上文觀察到的 Attention 輸出差異是否反映在訓(xùn)練過程中更新的模型權(quán)重中。


研究者利用兩個(gè)指標(biāo)來衡量使用 Baseline Attention 訓(xùn)練的模型與使用 Flash Attention 訓(xùn)練的模型之間的模型權(quán)重差異。首先計(jì)算最大差異,即找出權(quán)重矩陣之間差異的絕對值并取最大值,從而得出偏差的上限,如下所示:


Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)-AI.x社區(qū)


雖然最大差值提供了數(shù)值偏差的上限,但它沒有考慮到每個(gè)矩陣的分布情況。因此,研究者通過 Wasserstein Distance 來量化權(quán)重差異,這是衡量張量之間相似性的常用度量。雖然在計(jì)算上稍顯復(fù)雜,但 Wasserstein Distance 包含了張量分布的形狀信息以衡量相似性。計(jì)算公式概述如下:


Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)-AI.x社區(qū)


數(shù)值越低,表明矩陣之間的相似度越高。


利用這兩個(gè)指標(biāo),研究者隨后量化了在整個(gè)訓(xùn)練過程中與 Baseline Attention 相比,F(xiàn)lash Attention 的模型權(quán)重是如何變化的:


Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)-AI.x社區(qū)


根據(jù) Wasserstein Distance 和 Max Difference 這兩個(gè)指標(biāo),在整個(gè)訓(xùn)練過程中,F(xiàn)lash Attention 的加入確實(shí)改變了模型權(quán)重,而且隨著訓(xùn)練的繼續(xù),這種差異只會越來越大,這表明了使用 Flash Attention 訓(xùn)練的模型與使用 Baseline Attention 訓(xùn)練的相同模型收斂到了不同的模型。


然而,訓(xùn)練是一個(gè)隨機(jī)過程,某些模型結(jié)構(gòu)的改變可能會在下游效應(yīng)和準(zhǔn)確性方面產(chǎn)生相似的結(jié)果。即使使用 Flash Attention 和 Baseline Attention 訓(xùn)練的模型權(quán)重不同,這也是值得關(guān)注的。


完全訓(xùn)練模型并評估準(zhǔn)確性是一項(xiàng)成本昂貴且資源密集的任務(wù),特別是對于訓(xùn)練需要數(shù)月的大模型來說。


研究者通過配置一個(gè) proxy 來探尋:


(a) 這些權(quán)重變化的意義有多大?

(b) 能否將其與其他廣泛采用的訓(xùn)練優(yōu)化中的標(biāo)準(zhǔn)權(quán)重變化聯(lián)系起來?


為了實(shí)現(xiàn)這一目標(biāo),研究者設(shè)計(jì)了一系列實(shí)驗(yàn)來比較在不同場景下,訓(xùn)練過程中的權(quán)重差異是如何變化的。


除了對比使用 Flash Attention 和 Baseline Attention 的訓(xùn)練過程外,他們還量化了在訓(xùn)練開始時(shí)權(quán)重被初始化為不同隨機(jī)值的相同訓(xùn)練過程中的權(quán)重差異。這提供了一個(gè)界限,因?yàn)殡S機(jī)權(quán)重初始化是一種常用的技術(shù),并且通常會產(chǎn)生等效的結(jié)果。


此外,研究者還測量了使用不同精度訓(xùn)練的模型權(quán)重的變化。數(shù)值精度(即 FP16 與 FP32)有可能導(dǎo)致下游變化,這作為確定了 Flash Attention 權(quán)重重要性的一個(gè)上限。


如圖 8 所示,可以發(fā)現(xiàn),使用 Flash Attention 的模型權(quán)重偏差變化率與不同模型初始化的權(quán)重偏差變化率相當(dāng)或更?。ㄗ⒁饧t色和藍(lán)色曲線的斜率)。


此外,使用 FP16 與 FP32 時(shí)的權(quán)重變化率比不同模型初始化時(shí)的權(quán)重變化率更高,變化也更大。


這些結(jié)果提供了一個(gè) proxy,并表明:「雖然 Flash Attention 會出現(xiàn)數(shù)值偏差,但它會被隨機(jī)模型初始化和低精度訓(xùn)練所限制。而且所引入的模型權(quán)重偏差大約是低精度訓(xùn)練時(shí)的 1/2 至 1/5 倍?!?/p>


Flash Attention穩(wěn)定嗎?Meta、哈佛發(fā)現(xiàn)其模型權(quán)重偏差呈現(xiàn)數(shù)量級波動(dòng)-AI.x社區(qū)

圖 8: 使用 Wasserstein Distance metric 測量的訓(xùn)練過程中的相對權(quán)重差異。


更多研究細(xì)節(jié),可參考原論文。


本文轉(zhuǎn)自 機(jī)器之心 ,作者:機(jī)器之心


原文鏈接:??https://mp.weixin.qq.com/s/sG3JaZR1isZApWP6ZkYe6Q??

標(biāo)簽
收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦