不分割成token,直接從字節(jié)中高效學習,Mamba原來還能這樣用
在定義語言模型時,通常會使用一種基本分詞方法,把句子分為詞(word)、子詞(subword)或字符(character)。其中,子詞分詞法一直是最受歡迎的選擇,因為它在訓練效率和處理詞匯表外單詞的能力之間實現(xiàn)了自然的折中。然而,一些研究指出了子詞分詞法的問題,如對錯別字、拼寫和大小寫變化以及形態(tài)變化缺乏穩(wěn)健性。
因此,有些研究人員另辟蹊徑,采用了一種使用字節(jié)序列的方法,即從原始數(shù)據(jù)到預測的端到端映射,中間不進行任何分詞。與子詞模型相比,基于字節(jié)級的語言模型能夠更容易地在不同的書寫形式和形態(tài)變化之間進行泛化。當然,將文本建模為字節(jié)意味著生成的序列要比對應的子詞長得多。如此一來,效率的提升就要依靠架構(gòu)的改進來實現(xiàn)了。
自回歸 Transformer 在語言建模中占主導地位,但效率問題尤為突出:計算成本隨序列長度呈二次方增長,因此對長(字節(jié))序列的擴展能力很差。研究人員壓縮了 Transformer 的內(nèi)部表示,以便處理長序列,例如開發(fā)了長度感知建模方法,在這種方法中,token 組在中間層內(nèi)合并。最近,Yu 等人 [2023] 提出了 MegaByte Transformer,它使用固定大小的字節(jié)片段作為子詞的模擬壓縮形式。因此,MegaByte 可以降低計算成本。不過,這可能還不是最好的方法。
在一份新論文中,來自康奈爾大學的研究者介紹了一種高效、簡單的字節(jié)級語言模型 MambaByte。該模型對最近推出的 Mamba 架構(gòu)進行了直接改造。Mamba 建立在狀態(tài)空間模型(SSM)開創(chuàng)的方法基礎(chǔ)上,引入了對文本等離散數(shù)據(jù)更有效的選擇機制,并提供了高效的 GPU 實現(xiàn)。作者的簡單觀察結(jié)果是,使用 Mamba(不做修改)可以緩解語言建模中的主要計算瓶頸,從而消除 patching 并有效利用可用的計算資源。
- 論文標題:MambaByte: Token-free Selective State Space Model
- 論文鏈接:https://arxiv.org/pdf/2401.13660.pdf
他們在實驗中將 MambaByte 與 Transformers、SSM 和 MegaByte(patching)架構(gòu)進行了比較,這些架構(gòu)都是在固定參數(shù)和固定計算設(shè)置下,并在多個長篇文本數(shù)據(jù)集上進行比較的。圖 1 總結(jié)了他們的主要發(fā)現(xiàn)。
與字節(jié)級 Transformers 相比,MambaByte 能更快地實現(xiàn)更好的性能,計算效率也明顯更高。作者還考慮了無 token 語言模型與現(xiàn)有最先進的子詞模型相比的可行性。在這方面,他們發(fā)現(xiàn) MambaByte 與各種子詞基線模型相比具有競爭力,但它能處理更長的序列。研究結(jié)果表明,MambaByte 是現(xiàn)有依賴分詞器( tokenizer)的模型的有力替代品,有望用來促進端到端學習。
背景:選擇性狀態(tài)空間序列模型
SSM 通過一階微分方程對隱藏狀態(tài)的跨時間演變進行建模。線性時不變(time-invariant) SSM 在幾種模態(tài)的深度學習中顯示出了良好的效果。然而,Mamba 作者 Gu 和 Dao 最近認為,這些方法的恒定動態(tài)缺乏隱藏狀態(tài)中依賴輸入的上下文選擇,而這可能是語言建模等任務所必需的。為此,他們提出了 Mamba,該方法將給定輸入 x (t) ∈ R、隱藏狀態(tài) h (t) ∈ R^n 和輸出 y (t) ∈ R 在時間 t 的時變連續(xù)狀態(tài)動態(tài)定義為:
其參數(shù)為對角時不變系統(tǒng)矩陣 A∈R^(n×n),以及隨時間變化的輸入和輸出矩陣 B (t)∈R^(n×1) 和 C (t)∈R^(1×n)。
要對字節(jié)等離散時間序列建模,必須通過離散化來逼近 (1) 中的連續(xù)時間動態(tài)。這就產(chǎn)生了離散時間隱態(tài) recurrence,每個時間步都有新矩陣 A、B 和 C,即
請注意,(2) 類似于循環(huán)神經(jīng)網(wǎng)絡的線性版本,可以在語言模型生成過程中以這種循環(huán)形式應用。離散化要求每個輸入位置都有一個時間步,即 ?[k],對應于 的 x [k] = x (t_k)。然后就可以根據(jù) ?[k] 計算出離散時間矩陣 A、B 和 C。圖 2 展示了 Mamba 如何為離散序列建模。
在 Mamba 中,SSM 項是輸入選擇性的,即 B、C 和 ? 被定義為輸入 x [k]∈R^d 的函數(shù):
其中 W_B ∈ R^(n×d)(C 的定義類似),W_? ∈ R^(d×r) 和 W_R ∈ R^(r×d)(對于某個 r ?d)是可學習的權(quán)重,而 softplus 則確保正向性。請注意,對于每個輸入維度 d,SSM 參數(shù) A、B 和 C 都是相同的,但時間步數(shù) ? 是不同的;這導致每個時間步數(shù) k 的隱藏狀態(tài)大小為 n × d。
Mamba 將這個 SSM 層嵌入到一個完整的神經(jīng)網(wǎng)絡語言模型中。具體來說,該模型采用了一系列門控層,其靈感來源于之前的門控 SSM。圖 3 顯示了將 SSM 層與門控神經(jīng)網(wǎng)絡相結(jié)合的 Mamba 架構(gòu)。
線性 recurrence 的并行掃描。在訓練時,作者可以訪問整個序列 x,從而更高效地計算線性 recurrence。Smith et al. [2023] 的研究證明,使用工作效率高的并行掃描可以高效計算線性 SSM 中的順序 recurrence。對于 Mamba,作者首先將 recurrence 映射到 L 個元組序列,其中 e_k =,然后定義一個關(guān)聯(lián)算子
使得
。最后,他們應用并行掃描計算序列
。一般來說,這需要
時間,使用 L/2 個處理器,其中
是矩陣乘法的成本。注意,A 是一個對角矩陣,線性 recurrence 可在
時間和 O (nL) 空間內(nèi)并行計算。使用對角矩陣進行并行掃描的運行效率也很高,只需 O (nL) FLOPs。
實驗結(jié)果
表 2 顯示了每個數(shù)據(jù)集的每字節(jié)比特數(shù)(BPB)。在本實驗中,MegaByte758M+262M 和 MambaByte 模型使用相同的每字節(jié) FLOP 數(shù)(見表 1)。作者發(fā)現(xiàn),在所有數(shù)據(jù)集上,MambaByte 的性能始終優(yōu)于 MegaByte。此外,作者注意到,由于資金限制,他們無法對 MambaByte 進行完整的 80B 字節(jié)訓練,但 MambaByte 在計算量和訓練數(shù)據(jù)減少 63% 的情況下仍優(yōu)于 MegaByte。此外,MambaByte-353M 還優(yōu)于字節(jié)級 Transformer 和 PerceiverAR。
在如此少的訓練步驟中,MambaByte 為什么比一個大得多的模型表現(xiàn)得更好?圖 1 通過觀察參數(shù)數(shù)量相同的模型進一步探討了這種關(guān)系。圖中顯示,對于參數(shù)大小相同的 MegaByte 模型,輸入 patching 較少的模型表現(xiàn)更好,但在計算歸一化后,它們的表現(xiàn)類似。事實上,全長的 Transformer 雖然在絕對意義上速度較慢,但在計算歸一化后,其性能也與 MegaByte 相似。相比之下,改用 Mamba 架構(gòu)可以顯著提高計算使用率和模型性能。
根據(jù)這些發(fā)現(xiàn),表 3 比較了這些模型在 PG19 數(shù)據(jù)集上的較大版本。在這個實驗中,作者將 MambaByte-972M 與 MegaByte-1.3B+350M 和其他字節(jié)級模型以及幾個 SOTA 子詞模型進行了比較。他們發(fā)現(xiàn),MambaByte-972M 即使只訓練了 150B 字節(jié),其性能也優(yōu)于所有字節(jié)級模型,并與子詞模型相比具有競爭力。
文本生成。Transformer 模型中的自回歸推理需要緩存整個上下文,這會大大影響生成速度。MambaByte 不存在這一瓶頸,因為它每層只保留一個隨時間變化的隱藏狀態(tài),因此每生成一步的時間是恒定的。表 4 比較了 MambaByte-972M 和 MambaByte-1.6B 與 MegaByte-1.3B+350M 在 A100 80GB PCIe GPU 上的文本生成速度。雖然 MegaByte 通過 patching 大大降低了生成成本,但他們觀察到 MambaByte 由于使用了循環(huán)生成,在參數(shù)相似設(shè)置下速度達到了前者的 2.6 倍。