在Transformer時代重塑RNN,RWKV將非Transformer架構擴展到數百億參數
Transformer 模型在幾乎所有自然語言處理(NLP)任務中都帶來了革命,但其在序列長度上的內存和計算復雜性呈二次方增長。相比之下,循環(huán)神經網絡(RNNs)在內存和計算需求上呈線性增長,但由于并行化和可擴展性的限制,很難達到與 Transformer 相同的性能水平。本文提出了一種新穎的模型架構,Receptance Weighted Key Value(RWKV),將 Transformer 的高效可并行訓練與 RNN 的高效推理相結合。實驗證明,RWKV 的性能與相同規(guī)模的 Transformer 相當。
深度學習技術在人工智能領域取得了重大進展,在各種科學和工業(yè)應用中發(fā)揮了關鍵作用。這些應用通常涉及復雜的序列數據處理任務,包括自然語言理解、對話式人工智能、時間序列分析等,其中用到的技術主要包括循環(huán)神經網絡(RNNs)、卷積神經網絡(CNNs)和 Transformer 等。
不過,這些方法各自存在不同的缺點,從而限制了它們在某些場景下的效率。循環(huán)神經網絡(RNNs)面臨著梯度消失的問題,使得它們難以對長序列進行訓練。此外,在訓練過程中無法在時間維度上并行化,進而限制了其可擴展性。另一方面,卷積神經網絡(CNNs)只擅長捕捉局部模式,在處理長程依賴方面還很欠缺,而這對于許多序列處理任務至關重要。
Transformer 模型由于其處理局部和長程依賴關系的能力以及可并行化訓練的特點而成為一個強大的替代方案,如 GPT-3、ChatGPT、GPT-4、LLaMA 和 Chinchilla 等都展示了這種架構的能力,推動了自然語言處理領域的前沿。盡管取得了這些重大進展,Transformer 中固有的自注意力機制帶來了獨特的挑戰(zhàn),主要是由于其二次復雜度造成的。這種復雜性使得該架構在涉及長輸入序列或資源受限情況下計算成本高昂且占用內存。這也促使了大量研究的發(fā)布,旨在改善 Transformer 的擴展性,但往往以犧牲一些特性為代價。
為了應對這些挑戰(zhàn),一個由 27 所大學、研究機構組成的開源研究團隊,聯合發(fā)表論文《 RWKV: Reinventing RNNs for the Transformer Era 》,文中介紹了一種新型模型:RWKV(Receptance Weighted Key Value),這是一種新穎的架構,有效地結合了 RNN 和 Transformer 的優(yōu)點,同時規(guī)避了兩者的缺點。RWKV 設計精良,能夠緩解 Transformer 所帶來的內存瓶頸和二次方擴展問題,實現更有效的線性擴展,同時保留了使 Transformer 在這個領域占主導的一些性質。
- 論文地址:https://arxiv.org/pdf/2305.13048.pdf
- RWKV 模型下載:https://huggingface.co/BlinkDL/rwkv-4-raven
- Demo 地址:https://www.codewithgpu.com/i/app/BlinkDL/ChatRWKV/RWKV-4-Raven-7B
本文利用線性注意力機制,允許將模型定義為 Transformer 或 RNN,從而在訓練期間并行化計算,并在推理過程中保持恒定的計算和內存復雜性,使其成為第一個可擴展到數百億參數的非 Transformer 架構。
RWKV 其中的一個特征是它能夠提供并行訓練和強大的可擴展性,類似于 Transformer。此外,該研究對 RWKV 中的注意力機制進行了重新闡述,引入了線性注意力的一個變體,避開了傳統點積(dot-product)token 交互,轉而采用更有效的通道導向注意力( channel directed attention )。這種方法與傳統的 Transformer 架構形成了鮮明的對比,其中特定的 token 交互主導了注意力。在 RWKV 中,線性注意力的實施是無需近似的,這在效率上提供了顯著的改進,并增強了可擴展性,詳見表 1。
該研究表示,開發(fā) RWKV 的主要動機是彌補神經網絡架構在計算效率和表達能力之間的差距。它為處理涉及數十億參數的大規(guī)模模型的任務提供了一個有希望且可行的解決方案,以極低的計算成本展現出強有力的競爭性。
實驗結果表明,RWKV 可以成為一個有價值的工具,用于解決各個領域擴展和部署人工智能模型的各種挑戰(zhàn),特別是那些涉及序列數據處理的領域。RWKV 為下一代更可持續(xù)、計算效率更高的序列處理任務的 AI 模型鋪平了道路。
總結而言,本文的貢獻如下:
- 引入了 RWKV 網絡架構,該架構結合了 RNN 和 Transformer 的優(yōu)點,同時減輕了它們已知的限制。
- 本文提出了一個新的注意力機制重構,進而提出線性注意力,避開了與標準 Transformer 模型相關的二次復雜性。
- 本文在基準數據集上進行了一系列全面的實驗,展示了 RWKV 在處理涉及大規(guī)模模型和長距離依賴任務上的性能、效率和可擴展性。
- 發(fā)布了預訓練模型,其大小從 1.69 億到 140 億的參數不等,這些模型是在 Pile 上訓練的。
值得注意的是,論文參與機構之一的 EleutherAI 表示:這篇論文還不是最終版本,后續(xù)會不斷完善。
RWKV 模型
RWKV 架構的名稱來源于時間混合和通道混合塊中使用的四個主要模型元素,分別如下:
- R:Receptance 向量,用于接收以往信息;
- W:權重(weight)是位置權重衰減向量,是可訓練的模型參數;
- K:鍵(Key)是類似于傳統注意力中 K 的向量;
- V:值(Value)是類似于傳統注意力中 V 的向量。
每一時間步的主要元素之間的交互是相乘增加的,具體如下圖 2 所示。
架構細節(jié)
RWKV 架構由一系列堆疊的殘差塊組成,每個殘差塊又由具有循環(huán)結構的時間混合和通道混合子塊組成。
循環(huán)被表示為當前輸入和前一個時間步的輸入之間的線性插值(研究者稱這種技術為時移混合或 token shift,如下圖 3 所示),該插值可以針對輸入嵌入的每個線性投影進行獨立調整(比如時間混合中的 R、K 和 V,通道混合中的 R 和 K),并作為公式 14 中形式化的 WKV 的時變更新。
類 Transformer 的并行化
RWKV 可以在時間并行模式下進行高效地并行化,讓人聯想到 Transformer。單個層中一個 batch 序列的時間復雜度為 O (BTd^2 ),它主要由矩陣乘法 W_□, □ ∈ {r, k, v, o}(假設 B 個序列、T 個最大 token 和 d 個通道)。同時更新注意力分數 wkv_t 需要串行掃描,并且復雜度為 O (BTd)。
類 RNN 的序列解碼
在循環(huán)網絡中,將狀態(tài) t 時的輸出用作狀態(tài) t+1 時的輸入很常見。這在語言模型的自回歸解碼推理中尤為明顯,要求每一個 token 在饋入下一步之前必須進行計算,從而使 RWKV 可以利用類 RNN 結構(即時序模式)。在這種情況下,RWKV 可以方便地循環(huán)用于推理解碼,從而利用每個輸出 token 僅依賴于最新狀態(tài)的優(yōu)勢。
然后 RWKV 充當 RNN 解碼器,在序列長度方面保持恒定速度和內存占用,從而更高效地處理更長的序列。相比之下,自注意力通常需要 KV 緩存相對于序列長度呈線性增長,這會導致效率下降,并隨序列長度增加消耗更多內存和時間。
軟件實現
RWKV 最初使用 PyTorch 深度學習庫和自定義 CUDA 內核(它用于 WKV 計算)來實現。盡管 RWKV 是一個通用循環(huán)網絡,但其當前的實現主要集中在語言建模任務(RWKV-LM)。該模型架構包含了一個嵌入層,為此研究者遵循第 4.7 節(jié)中的設置,并按照第 4.6 節(jié)中的原則依次應用幾個相同的殘差塊,具體如上圖 2 和 3 所示。
梯度穩(wěn)定性和層堆疊
RWKV 架構被設計為 Transformer 和 RNN 的融合,與傳統的 RNN 相比,Transformers 具有穩(wěn)定梯度和更深層次架構的優(yōu)勢,同時推理效率高。
RWKV 模型具有用于更新類似注意力分數的單步過程,其中包括一個依賴于時間的 softmax 操作,該操作有助于數值穩(wěn)定性并防止梯度消失(有關嚴格證明,請參見附錄 F)。直觀地說,此操作可確保梯度沿最相關的路徑傳播。Layer normalization (Ba et al., 2016) 是架構的另一個關鍵方面,它通過穩(wěn)定梯度、解決梯度消失和爆炸問題來增強深度神經網絡的訓練動態(tài)。
利用時間結構進行時序數據處理
RWKV 通過三種機制的組合來捕獲和傳播時序信息:循環(huán)、時間衰減和 token shift。
RWKV 時間混合塊中的循環(huán)是模型捕獲序列元素之間復雜關系和隨時間傳播局部信息的能力的基礎。
時間衰減機制(等式 14 中的 e^?w 和 e^u)保持了對序列元素之間位置關系的敏感性。通過逐漸減少以往信息隨時間的影響,該模型保留了時間局部性和進展感,這對于時序處理至關重要。
token shift 或 time-shift 混合或(圖 3 中的對角線箭頭),也有助于模型適應時序數據。通過在當前輸入和前一個時間步輸入之間進行線性插值,模型自然地聚合和門控輸入通道中的信息。
實驗結果
實驗的重點是回答以下問題:
- RQ1:在參數數量和訓練 token 數量相等的情況下,RWKV 與二次 transformer 架構相比具有競爭力嗎?
- RQ2:增加參數數量時,RWKV 是否仍然具有與二次 transformer 架構相競爭的能力?
- RQ3:當 RWKV 模型被訓練用于開源二次 transformer 無法高效處理的上下文長度時,增加 RWKV 的參數是否能夠獲得更好的語言建模損失?
首先是回答 RQ1 和 RQ2 問題,從圖 4 可以看出,在六個基準測試中(Winogrande、PIQA、ARC-C、ARC-E、LAMBADA 和 SciQ),RWKV 與開源二次復雜度 transformer 模型 Pythia、OPT 和 BLOOM 具有相當的競爭力。RWKV 甚至在四個任務(PIQA、OBQA、ARC-E 和 COPA)中勝過了 Pythia 和 GPT-Neo。
對于 RQ3,圖 5 顯示,增加上下文長度會導致 Pile 上的測試損失降低,這表明 RWKV 能夠有效利用較長的上下文信息。