Mamba-2新架構出世一統(tǒng)江湖!普林斯頓CMU華人再出神作,性能狂飆8倍
年前,Mamba被頂會ICLR拒稿的消息曾引起軒然大波。
甚至有研究人員表示:如果這種工作都被拒了,那我們這些「小丑」要怎么辦?
這次,新一代的Mamba-2卷土重來、再戰(zhàn)頂會,順利拿下了ICML 2024!
仍是前作的兩位大佬(換了個順序),仍是熟悉的配方:
論文地址:https://arxiv.org/pdf/2405.21060
開源代碼和模型權重:https://github.com/state-spaces/mamba
不同的是,作者在更高的視角上,統(tǒng)一了狀態(tài)空間模型(SSM)和注意力機制(Attention),也就是文章標題所說的「Transformers are SSMs」。
——這下咱們都是一家人了,不用動不動就「打生打死」了。
性能方面,Mamba-2采用了新的算法(SSD),比前代提速2-8倍,對比FlashAttention-2也不遑多讓,在序列長度為2K時持平,之后便一路遙遙領先。
在Pile上使用300B token訓練出的Mamba-2-2.7B,性能優(yōu)于在同一數(shù)據(jù)集上訓練的Mamba-2.8B、Pythia-2.8B,甚至是更大的Pythia-6.9B。
從理論上整合了SSM和Transformer,同等性能下,模型更小,消耗更低,速度更快。
更重要的是,能夠利用GPU的硬件資源(矩陣乘法單元),以及針對Transformer的一系列優(yōu)化。
——Mamba-2大有一統(tǒng)江湖之勢。
1代Mamba,爆發(fā)式占領AI社區(qū)
事實上,關于1代Mamba的各種研究一直在爆發(fā)性地增長,arxiv已經被各種Mamba所占領,谷歌學術的引用量也達到了350多。
后續(xù)工作如雨后春筍一般冒出,包括視覺、基因組學、圖表等的直接應用,以及回憶能力、上下文學習能力、形式語言表達能力等方面的研究。
作者興奮地表示:「我們多年來一直在追求的高效序列模型研究路線,真正引起了機器學習社區(qū)的共鳴。」
唯一遺憾的是,Mamba遭到ICLR拒稿,所以關于Mamba到底有沒有前途這個事也就被打上了問號。
現(xiàn)在,問題解決了,不但論文被接收了,而且還證明了Transformer和Mamba其實是一家人——
「你說我不行?那Transformer到底行不行?」
值得注意的是,之前很火的Vision Mamba以及另一篇關于Mamba的研究也殺入了ICML 2024。
對于改進Mamba的初衷,作者表示,當前AI社區(qū)的大家都在努力解決Transformer的問題,盡管SSM的特性和效果都相當好,但卻跟社區(qū)的努力方向不一致。
這次的Mamba-2可以把針對Transformer的優(yōu)化都用上,不浪費大家的努力。
新架構一統(tǒng)江湖
在介紹新架構之前,小編先幫大家簡單理一下背景。
狀態(tài)空間模型SSM之所以如此令人著迷,是因為它們顯得如此之「基礎」。
比如,它們與序列模型的許多主要范式,都有著豐富的聯(lián)系。
它們似乎抓住了連續(xù)、卷積和循環(huán)序列模型的本質,把所有這些元素都包含在了一個簡單優(yōu)雅的模型里。
不過,另一個主要的序列模型范式——注意力機制的變體,卻更加無所不在。
然而SSM卻總感覺和Attention是脫節(jié)的。
在這里,研究者們發(fā)出了「靈魂拷問」——SSM和注意力之間的概念聯(lián)系是什么?有無可能將二者結合起來?
那就要從公式說起了。
狀態(tài)空間模型SSM可以這么定義:
這是個微分方程,利用導數(shù)定義進行代換:
可以得到SSM的解:
這個東西就跟RNN一毛一樣了:
所以可以認為SSM等價于RNN。
如果將RNN的遞歸結構展開,那么它又可以等價于卷積:
此時,便可以利用卷積的特性進行并行訓練,而進行推理時又可以享受RNN帶來的O(1)復雜度。
當然,好事不能讓你全占了,這種結構仍然逃不過固有的梯度爆炸(或消失),以及難以勝任選擇性復制和上下文學習等任務。
為此,Mamba在SSM的基礎上加入了能夠隨輸入變化的參數(shù)。
不過這樣做的代價是失去了固定kernel帶來的并行性,所以作者另辟蹊徑,使用前綴和的方式來加速RNN的訓練。
不過,從計算角度來看,Mamba在硬件效率上仍然遠不如注意力機制。
原因在于,目前常用的GPU、TPU等加速器,是為矩陣乘法進行過專門優(yōu)化的。
1代Mamba吃不到硬件矩陣運算單元的紅利,盡管推理時有速度優(yōu)勢,但訓練時問題就大了。
所以作者就想,我能不能把Mamba的計算重構成矩陣乘法呢?
于是,新一代的Mamba誕生了。
結構化狀態(tài)空間對偶性:SSD
Mamba-2的核心,是結構化狀態(tài)空間對偶性(State Space Duality,SSD)的概念:
1. SSD模型指的是一個特定的獨立層,比如注意力層或狀態(tài)空間模型(SSM),可以被整合到深度神經網絡中;
2. SSD框架是一個用于推理該模型(以及更多理論連接)的通用框架;
3. SSD算法是一種比以前的SSM更高效地計算SSD層的算法。
SSD框架(紅色,藍色):狀態(tài)空間模型(即半分離矩陣)和結構化掩碼注意力涵蓋了大量高效的序列模型。它們的交集就是SSD模型(紫色)
原始的Mamba(或更準確地說,其核心「S6」層)實際上是一個具有對角結構的選擇性狀態(tài)空間模型(SSM)。
Mamba-2的SSD層只做了一個小改動:它進一步限制了對角矩陣??,使其成為標量乘以單位矩陣的結構。換句話說,??的對角元素必須都是相同的值。
在這種情況下,??可以表示為形狀(??),并且還可以將????識別為一個標量(有時會表示為????)。
所謂「對偶性」是指,方程(1)(標量-恒等結構????的情況)和(3)中定義的兩個模型實際上是完全相同的模型。
因此,我們可以將其視為一個特定函數(shù):
SSD vs. SSM
與之前的狀態(tài)空間模型(SSM)相比,SSD在遞歸矩陣??上增加了更多結構:
1. Mamba-1(S6)在矩陣??上使用對角結構,而Mamba-2(SSD)在矩陣??上使用標量乘以單位矩陣的結構;
2. Mamba-1的頭維度是??=1(即所有通道完全由獨立的SSM控制),而Mamba-2使用的頭維度是??>1(默認情況下類似于??=64)。
特別是,這可以通過兩種方式視為權重共享:
1. 通過將矩陣??的對角結構限制為標量乘以單位矩陣,遞歸動態(tài)在狀態(tài)空間的所有??元素之間共享;
2. 這些動態(tài)也在給定頭的所有??通道之間共享。
換句話說,一個單一的SSM頭的總狀態(tài)大小為??×??,在Mamba-1中由獨立的標量遞歸控制,而在Mamba-2中由單一的共享遞歸控制。
而這些變化,主要就是為了提高效率——讓模型能夠以「雙重注意形式」查看,從而允許使用矩陣乘法。
因此,與Mamba-1相比,Mamba-2支持更大的狀態(tài)維度(從N=16提升到了N=64、N=256甚至更高),同時在訓練期間速度更快。
SSD vs. Attention
與標準(自)注意力機制相比,SSD只有兩點不同:
1. 取消了softmax歸一化;
2. 以乘法方式應用單獨的元素級掩碼矩陣。
第一個不同之處在于,它將模型的有效狀態(tài)大小從線性減少到常數(shù),并將效率從二次方提升到了線性。
第二個不同之處是SSD與標準線性注意力的區(qū)別。一種理解掩碼的方法是將其視為依賴于輸入的相對位置編碼,由于掩碼??的存在,標準的注意力得分????????會被一個權重????:??×=?????????+1所衰減,這可以理解為基于位置??和??之間距離的「折現(xiàn)細數(shù)」(discount factor)。
在注意力形式中,這種依賴輸入的位置掩碼可以解釋為Mamba「選擇性」的關鍵因素!
SSD算法
由于Mamba-1的算法和實現(xiàn)沒有使用張量核心,因此只能進行小規(guī)模的狀態(tài)擴展(通常為??=16)。
相比之下,矩陣乘法的FLOPs要比非矩陣乘法快得多(最多快16倍):
- A100 GPU有312 TFLOPS的BF16矩陣乘法性能,但只有19 TFLOPS的FP32算術性能;
- H100有989 TFLOPS的BF16矩陣乘法性能,但只有67 TFLOPS的FP32算術性能。
這次,Mamba-2的一個主要目標,便是利用張量核心來加速SSM。
由于SSD連接了SSM和結構化矩陣,計算SSM或線性注意力的高效算法,可以直接對應于「token混合」或「序列混合」矩陣??的不同分解。
如今,這個算法不僅速度更快,而且比原始的Mamba選擇性掃描更容易實現(xiàn),僅需大約25行代碼!
def segsum(x):"""Naive segment sum calculation. exp(segsum(A)) produces a 1-SS matrix,
which is equivalent to a scalar SSM."""
T = x.size(-1)
x_cumsum = torch.cumsum(x, dim=-1)
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagnotallow=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)return x_segsum
def ssd(X, A, B, C, block_len=64, initial_states=None):"""
Arguments:
X: (batch, length, n_heads, d_head)
A: (batch, length, n_heads)
B: (batch, length, n_heads, d_state)
C: (batch, length, n_heads, d_state)
Return:
Y: (batch, length, n_heads, d_head)
"""assert X.dtype == A.dtype == B.dtype == C.dtype
assert X.shape[1] % block_len == 0# Rearrange into blocks/chunks
X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)]
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A))
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)# 2. Compute the state for each intra-chunk# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries# (middle term of factorization of off-diag blocks; A terms)if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]# 4. Compute state -> output conversion per chunk# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p")return Y, final_state
Mamba-2架構
Mamba-2架構的核心貢獻是提出了新的SSD層,及其理論,與此同時,研究者也對Mamba的神經網絡架構做了一些小改變。
Mamba-2塊通過刪除連續(xù)線性映射來簡化Mamba塊:SSM參數(shù)??, ??, ??是在塊的開頭生成的,而不是作為SSM輸入??的函數(shù)。如NormFormer中一樣,添加了一個額外的歸一化層,以提高穩(wěn)定性。B和C映射只有一個在??頭之間共享的頭,類似于多值注意力(MVA)
主要的變化是,??輸入并行生成(??, ??, ??)SSM參數(shù),而非按順序生成。
之所以這樣的做,與注意力有一定關系。但實際上,它更加簡潔,易于使用張量并行等擴展技術。
此外,模型架構還有一些其他不同之處。不過,研究作者想要強調的是,這些架構更改并不是模型的真正要點。
論文中,研究人員主要討論了兩種設計選擇,最終形成Mamba-2架構。
首先是,塊設計。
1. 并行參數(shù)映射
在Mamba-2中,SSD層被視為從??, ??, ??, → ?? 的映射。與標準注意力架構的類比,其中??, ??, ??對應于并行創(chuàng)建的Q, K, V投影。
2. 額外歸一化
在初步實驗中,研究者發(fā)現(xiàn)在較大的模型中容易出現(xiàn)不穩(wěn)定性。
他們通過在最終輸出映射前的數(shù)據(jù)塊中添加一個額外的歸一化層(如LayerNorm、GroupNorm或RMSNorm)來緩解這一問題。
這種歸一化的用法與NormFormer架構有最直接的關系,后者也在MLP和MHA塊的末尾添加了歸一化層。
此外,研究者還發(fā)現(xiàn),這種變化與最近從線性注意力角度衍生而來,并與Mamba-2相關的模型類似。
最初的線性注意力公式通過一個分母項,進行了歸一化,它模仿了標準注意力中softmax函數(shù)的歸一化。
在TransNormerLLM和RetNet研究中,卻發(fā)現(xiàn)了這種歸一化不穩(wěn)定性。因此在線性注意力層之后增加了一個額外的LayerNorm或GroupNorm。
研究者提出的「額外歸一化層」與此前研究策略有不同,這是在「乘法門」分支之后加入的歸一化層。
其次是,序列轉換多頭模式。
此前,曾提到了SSM被定義為序列轉換,其中:
- ??, ??, ??參數(shù)的狀態(tài)維度是N;
- 它們定義了序列轉換,比如可以表示為矩陣
;
- 這一轉換只針對輸入序列進行操作,并獨立于P軸。
總而言之,大家可以將其視為,定義序列轉換的一個頭。
多頭序列變換由H個獨立的頭組成,模型總維度為D=d_model。這些參數(shù)在多頭之間是共享的,從而形成一個頭模式(head pattern)。
狀態(tài)大小N和頭維度P,分別類似于注意力的??頭維度和??頭維度。
正如現(xiàn)代Transformer架構,在Mamba-2中,研究牛人員通常將這些維度,選擇為64或128左右的常數(shù)。
當模型維度D增加時,便會增加頭數(shù)量,同時保持頭維度N和P不變。
為了說明如何做到這一點,研究人員可以將多頭注意力的思想進行移植和擴展,從而為SSM或任何一般序列變換定義類似的模式。
1. 多頭SSM(MHS) / 多頭注意力(MHA)模式
經典多頭注意力(MHA)模式假定頭維度P,能夠被模型維度D整除。頭的數(shù)量H被定義為H=D/P。
然后,通過為每個參數(shù)創(chuàng)建H個獨立的副本,就構建出了H個獨立序列變換的「頭」副本。
值得注意的是,雖然MHA模式最初只是針對注意力序列變換而描述的,但其可以應用到任何符合定義的序列變換上。
比如,上圖中,多頭SSD層將接受形狀如公式 (17) 所示的輸入,其中SSD算法被復制到了H=n_heads維度上。
2. 多合約SSM(MCS) / 多查詢注意(MQA)模式
多查詢注意力,是對注意力的一種巧妙優(yōu)化,可以顯著提高自回歸推理的速度,它依賴于緩存??張量。
這一技術只需避免給??和??額外的頭維度,或者換句話說,將(??, ??)的單個頭廣播到??的所有查詢頭。
利用狀態(tài)空間對偶性(SSD),便可以多頭注意力(MQA)定義為與方程 (18) 等效的的狀態(tài)空間模型(SSM)。
在這里,??和??(分別對應注意力機制中的V和K)在H個頭之間共享。
由于控制SSM狀態(tài)收縮的??參數(shù),在每個頭中都有獨立的副本,因此研究人員也將其稱之為多合約SSM(MCS)頭模式。
類似地,研究人員還可以定義一種多鍵注意力(MKA),或多擴展SSM(MES)頭模式。其中??(控制SSM擴展)在每個頭中是獨立的,而??和??則在所有的頭中共享。
3. 多輸入SSM(MIS)/多值注意力(MVA)模式
雖然MQA因其KV緩存而受到關注,但它并不是SSM的自然選擇。
相反,在Mamba中,??被視為SSM的主要輸入,因此??和??是輸入通道共享的參數(shù)。
研究人員在公式(20)中定義了新的多輸入SSM(MIS)模式的多值注意(MVA),它同樣可以應用于任何序列變換,如SSD。
有了以上詞匯,現(xiàn)在便可以更準確地描述最初的Mamba架構。
Mamba架構的選擇性SSM(S6)層具有的特征是:
- 頭維度??=1:每個通道都有獨立的SSM動態(tài)??;
- 多輸入SSM(MIS)或多值注意力(MVA)頭結構:矩陣??、??(對應于注意力對偶性中的K、Q應)在輸入??(對應于注意力中的V)的所有通道中共享。
當然,作者表示,也可以在應用SSD時,去掉這些頭模式的辯題。
有趣的是,盡管在參數(shù)數(shù)量和總狀態(tài)維度上都有所控制,但在下游性能上卻存在明顯差異。他們根據(jù)經驗發(fā)現(xiàn)Mamba最初使用的MVA模式性能最佳。
第三是,分組頭模式。
多查詢注意力的理念可以擴展到分組查詢注意力:與使用1個K和V頭不同,它可以創(chuàng)建G個獨立的K和V頭,其中1<G,而且G可以整除H。
這樣做有兩個動機:一是彌合多查詢注意力和多頭注意力性能差距,二是通過將G設置為分片數(shù)(shards)的倍數(shù),以實現(xiàn)更高效的張量并行。
最后,研究人員還提到了線性注意力的其他SSD擴展項。
比如,核注意力近似于Softmax注意力,指數(shù)核特征圖。
語言建模
論文中,雖沒有像Mamba-1那樣廣泛地測試Mamba-2,但作者認為新架構總體上可與第一代性能相當,或者更好。
另外,研究稱,全語言模型結果使用與Mamba相同的協(xié)議,并且在Chinchilla Law上的擴展性略好于Mamba。
Pile數(shù)據(jù)集上的充分訓練的模型,以及標準的零樣本下游任務評估中,也看到了類似的趨勢。
即使在性能相當?shù)那闆r下,Mamba-2的訓練速度也比初代Mamba快得多!
合成語言建模:MQAR
更有趣的是,研究者針對Mamba-2再次嘗試了一項合成任務。
初代Mamba論文中,曾研究了「合成復制」和「誘導頭」等合成任務后,后續(xù)研究中開始研究更難的聯(lián)想回憶任務。
目前,由Zoology和Based團隊引入的多查詢聯(lián)想回憶(MQAR),已經成為行業(yè)里的事實標準。
這次,研究人員測試了一個更難的版本,結果發(fā)現(xiàn),Mamba-2的性能顯著優(yōu)于Mamba-1。
其中一個原因是,新架構的「狀態(tài)」要大得多——最多是Mamba-1的16倍。這也是Mamba-2的設計初衷之一
另外,即便是在控制狀態(tài)大小的情況下,Mamba-2在這一特定任務上的表現(xiàn)也明顯優(yōu)于Mamba-1。
系統(tǒng)和擴展優(yōu)化
好在,Transformer誕生后,整個研究界和大公司已經對它進行了長達7年的系統(tǒng)優(yōu)化。
SSD框架在SSM和注意力之間建立聯(lián)系后,也可以讓我們?yōu)镸amba-2等模型實現(xiàn)很多類似的優(yōu)化。
為此,研究者的重點,就是用于大規(guī)模訓練的張量并行和序列并行,以及用于高效微調和推理的變長序列。
張量并行
使用張量并行(TP)進行Mamba-1的大規(guī)模訓練時,一個難點在于,它每層需要進行2次全歸約(all-reduce),而Transformer中的注意力或MLP層,每層只需1次全歸約。
這是因為,一些SSM參數(shù)是內部激活的函數(shù),而不是層輸入的函數(shù)。
在Mamba-2中,采用了「并行投影」結構,所有SSM參數(shù)都是層輸入的函數(shù),因此,就可以輕松地將TP應用于輸入投影。
將輸入投影和輸出投影矩陣,根據(jù)TP的程度分成2、4、8個分片。
使用分組歸一化,分組數(shù)量可被TP程度整除,這樣每個GPU都能單獨進行歸一化。就是這些改變,使得每層只需1次全歸約,而不是2次。
序列并行
在訓練非常長的序列時,可能需要沿序列長度進行拆分,并將不同部分分配給不同的設備。
有兩種主要的序列并行(SP)形式:對于殘差和歸一化操作,這種形式將TP中的全歸約替換為規(guī)約-散布、殘差+歸一化,然后是all-gather。
由于Mamba-2使用與Transformer相同的殘差和歸一化結構,這種SP形式可以直接應用,無需修改。
對于注意力或SSM操作,也稱為上下文并行(CP)。
對于注意力機制,可以使用Ring注意力沿序列維度進行拆分。
對于Mamba-2,SSD框架又再次幫了大忙:使用相同的塊分解,就可以讓每個GPU計算其本地輸出和最終狀態(tài),然后在更新每個GPU的最終輸出之前,在GPU之間傳遞狀態(tài)。
可變長度
在微調和推理過程中,同一批次中經常會出現(xiàn)不同長度的序列。
對于Transformer,通常會采用填充方式使所有序列長度相同(雖然會浪費計算資源),或者專門為可變長度序列實現(xiàn)注意力機制,并進行負載平衡。
而對于SSM,就可以將整個批次視為一個長「序列」,并通過將每個序列末尾token的狀態(tài)轉移????設置為0,避免在批次中的不同序列之間傳遞狀態(tài)。
結果
結果顯示,更快的SSD算法,直接能讓我們將狀態(tài)維度增加到64或128!而在Mamba-1中,維度僅為16。
盡管從技術角度看,對于相同的??,Mamba-2比Mamba-1受到的限制會更多,然而更大的狀態(tài)維度,帶來的結果通常就是模型質量的提升。
更受限制,但更大的狀態(tài)維度通常會提升模型質量。
比如我們開頭所見的,在Pile上訓練3000億tokens,Mamba-2的表現(xiàn)就明顯優(yōu)于Mamba-1和Pythia。
而混合模型的表現(xiàn),也很令人滿意。
從最近的Jamba和Zamba的工作中,研究者發(fā)現(xiàn),將Mamba層與注意力層結合,可以超過純Transformer或Mamba模型的性能。
在2.7B參數(shù)和3000億tokens規(guī)模上驗證一個僅包含6個注意力塊(和58個SSD塊)的混合模型后可以發(fā)現(xiàn),其表現(xiàn)優(yōu)于64個SSD塊以及標準的Transformer++基線模型(32個門控MLP和32個注意力塊)。
混合Mamba/注意力模型的下游評估
而且,對于相同的狀態(tài)維度,SSD算法比Mamba-1的選擇性掃描算法快得多,并且在計算上更能擴展到更大的狀態(tài)維度。
其中的關鍵就在于,要充分利用張量核心的強大計算能力!
序列長度2K的效率基準
未來方向
如今,線性注意力和SSM連接起來后,前途一片大好,更快的算法、更好的系統(tǒng)優(yōu)化,就在眼前了。
作者提出,接下來AI社區(qū)需要探索的,有以下三個方向——
理解:含有少量(4-6)注意力層的混合模型表現(xiàn)非常出色,甚至超過了純Mamba(-2)或Transformer++。
這些注意力層的作用是什么?它們能被其他機制替代嗎?
訓練優(yōu)化:盡管SSD可能比注意力機制更快,但由于Transformer中的MLP層非常適合硬件,整體上Mamba-2在短序列長度(例如2K)上,可能仍然比Transformer慢。
未來,是不是可以讓SSD利用H100的新特性,讓SSM在2-4K序列長度的大規(guī)模預訓練中,比Transformer還快?
推理優(yōu)化:有許多針對Transformers的優(yōu)化方法,特別是處理KV緩存(量化、推測性解碼)。
如果,模型狀態(tài)(如SSM狀態(tài))不再隨著上下文長度擴展,KV緩存不再是瓶頸,那時的推理環(huán)境,會如何變化?