清華朱軍團(tuán)隊(duì)新作:使用4位整數(shù)訓(xùn)練Transformer,比FP16快2.2倍,提速35.1%,加速AGI到來(lái)!
將激活、權(quán)重和梯度量化為4位,有望加速神經(jīng)網(wǎng)絡(luò)訓(xùn)練。
然而,現(xiàn)有的4位訓(xùn)練方法需要自定義數(shù)字格式,而現(xiàn)代硬件不支持這種格式。
最近,清華朱軍團(tuán)隊(duì)提出了一種使用INT4算法實(shí)現(xiàn)所有矩陣乘法的Transformer訓(xùn)練方法。
使用超低INT4精度進(jìn)行訓(xùn)練,是非常具有挑戰(zhàn)性的。為了實(shí)現(xiàn)這一目標(biāo),研究者仔細(xì)分析了Transformer中激活和梯度的具體結(jié)構(gòu),為它們提出專用的量化器。
對(duì)于前向傳播,研究者確定了異常值的挑戰(zhàn),并提出了Hadamard量化器來(lái)抑制異常值。
對(duì)于后向傳播,他們通過(guò)提出位分割,來(lái)利用梯度的結(jié)構(gòu)稀疏性,并利用分?jǐn)?shù)采樣技術(shù)來(lái)準(zhǔn)確量化梯度。
這種新的算法,在自然語(yǔ)言理解、機(jī)器翻譯和圖像分類等廣泛任務(wù)上,都實(shí)現(xiàn)了具有競(jìng)爭(zhēng)力的準(zhǔn)確性。
原型線性算子運(yùn)算速度比FP16同類算子快2.2倍,訓(xùn)練速度提高了35.1%。
圖片
論文地址:https://arxiv.org/abs/2306.11987
代碼地址:https://github.com/xijiu9/Train_Transformers_with_INT4
全新的INT 4訓(xùn)練算法
訓(xùn)練神經(jīng)網(wǎng)絡(luò)對(duì)計(jì)算的要求很高。使用低精度算術(shù)進(jìn)行訓(xùn)練(完全量化訓(xùn)練/FQT)有望提高計(jì)算和內(nèi)存效率。
FQT方法在原來(lái)的全精度計(jì)算圖中添加了一些量化器和反量化器,并用消耗更小的低精度浮點(diǎn)運(yùn)算,代替了消耗更高的浮點(diǎn)運(yùn)算。
FQT的研究旨在降低訓(xùn)練數(shù)值精度,而不犧牲太多的收斂速度或精度。
所需的數(shù)值精度已從FP16降低到FP8、INT32+INT8和INT8+INT5。
FP8訓(xùn)練是在帶有Transformer引擎的Nvidia H100 GPU中實(shí)現(xiàn)的,加速了大規(guī)模Transformer的訓(xùn)練。最近的訓(xùn)練數(shù)值精度,已經(jīng)降到了4位。
然而,這些4位訓(xùn)練方法不能直接用于加速,因?yàn)樗鼈冃枰远x數(shù)字格式,而現(xiàn)代硬件不支持這些格式。
首先,前向傳播中的不可微量化器,會(huì)使損失情況變得崎嶇不平,基于梯度的優(yōu)化器很容易陷入局部最優(yōu)。
其次,梯度僅僅以低精度近似計(jì)算。這種不精確的梯度會(huì)減慢訓(xùn)練過(guò)程,甚至導(dǎo)致訓(xùn)練不穩(wěn)定或發(fā)散。
而在這項(xiàng)工作中,研究者為Transformer提出了一種新穎的INT4訓(xùn)練算法。
圖片
訓(xùn)練Transformer的所有高消耗的線性運(yùn)算,都可以寫在矩陣乘法(MM)的形式中。
這種MM形式,可以讓我們?cè)O(shè)計(jì)更靈活的量化器,通過(guò)利用Transformer中激活、權(quán)重和梯度的特定結(jié)構(gòu),就可以更好地近似于FP32矩陣乘法。
隨機(jī)數(shù)值線性代數(shù) (RandNLA) 領(lǐng)域的進(jìn)步,被這種量化器充分利用。
對(duì)于前向傳播,研究者發(fā)現(xiàn),激活中的異常值是精度下降的主要原因。
為了抑制異常值,他們提出了Hadamard量化器,它會(huì)對(duì)激活矩陣的變換版本進(jìn)行量化。這種變換是塊對(duì)角Hadamard矩陣,它將離群值中攜帶的信息傳播到矩陣的鄰近條目,從而縮小了離群值的數(shù)值范圍。
對(duì)于后向傳播,他們利用了激活梯度的結(jié)構(gòu)稀疏性。研究者發(fā)現(xiàn),一些token的梯度非常大。同時(shí),其余大多數(shù)token的梯度非常均勻,甚至比較大梯度的量化殘差更均勻。
圖片
因此,與其計(jì)算所有梯度,不如節(jié)省計(jì)算較大梯度殘差的計(jì)算資源。
為了利用這種稀疏性,研究者提出了位分割,將每個(gè)token的梯度分割為高4位和低4位。
然后,通過(guò)杠桿分?jǐn)?shù)采樣(leverage score sampling)來(lái)選擇信息最豐富的梯度,這是RandNLA的一種重要采樣技術(shù)。
圖片
結(jié)合前向和后向傳播的量化技術(shù),研究者提出了一種使用INT4MM進(jìn)行Transformer中所有線性運(yùn)算的算法, 并且評(píng)估了在各種任務(wù)上訓(xùn)練Transformer的算法,包括自然語(yǔ)言理解、問(wèn)答、機(jī)器翻譯和圖像分類。
與現(xiàn)有的4位訓(xùn)練算法相比,他們的算法實(shí)現(xiàn)了有競(jìng)爭(zhēng)力的或更高的精度。
此外,這種算法與GPU等當(dāng)代硬件兼容,因?yàn)樗恍枰狥P4或?qū)?shù)格式等自定義的數(shù)字格式。
這種原型量化+INT4 MM算子實(shí)現(xiàn),速度比FP16MM基線快2.2倍,并且將訓(xùn)練速度提高了35.1%。
相關(guān)工作
完全量化訓(xùn)練
完全量化訓(xùn)練 (FQT) 方法通過(guò)將激活、權(quán)重和梯度量化為低精度來(lái)加速訓(xùn)練,因此訓(xùn)練期間的線性和非線性算子可以用低精度算術(shù)來(lái)實(shí)現(xiàn)。
FQT的研究設(shè)計(jì)了新穎的數(shù)值格式和量化算法,可以更好地逼近全精度張量。
目前的研究前沿是4位FQT。由于梯度的數(shù)值范圍很大以及從頭開始訓(xùn)練量化網(wǎng)絡(luò)的優(yōu)化問(wèn)題,F(xiàn)QT具有挑戰(zhàn)性。
由于這些挑戰(zhàn),現(xiàn)有的4位FQT 算法在某些任務(wù)上的精度仍然下降了1-2.5%,并且無(wú)法支持當(dāng)代硬件。
圖片
其他有效的訓(xùn)練方法
混合專家在不增加訓(xùn)練預(yù)算的情況下提高了模型容量。
結(jié)構(gòu)性dropout利用計(jì)算有效的方法來(lái)正則化模型。高效的注意力降低了計(jì)算注意力的二次時(shí)間復(fù)雜度。
分布式訓(xùn)練系統(tǒng)通過(guò)利用更多的計(jì)算資源,減少了訓(xùn)練時(shí)間。
研究者降低數(shù)值精度的工作與這些方向具有正交性。
圖片
前向傳播
神經(jīng)網(wǎng)絡(luò)訓(xùn)練是一個(gè)迭代優(yōu)化過(guò)程,通過(guò)前向和后向傳播計(jì)算隨機(jī)梯度。
研究團(tuán)隊(duì)使用4位整數(shù)(INT4)算法加速前向和后向傳播。
正向傳播能以線性和非線性(GeLU, normalization, softmax等)算子的組合來(lái)實(shí)現(xiàn)。
在我們的訓(xùn)練過(guò)程中,我們用INT4算術(shù)加速所有線性運(yùn)算符,并將所有計(jì)算量較小的非線性運(yùn)算符保留在16位浮 點(diǎn)(FP16)格式中。
Transformer中的所有線性運(yùn)算都可以寫成矩陣乘法(MM)的形式。
為了便于表述,本文考慮以下簡(jiǎn)單矩陣乘法的加速:
圖片
這種MM的最主要用例是全連接層。
考慮一個(gè)輸入形狀為(批量大小S,序列長(zhǎng)度T,維度D)的Transformer。
全連接層可以表述成上邊的公式,其中X是N = STtoken的激活,W是權(quán)重矩陣。
對(duì)于注意力層,可能需要批量矩陣乘法(BMMS)。
我們提出的技術(shù)可以應(yīng)用于BMMS。
學(xué)習(xí)步長(zhǎng)量化(Learned Step Quantization)
為了加速訓(xùn)練,必須使用整數(shù)運(yùn)算來(lái)計(jì)算前向傳播。
研究人員為此目的,利用學(xué)習(xí)步長(zhǎng)量化器(LSQ)。
LSQ是靜態(tài)量化,他的量化尺度不依賴于輸入的方法,因此比動(dòng)態(tài)方法消耗更小,量化方法,需要在每次迭代時(shí)動(dòng)態(tài)計(jì)算量化尺度。
激活異常值
簡(jiǎn)單地將LSQ應(yīng)用到具有4位激活/權(quán)重的FQT會(huì)導(dǎo)致精度下降,因?yàn)闀?huì)激活異常值。
圖片
如上圖所示,激活有一些離群值條目,它們是其規(guī)模比其他條目大得多。
不幸的是,Transformers傾向于將信息存儲(chǔ)在這些異常值中,而且這樣的截?cái)鄷?huì)嚴(yán)重?fù)p害準(zhǔn)確性。
當(dāng)訓(xùn)練任務(wù)是在一些新的下游任務(wù)上微調(diào)預(yù)訓(xùn)練模型時(shí),異常值問(wèn)題尤為明顯。
因?yàn)轭A(yù)訓(xùn)練模型比隨機(jī)初始化包含更多的異常值 。
Hadamard量化
我們提出了Hadamard量化(HQ)來(lái)解決異常值問(wèn)題。
其主要思想是將另一個(gè)具有較少異常值的線性空間中的矩陣進(jìn)行量化。
激活矩陣中的異常值形成了一個(gè)特征結(jié)構(gòu)(feature-wise structure)。
他們通常集中在幾個(gè)維度上,也就是說(shuō)X中只有幾列顯著大于其他列。
哈達(dá)瑪變換(Hardamand transform)是一個(gè)線性變換,它可以將異常值分?jǐn)偟狡渌麠l目中。
后向傳播
現(xiàn)在我們考慮使用INT4操作來(lái)加速線性層的后向傳播。
我們將在本節(jié)中討論激活梯度/權(quán)重梯度的計(jì)算。
梯度的結(jié)構(gòu)稀疏性
我們注意到,在訓(xùn)練過(guò)程中梯度矩陣往往非常稀疏。
而且稀疏性具有這樣的結(jié)構(gòu):
的幾行(比如tokens)具有較大的條目,而大多數(shù)其他行卻接近全零向量。
圖片
這種結(jié)構(gòu)稀疏性源于現(xiàn)代神經(jīng)網(wǎng)絡(luò)的嚴(yán)重過(guò)度參數(shù)化。
幾乎在整個(gè)訓(xùn)練過(guò)程中,網(wǎng)絡(luò)都以超參數(shù)化方案運(yùn)行,除了一些困難的例子之外,它可以很好地適應(yīng)大多數(shù)訓(xùn)練數(shù)據(jù)。
因此,對(duì)于擬合良好的數(shù)據(jù)點(diǎn),(激活)梯度將接近于零。
研究人員發(fā)現(xiàn)對(duì)于預(yù)訓(xùn)練任務(wù),例如,經(jīng)過(guò)幾個(gè)訓(xùn)練周期后,結(jié)構(gòu)稀疏性很快就會(huì)出現(xiàn)。
對(duì)于微調(diào)任務(wù),梯度整個(gè)訓(xùn)練過(guò)程中始終是稀疏的。
位分割(Bit Splitting)和杠桿分?jǐn)?shù)采樣(Leverage Score Sampling)
如何設(shè)計(jì)梯度量化器,以利用結(jié)構(gòu)稀疏性在反向傳播期間準(zhǔn)確計(jì)算MM呢?
高級(jí)的思路是:梯度的許多行都是如此小,對(duì)參數(shù)梯度影響很小,但浪費(fèi)了大量的計(jì)算量。
另一方面,大行無(wú)法用INT4精確表示。
我們放棄掉一些小行并使用節(jié)省下來(lái)的計(jì)算能力來(lái)更準(zhǔn)確地表示大行。
實(shí)驗(yàn)
研究人員在包括語(yǔ)言模型在內(nèi)的各種任務(wù)上評(píng)估我們的INT4訓(xùn)練算法微調(diào)、機(jī)器翻譯和圖像分類。
研究人員用CUDA和cutlass執(zhí)行了他們提出的HQ-MM和LSS-MM算法。
研究人員用INT4實(shí)現(xiàn)替換所有浮點(diǎn)線性運(yùn)算符,但沒(méi)有簡(jiǎn)單地使用LSQ來(lái)嵌入層,并保持最后一個(gè)分類器層的精度。
最后研究人員對(duì)所有評(píng)估的模型采用了默認(rèn)架構(gòu)、優(yōu)化器、調(diào)度器和超參數(shù)。
收斂模型精度
研究人員在下表中比較了收斂模型在各種任務(wù)上的準(zhǔn)確性。
圖片
作為對(duì)照的方法包括全精度訓(xùn)練(FP)、INT8訓(xùn)練(INT8)、FP4訓(xùn)練(「超低」),使用LSQ進(jìn)行激活和權(quán)重(LSQ+LUQ)的4 位對(duì)數(shù)量化,以及我們這種利用HQ進(jìn)行前向傳播,利用LSS進(jìn)行反向傳播(HQ+LSS)的算法。
「超低」沒(méi)有公開的實(shí)現(xiàn),因此我們僅列出了它在機(jī)器上的原始論文中的性能翻譯任務(wù)。
除了大型機(jī)器翻譯任務(wù)和大型視覺(jué)Transformer任務(wù)之外,我們將每次運(yùn)行重復(fù)三次,并將標(biāo)準(zhǔn)差報(bào)告為表中的下標(biāo)。
研究人員沒(méi)有進(jìn)行任何類型的知識(shí)蒸餾或數(shù)據(jù)增強(qiáng)。
消融實(shí)驗(yàn)
研究人員進(jìn)行的消融實(shí)驗(yàn)?zāi)康氖钦故厩跋蚝秃笙蚍椒ǖ挠行浴?/span>
研究不同量化器的前向傳播的有效性,我們將后向傳播留在FP16中。
結(jié)果如下圖所示。
圖片
計(jì)算和內(nèi)存效率
最后,研究人員通過(guò)評(píng)估他們的原型實(shí)現(xiàn),展示了他們的方法加速神經(jīng)網(wǎng)絡(luò)訓(xùn)練的潛力。
而且他們的實(shí)施還沒(méi)有完全優(yōu)化。
研究人員也沒(méi)有將線性算子與非線性和歸一化進(jìn)行融合。
因此,結(jié)果不能完全反映INT4訓(xùn)練算法的潛力。
完全優(yōu)化的實(shí)施需要大量工程,超出了我們論文的討論范圍。
結(jié)論
研究人員提出了一種對(duì)硬件很友好的Transformer INT4的訓(xùn)練方法。
通過(guò)分析Transformer中MM的屬性,研究人員提出了HQ和LSS方法來(lái)量化激活和梯度,同時(shí)保持準(zhǔn)確性。
在幾個(gè)重要任務(wù)上,我們的方法與現(xiàn)有的INT4方法表現(xiàn)相當(dāng),甚至更好。
研究人員的這些工作可能會(huì)擴(kuò)展到除了Transformers之外的其他MM架構(gòu)中,例如 MLP-Mixer、圖神經(jīng)網(wǎng)絡(luò)和循環(huán)神經(jīng)網(wǎng)絡(luò)網(wǎng)絡(luò)。
這是他們未來(lái)的研究方向。
更廣泛的影響:研究人員的算法可以提高效率并減少訓(xùn)練神經(jīng)網(wǎng)絡(luò)的能源消耗,這有助于減少深度學(xué)習(xí)造成的碳排放。
但是,高效的訓(xùn)練算法還可能促進(jìn)那些,對(duì)于人來(lái)安全存在隱患的大語(yǔ)言模型和惡意人工智能應(yīng)用程序的開發(fā)。
比如,會(huì)被用于虛假內(nèi)容生成的相關(guān)模型和應(yīng)用。
限制:這項(xiàng)工作的主要限制是它只能加速具有較大規(guī)模的矩陣乘法(線性層)的大模型,但不能加速卷積層。
而且,所提出的方法還不能很好地適用于OPT-175B等超大模型。
據(jù)我們所知,即使是INT8訓(xùn)練對(duì)于這些超大型模型來(lái)說(shuō)仍然是尚待解決的問(wèn)題。