英偉達(dá)提出首個Mamba-Transformer視覺骨干網(wǎng)絡(luò)!打破精度/吞吐瓶頸 | CVPR 2025
正如標(biāo)題所言「Attention is all you need」,Transformer已成為不同領(lǐng)域的「霸主」,包括計算機(jī)視覺、自然語言處理、語音處理和機(jī)器人技術(shù)。
第一個挑戰(zhàn)Transformer架構(gòu)的是Mamba,一種新的狀態(tài)空間模型(SSM),它具有線性時間復(fù)雜度,并在多個語言建模任務(wù)中超越或與Transformer媲美。
但在不同的視覺任務(wù)上,Vision Transformer (ViT) 和卷積神經(jīng)網(wǎng)絡(luò) (CNN) 架構(gòu)的骨干網(wǎng)絡(luò),仍然優(yōu)于基于Mamba的視覺模型。
而這一次,英偉達(dá)高級工程師Ali Hatamizade,宣布被頂會CVPR-2025接受的MambaVision,在視覺任務(wù)上超越以往的模型,而設(shè)計的關(guān)鍵在于將Mamba和Transformer混合。
正如圖1所示,在ImageNet-1K基準(zhǔn)上,MambaVision的Top-1準(zhǔn)確率和圖像處理能力達(dá)到了新的Pareto最優(yōu)點(diǎn),超越了Mamba、CNN和ViT基于的模型,有時差距非常顯著。
在下游任務(wù)如目標(biāo)檢測、實(shí)例分割以及語義分割中,采用MambaVision作為骨干網(wǎng)絡(luò)的模型在MS COCO和ADE20數(shù)據(jù)集上分別超越了同等規(guī)模的對比模型。
MambaVision是首個針對計算機(jī)視覺應(yīng)用,結(jié)合Mamba和Transformer的混合架構(gòu)的嘗試。主要貢獻(xiàn)總結(jié)如下:
1 引入了重新設(shè)計的適用于視覺任務(wù)的Mamba模塊,提升了相較于原始Mamba架構(gòu)的準(zhǔn)確性和圖像處理能力。
2 系統(tǒng)性地研究了Mamba和Transformer模塊的融合模式,并展示了在最終階段加入自注意力模塊,顯著提高了模型捕捉全局上下文和長距離空間依賴的能力。
論文鏈接:https://arxiv.org/abs/2407.08083
在這項(xiàng)工作中,作者系統(tǒng)地重新設(shè)計了Mamba模塊,使其更加適合視覺任務(wù)。
新方法是一種混合架構(gòu),結(jié)合了新提出的公式(即MambaVision Mixer和MLP)以及Transformer模塊。
具體來說,研究了不同的集成模式,比如以等參數(shù)方式將Transformer模塊添加到早期、中間和最終層,或者每隔l層添加一次。
分析表明,在最終階段利用多個自注意力模塊,可以顯著增強(qiáng)捕捉全局上下文和長程空間依賴的能力。
使用混合架構(gòu)相較于純Mamba或ViT模型,圖像處理能力也得到了顯著提升。
網(wǎng)絡(luò)架構(gòu)
宏觀架構(gòu)
如圖2所示,MambaVision 采用了分層架構(gòu),由4個不同的階段組成。
前兩個階段使用基于CNN的層,負(fù)責(zé)在較高輸入分辨率下進(jìn)行快速特征提取,而第3和第4階段則包括了新提出的 MambaVision和Transformer模塊。
具體來說,給定一個大小為H×W×3的圖像,輸入首先被轉(zhuǎn)換為大小為H/4×W/4×C的重疊patch,并通過兩層連續(xù) 3×3的CNN 層(步幅為2)構(gòu)成的主干投影到C維嵌入空間中。
在各個階段之間的下采樣模塊由一個批歸一化的3×3的CNN 層(步幅為2)組成,將圖像分辨率減半。
此外,第1和第2階段中的CNN模塊,采用了通用的殘差模塊結(jié)構(gòu),具體如下:
其中:Conv3×3 表示3×3卷積操作;BN表示批歸一化(Batch Normalization);GELU 是激活函數(shù),表示 Gaussian Error Linear Unit;z^ 是經(jīng)過卷積、批歸一化和激活函數(shù)處理后的中間結(jié)果;最后,z是通過卷積和批歸一化后的結(jié)果與原始輸入相加,形成殘差連接。
這種結(jié)構(gòu)有助于緩解深層網(wǎng)絡(luò)訓(xùn)練中的梯度消失問題,并提高模型的訓(xùn)練效率。
Mamba架構(gòu)
Mamba是結(jié)構(gòu)化狀態(tài)空間序列模型的擴(kuò)展,能夠通過可學(xué)習(xí)的隱狀態(tài) h(t),將一維連續(xù)輸入x(t)轉(zhuǎn)換為y(t)。該過程的公式如下:
其中,矩陣A,B,C是模型的參數(shù)。
離散化:為了提高計算效率,以上公式中的連續(xù)參數(shù)A,B和C需要轉(zhuǎn)化為離散參數(shù)。具體而言,假設(shè)時間步長為Δ,可以應(yīng)用零階保持規(guī)則來獲取離散參數(shù):
這種離散化方法能夠提升計算效率,便于在實(shí)際應(yīng)用中實(shí)現(xiàn) Mamba 模型。
使用離散參數(shù)代入到原方程:
此外,對于一個大小為T的輸入序列,可以用帶有卷積核K的全局卷積,進(jìn)一步簡化上式中的輸出,具體如下
選擇性:Mamba 進(jìn)一步擴(kuò)展了S4公式,引入了一種選擇機(jī)制,使得模型能夠進(jìn)行依賴于輸入的序列處理。這種機(jī)制使得模型的參數(shù)B 、C和Δ可以根據(jù)輸入動態(tài)調(diào)整,從而濾除無關(guān)信息。
設(shè)輸入X是TxC矩陣,其中 T 為序列長度,C為嵌入維度,第3和第4階段的第n層輸出可以按如下方式計算:
其中,NormNorm和MixerMixer分別表示層歸一化和 token 混合模塊的選擇。
層架構(gòu)
在不失一般性的情況下,層歸一化(Layer Normalization)被用于 NormNorm。給定N層,前 N/2層使用 MambaVision混合模塊,而剩余的N/2層使用自注意力機(jī)制。
MambaVision 混合模塊:重新設(shè)計了原始的Mamba混合模塊,使其更適合視覺任務(wù)。
如圖3所示,首先將因果卷積(causal convolution)替換為常規(guī)卷積,因?yàn)橐蚬矸e將信息限制在一個方向上,這對視覺任務(wù)來說不僅沒必要,而且局限性還很大。
此外,添加了一個不包含SSM(狀態(tài)空間模型)的對稱分支,該分支由額外的卷積和SiLU激活函數(shù)組成,以補(bǔ)償由于SSM的順序約束而可能丟失的內(nèi)容。
然后,將兩個分支的輸出拼接起來,并通過最終的線性層進(jìn)行投影。這種組合確保了最終的特征表示,同時包含順序信息和空間信息,從而充分利用了兩個分支的優(yōu)勢。
注意到,每個分支的輸出被投影到一個大小為C/2的嵌入空間(即原始嵌入維度的一半),以保持與原始模塊設(shè)計相似的參數(shù)量。
給定輸入Xin,MambaVision混合模塊的輸出Xout計算如下:
其中,Linear(Cin,Cout)(?)表示一個線性層,輸入和輸出的嵌入維度分別為Cin和Cout;Scan是選擇性掃描操作(selective scan);σ是激活函數(shù),這里使用的是Sigmoid線性單元(SiLU;Conv和Concat 分別表示1D卷積和拼接操作。
實(shí)驗(yàn)結(jié)果
表1展示了ImageNet-1K分類結(jié)果。具體來說,與不同類別的模型進(jìn)行了比較,包括基于卷積的模型、基于 Transformer的模型、卷積-Transformer混合模型以及基于Mamba的模型,并證明新模型在ImageNet Top-1準(zhǔn)確率和圖像處理能力方面大幅超越了之前的工作。
例如,與流行的模型如ConvNeXt和Swin Transformers相比,MambaVision-B(84.2%)優(yōu)于 ConvNeXt-B(83.8%)和 SwinB(83.5%),同時在圖像處理能力上也有顯著優(yōu)勢。
在與基于 Mamba 的模型比較時也觀察到了類似的趨勢。具體來說,盡管MambaVision-B(84.2%的圖像處理能力顯著更高,但仍優(yōu)于 VMamba-B(83.9%)。
與同等規(guī)模的模型相比,MambaVision 型變體的FLOPs遠(yuǎn)低于它們。例如,MambaVision-B 的GFLOPs比 MaxViT-B 少了56%。
表2展示在MS COCO數(shù)據(jù)集上的目標(biāo)檢測和實(shí)例分割結(jié)果。
具體來說,訓(xùn)練了不同檢測尺寸的模型,以進(jìn)一步驗(yàn)證 MambaVision 不同場景下的有效性。
通過簡單的Mask-RCNN檢測頭,預(yù)訓(xùn)練的MambaVision-T骨干網(wǎng)絡(luò),超過了 ConvNeXt-T和 Swin-T模型。
使用Cascade Mask-RCNN網(wǎng)絡(luò)時,MambaVision-T、MambaVision-S和MambaVision-B都超過了競爭對手。
表3展示了在ADE20K數(shù)據(jù)集上的語義分割基準(zhǔn)測試。
對于這些實(shí)驗(yàn),使用了 UPerNet,以便與其他模型進(jìn)行比較。
觀察到,MambaVision 模型在不同變體下超越了同等規(guī)模的競爭模型。
例如,MambaVision-T、MambaVision-S 和 MambaVision-B分別在mIoU上超越了Swin-T、Swin-S和Swin-B,提升幅度為+0.6、+0.6和+1.0。
盡管沒有對下游任務(wù)進(jìn)行大量的超參數(shù)調(diào)優(yōu)優(yōu)化,這些結(jié)果仍然證明了MambaVision作為一種有前景的視覺任務(wù)骨干網(wǎng)絡(luò)的可行性,特別是在高分辨率設(shè)置下。
消融實(shí)驗(yàn)和更多細(xì)節(jié)請參考原文。