Mamba真比Transformer更優(yōu)嗎?Mamba原作者:兩個都要!混合架構才是最優(yōu)解
去年12月,CMU、普林斯頓的兩位華人學者Albert Gu和Tri Dao一舉推出了Mamba架構,向Transformer多年的霸主地位發(fā)起挑戰(zhàn)。
論文地址:https://arxiv.org/abs/2312.00752
完全拋棄注意力機制和MLP模塊、上下文長度線性縮放、推理速度比Transformer快5倍…這些特點讓所有人都為之一振,Jim Fan大佬也發(fā)推贊嘆「為推翻Transformer的研究感到興奮」。
論文發(fā)表后的6個月中,兩位作者發(fā)現,雖然Mamba很強大,但是大家依舊更關注各種Transformer的變體。
畢竟整個學術社區(qū)在注意力機制上深耕多年,從模型、標準庫到算子、GPU,此時完全拋棄之前的研究、轉向Mamba的SSM不太現實,也讓Mamba架構顯得非常格格不入。
于是,我們看到Mamba-2的論文在更高層面上將SSM和注意力機制統(tǒng)一了起來,同時相比Mamba-1實現了2~8倍的速度提升。
論文地址:https://arxiv.org/abs/2405.21060
就在大家都期待著「王者歸來」的Mamba-2與Transformer一決高下時,英偉達、威斯康星-麥迪遜大學、普林斯頓、CMU等多個機構的作者共同發(fā)表了一篇實證研究文章,發(fā)現基于Mamba架構的語言模型在長上下文任務上不敵Transformer。
其實不管出現哪種創(chuàng)新的方法或模型,有論文提出批評意見總是難免的。但細看這篇文章居然發(fā)現,Mamba的創(chuàng)造者Tri Dao和Albert Gu兩人竟然也在作者列表中。
論文地址:https://arxiv.org/abs/2406.07887
在此為兩位科學家實事求是的精神點贊。
此外,作者列表中還能發(fā)掘到另一個華點——Albert Gu和Tri Dao都有了新title。
Albert Gu現任Cartesia AI的聯合創(chuàng)始人兼首席科學家,他們最新的產品是實時語音交互API Cartesia Sonic。
https://cartesia.ai
Tri Dao是Together AI的創(chuàng)始科學家,該公司主要提供云服務,同時也貢獻前沿的開源研究。
https://www.together.ai
接下來我們還是詳細看看,這篇文章對Mamba和Transformer的能力具體做了哪些對比研究。
簡介
在迄今為止的研究中(包括提出Mamba架構的論文),SSM與Transformer的對比都只進行了較小規(guī)模的實驗(<3B參數,<1T token),這些結論在訓練預算更大的情況下是否成立?
這篇技術報告就是要回答這個問題。作者分別訓練出Mamba、Mamba-2、Mamba-2-Hybrid、Transformer等4種架構的8B參數模型,在35個NLP下游任務中對比性能。
訓練數據包括1.1T和3.5T兩個數據集,都是英偉達用于訓練Nemotron-4的數據集的前身,由70%英語、15%非英語和15%代碼組成
其中,Mamba-2-Hybrid是一個SSM-Transformer的混合架構模型,包含24個Mamba-2層,以及均勻分布在整個模型中的4個自注意力層和28個MLP層。
總體而言,這項對比實驗消除了比較不同LLM的常見困難,包括訓練數據、分詞器、評估管道等方面,確保評估流程的標準和可重復性。
為了方便復現和進一步研究,用于訓練Mamba、Mamba-2和Mamba-2-Hybrid的代碼已經開源,而且研究團隊還在HuggingFace上發(fā)布了Mamba-2 8B和Mamba-2-Hybrid 8B的模型權重(作為英偉達Megatron-LM框架和代碼庫的一部分)。
https://huggingface.co/nvidia
實驗結果表明,雖然Mamba和Mamba-2更擅長建模語言,但在上下文學習方面,以及從上下文中回憶信息時,性能落后于Transformer模型。
尤其是在MMLU基準上,即使提高了訓練數據的token數量,基于Mamba的模型依舊和Transformer有不小的差距。
Mamba vs. Transformer
用于評估的35個下游任務大致包含3個類別:
- 標準短上下文任務(12個):HellaSwag、ARC-Easy、ARC-Challenge、MMLU、OpenBookQA、TruthfulQA等
- 自然長上下文任務(9個):LongBench中的6個任務和LM Evaluation Harness框架中的3個任務
- 綜合長上下文任務(14個):RULER框架中的13個開源測試(包括「大海撈針」的8個變體)以及今年剛提出的「電話簿」(Phonebook)任務,旨在衡量模型在長輸入文本中檢索、跟蹤、聚合信息的能力。
表2展示了經過1.1T數據訓練后,純SSM架構的Mamba和Mamba-2與Transformer模型的部分評估結果。
在常見任務上,Mamba和Mamba-2的性能都可以匹配甚至超過Transformer模型,但MMLU基準是一個例外。進行零樣本或少樣本學習時,Mamba-2相比Transformer分別有10分和17分的差距。
因為在1.1T數據集上Mamba模型的訓練速度就已經比Mamba-2慢了將近3×(模型的狀態(tài)維度較大),出于效率方面的考量,在3.5T數據集上只訓練了Mamba-2模型和Transormer模型,部分結果如表3所示。
從表3可知,更多的訓練數據有助于Mamba-2在MMLU任務上得到改進,5-shot分數的差距縮小到僅1.37分,其他任務上依舊全面領先Transformer。
Mamba折戟MMLU與電話簿任務
由于MMLU在一眾下游任務的結果中顯得如此反常,論文對此進行了更細致的拆解和討論。
如上圖所示,MMLU的任務類似于考試中的選擇題,但在cloze格式中也可以不提供備選答案,以填空題的方式提供給模型。
表4中提供了MMLU按照格式細分后,3個模型各自的分數(用1.1T token訓練)。在標準模式和選擇題模式中,Mamba架構不敵Transformer,但在填空題模式中居然實現了分數反超。
結合表3中的結果,我們有理由推斷,純SSM模型和Transformer模型包含的知識內容應該是同等級別的,但前者需要更多的訓練才能理解MMLU的前兩種格式。
作者推斷,這種差距可能源于Transformer強大的上下文學習能力,可以看到該模型從0-shot到5-shot的準確度提升非常明顯。
此外,SSM模型可能無法直接將答案所需的知識路由到輸出的單個答案token中(即ABCD選項的其中一個),而這正是自注意力層擅長的任務。
此外,Mamba系列模型在「電話簿」上的表現也并不理想,該任務旨在衡量模型通過少數示例進行上下文學習,以及從上下文中復制信息的能力。
下圖展現了任務的兩種變體,標準版是先提供整個電話簿,再給出目標查詢;反轉版則是先查詢,再給電話簿。
圖3a、c分別展示了3個模型在這兩個任務變體上的準確率。
Transformer在電話簿長度不超過預訓練的上下文長度(4096)時,準確率接近100%,相比之下,Mamba和Mamba-2在輸入序列達到500 token時就出現了顯著的性能滑坡。
如果仔細觀察Mamba系列的輸出答案(圖2b),可以發(fā)現SSM架構的模型并非完全無法記憶上下文信息,而是保留了一些模糊記憶,給出的電話號碼通常有幾位是正確的。
綜合以上結果,我們可以將MMLU和「電話簿」任務確立為純SSM架構模型的挑戰(zhàn)性任務,并且推測出可能原因:這兩個任務需要上下文學習、token間信息路由以及從上下文復制的能力,它們可能是Mamba系列模型的能力軟肋。
SSM-Transformer混合架構
由于在MMLU和「電話簿」任務上看到了SSM架構的能力缺陷,作者想到——讓SSM和Transformer強強聯合,能夠起到取長補短的效果?
于是他們將自注意力和MLP層添加到Mamba架構中,想看看模型能否克服上述問題。
論文首先報告了一系列消融實驗的結果,通過對比在下游任務上的表現,探索出了能達到最佳性能的架構設計與參數(表6)。
56層的Mamba-2-Hybrid中包含4個(7.1%)自注意力層,24 個(42.9%)Mamba-2層和28個(50%)MLP 層,其中Mamba-2層使用與Mamba-2模型相同的參數。
自注意力、MLP層的數量以及MLP層擴展因子這些參數的選擇并非隨機,而是根據驗證集上損失值結果(圖4)進行的最優(yōu)化設計。
消融實驗的結果還顯示,混合模型中不添加旋轉位置編碼(RoPE)能達到更好的下游任務性能(表5),而且Mamba層、自注意力層、MLP層的順序也會影響模型能力。
首先,Mamba層必須出現在架構的開頭,以確保模型自然地學習到位置信息。相比使用重復塊模式,將自注意力和MLP均勻分散在整個模型是更好的配置。
而且通過計算驗證集上的模型困惑度(perplexity)可以得知,相比多頭注意力(MHA),使用組查詢注意力層(GQA)能減少推理計算量和內存量,但幾乎不會造成模型質量的下降。
效率方面,Mamba-2-Hybrid實現了29.9%的FLOP利用率(MFU),與Transfomer的30.7%基本相當。此外,前者有推理速度方面的巨大優(yōu)勢。
在長上下文情境中,受益于多個SSM層的存在,Mamba-2-Hybrid的token生成速度比Transformer加速了將近8×(圖5)。
評估
測評發(fā)現,這種混合架構果然有了「取長補短」的效果,混合架構在5-shot MMLU測評中同時超過了單純的Transformer和SSM架構,取得得了最高準確度(圖6)。
從表7中的多個基準總體來看,Mamba-2-Hybrid在效率更高的同時,性能也超過了Transformer模型。
相比Mamba-2,混合架構的長上下文能力也得到了顯著提高(表10),在RULER基準上的綜合任務、「大海撈針」任務的平均成績也都超過了Transformer。
在Mamba系列表現較差的「電話簿」任務上,Mamba-2-Hybrid可以在預訓練上下文長度 (4K) 內以近乎完美的精度完成電話簿任務,還可以稍微超出該長度進行泛化,在最多5.5k token的電話簿上實現100%準確率。
甚至,Mamba-2-Hybrid的潛力還不止于此,當預訓練長度擴展到128k并在4個自注意力層中使用全局注意力時,「電話簿」任務的100%準確率也延伸到了將近150k token。
結論
論文開頭的評估結果表明,在更大訓練預算的情況下,純SSM模型依舊能在下游任務上超過Transformer,但上下文學習和信息檢索能力有所局限。
基于此,作者提出的混合架構模型Mamba-2-Hybrid能夠在提高效率的同時繼續(xù)表現出比Transformer更強大的性能,并彌補了純SSM架構的相關缺陷。
這項研究所展示的全面結果告訴我們,Mamba和Transformer這兩種架構各有長短,也許并不需要其中一個取代另一個,將二者結合起來是一條值得探索的、有巨大潛力的路徑。