「乘法變加法」!MIT清華校友全新方法優(yōu)化Transformer:Addition is All You Need
LLM能耗的瘋狂增長(zhǎng),甚至已經(jīng)引起了聯(lián)合國(guó)的注意,成為了不容小覷的能源消耗者。
據(jù)統(tǒng)計(jì),2023年初ChatGPT服務(wù)的平均用電量為每天564兆瓦時(shí),相當(dāng)于18000個(gè)美國(guó)家庭每天的總用電量。
谷歌的情況更加嚴(yán)峻。最壞的情況下,谷歌AI服務(wù)消耗的電力可能和一整個(gè)愛(ài)爾蘭相當(dāng),約為每年29.3 TWh。
要在提升推理速度的同時(shí)降低大模型的能耗,減少神經(jīng)網(wǎng)絡(luò)所需的計(jì)算量才是關(guān)鍵。
而LLM等大規(guī)模神經(jīng)網(wǎng)絡(luò),大部分計(jì)算量正是消耗在浮點(diǎn)級(jí)精度的矩陣乘法上。
從線性注意力機(jī)制到量化,大多數(shù)Transformer的優(yōu)化都離不開對(duì)于乘法效率的大幅提高。要么減少運(yùn)算操作次數(shù),要么減少操作數(shù)的位數(shù)。
但如果從乘法運(yùn)算這個(gè)更加底層的邏輯出發(fā),兩位華人研究者提出,可以用一個(gè)整數(shù)加法器以高精度近似進(jìn)行浮點(diǎn)數(shù)乘法運(yùn)算,即L-Mul乘法算法。
論文地址:https://arxiv.org/abs/2410.00907
相比量化過(guò)程中的FP8乘法,L-Mul能達(dá)到更高的精度,而且運(yùn)算量顯著減少。
實(shí)驗(yàn)結(jié)果顯示,在張量處理硬件中應(yīng)用L-Mul操作能將逐元素浮點(diǎn)張量乘法的能量成本降低95%,點(diǎn)積的能量成本降低80%。
此外,L-Mul可以直接集成到各個(gè)級(jí)別的現(xiàn)有模型中,無(wú)需額外訓(xùn)練,甚至能無(wú)損替換注意力機(jī)制中所有的矩陣、元素級(jí)別的浮點(diǎn)數(shù)乘法。
整體而言,L-Mul方法專注于提高對(duì)張量進(jìn)行算術(shù)運(yùn)算的效率——這與當(dāng)前在I/O和控制優(yōu)化方面的研究是相互獨(dú)立但又相輔相成的。
由此作者認(rèn)為,真正高能效、高計(jì)算效率的人工智能計(jì)算將從I/O、控制流,和算術(shù)運(yùn)算的全面優(yōu)化整合中產(chǎn)生。
論文簡(jiǎn)介
大多數(shù)機(jī)器學(xué)習(xí)模型,包括神經(jīng)網(wǎng)絡(luò),都使用浮點(diǎn)張量來(lái)表示它們的輸入、輸出和可訓(xùn)練參數(shù)。
其中,典型的選擇是32位和16位浮點(diǎn)張量,即fp32和fp16。
在現(xiàn)代計(jì)算硬件中,浮點(diǎn)數(shù)之間的乘法比加法運(yùn)算消耗更多的能量,浮點(diǎn)數(shù)運(yùn)算也顯然比整數(shù)更加昂貴。
用n代表數(shù)字位數(shù),那么整數(shù)加法的計(jì)算復(fù)雜度僅有O(n);而對(duì)于指數(shù)部分有e位、尾數(shù)部分有m位的浮點(diǎn)數(shù),乘法運(yùn)算則需要O(e)復(fù)雜度的加法加上O(m^2)復(fù)雜度的乘法。
如表1所示,元素級(jí)別的運(yùn)算上,fp32乘法和int32加法已經(jīng)差距懸殊,能量高出37倍;如果是張量級(jí)別的運(yùn)算,那更是相差甚遠(yuǎn)。
比如下面兩種常用的運(yùn)算:逐元素乘法Y_1和點(diǎn)積Y_2。
計(jì)算Y_1時(shí),如果A和X都是fp32張量,相比int32矩陣的加法所消耗的能量也會(huì)高出37倍。
同樣,計(jì)算Y_2時(shí)涉及m×n×k次的浮點(diǎn)乘法和加法,兩個(gè)數(shù)字的每次乘加運(yùn)算都會(huì)消耗0.9+3.7=4.6(pJ)能量。
如果替換為int32,那么每次運(yùn)算的能量成本就變?yōu)?.1+0.9=1.0 pJ,僅為原始成本的21.7%。
類似地,如果原始精度為fp16,替換為int16后也能達(dá)到1?(0.05+0.4)/(1.1+0.4)=70%的效率提升。
線性復(fù)雜度乘法(L-MUL)
那么,對(duì)于n位的浮點(diǎn)數(shù),到底要如何用整數(shù)加法近似計(jì)算浮點(diǎn)數(shù)乘法,實(shí)現(xiàn)O(n)復(fù)雜度?
考慮兩個(gè)浮點(diǎn)數(shù)x和y,它們的指數(shù)和小數(shù)部分的位數(shù)分別為x_e、y_e和x_m、y_m。
傳統(tǒng)的浮點(diǎn)乘法可以表示為:
再加上一個(gè)異或操作(⊕)來(lái)決定結(jié)果的符號(hào)為正或?yàn)樨?fù)。
其中,尾數(shù)部分的乘法操作是提升效率的瓶頸,復(fù)雜度為O(m^2)。
L-Mul所做的,就是移除這個(gè)操作,引入了一種新的乘法算法,以O(shè)(m)的計(jì)算復(fù)雜度處理尾數(shù):
對(duì)比上面的公式可以發(fā)現(xiàn),我們僅僅是將x_m · y_m替換為2^{-l?(m)},其中l(wèi)(m)是一個(gè)簡(jiǎn)單的分段函數(shù)。
雖然等式(1)包含4個(gè)加法操作,但浮點(diǎn)數(shù)的位格式設(shè)計(jì)能幫助我們用一個(gè)加法器實(shí)現(xiàn)L-Mul算法。
浮點(diǎn)格式隱式處理1+x_m,所以不必計(jì)算(1+...)的值;整數(shù)加法操作還會(huì)自動(dòng)將尾數(shù)進(jìn)位發(fā)送到指數(shù),這與傳統(tǒng)浮點(diǎn)乘法器中的舍入過(guò)程不同。
在傳統(tǒng)方法中,小數(shù)部分需要手動(dòng)舍入為1.x,并且向指數(shù)部分添加進(jìn)位需要作為獨(dú)立步驟進(jìn)行;而根據(jù)L-Mul中的分段函數(shù)l(m),如果尾數(shù)和大于2,進(jìn)位會(huì)自動(dòng)添加到指數(shù)。
因此,通過(guò)跳過(guò)尾數(shù)乘法和舍入操作,L-Mul算法比傳統(tǒng)浮點(diǎn)乘法更高效。
算法的具體實(shí)現(xiàn)過(guò)程如圖2所示,最佳實(shí)現(xiàn)是在硬件級(jí)別,因此作者添加了在英偉達(dá)GPU上模擬該過(guò)程的內(nèi)聯(lián)PTX匯編代碼。
常規(guī)浮點(diǎn)乘法和L-Mul算法的復(fù)雜度比較;在匯編代碼中,$1和$2是存儲(chǔ)輸入的fp32寄存器,$0是用于輸出的fp32寄存器。s1、s2、r0、r1、r2是存儲(chǔ)中間結(jié)果的無(wú)符號(hào)int32寄存器
L-Mul結(jié)果的構(gòu)造可以用以下等式表示,其中所有位級(jí)計(jì)算都作為無(wú)符號(hào)整數(shù)之間的操作執(zhí)行:
在此基礎(chǔ)上,作者進(jìn)一步用L-Mul實(shí)現(xiàn)了注意力機(jī)制。
在Transformer模型中,注意力機(jī)制由于其處理輸入上下文C的O(|C|^2)復(fù)雜度而具有高計(jì)算成本。
但如果使用L-Mul,無(wú)需額外訓(xùn)練,就可以用最小的性能損失替代復(fù)雜的張量乘法,實(shí)現(xiàn)更高效的注意力機(jī)制,如下所示:
其中L-matmul(Q, K^T)表示矩陣乘法操作,其中所有常規(guī)浮點(diǎn)乘法都被替換為整數(shù)加法,用L-Mul實(shí)現(xiàn),顯著降低了計(jì)算資源消耗。
精度和成本分析
精度分析的目標(biāo)是確定L-Mul近似計(jì)算的精度,相當(dāng)于將浮點(diǎn)數(shù)的小數(shù)部分舍入到多少位,并和具有2位或3位尾數(shù)的fp8(e5m2或e4m3)進(jìn)行比較。
考慮正浮點(diǎn)數(shù)x、y,并明確舍入后要保留的k位,可以寫成以下格式:
其中x_k、y_k是x_m、y_m的前k位,x_r、y_r是k位舍入后將被忽略的剩余位的值。x′、y′是保留尾數(shù)前k位并進(jìn)行舍入后的數(shù)值。
考慮x和y在全精度下有m位尾數(shù)。例如,F(xiàn)P16有10位尾數(shù),BF16包含7位。
乘法運(yùn)算Mul(x, y) = x · y的誤差及其期望值可以表示為:
與k位尾數(shù)的浮點(diǎn)乘法相比,k位尾數(shù)L-Mul的誤差為:
利用上述方程,可以計(jì)算k位L-Mul和浮點(diǎn)乘法之間精度差的期望值,具體來(lái)說(shuō):
當(dāng)x_m、y_m呈均勻分布時(shí),可以計(jì)算以下期望:
通過(guò)估計(jì)f1?(m,k)和f2?(k)并進(jìn)一步推斷E?[e^k_{l?m?u?}k] 和 E?[e^k_{m?u?l}]可以得知, 如果是在操作數(shù)均勻分布的情況下,L-Mul比f(wàn)p8_e5m2更精確;然而,預(yù)訓(xùn)練LLM的權(quán)重分布通常是存在偏差的。
這種近似計(jì)算究竟能否適用于當(dāng)前的LLM,還需要實(shí)驗(yàn)結(jié)果來(lái)證明。
基于五個(gè)流行大語(yǔ)言模型的組合權(quán)重分布,實(shí)驗(yàn)結(jié)果發(fā)現(xiàn),在實(shí)踐中,L-Mul可以在使用5位尾數(shù)的情況下實(shí)現(xiàn)超越fp8_e4m3的更高準(zhǔn)確度。
此外,結(jié)合門運(yùn)算的復(fù)雜度估算可以進(jìn)一步證實(shí),L-Mul比f(wàn)p8乘法更加高效且準(zhǔn)確。這一結(jié)果突顯了L-Mul在低精度計(jì)算中的潛在優(yōu)勢(shì)。
關(guān)于精度和成本分析的更詳細(xì)理論推導(dǎo)可見(jiàn)于論文2.3節(jié)以及附錄A。
LLM實(shí)驗(yàn)結(jié)果
要證明L-Mul的實(shí)際應(yīng)用價(jià)值,就需要在LLM的實(shí)際任務(wù)上運(yùn)行。
精度分析
論文選擇了各種基于Transformer的語(yǔ)言模型,包括Llama 3.1、Mistral、Gemma 2等,并在各種語(yǔ)言和視覺(jué)任務(wù)基準(zhǔn)上評(píng)估了L-Mul算法的數(shù)值精度。
對(duì)比全精度模型權(quán)重的運(yùn)行結(jié)果,可以證明,對(duì)基于Transformer的LLM而言,在注意力機(jī)制中用L-Mul替換標(biāo)準(zhǔn)乘法運(yùn)算可以達(dá)到幾乎無(wú)損的近似效果,可以在微調(diào)或免訓(xùn)練設(shè)置下替換Transformer層中的不同模塊。
圖3展示了選擇不同k值和l(k)值的均方誤差(mean square errors)結(jié)果,實(shí)驗(yàn)包含Llama 3.1和Gemma 2的兩個(gè)小模型,在GSM8k數(shù)據(jù)集上運(yùn)行。
在兩個(gè)模型中,使用3位尾數(shù)的L-Mul比f(wàn)p8_e5m2更精確,而使用4位尾數(shù)的L-Mul可以達(dá)到或近似于fp8_e4m3的誤差水平。
紅色表示平均誤差低于fp8_e4m3,下劃線表示誤差介于e4m3和e5m2之間
以上兩個(gè)模型的平均誤差如圖4所示。
前面的理論推導(dǎo)顯示,L-Mul在使用的計(jì)算資源少于fp8_e5m2時(shí),期望誤差可以低于fp8_e4m3,此處的實(shí)驗(yàn)結(jié)果正式了前面理論估計(jì)的正確性。
實(shí)驗(yàn)表明,在各種規(guī)模的LLM中,使用6位尾數(shù)FP操作數(shù)的L-Mul算法近似達(dá)到最低平均誤差,顯著優(yōu)于e5m2、e4m3兩種fp8格式。
此外,3位和4位尾數(shù)的L-Mul分別達(dá)到或超過(guò)了fp8_e5m2和fp8_e4m3的精度。
L-Mul與不同格式fp8浮點(diǎn)是進(jìn)行乘法運(yùn)算的誤差水平比較
基準(zhǔn)測(cè)試
本節(jié)的實(shí)驗(yàn)旨在證明,L-Mul可以在不損失性能的情況下替代注意力機(jī)制中的張量乘法,而使用fp8乘法則會(huì)降低推理精度。
這就意味著,L-Mul可以在降低注意力計(jì)算能耗80%的同時(shí)達(dá)到相同的推理性能。
對(duì)于文本任務(wù),表2展示了Llama和Mistral模型在各種自然語(yǔ)言基準(zhǔn)測(cè)試上的評(píng)估結(jié)果,包括MMLU、BBH、ARC-C等。
結(jié)果表明,L-Mul不僅顯著減少了計(jì)算資源,而且在絕大多數(shù)測(cè)試中(12/14)的得分高于fp8_e4m3。
與bf16推理相比,性能差距被降低到最低水平。在兩個(gè)模型中,bf16和L-Mul之間在常識(shí)、結(jié)構(gòu)化推理和語(yǔ)言理解方面的平均性能差異僅為0.07%。
值得注意的是,對(duì)于Mistral和Gemma2兩個(gè)模型,基于L-Mul的注意力機(jī)制與bf16基準(zhǔn)相比略微提高了平均性能,分別達(dá)到52.92%和47.01%。
Llama3.1使用L-Mul時(shí),準(zhǔn)確率略低于bf16,但仍高于fp8_e4m3和fp8_e5m2。
相反,將注意力計(jì)算中的張量四舍五入到fp8_e5m2會(huì)導(dǎo)致顯著的性能下降,盡管e5m2比L-Mul更復(fù)雜。
3個(gè)語(yǔ)言模型在GSM8k數(shù)據(jù)集上使用少樣本提示的運(yùn)行結(jié)果,包括L-Mul方法和3種精度bf16、fp8_e4m3、fp8_e5m2的對(duì)比
視覺(jué)-語(yǔ)言任務(wù)主要用Llava模型進(jìn)行了測(cè)試,結(jié)果如表4所示。
除了在TextVQA基準(zhǔn)上的準(zhǔn)確率差距略大,達(dá)到了0.5%,在POPE、VQAv2、Llava-Bench、VizWiz等其他基準(zhǔn)上,L-Mul達(dá)到了和bf16相似甚至更好的性能。
此外,誤差估計(jì)和消融實(shí)驗(yàn)(表5)可以進(jìn)一步表明,在無(wú)需額外訓(xùn)練的設(shè)置下,4位尾數(shù)的L-Mul可以達(dá)到與fp8_e4m3相當(dāng)?shù)臏?zhǔn)確性,而3位尾數(shù)的L-Mul優(yōu)于fp8_e5m2乘法。
微調(diào)
以上的實(shí)驗(yàn)結(jié)果,是直接將預(yù)訓(xùn)練LLM從標(biāo)準(zhǔn)注意力適配到新的基于L-Mul的注意力機(jī)制運(yùn)行的,沒(méi)有進(jìn)行額外訓(xùn)練。
進(jìn)一步的研究還表明,微調(diào)可以彌補(bǔ)L-Mul和標(biāo)準(zhǔn)乘法之間的性能差距。
本節(jié)的實(shí)驗(yàn)中,不僅在Gemma2的注意力機(jī)制層中實(shí)現(xiàn)L-Mul,而且對(duì)于模型中所有乘法運(yùn)算——包括線性變換中的矩陣乘法、元素級(jí)乘法以及注意力機(jī)制層內(nèi)的乘法,都使用L-Mul和fp8_e4m3進(jìn)行近似,之后在GSM8k數(shù)據(jù)集上對(duì)更新后的模型進(jìn)行微調(diào)。
將注意力機(jī)制、線性變換和逐元素乘積中的所有乘法運(yùn)算替換為3位尾數(shù)L-Mul的模型進(jìn)行微調(diào),其性能可與使用fp8_e4m3累積精度的標(biāo)準(zhǔn)模型微調(diào)相媲美。
值得注意的是,本實(shí)驗(yàn)中的L-Mul操作使用3位尾數(shù)(k=3),累加精度為fp8_e4m3,以探索極其高效的設(shè)置。
結(jié)果可以看出,在fp8精度下,微調(diào)后的fp8_e4m3 L-Mul模型達(dá)到了與標(biāo)準(zhǔn)微調(diào)fp8_e4m3模型相當(dāng)?shù)男阅堋?/span>
這表明,L-Mul可以在不影響微調(diào)模型性能的情況下提高訓(xùn)練效率。此外,也揭示了訓(xùn)練L-Mul原生LLM的潛質(zhì),用于更加精確、節(jié)能的模型托管。
微調(diào)后fp8和L-Mul模型在零樣本設(shè)置下的評(píng)估
作者介紹
Hongyin Luo
Hongyin Luo是MIT計(jì)算機(jī)科學(xué)與人工智能實(shí)驗(yàn)室(CSAIL)的研究科學(xué)家,在Jim Glass博士領(lǐng)導(dǎo)的口語(yǔ)語(yǔ)言系統(tǒng)(SLS)小組工作。
他于2016年在清華大學(xué)獲得學(xué)士學(xué)位,導(dǎo)師是NLP領(lǐng)域的大牛級(jí)人物:劉知遠(yuǎn)和孫茂松。
隨后于2022年在MIT EECS獲得博士學(xué)位,專注自然語(yǔ)言處理中的自訓(xùn)練研究。
他的研究重點(diǎn)是提高語(yǔ)言模型的效率、透明性和推理能力。最新研究結(jié)合了自然語(yǔ)言與不同的形式推理引擎,包括蘊(yùn)涵模型(entailment model)和程序解釋器。
他構(gòu)建了小型語(yǔ)言模型,以1/500的計(jì)算量表現(xiàn)優(yōu)于GPT3-175B,開發(fā)了處理搜索引擎噪聲的自我去噪語(yǔ)言模型,以及無(wú)需任務(wù)特定示例即可實(shí)現(xiàn)準(zhǔn)確推理的自然語(yǔ)言嵌入程序。