多虧Transformer,Mamba更強了!僅用1%計算量達(dá)新SOTA
Attention is all you need.
至少在矩陣這兒是。
Mamba架構(gòu)最新進(jìn)展:僅需1%計算量,新模型性能達(dá)SOTA。
能做到這一點,還多虧了Transformer。
圖片
通過將Transformer模型中的知識有效遷移到Mamba等替代架構(gòu)中,模型能在保持較低計算成本的同時,性能更好。
這就是由Mamba主創(chuàng)之一Albert Gu領(lǐng)銜的最新成果。
值得一提的是,這種方法還適用于Mamba以外的非Transformer架構(gòu)。
從Transformer到SSMs
Transformer由于依賴二次自注意力機制,所需計算量很大。
二次自注意力機制能讓模型在處理序列數(shù)據(jù)時有效捕捉序列內(nèi)部的長距離依賴關(guān)系,但是由于二次時間復(fù)雜度(如果輸入規(guī)模翻倍,模型計算所需時間增加4倍),導(dǎo)致處理長序列的計算成本很高。
為了解決這個問題,學(xué)界提出了很多新架構(gòu),比如Mamba、RWKV等,它們的微調(diào)和推理成本更低。
考慮到Transformer模型預(yù)訓(xùn)練已經(jīng)投入了大量計算資源,研究人員想到,為什么不能在此基礎(chǔ)上進(jìn)行提升?
所以在本項研究中,他們提出了一種蒸餾方法MOHAWK,利用Transformer預(yù)訓(xùn)練模型來訓(xùn)練SSMs模型。
其核心在于注意力機制、線性注意力、Mamba的結(jié)構(gòu)化掩碼注意力SMA等,都是跨輸入長度維度的序列轉(zhuǎn)換。因此它們都有各自的矩陣混合器,比如softmax。
圖片
通過將注意力和SSMs視為通過應(yīng)用不同類別的矩陣來混合不同token嵌入的序列變換,序列模型架構(gòu)可以分解為獨立序列混合和通道混合塊。
比如Transformer由注意力(序列混合器)和MLP(通道混合器)塊組成,使用這種分解可以蒸餾模型的每個元素。
具體蒸餾分為三個階段:
第一階段:矩陣對齊(Matrix Orientation)。對齊序列變換矩陣本身。
第二階段:隱藏狀態(tài)對齊(Hidden-State Alignment)。對齊網(wǎng)絡(luò)每個單獨層的隱藏狀態(tài)表示,且不犧牲預(yù)先學(xué)習(xí)的表示。
第三階段:權(quán)重轉(zhuǎn)移和知識蒸餾(Weight-Transfer and Knowledge Distillation)。通過一個端到端訓(xùn)練階段,將權(quán)重轉(zhuǎn)移,最終使用只有一小部分訓(xùn)練數(shù)據(jù)來蒸餾網(wǎng)絡(luò)的最終輸出。
利用這個方法來實際修改一個模型,比如Phi-Mamba。
圖片
它結(jié)合了Mamba-2和Phi-1.5。
通過MOHAWK方法,該模型從預(yù)訓(xùn)練的Transformer模型中學(xué)習(xí),同時作為狀態(tài)空間模型,它在處理長序列上比傳統(tǒng)Transformer架構(gòu)更高效。
該模型僅使用3B token進(jìn)行蒸餾,數(shù)據(jù)量為從頭訓(xùn)練模型的1%,但是性能達(dá)到開源非Transformer架構(gòu)中的SOTA。
圖片
實驗發(fā)現(xiàn),隱藏狀態(tài)對齊更好,可以提高后續(xù)階段的性能。
圖片
研究團(tuán)隊也發(fā)布了混合Phi-Mamba-1.5B,通過5B token蒸餾,模型與類似混合模型表現(xiàn)相當(dāng),但是注意力層只用了4層。
圖片
值得一提的是,這種蒸餾方法不止適用于Mamba。
圖片
該研究由CUM助理教授、Cartesia AI聯(lián)合創(chuàng)始人及首席科學(xué)家Albert Gu領(lǐng)銜。
去年,他和FlashAttention作者Tri Dao一起提出了Mamba,成為第一個真正實現(xiàn)匹配Transformer性能的線性時間序列模型。
論文地址:https://arxiv.org/abs/2408.10189