1-bit大模型還能再突破!新一代BitNet架構(gòu)啟用4位激活值
量化到1 bit的LLM還能再突破?
這次,他們對(duì)激活值下手了!
近日,BitNet系列的原班人馬推出了新一代架構(gòu):BitNet a4.8,為1 bit大模型啟用了4位激活值:
圖片
論文地址:https://arxiv.org/pdf/2411.04965
眾所周知,激活值量化通常是比較難辦的。
本次的BitNet a4.8采用混合量化和稀疏化策略,來減輕異常通道引入的量化誤差。
簡單來說就是,對(duì)注意力層和FFN層的輸入采用4位量化,同時(shí)用8位整數(shù)稀疏化中間狀態(tài)。
大量實(shí)驗(yàn)表明,BitNet a4.8在相同的訓(xùn)練成本下,實(shí)現(xiàn)了與前代BitNet b1.58相當(dāng)?shù)男阅埽瑫r(shí)因?yàn)榭梢猿缘?位(INT4/FP4)內(nèi)核的計(jì)算紅利,實(shí)現(xiàn)了更快的推理速度。
BitNet a4.8僅激活55%的參數(shù),并支持3 bit KV cache,進(jìn)一步提升了大規(guī)模LLM部署和推理的效率。
BitNet a4.8
圖片
模型架構(gòu)
模型的整體架構(gòu)如圖1所示,BitNet a4.8采用了與BitNet b1.58相同的布局。
作者使用BitLinear替換注意力(MHA)和前饋網(wǎng)絡(luò)(FFN)中的線性投影,以從頭開始學(xué)習(xí)1.58 bit權(quán)重。對(duì)于激活值,采用混合量化和稀疏化策略來減輕異常值維度引入的誤差。
圖片
圖2說明了模型大小為7B的BitNet b1.58中,每個(gè)模塊輸入的分布。
注意力層和FFN層的輸入通常類似高斯分布,而在FFN下采樣之前的激活值和注意力中的輸出投影中,發(fā)現(xiàn)了很多異常值通道和大量接近零的條目(全精度LLM也有類似觀察結(jié)果)。
圖片
如圖3所示,直接將低位量化應(yīng)用于這些中間狀態(tài)會(huì)引入很大的量化誤差。
因此,作者使用Q-Sparse的稀疏化方法,將這些中間狀態(tài)保持在8位(同時(shí)消除了計(jì)算瓶頸)。
對(duì)于自注意層的輸出投影,使用sparsify-then-quantize函數(shù):
兩個(gè)Q分別表示權(quán)重W和激活X的量化函數(shù),M是掩碼,根據(jù)激活X的絕對(duì)值取topK,⊙是元素乘法。
具體來說,權(quán)重量化和激活值量化函數(shù)可以表述為:
對(duì)于FFN,這里采用squared ReLU和門控線性單元(GLU)來進(jìn)一步提高激活的稀疏性:
根據(jù)初步實(shí)驗(yàn)的結(jié)果,使用squared ReLU時(shí),下采樣輸入的稀疏性超過了80%,且對(duì)性能的影響最小。
此外,作者還觀察到gate + squared ReLU的輸出也表現(xiàn)出高激活稀疏性(7B模型為67.5%)。通過首先計(jì)算gate projection,然后僅在非零通道上執(zhí)行up projection,可以進(jìn)一步減少推理的計(jì)算量。
相比之下,attention和FFN的輸入中包含的異常值特征要少得多,可以使用absmean函數(shù)將激活值量化為4位整數(shù):
模型訓(xùn)練
初始化
BitNet a4.8使用BitNet b1.58的權(quán)重開始訓(xùn)練,分為W1.58A8與W1.58A4兩階段。
第一階段使用8位激活和GLU + squared ReLU訓(xùn)練模型;第二階段采用上面介紹過的混合量化和稀疏化。
圖片
BitNet a4.8只需少量訓(xùn)練,即可快速適應(yīng)4bit位寬和稀疏激活,同時(shí)性能損失可以忽略不計(jì)。
梯度近似
作者使用直通估計(jì)器(STE)對(duì)BitNet a4.8進(jìn)行梯度逼近,使用混合精度訓(xùn)練來更新參數(shù)。
圖片
這里直接繞過了不可微函數(shù),包括反向傳播過程中的量化函數(shù)和topK稀疏函數(shù)。對(duì)于混合精度訓(xùn)練,保持全精度latent weight來累積參數(shù)更新。
模型量化
浮點(diǎn)量化提供了比基于整數(shù)的量化更寬的動(dòng)態(tài)范圍,這對(duì)于處理激活值的長尾分布至關(guān)重要。
研究人員將FFN下采樣層的輸入保留為8位整數(shù),其他激活值使用MinMax量化器量化為FP4:
公式中E和M分別表示指數(shù)和尾數(shù)部分的位寬。這里采用E2M1格式,因?yàn)樗膭?dòng)態(tài)范圍更大。
實(shí)驗(yàn)
本文將BitNet a4.8、BitNet b1.58,以及各種參數(shù)量大小的FP16精度LLaMA進(jìn)行了比較。
其中的1.58 bit模型,遵循BitNet b1.58的訓(xùn)練方案,采用了兩階段權(quán)重衰減和學(xué)習(xí)率調(diào)度。
圖片
所有模型都使用RedPajama數(shù)據(jù)集中的100B token進(jìn)行訓(xùn)練,以確保公平比較。
對(duì)于BitNet a4.8,作者首先使用95B token來訓(xùn)練8位激活值的模型。然后重用優(yōu)化器狀態(tài),并使用5B token進(jìn)行混合量化和稀疏化的訓(xùn)練。實(shí)驗(yàn)將topK設(shè)置為50%(attention的輸出投影位置)。
作者使用lm-evaluation-harness工具包,評(píng)估模型在一系列語言任務(wù)上的zero-shot準(zhǔn)確性,包括ARC-Easy(ARCe)、ARCChallenge(ARCc)、Hellaswag(HS)、Winogrande(WGe)和PIQA(PQ)。另外還測試了在C4數(shù)據(jù)集(測試集)上的困惑度。
主要結(jié)果
圖片
表1總結(jié)了BitNet a4.8、BitNet b1.58和FP16 LLaMA的詳細(xì)測試結(jié)果。
全精度(FP16)LLaMA和BitNet b1.58之間的性能差距,隨著模型大小的增長而縮小。對(duì)于7B模型,BitNet b1.58在語言模型困惑度和任務(wù)的平均準(zhǔn)確性方面與LLaMA相當(dāng)。
此外,相比于BitNet b1.58,BitNet a4.8的平均精度幾乎沒有損失。
圖片
表2展示了各種大小的BitNet a4.8、BitNet b1.58 和 FP16 LLaMA中每個(gè)模塊的詳細(xì)稀疏性(使用C4驗(yàn)證集上的非嵌入?yún)?shù)計(jì)算)。
值得注意的是,BitNet a4.8的稀疏性明顯高于BitNet b1.58和LLaMA。
比如在7B模型中,BitNet a4.8的整體稀疏性達(dá)到了44.5%,只有3.4B的活躍參數(shù)。down projection層的輸入顯示出特別高的稀疏性,且中間狀態(tài)分布以零為中心。
此外,gate projection的輸出非常稀疏,導(dǎo)致了up projection的高稀疏性(因?yàn)橹恍枰趶腉ate中選擇非零通道來執(zhí)行投影)。
具體來說,對(duì)于7B BitNet a4.8,Gate和up projection的稀疏率分別為67.5%和12.0%。
圖片
表3顯示了BitNet a4.8在3B和7B模型大小下,low-bit attention的詳細(xì)情況。模型使用4位KV或QKV頭,精度損失可忽略不計(jì),同時(shí)KV cache可以量化為3位整數(shù)。
low-bit attention對(duì)于高效的長序列建模至關(guān)重要,它減少了KV cache的內(nèi)存占用和IO,并加速了注意力計(jì)算。
在本文的實(shí)驗(yàn)中,作者采用RoPE后量化。使用absmax函數(shù)將QKV頭直接量化為無符號(hào)整數(shù),無需任何校準(zhǔn)數(shù)據(jù)集。
對(duì)于3 bit KV量化,研究人員將bos token的頭保留為4 bit,因?yàn)樗嗟漠惓V堤卣鳌?/p>
消融實(shí)驗(yàn)
圖片
圖4顯示了700M BitNet a4.8的訓(xùn)練損耗曲線,比較了使用完整的INT4/FP4量化,以及本文的混合量化和稀疏化。
完整的INT4量化會(huì)導(dǎo)致發(fā)散,而混合架構(gòu)在訓(xùn)練困惑度方面明顯優(yōu)于完整的FP4架構(gòu)。
使用RedPajama數(shù)據(jù)集中25B token,來進(jìn)行模型的第一階段訓(xùn)練,采用absmean和MinMax量化器分別進(jìn)行完整的INT4和FP4量化。
對(duì)于完整的INT4量化,由于其輸入具有更大的異常值,這里設(shè)置β = 2*mean(|X|)。
圖片
接下來為1.3B BitNet a4.8的down projection層輸入,設(shè)置不同的量化或激活函數(shù)。
所有模型都使用RedPajama數(shù)據(jù)集中的50B token進(jìn)行第一階段訓(xùn)練。為了確保公平比較,其他激活值都保留在8位。
圖5顯示了這些模型的訓(xùn)練損失曲線。Squared ReLU的訓(xùn)練困惑度比Swish略好,同時(shí)實(shí)現(xiàn)了更高的稀疏性。
此外,對(duì)down projection的輸入應(yīng)用FP4量化會(huì)導(dǎo)致性能顯著下降,而將INT4激活與STE一起使用會(huì)導(dǎo)致發(fā)散。
參考資料: