DenseMamba:大模型的DenseNet時刻,Mamba和RetNet精度顯著提升
隨著 ChatGPT 的突破性進展,大型語言模型(LLMs)迎來了一個嶄新的里程碑。這些模型在語言理解、對話交互和邏輯推理方面展現(xiàn)了卓越的性能。過去一年,人們目睹了 LLaMA、ChatGLM 等模型的誕生,它們基于 Transformer 架構,采用多頭自注意力(MHSA)機制來捕捉詞匯間的復雜關系,盡管 MHSA 模塊在模型中扮演著核心角色,但其在推理過程中對計算和內(nèi)存資源的需求卻極為龐大。具體來說,對于長度為 N 的輸入句子,自注意力的計算復雜度高達 O (N^2),而內(nèi)存占用則達到了 O (N^2D),其中 D 是模型的維度。
為了應對這一挑戰(zhàn),最新的研究致力于簡化 Transformer 架構,以降低其在計算和空間上的復雜度。研究者們探索了多種創(chuàng)新方法,包括卷積語言模型、循環(huán)單元、長上下文模型,以及狀態(tài)空間模型(SSMs)。這些新興技術為構建高效能的 LLMs 提供了強有力的替代方案。SSMs 通過引入高效的隱藏狀態(tài)機制,有效處理長距離依賴問題,同時保持了訓練的并行性和推理的高效率。隱藏狀態(tài)能夠在時間維度上傳遞信息,減少了在每一步中訪問歷史詞匯的計算負擔。通過狀態(tài)轉(zhuǎn)移參數(shù) A,隱藏狀態(tài)能夠?qū)⑶耙粫r間步的信息傳遞至當前時間步,實現(xiàn)對下一個詞匯的自回歸預測。
盡管隱藏狀態(tài)在 SSMs 中起著至關重要的作用,但其在以往的研究中并未得到充分研究。不同層的權重和隱藏特征包含了從細粒度到粗粒度的多層次信息。然而,在早期的 SSMs 版本中,隱藏狀態(tài)僅在當前層內(nèi)流動,限制了其傳遞更深層信息的能力,從而影響了模型捕獲豐富層次信息的能力。
為了解決這個挑戰(zhàn),華為諾亞方舟實驗室的科研團隊發(fā)表了新工作《DenseMamba: State Space Models with Dense Hidden Connection for Efficient Large Language Models》, 提出一個適用于各類 SSM 模型例如 Mamba 和 RetNet 的 DenseSSM 方法,該方法有選擇地將淺層隱藏狀態(tài)整合到深層,保留了對最終輸出至關重要的淺層細粒度信息,以增強深層感知原始文本信息的能力。
- 論文鏈接:https://arxiv.org/abs/2403.00818
- 項目主頁:https://github.com/WailordHe/DenseSSM
文章首先分析了狀態(tài)空間模型(SSMs)中的隱藏狀態(tài)退化問題,
上標 “l(fā)” 表示第 l 個塊。其中,Θ(·) 是從 SSM 模塊的最后一個輸出到輸入的轉(zhuǎn)換,例如卷積和前饋網(wǎng)絡(FFN)。從公式 (7) 可以看出,從第 (l-m) 層到第 l 層的隱藏信息傳遞需要經(jīng)過 m 個變換塊和 m 次 BC 矩陣乘法。這樣復雜的計算過程可能導致顯著的信息丟失,這意味著在第 l 層嘗試檢索淺層的某些信息變得非常困難和不清晰。
方法
密集(Dense)隱藏層連接
在上述分析中發(fā)現(xiàn)隨著層深度的增加,SSM 中重要隱藏狀態(tài)的衰減。因此,DenseSSM 提出了一種密集連接的隱藏狀態(tài)方法,以更好地保留來自淺層的細粒度信息,增強深層感知原始文本信息的能力。對于第 l 個塊,DenseSSM 在其前 m 個塊中密集連接隱藏狀態(tài)。
首先,收集淺層隱藏狀態(tài),并引入一個選擇性轉(zhuǎn)換模塊 φ,同時將它們投影到目標層的子空間并選擇有用的部分:
操作是融合中間隱藏向量和當前隱藏狀態(tài)的函數(shù)。具有所提出的密集隱藏層連接的 SSM 被稱為 DenseSSM, 下圖為遞歸模式的 DenseSSM 示例。
DenseSSM 也可以基于卷積模式以實現(xiàn)高效訓練。根據(jù)狀態(tài)空間模型(SSM)的公式可以得到:
這個過程可以通過對輸入序列進行卷積來實現(xiàn):
在文章所提出的 DenseSSM 中,可以獲得隱藏狀態(tài)加強的 SSM 的輸出:
DenseSSM 方法的并行實現(xiàn)示例圖:
Selective Transition Module (選擇性轉(zhuǎn)換模塊)
選擇性轉(zhuǎn)換模塊 φ(·) 的目的是將輸入投影到目標子空間,并同時選擇隱藏信息的有用部分。通過投影層和門控選擇機制實現(xiàn)了選擇性轉(zhuǎn)換模塊,如上圖所示。首先,前 m 個 SSM 塊中的隱藏狀態(tài)會被投影到相同的空間:
然后,根據(jù)輸入生成門控權重,并使用它們來選擇有用的隱藏狀態(tài):
在實踐中作者保持了簡單且高效的實現(xiàn)。投影層使用線性變換實現(xiàn),而門控模塊則使用參數(shù)高效的帶有激活函數(shù)的兩層 MLP。
Hidden Fusion Module (隱藏層融合模塊)
選擇性轉(zhuǎn)換模塊后從淺層獲得了選擇的隱藏狀態(tài),即后,DenseSSM 方法利用一個隱藏融合模塊將這些精選的淺層隱藏狀態(tài)與當前層的隱藏狀態(tài)結合起來。由于這些精選狀態(tài)已經(jīng)被投影到相同的空間,因此可以簡單地將它們累加到當前層的隱藏狀態(tài)上:
為了保持模型的高效性,其他可能的實現(xiàn)方式,例如拼接和交叉注意力機制沒有被使用。
擴展到 RetNet
RetNet 可以被視為一種狀態(tài)空間模型,它利用線性注意力來簡化自注意力的計算復雜度。與標準 Transformer 相比具有快速推理和并行化訓練兼得的優(yōu)勢。
其中,是循環(huán)狀態(tài), RetNet 的密集 KV 連接執(zhí)行方式如下。首先,淺層的 K 和 V 被連接起來:
然后,這些 K 和 V 被注入到當前層的原始鍵(或值)中:
配備了使用所提出 DenseSSM 方法的密集鍵值(KV)連接的 RetNet 被稱為 DenseRetNet,如下圖所示。
此外,DenseRetNet 也可以在并行模式下實現(xiàn),也就是說,可以在 GPU 或 NPU 上并行訓練。DenseRetNet 的并行模式公式如下:
實驗
文章進行了全面的實驗,以驗證所提出的 DenseSSM 的有效性。這些實驗在不同的架構上進行,包括 RetNet 和 Mamba。
預訓練數(shù)據(jù)
在實驗中,選擇了 The Pile 數(shù)據(jù)集的一個子集,并從頭開始訓練所有模型。為了確保訓練集包含 150 億(15B)個 tokens,對數(shù)據(jù)集進行了隨機抽樣。在所有實驗中,統(tǒng)一使用了 LLaMA 分詞器來處理這些數(shù)據(jù)。
評估數(shù)據(jù)集
在評估模型性能時,特別關注了模型在多種下游任務上的零樣本和少樣本學習能力。這些任務包括了一系列測試常識推理和問答的數(shù)據(jù)集,例如 HellaSwag、BoolQ、COPA、PIQA、Winograd、Winogrande、StoryCloze、OpenBookQA、SciQ、ARC-easy 和 ARC-challenge。此外,文章還報告了 WikiText 和 LAMBADA 的詞困惑度指標。所有評估都通過使用 LM evaluation harness 標準化的評估工具進行,以確保評估模型能力的一致性。
實驗設置
為了驗證提出的 DenseSSM 機制的有效性,選擇了 350M 和 1.3B 兩種模型規(guī)格進行實驗。所有模型都是從頭開始訓練的,并進行了一個 Epoch 的訓練,共使用了 1.5B tokens。訓練時,設置訓練的 batch size 為 0.5M,序列長度為 2048 個 token。訓練過程中使用了 AdamW 優(yōu)化器,并采用了多項式學習率衰減,warm-up 比例設置為總訓練步數(shù)的 1.5%。權重衰減設置為 0.01,梯度裁剪設置為 1。
DenseRetNet 的實驗
DenseRetNet 模型的大小和超參數(shù)設置詳細列出如下。此外,DenseRetNet 模型中還進一步集成了全局注意力單元(GAU)。GAU 將注意力機制與前饋網(wǎng)絡(FFN)塊結合為一個單元,這使得模型能夠同時進行通道混合和 token 混合。與原始的 GAU 不同,多頭機制仍然被采用以實現(xiàn)多尺度的指數(shù)衰減,這種設計旨在提高模型對不同尺度特征的捕捉能力,從而提升性能。
在通用語料庫以及包括常識推理和問答在內(nèi)的多種下游任務上,對 DenseRetNet 模型進行了評估。實驗結果的比較表格顯示,DenseRetNet 模型在 Wikitext 和 LAMBADA 語料庫上取得了更低的困惑度。此外,在零樣本和少樣本設置的下游任務中,DenseRetNet 表現(xiàn)出了顯著的優(yōu)勢。與 RetNet 相比,DenseRetNet 顯著提升了性能,并且在與基于 Transformer 的語言模型的比較中,實現(xiàn)了更優(yōu)越的性能表現(xiàn)。這些結果表明,DenseRetNet 在處理自然語言處理任務時,具有強大的能力和潛力。
DenseMamba 的實驗
下表詳細列出了 DenseMamba 模型的參數(shù)設置。由于 DenseMamba 使用的分詞器相比于 Mamba 模型中使用的 GPT-NeoX 分詞器規(guī)模較小,為了使參數(shù)數(shù)量相匹配,作者在模型中增加了兩層。除此之外,模型結構和其他訓練設置均遵循了 Mamba 論文中的描述。具體而言,對于 360M 參數(shù)的模型,學習率被設定為 3e-4;對于 1.3B 參數(shù)的模型,學習率被設定為 2e-4。在這兩種情況下,均沒有采用 dropout 技術。
下表比較了 DenseMamba 與相對應模型的性能。DenseMamba 在測試集上表現(xiàn)出卓越的困惑度和準確性,優(yōu)于 Mamba 和其他基于 Transformer 的模型。
總結
文章提出了一個新的框架 ——DenseSSM(密集狀態(tài)空間模型),旨在通過增強隱藏信息在不同層之間的流動來提升狀態(tài)空間模型(SSM)的性能。在 SSM 中,隱藏狀態(tài)是存儲關鍵信息的核心單元,更有效地利用這些狀態(tài)對于模型的基本功能至關重要。為了實現(xiàn)這一目標,作者提出了一種方法,即從淺層收集隱藏狀態(tài),并將它們有選擇性地融合到深層的隱藏狀態(tài)中,這樣可以增強 SSM 對文本低層信息的感知能力。
DenseSSM 方法的設計考慮到了保持 SSM 原有的優(yōu)點,如高效的自回歸推理能力和高效的并行訓練特性。通過將 DenseSSM 方法應用于流行的架構,例如 RetNet 和 Mamba,作者成功地創(chuàng)造了具有更強大的基礎語言處理能力的新架構。這些新架構在公共基準測試中表現(xiàn)出了更高的準確性,證明了 DenseSSM 方法的有效性。