我們還需要Transformer中的注意力嗎?
最近幾周,AI 社區(qū)有一個熱門話題:用無注意力架構(gòu)來實現(xiàn)語言建模。簡要來說,就是機器學(xué)習(xí)社區(qū)有一個長期研究方向終于取得了實質(zhì)性的進展,催生出 Mamba 兩個強大的新模型:Mamba 和 StripedHyena。它們在很多方面都能比肩人們熟知的強大模型,如 Llama 2 和 Mistral 7B。這個研究方向就是無注意力架構(gòu),現(xiàn)在也正有越來越多的研究者和開發(fā)者開始更嚴(yán)肅地看待它。
近日,機器學(xué)習(xí)科學(xué)家 Nathan Lambert 發(fā)布了一篇題為《狀態(tài)空間 LLM:我們需要注意力嗎?》的文章,詳細介紹了 2023 年無注意力模型的發(fā)展情況。他還表示:2024 年你將會有不同的語言模型架構(gòu)可選。需要說明,這篇文章包含不少數(shù)學(xué)內(nèi)容,但深度理解它們是值得的。鑒于這篇文章較長,所以這里先列出分節(jié)目錄,以方便讀者索引:
- 引言:我們?yōu)槭裁纯赡懿⒉幌胧褂米⒁饬σ约笆裁词茄h(huán)神經(jīng)網(wǎng)絡(luò)。
- Mamba 模型:這種新的狀態(tài)空間模型能為未來多種類別的語言模型提供功能和硬件加速。
- StripedHyena 模型:這種來自 Together AI 的新 7B 模型組合了 RNN 和 Transformer 方向的最近研究成果,表現(xiàn)很出色。
- Monarch Mixers 研究:這篇新論文給出了一個范例,展示了該研究的運作方式,以及為什么沒有注意力或 MLP 也能成功。
- Zoology 研究:這是一個用于高效 LLM 架構(gòu)研究的代碼庫,另外還有基于這些研究得到的模型 Based。
此外還有更多鏈接、閱讀材料和資源。
如果你對這些內(nèi)容感興趣,可以閱讀作者 Nathan Lambert 對這一領(lǐng)域的兩位領(lǐng)先研究者的采訪,參閱機器之心報道《誰能撼動 Transformer 統(tǒng)治地位?Mamba 作者談 LLM 未來架構(gòu)》。
注意力 vs 循環(huán)和狀態(tài)空間模型(SSM)
本文的核心是理解不同的計算方式會怎樣為模型帶來不同的能力。本文關(guān)注的主題是語言,但其中的思想也適用于其它許多模態(tài)(事實上,這些新架構(gòu)最早取得成功的模態(tài)是音頻)。當(dāng)模型的內(nèi)部不同時,就會出現(xiàn)不同的歸納方式、訓(xùn)練會有新的擴展律、不同的推理時間成本、新的表達能力水平(即模型可學(xué)習(xí)的任務(wù)的復(fù)雜度)等等。架構(gòu)會改變有關(guān)模型的表達方式的一切,即便數(shù)據(jù)一樣也是如此。
一如既往,不同的架構(gòu)選擇都有各自的優(yōu)劣之處?,F(xiàn)如今最流行的 Transformer 架構(gòu)的核心組件注意力有非常出色的性能和可用性,原因有很多。本文不會把這些原因都列出來;簡單來說,注意力有利于處理語言任務(wù)時有一個自然的歸納偏差的模型、可以輕松在 GPU 和 TPU 上擴展訓(xùn)練的模型、可以高效處理大批量輸入的模型(例如存儲鍵 - 值矩陣)等等。
究其核心,注意力中有從過去每個 token 到當(dāng)前 token 的映射。正是這種密集架構(gòu),讓模型有能力表征許多不同的內(nèi)容并關(guān)注長上下文樣本。
而循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)將時間整合進模型的方式卻大不相同,這是本文要討論的主要競爭方法。這些模型會在每次遇到新的輸入數(shù)據(jù)時更新一個內(nèi)部狀態(tài)變量(以下記為 x)原理上講,這種內(nèi)部狀態(tài)可以捕獲任意系統(tǒng)的相關(guān)長期行為,而無需與數(shù)據(jù)之間有直接的計算鏈接。這能讓模型在計算長序列時的效率非常高,但直到最近,人們都還沒能證明其在性能上能媲美基于注意力的模型。下圖比較了注意力和 RNN 的計算圖譜:
在討論這些模型時,會遇到很多奇特的術(shù)語。而研究社區(qū)想做的是創(chuàng)造一種具有 RNN 那樣的時間依賴能力,同時又能維持注意力或卷積等架構(gòu)的高效訓(xùn)練能力的模型。為此,最近出現(xiàn)了許多圍繞狀態(tài)空間模型(SSM)的研究成果,其遵照的是狀態(tài)的連續(xù)時間或離散時間演變:x'(t) = Ax (t) + Bu (t), y (t) = Cx (t) + Du (t)。使用巧妙的線性代數(shù)或微分方程,根據(jù)它是連續(xù)時間或離散時間,控制這個狀態(tài)演變的矩陣可以表示成一個一維卷積。卷積架構(gòu)的效率很高,所以這是個好預(yù)兆,但除此之外,本文不會出現(xiàn)艱深的數(shù)學(xué)。
下面展示了其方程,來自 Mamba 論文(https://arxiv.org/abs/2312.00752 )。除非你想成為這方面的專家,否則你只需要知道這是在連續(xù)時間中構(gòu)建的(1a 和 1b),往往會被離散化(2a 和 2b),并會得到一個核 K(3a 和 3b)。從技術(shù)上講,這是一個一維卷積。
Mamba 的 SSM 方程
作者表示,盡管可以預(yù)期這不會在 2024 年改變一切,但卻有可能在 2-4 年內(nèi)帶來天翻地覆的改變。不同的任務(wù)將會使用不同的 LLM 架構(gòu)。作者還預(yù)計 ChatGPT 這樣的系統(tǒng)將會使用多種類型的語言模型來執(zhí)行日常任務(wù)。正如本文將描述的那樣,基于這種 RNN 結(jié)構(gòu)構(gòu)建的模型(因為一些技術(shù)原因進行了許多修改)在長上下文任務(wù)的潛在準(zhǔn)確度和推理成本方面有明顯的規(guī)模擴展優(yōu)勢。
如果你對語言建模和機器學(xué)習(xí)的艱深數(shù)學(xué)感興趣,那么十二月必定是個好月份。很多理性的人都知道注意力多半會被替代,只是疑惑會被怎樣替代以及何時會發(fā)生??紤]到對特定于注意力的基礎(chǔ)設(shè)施的投資之巨,作者預(yù)計這種方面短期內(nèi)還無法達到 GPT-N 或 Gemini 那樣的地位。如果它成功了且注意力被放棄,那谷歌就會面臨一個大麻煩,因為 TPU 不見得也能用于這些新技術(shù)(就像 TPU 已經(jīng)不能很好地處理 MoE 一樣)。盡管如此,SSM 及相關(guān)技術(shù)依然面臨諸多挑戰(zhàn),而且很多東西都還沒有得到概念驗證,例如:
- 高效利用 GPU 的能力,這是有效擴展所需的。
- 輕松微調(diào)模型并維持大多數(shù)性能的能力。
- 執(zhí)行上下文學(xué)習(xí)以及系統(tǒng) prompt 等功能的能力。
- 事實上,大型 Transformer 模型的大多數(shù)參數(shù)和計算依然還是前向網(wǎng)絡(luò)(FFN),這也是 SSM 要么使用,要么不修改的部分。
- RNN 中隱藏狀態(tài)所需能力的瓶頸。
- 整合檢索記憶等功能的能力,尤其是對于長文檔。這更側(cè)重于整合復(fù)雜的信息源,而不是已經(jīng)存在的長文本擴展。
列表不盡于此。當(dāng)你在閱讀此文章時,請記住這種新的無注意力技術(shù)很有潛力,但這并不意味著它現(xiàn)在就能與當(dāng)前最先進的技術(shù)競爭,即便它確實有幾個非常亮眼的結(jié)果。這只是意味著我們還處于很早期。
最近發(fā)布的模型
Mamba 以及 RNN 的高效計算
12 月 4 日,Albert Gu 和 Tri Dao 公布了 Mamba,一個力圖讓無注意力 LLM 的性能比肩 Transformer 的新模型,同時還能解決長上下文場景中計算受限的問題。Mamba 有三大關(guān)鍵特性:
1.(數(shù)據(jù))選擇機制:「我們設(shè)計了一種簡單的選擇機制,可以根據(jù)輸入對 SSM 的參數(shù)進行參數(shù)化?!?/span>
2. 硬件感知型算法:一個將卷積切換成對特征的掃描的開關(guān),可讓模型在已有硬件上的運行效率更高。
3. 架構(gòu):將之前 SSM 的循環(huán)與 transformer 的前向模塊風(fēng)格結(jié)合起來。
這三方面都涉及大量數(shù)學(xué)知識。歸根結(jié)底,它們是設(shè)計用于在不導(dǎo)致計算低效的前提下提升 SSM 表達能力的技術(shù)。
數(shù)據(jù)選擇機制能夠?qū)⒀h(huán)空間 B 和 C 的處理矩陣表示為輸入文本 x 的函數(shù)(這也被稱為移除矩陣的線性時不變性(LTI))。這能以通用性為代價提升表達能力,因為輸入序列會根據(jù)領(lǐng)域的不同而有非常不同的行為。這些矩陣可以學(xué)習(xí)哪些輸入 token 是最重要的,因此這個機制名為「選擇」。
Mamba 論文給出的算法
硬件感知型組件關(guān)注的重點是如何將隱藏狀態(tài) h 保存在內(nèi)存中效率最高的部分。SSM 更新中使用的核心參數(shù)(線性化的 A、B 和 C 矩陣)會被保存在一種名為 SRAM 的緩存中,這樣一來,到處搬運權(quán)重就不會造成較大的計算瓶頸。下圖展示了所使用的內(nèi)存類型:
最后,Mamba 論文還包含一個新的模型模塊,其設(shè)計靈感來自 SSM 和 Transformer 模型。本文作者并不清楚如此設(shè)計的原因,但考慮到激活和 MLP 對當(dāng)前機器學(xué)習(xí)的重要性,以及當(dāng)前最佳的 LLM 基本都涉及到它們,因此這種做法應(yīng)當(dāng)是合理的。
這一項目可被看作是基于 SSM 社區(qū)大量研究成果的集大成之作(StripedHyena 確非如此)。使用定制版的 CUDA 核,GPU 可以說是火力全開。
專門設(shè)計了架構(gòu)的 CUDA 核能將推理速度提得飛快,如下所示:
最后,與 Pythia 套件相比,考慮到模型大小,其平均評估性能較低。需要指出,Pythia 在參數(shù)方面的模型效率上已經(jīng)不是當(dāng)前最佳了,但其做對比的是 Mistral 等模型(還在 y 軸上更上面),所以這其實是褒獎。此外,該模型不一定穩(wěn)健和靈活,參見下圖,但它也值得了解。
正如之前的采訪提到的那樣,Tri Dao 表示架構(gòu)只會以更好的擬合度將擴展律曲線向上移,而數(shù)據(jù)驅(qū)動型 LLM 依然是創(chuàng)造最佳模型的最大要素。作者認為這會進一步限制模型架構(gòu),以便在可用計算和相關(guān)任務(wù)上獲得更好的性能。這是很棒的東西。Mamba 的模型和代碼地址:https://github.com/state-spaces/mamba
此外,GitHub 上還有 Mamba 的最小實現(xiàn)版本:https://github.com/johnma2006/mamba-minimal
真實世界性能:StripedHyena-7B
盡管之前的兩個項目旨在推進 LLM 的架構(gòu)發(fā)展,但 StripedHyena(SH)卻在新 LLM 架構(gòu)方面顯得最惹人眼球,但其目標(biāo)是將許多先進方法和架構(gòu)(包括注意力)組合起來推進 LLM 的性能表現(xiàn)。
12 月 8 日,Together AI 發(fā)布了第一個模型:StripedHyena-7B。StripedHyena(SH) 讓人震驚。這種新的語言模型與很多常用的語言模型表現(xiàn)相當(dāng)。根據(jù)其博客描述(https://www.together.ai/blog/stripedhyena-7b ),可知該模型采用了一種名為 grafting(嫁接)的技術(shù)。本質(zhì)上講,就像是 Together AI 取用了不同預(yù)訓(xùn)練模型的模塊,將它們連接到一起,并繼續(xù)訓(xùn)練模型使其性能穩(wěn)定。這篇博客中寫到:
「我們嫁接了 Transformer 和 Hyena 的架構(gòu)組件,并在 RedPajama 數(shù)據(jù)集的一個混合數(shù)據(jù)集上進行了訓(xùn)練,并用更長的上下文數(shù)據(jù)進行了增強。
Hyena 這個名稱來自論文《Hyena Hierarchy: Towards Larger Convolutional Language Models》。
在 OpenLLM 排行榜任務(wù)上,StripedHyena 輕松擊敗了 Llama 2 和 Yi 7B!
和許多無注意力架構(gòu)一樣,該模型的一大賣點是長上下文性能。這篇論文使用 ZeroScrolls 基準(zhǔn)表明 StripedHyena 在該任務(wù)上的平均 f1 分數(shù)比 Mistral 7b v0.1 高 3 分(盡管其并未在每個子類上都獲勝)。盡管勝過 Mistral 的 StripedHyena 的得分也不過只有 27.5,但考慮到 GPT-4 的平均分也只有 41.7,因此這個成績還算是不錯了。
除了性能之外,這項研究的還有一大部分是圍繞 LLM 的不同概念的計算效率。其博客詳細說明了不同架構(gòu)的模型的 Chinchilla 風(fēng)格的擴展律。下圖左邊比較了 Llama-2 與 StripedHyena,右邊則是根據(jù)預(yù)算的最優(yōu)注意力比例:
與 Mamba 一樣,發(fā)布的 StripedHyena 包含有關(guān)推理提升的大量細節(jié)。首先是端到端的完整速度:
其次,隨著上下文長度增長,其總的內(nèi)存使用量呈現(xiàn)次二次的擴展趨勢。
鑒于該模型更注重真實世界性能,作者將其與 Mistral 7b 進行了比較。作者表示盡管他更喜歡 Mistral 的答案,但 StripedHyena 的答案是正確的,如果它早一點發(fā)布,說不定還能當(dāng)一段時間的最佳模型。這表明這些新架構(gòu)并沒有落后太多。
不過其也有些局限。研究團隊沒有分享基礎(chǔ)模型使用的數(shù)據(jù),只是說「RedPajama 數(shù)據(jù)集的一個混合數(shù)據(jù)集,并用更長的上下文數(shù)據(jù)進行了增強」。
在他們博客的最后,他們明確表示我們應(yīng)該關(guān)注 Together 在這個方向的下一步動作:
- 有更長上下文的更大模型
- 多模態(tài)支持
- 更進一步的性能優(yōu)化
- 將 StripedHyena 整合進檢索流程,以充分利用更長的上下文。
近期研究
這一節(jié)將介紹可能會有 Mamba 和 StripedHyena 那樣影響力的新論文,但它們可能還需要進一步的研究才能達到這一步。
Monarch Mixer:沒有注意力或多層感知器的模型
- 論文題目:Monarch Mixer:A Simple Sub-Quadratic GEMM-Based Architecture
- 論文:https://arxiv.org/abs/2310.12109
- 代碼:https://github.com/HazyResearch/m2
- 博客:https://hazyresearch.stanford.edu/blog/2023-07-25-m2-bert
這篇論文研究的不僅僅是去除 Transformer 中的注意力,還去除了占用了大部分參數(shù)的 MLP。這類研究將在未來 6-12 個月內(nèi)出現(xiàn)在類 Mamba 的模型中。
GEMM 是一種廣義的矩陣乘法算法,是許多流行架構(gòu)的基礎(chǔ)運算,包括 FFN、RNN、LSTM 和 GRU。GEMM 能在 GPU 上非常高效地執(zhí)行。可以預(yù)見,由于矩陣乘法是現(xiàn)代機器學(xué)習(xí)的核心,但以新的方式設(shè)置它們并不總是能成功地擴展!
我們先過一遍其抽象描述,理解下這是什么意思:
「在序列長度和模型維度方面,機器學(xué)習(xí)模型在不斷得到擴展,以支持更長的上下文并取得更優(yōu)的性能。但是,Transformer 等現(xiàn)有架構(gòu)會隨這些軸呈二次擴展。我們就會問:是否存在能隨序列長度和模型維度呈次二次擴展的高性能架構(gòu)?」
盡管現(xiàn)在的長上下文長度已經(jīng)很長了,但其推理效率并不高。作者認為模型大小在這里不是一個核心因素,或者換一種說法:尋找一種擴展律遵循比指數(shù)冪律更小規(guī)律的架構(gòu)。下面繼續(xù):
「我們提出了 Monarch Mixer (M2),一種沿序列長度和模型維度都是次二次擴展的新架構(gòu)?!筂onarch 矩陣是一類簡單的表達結(jié)構(gòu)化矩陣,可以捕獲很多線性變換,能在 GPU 上實現(xiàn)很高的計算效率,并次二次地擴展?!?/span>
下圖展示了 Monarch 矩陣的圖片,它會首先按序列長度有效地混合輸入,然后再使用模型維度,而不是兩者同時。
「為了驗證概念,我們探索了 M2 在三個領(lǐng)域的性能:非因果 BERT 式的語言建模、ViT 式的圖像分類、因果 GPT 式的語言建模?!?/span>
本文關(guān)注的是 GPT 式的語言建模,因為這種技術(shù)勢頭正盛,但無注意力架構(gòu)很可能助益許多領(lǐng)域(就像擴散模型那樣)。對于最近才涉足機器學(xué)習(xí)領(lǐng)域的人來說,Transformer(BERT)的雙向編碼器表征是值得了解的 —— 它是第一個產(chǎn)生大量微調(diào)版本并帶來諸多好處的 Transformer 模型。繼續(xù)談 BERT 式模型和 ViT(視覺 Transformer)的性能:
「對于非因果式 BERT 式建模,M2 在下游的 GLUE 質(zhì)量上可以比肩 BERT-base 和 BERT-large,同時參數(shù)量少 27%,在 4k 序列長度上也能實現(xiàn)高 9.1 倍的吞吐量。在 ImageNet 上,M2 只用一半的參數(shù)量就讓準(zhǔn)確度超過了 ViT-b 1%?!?/span>
現(xiàn)在回到 GPT 式模型:
「因果 GPT 式模型帶來了一個技術(shù)挑戰(zhàn):通過掩碼實現(xiàn)強制因果性會帶來二次的計算瓶頸。為了減輕這個瓶頸的影響,我們基于多元多項式評估和插值對 Monarch 矩陣進行了全新的理論分析,這讓我們可以將 M2 參數(shù)化成因果模型,同時保留次二次特性。使用這種參數(shù)化,在 The PILE 上,M2 能在預(yù)訓(xùn)練困惑度指標(biāo)上比肩 360M 參數(shù)的 GPT 式 Transformer—— 這首次表明也許不使用注意力或 MLP 也能達到比肩 Transformer 的質(zhì)量。」
這一段的信息量很大。本質(zhì)上講,當(dāng)使用 transformer 執(zhí)行推理時,需要將注意力矩陣遮掩成上三角矩陣,以便每個生成的 token 僅會看過去的 token。這在該模型的解碼器部分,如果你觀察 BERT 這樣的編碼器,你會看到一個完全激活的注意力矩陣。M2 論文中的數(shù)學(xué)緩解了這個二次瓶頸(粗略地講,直接為 N 個生成的 token 關(guān)注 M 個上下文 token)。這很了不起,但其中的數(shù)學(xué)難以理解。但你可以在那個采訪中看理解的人如何解釋。
那篇論文中的模型有 360M 參數(shù),所以這在很多 GPT2 模型的范圍內(nèi),還有很長的路要走。
模型 Zoology 和 Based 模型
在這波模型發(fā)布熱潮中還有 Zoology:一個在合成任務(wù)上理解和測試語言模型架構(gòu)的軟件庫。地址:https://github.com/HazyResearch/zoology
Hazy Research 的 Christopher Re 研究組分享了兩篇相關(guān)的博客文章。
第一篇文章研究了不同的架構(gòu)如何管理語言建模中的關(guān)聯(lián)召回(associate recall)問題,即模型檢索和組合來自不同數(shù)據(jù)源或概念的信號的能力。地址:https://hazyresearch.stanford.edu/blog/2023-12-11-zoology1-analysis
簡而言之,這是基于注意力的模型表現(xiàn)很好的任務(wù)。其中給出了一個惹眼的例子:
「我們發(fā)現(xiàn)一個 7000 萬參數(shù)的注意力模型的表現(xiàn)優(yōu)于一個 14 億參數(shù)的門控卷積模型。盡管關(guān)聯(lián)召回聽上去像是只有內(nèi)行人才懂的任務(wù),但其在機器學(xué)習(xí)領(lǐng)域已經(jīng)有很長的歷史,并且之前已有研究表明解決這些任務(wù)與上下文學(xué)習(xí)等吸引人的功能有關(guān)?!?/span>
第二篇文章更是引起了人們關(guān)注,因為他們基于自己的發(fā)現(xiàn)發(fā)布了一個新模型架構(gòu):Based 地址:https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based
這篇文章中寫到:
我們在這三個維度上展示了這些性質(zhì),結(jié)果表明 Based 能夠提供:
簡單直觀的理解:Based 的基礎(chǔ)是一種觀點,即簡單的卷積和注意力就善于建模不同類型的序列。我們沒有引入新的復(fù)雜性來克服它們各自的缺點,而是以一種直觀的方式將各自常用的版本(短一維卷積,「尖峰」線性注意力)組合起來,從而兼取了兩者之長。
高質(zhì)量建模:盡管模型很簡單,但評估中發(fā)現(xiàn),在語言建模困惑度方面, Based 在多個規(guī)模等級上都優(yōu)于完整版的 Llama-2 式的 Transformer(旋轉(zhuǎn)嵌入、SwiGLU MLP 等)和現(xiàn)代的狀態(tài)空間模型(Mamba、Hyena)。
高效的高吞吐量推理:對于使用純 PyTorch 實現(xiàn)的模型,Based 的推理吞吐量比相競爭的 Transformer(一個參數(shù)量相當(dāng)?shù)?Mistral 并使用了滑動窗口注意力和 FlashAttention 2)高 4.5 倍。為了讓 LLM 執(zhí)行批處理任務(wù),高吞吐量是非常關(guān)鍵的。
總而言之,不同的架構(gòu)都有各自不同的優(yōu)勢和短板。