原作者帶隊再次改造xLSTM,7B模型速度最快超Mamba 50%,權(quán)重代碼全開源
近年來,大型語言模型(LLM)通過大量計算資源在推理階段取得了解決復(fù)雜問題的突破。推理速度已成為 LLM 架構(gòu)的關(guān)鍵屬性,市場對高效快速的 LLM 需求不斷增長。
其中,采用 Transformer 架構(gòu)的模型雖然占據(jù)了主流,但在輸入序列長度增加時,計算量會呈二次方增長。因此,自上個世紀(jì) 90 年代興起的 LSTM 卷土重來,它的提出者和奠基者 Sepp Hochreiter 在去年 5 月推出了 xLSTM,將 LSTM 擴展到數(shù)十億參數(shù),成為 Transformer 的有力替代品,提供了與序列長度線性相關(guān)的計算擴展和穩(wěn)定的內(nèi)存占用。
然而,xLSTM 在擴展至更大參數(shù)規(guī)模時存在限制,推理速度和效率具體如何也沒做系統(tǒng)測評。
近日,Sepp Hochreiter 等來自 NXAI、JKU 的研究者再次對 xLSTM 進行了優(yōu)化,現(xiàn)在可以擴展到 70 億參數(shù)了。
具體來講,xLSTM 7B 模型基于 DCLM 數(shù)據(jù)集,使用 128 塊 H100 GPU,在 8192 上下文長度下訓(xùn)練了 2.3 萬億 token。研究者對原始 xLSTM 架構(gòu)進行了改進,確保訓(xùn)練效率和穩(wěn)定性,同時保持任務(wù)性能。新架構(gòu)依靠 mLSTM 單元和并行訓(xùn)練模式,實現(xiàn)高性能的同時最大化速度。
- 論文標(biāo)題:xLSTM 7B: A Recurrent LLM for Fast and Efficient Inference
- 論文地址:https://arxiv.org/pdf/2503.13427
- 代碼地址:https://github.com/NX-AI/xlstm
- Hugging Face 地址:https://huggingface.co/NX-AI/xLSTM-7b
通過修改模塊架構(gòu),研究者優(yōu)化了吞吐量,在低維空間運行 mLSTM 并添加前饋 MLP 層,同時去除了不必要的組件以提高 GPU 利用率。優(yōu)化后的架構(gòu)在保持相似性能的同時,將 token 吞吐量提高了 2 到 4 倍。研究者還優(yōu)化了訓(xùn)練穩(wěn)定性,特別是 mLSTM 單元的門控機制,有效解決了梯度問題。
在各類任務(wù)評估中,xLSTM 7B 與同規(guī)模 Transformer 和 Mamba 模型表現(xiàn)相當(dāng)。通過架構(gòu)優(yōu)化,該模型在推理效率測試中實現(xiàn)了最高的預(yù)填充和生成吞吐量,同時保持最低的 GPU 內(nèi)存占用。
論文作者之一 Günter Klambauer 表示,xLSTM 7B 成為了最快、最高效的 7B 語言模型!
優(yōu)化的 xLSTM 7B 架構(gòu)
xLSTM 7B 架構(gòu)的核心是 mLSTM 單元,它的循環(huán)和并行模式可以實現(xiàn)高效的訓(xùn)練和推理。為了充分發(fā)揮該單元的潛力,研究者重新審視了相鄰塊結(jié)構(gòu)的設(shè)計。
與 Mamba 等其他線性 RNN 類似,以前的 xLSTM 架構(gòu)將與通道卷積相結(jié)合的 mLSTM 單元置于線性上投影和下投影之間,這被稱為預(yù)上投影(pre up-projection )塊。這些塊將序列混合和通道混合結(jié)合在一個塊中,因此均勻堆疊,而無需交錯位于前饋 MLP 層。盡管預(yù)上投影塊架構(gòu)已展示出了對 1.4B 參數(shù) xLSTM 的競爭性語言建模性能,但由于以下幾方面的原因,它在計算效率方面付出了很大代價:
- 在預(yù)上投影塊中,mLSTM 在比模型嵌入維數(shù)高得多的維數(shù)上運行,這導(dǎo)致 mLSTM 操作的計算成本和 GPU 內(nèi)存使用量大幅增加。
- 省略位置前饋 MLP 層會導(dǎo)致模型中高效線性層 FLOP 的比例下降。
- 以前的 xLSTM 架構(gòu)使用幾個額外的組件,例如可學(xué)習(xí)的殘差連接、通道卷積以及用于計算查詢、鍵和值的?。▔K對角化)投影層。如果沒有自定義內(nèi)核融合,這些小操作會導(dǎo)致 GPU 上出現(xiàn)多個短內(nèi)核調(diào)用,無法有效利用張量核心,從而大幅降低 GPU 利用率。
- 以前,輸入和遺忘門預(yù)激活是通過連接的查詢、鍵和值投影計算出來的。而在大規(guī)模張量并行訓(xùn)練設(shè)置中,這需要每個 mLSTM 塊進行額外的全歸約操作,從而增加總體通信成本。
因此,為了將 xLSTM 擴展到更大的模型大小,研究者通過解決以上四個限制來優(yōu)化 mLSTM 塊以實現(xiàn)最大效率。
對于優(yōu)化 mLSTM 塊,研究者首先在模型的嵌入維數(shù)而不是更高維數(shù)的空間中操作 mLSTM 單元,并在每個 mLSTM 層之后放置位置前饋 MLP 層。此修改增加了高度優(yōu)化的線性層(即矩陣乘法)FLOP 的比例,并降低了 mLSTM 操作的計算成本。顯著減少的 GPU 內(nèi)存使用量使得在訓(xùn)練期間可以使用更大的批大小,從而提高了訓(xùn)練效率。
此外,研究者放棄了通道卷積和可學(xué)習(xí)的殘差連接等操作,并用密集線性層替換塊查詢、鍵和值投影。這再次增加了線性層 FLOP,并確保有效使用 mLSTM 層內(nèi)的張量核。最后,確保每個 head 的門預(yù)激活都是獨立計算的。
這些優(yōu)化產(chǎn)生了下圖 1 和下圖 8 中改進后的 mLSTM 塊和 xLSTM 架構(gòu),其中在 xLSTM 7B 架構(gòu)中堆疊了 32 個 mLSTM 塊。
下表 4 為 xLSTM 7B 的超參數(shù),包括模型參數(shù)(近 70 億)、詞表大小(50257)、塊數(shù)量(32)、模型維數(shù)(4096)以及 head 數(shù)(8)。
研究者觀察到,本文優(yōu)化在 1.4B 參數(shù)的模型訓(xùn)練中實現(xiàn)了 3.5 倍的加速,但在驗證困惑度方面略有損失,可以通過增加幾個訓(xùn)練步驟來緩解,詳見下表 2。
優(yōu)化穩(wěn)定性
研究者發(fā)現(xiàn),先前在 7B 參數(shù)規(guī)模下的 xLSTM 架構(gòu)在訓(xùn)練初期階段常出現(xiàn)不穩(wěn)定現(xiàn)象。具體而言,他們觀察到在較高學(xué)習(xí)率條件下訓(xùn)練會導(dǎo)致梯度幅度和損失值劇烈波動。本文通過以下方法解決了這些穩(wěn)定性問題:
- 使用 RMSNorm 替代 LayerNorm;
- 對輸入門和遺忘門實施軟上限限制;
- 對輸入門偏置進行負(fù)初始化。
1. 使用 RMSNorm 的預(yù)歸一化(Pre-Norm with RMSNorm)
下圖 9 中的實驗證實,預(yù)歸一化技術(shù)同樣適用于 xLSTM 架構(gòu)的預(yù)歸一化層。因此,研究者在 xLSTM 架構(gòu)中將 LayerNorm 替換為 RMSNorm(全稱為 Root Mean Square Normalization)。
2. 門控軟上限限制(Gate Soft-Capping)
為了降低潛在的大幅異常特征和相關(guān)損失峰值,研究者對輸入門和遺忘門的預(yù)激活值應(yīng)用了軟上限限制,使其值被限制在特定上限值 a 的 - a 與 a 之間。本文采用 a=15 對門控進行限制,所使用的函數(shù)為
3. 負(fù)輸入門偏置初始化(Negative Input Gate Bias Initialization)
研究者發(fā)現(xiàn),在訓(xùn)練初期,xLSTM 模型會出現(xiàn)較大的梯度范數(shù)峰值,這對模型的最終性能產(chǎn)生不利影響(詳見下圖 11)。將輸入門初始化為較大的負(fù)值(如 - 10)能有效緩解這些梯度范數(shù)峰值,從而提升模型性能。
綜上所述,這些優(yōu)化措施使 xLSTM 7B 的預(yù)訓(xùn)練過程變得極為穩(wěn)定,如下圖 2 所示。
語言建模性能評估
Huggingface 排行榜
研究者首先在 7B 參數(shù)規(guī)模上,將 xLSTM 7B 與最先進的 Transformer 和循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)大語言模型進行了基準(zhǔn)測試。
結(jié)果總結(jié)在下表 1 中,顯示 xLSTM 7B 在 7B 規(guī)模模型中排名居中,其中一些表現(xiàn)更好的模型受益于更大規(guī)模的訓(xùn)練數(shù)據(jù)集。研究者認(rèn)為,如果使用更大且更精心策劃的訓(xùn)練數(shù)據(jù)集,尤其是在早期訓(xùn)練階段更加注重數(shù)學(xué)和代碼數(shù)據(jù),xLSTM 7B 可能會達到最強 7B 模型的性能水平。
長文本評估與微調(diào)
研究者將 xLSTM 與幾種基線模型進行了比較:作為 Transformer 基線的 Llama 2 7B(未進行長文本微調(diào))和 Llama 3.1 8B(已進行長達 131K 詞元的長文本微調(diào)),作為狀態(tài)空間模型(State Space Model,SSM)基線的 CodestralMamba 和 FalconMamba,以及作為額外循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network,RNN)基線的 RWKV-5/6。
下表 3 展示了 RULER 評估結(jié)果。對于 xLSTM 7B,預(yù)訓(xùn)練中的長文本降溫(cooling)階段極大地提升了其長文本處理能力,使其性能與狀態(tài)空間模型相當(dāng),并且優(yōu)于 RWKV-5/6。
值得注意的是,長文本 xLSTM 7B 在 131K 上下文長度時實現(xiàn)了 20% 的平均準(zhǔn)確率,盡管在降溫階段訓(xùn)練時僅使用了最多 32K 的上下文長度。這一點尤為顯著,因為與具有不斷增長的 KV 緩存(Key-Value cache)的 Transformer 不同,xLSTM 7B 必須在有限容量的固定大小內(nèi)存中存儲整個序列的信息(見表 3)。
速度基準(zhǔn)測試
本研究主要關(guān)注本地單用戶推理場景,這在模型部署到邊緣設(shè)備時較為常見。除非另有說明,研究在單個英偉達 H100 GPU 上對批大小為 1 的 xLSTM 7B 模型進行生成式推理基準(zhǔn)測試,并將其與 Llama 2 和 Llama 3 模型進行了比較。
生成吞吐量
如下圖 4 所示,由于注意力機制隨輸入上下文長度呈二次方增長,Transformer 模型在較長預(yù)填充長度下的文本生成速度顯著降低。
研究表明,xLSTM 7B 的文本生成速度比 Mamba 快約 50%,這主要得益于其優(yōu)化的塊設(shè)計。即使在預(yù)填充長度為 0 的情況下,xLSTM 7B 也比采用類似塊設(shè)計的基于 Llama 的 Transformer 模型更快。
生成效率與內(nèi)存消耗分析
研究者測量了不同生成長度下的 token 生成時間和 GPU 內(nèi)存使用情況(不包括預(yù)填充)。圖 5(左)展示了循環(huán)模型在計算時間上呈線性增長,與 Transformer 呈二次方增長的對比;圖 5(右)則顯示了循環(huán)模型內(nèi)存占用保持恒定,而 Transformer 的 KV 緩存隨生成長度線性增長的對比。
得益于優(yōu)化的模塊設(shè)計,mLSTM 在低維空間中運行,使得 xLSTM 7B 模型與 Mamba 模型相比具有顯著更低的內(nèi)存占用(如下圖 5 右側(cè)所示)和更短的生成時間(如圖 5 左側(cè)所示)。
TTFT(Time To First Token)
在語言模型作為用戶界面(可能在邊緣設(shè)備上)的應(yīng)用場景中,較短的響應(yīng)時間至關(guān)重要。下圖 6 展示了不同模型在處理各種長度的預(yù)填充(prefill)內(nèi)容后,生成 1 個或 100 個 token 所需的響應(yīng)時間或延遲。在所有預(yù)填充長度條件下,xLSTM 7B 模型均表現(xiàn)出最快的響應(yīng)速度。
更多實驗結(jié)果請參閱原論文。