圖靈獎得主Yoshua Bengio新作:Were RNNs All We Needed?
自從 Transformer 模型問世以來,試圖挑戰(zhàn)其在自然語言處理地位的挑戰(zhàn)者層出不窮。
這次登場的選手,不僅要挑戰(zhàn) Transformer 的地位,還致敬了經(jīng)典論文的名字。
再看這篇論文的作者列表,圖靈獎得主、深度學(xué)習(xí)三巨頭之一的 Yoshua Bengio 赫然在列。
- 論文標(biāo)題:Were RNNs All We Needed?
- 論文地址:https://arxiv.org/pdf/2410.01201v1
最近,大家重新對用循環(huán)序列模型來解決 Transformer 長上下文的問題產(chǎn)生了興趣,出現(xiàn)了一大批有關(guān)成果,其中 Mamba 的成功引爆了 AI 圈,更是點燃了大家的研究熱情。
Bengio 和他的研究團(tuán)隊發(fā)現(xiàn),這些新的序列模型有很多共同點,于是他們重新審視了 LSTM 和 GRU 這兩種經(jīng)典 RNN 模型。
結(jié)果發(fā)現(xiàn),精簡掉其中的隱藏狀態(tài)依賴之后,不再需要基于時間反向傳播的 LSTM 和 GRU 的表現(xiàn)就能和 Transformer 打個平手。
LSTM 和 GRU 僅能順序處理信息,并且在訓(xùn)練時依賴反向傳播,這使得它們在處理大量數(shù)據(jù)時速度緩慢,最終被淘汰。
基于以上發(fā)現(xiàn),他們進(jìn)一步簡化了 LSTM 和 GRU,去掉了它們對輸出范圍的限制,并確保它們的輸出在時間上是獨立的,進(jìn)而得到了 minLSTM 和 minGRU。
相比傳統(tǒng) RNN,它們不僅訓(xùn)練時所需的參數(shù)顯著減少,還可以并行訓(xùn)練,比如上下文長度為 512 時,速度能提升 175 倍。
這其實也是 Bengio 長期關(guān)注 RNN 的系列研究成果。在今年五月,Bengio 及其研究團(tuán)隊和加拿大皇家銀行 AI 研究所 Borealis AI 合作發(fā)布了一篇名為《Attention as an RNN》的論文。
正如論文名字所示,他們將注意力機(jī)制重新詮釋為一種 RNN,引入了一種基于并行前綴掃描(prefix scan)算法的新的注意力公式,該公式能夠高效地計算注意力的多對多(many-to-many)RNN 輸出。基于新公式的模塊 Aaren,不僅可以像 Transformer 一樣并行訓(xùn)練,還可以像 RNN 一樣高效更新。
簡化 LSTM 和 GRU
在這一部分,研究者通過簡化和移除各種門中的若干隱藏狀態(tài)依賴關(guān)系,證明 GRU 和 LSTM 可通過并行掃描進(jìn)行訓(xùn)練。
在此基礎(chǔ)上,研究者進(jìn)一步簡化了這些 RNN,消除了它們對輸出范圍的限制(即 tanh),并確保輸出在規(guī)模上與時間無關(guān)。
綜合上述步驟,研究者提出了 GRUs 和 LSTMs 的最小版本(minGRUs 和 minLSTMs),它們可通過并行掃描進(jìn)行訓(xùn)練,且性能可與 Transformers 和最近提出的序列方法相媲美。
minGRU
研究者結(jié)合了兩個簡化步驟,得到了一個極簡版的 GRU(minGRU)。
由此產(chǎn)生的模型比原始 GRU 效率大大提高,只需要 個參數(shù),而不是 GRU 的
個參數(shù)(其中 d_x 和 d_h 分別對應(yīng)于 x_t 和 h_t 的大?。?。在訓(xùn)練方面,minGRU 可以使用并行掃描算法進(jìn)行并行訓(xùn)練,從而大大加快訓(xùn)練速度。
在實驗部分,研究者展示了在 T4 GPU 上,當(dāng)序列長度為 512 時,訓(xùn)練步驟的速度提高了 175 倍。參數(shù)效率的提高也非常顯著。通常,在 RNN 中會進(jìn)行狀態(tài)擴(kuò)展(即 ,其中 α ≥ 1),使模型更容易從輸入中學(xué)習(xí)特征。
minLSTM
研究者結(jié)合了三個簡化步驟,得到 LSTM 的最小版本(minLSTM):
與 LSTM 的 相比,最小版本(minLSTM)的效率明顯更高,只需要
個參數(shù)。此外,minLSTM 可以使用并行掃描算法進(jìn)行并行訓(xùn)練,大大加快了訓(xùn)練速度。例如,在 T4 GPU 上,對于長度為 512 的序列,minLSTM 比 LSTM 加快了 235 倍。在參數(shù)效率方面,當(dāng) α = 1、2、3 或 4(其中
)時,與 LSTM 相比,minLSTM 僅使用了 38%、25%、19% 或 15% 的參數(shù)。
Were RNNs All We Needed?
在本節(jié)中,研究者將對最小版本(minLSTMs 和 minGRUs)與傳統(tǒng)版本(LSTMs 和 GRUs)以及現(xiàn)代序列模型進(jìn)行了比較。
Minimal LSTMs 和 GRU 非常高效
在測試時,循環(huán)序列模型會按順序推出,從而使其推理更為高效。相反,傳統(tǒng) RNN 的瓶頸在于其訓(xùn)練,需要線性訓(xùn)練時間(通過時間反向傳播),這導(dǎo)致其最終被淘汰。人們對循環(huán)序列模型重新產(chǎn)生興趣,是因為許多新的架構(gòu)可以高效地進(jìn)行并行訓(xùn)練。
研究者對比了訓(xùn)練傳統(tǒng) RNN(LSTM 和 GRU)、它們的最小版本(minLSTM 和 minGRU)以及一種最新的序列模型所需的資源,還特別將重點放在與最近大受歡迎的 Mamba 的比較上。實驗考慮了 64 的批大小,并改變了序列長度。研究者測量了通過模型執(zhí)行前向傳遞、計算損失和通過后向傳遞計算梯度的總運行時間和內(nèi)存復(fù)雜度。
運行時間。在運行時間方面(見圖 1(左)),簡化版 LSTM 和 GRU(minLSTM 和 minGRU)Mamba 的運行時間相近。對 100 次運行進(jìn)行平均,序列長度為 512 的 minLSTM、minGRU 和 Mamba 的運行時間分別為 2.97、2.72 和 2.71 毫秒。
對于長度為 4096 的序列,運行時間分別為 3.41、3.25 和 3.15 毫秒。相比之下,傳統(tǒng)的 RNN 對應(yīng)程序(LSTM 和 GRU)所需的運行時間與序列長度成線性關(guān)系。對于 512 的序列長度,在 T4 GPU 上,minGRUs 和 minLSTMs 每個訓(xùn)練步驟的速度分別比 GRUs 和 LSTMs 快 175 倍和 235 倍(見圖 1(中))。隨著序列長度的增加,minGRUs 和 minLSTMs 的改進(jìn)更為顯著,在序列長度為 4096 時,minGRUs 和 minLSTMs 的速度分別提高了 1324 倍和 1361 倍。因此,在 minGRU 需要一天才能完成固定數(shù)量的 epoch 訓(xùn)練的情況下,其傳統(tǒng)對應(yīng)的 GRU 可能需要 3 年多的時間。
內(nèi)存。通過利用并行掃描算法高效地并行計算輸出,minGRU、minLSTM 和 Mamba 創(chuàng)建了一個更大的計算圖,因此與傳統(tǒng)的 RNN 相比需要更多內(nèi)存(見圖 1(右))。與傳統(tǒng)的 RNN 相比,最小變體(minGRU 和 minLSTM)多用了 88% 的內(nèi)存。與 minGRU 相比,Mamba 多用了 56% 的內(nèi)存。但實際上,運行時間是訓(xùn)練 RNN 的瓶頸。
刪除 的效果。最初的 LSTM 和 GRU 使用輸入 x_t 和之前的隱藏狀態(tài)
計算各種門電路。這些模型利用其與時間依賴的門來學(xué)習(xí)復(fù)雜函數(shù)。然而,minLSTM 和 minGRU 的訓(xùn)練效率是通過放棄門對之前隱藏狀態(tài)
的依賴性來實現(xiàn)的。因此,minLSTM 和 minGRU 的門僅與輸入 x_t 依賴,從而產(chǎn)生了更簡單的循環(huán)模塊。因此,由單層 minLSTM 或 minGRU 組成的模型的柵極是與時間無關(guān)的,因為其條件是與時間無關(guān)的輸入
。
然而,在深度學(xué)習(xí)中,模型是通過堆疊模塊構(gòu)建的。雖然第一層的輸入 與時間無關(guān),但其輸出
與時間有關(guān),并被用作第二層的輸入,即
。因此,從第二層開始,minLSTM 和 minGRU 的門也將隨時間變化,從而建立更復(fù)雜的函數(shù)模型。表 1 比較了不同層數(shù)的模型在 Mamba 論文中的選擇性復(fù)制任務(wù)上的表現(xiàn)??梢粤⒓纯闯鰰r間依賴性的影響:將層數(shù)增加到 2 層或更多,模型的性能就會大幅提高。
訓(xùn)練穩(wěn)定性。層數(shù)的另一個影響是穩(wěn)定性增強(qiáng),隨著層數(shù)的增加,準(zhǔn)確率的差異減?。ㄒ姳?1)。此外,雖然 minLSTM 和 minGRU 都能解決選擇性復(fù)制任務(wù),但可以看到 minGRU 是一種經(jīng)驗上比 minLSTM 更穩(wěn)定的方法,它能以更高的一致性和更低的方差解決該任務(wù)。在訓(xùn)練過程中,這兩組參數(shù)的調(diào)整方向不同,使得比率更難控制和優(yōu)化。相比之下,minGRU 的信息丟棄和添加由單組參數(shù)(更新門)控制,因此更容易優(yōu)化。
Minimal LSTMs 和 GRUs 表現(xiàn)良好
上述內(nèi)容展示了簡化傳統(tǒng) RNN 所帶來的顯著效率提升。這部分將探討最小版本的 LSTM 和 GRU 與幾種流行的序列模型相比的經(jīng)驗性能。
選擇性復(fù)制。此處考慮 Mamba 論文中的長序列選擇性復(fù)制任務(wù)。與最初的復(fù)制任務(wù)不同,選擇性復(fù)制任務(wù)的輸入元素相對于輸出元素是隨機(jī)間隔的,這增加了任務(wù)的難度。為了解決這個任務(wù),模型需要進(jìn)行內(nèi)容感知推理,記憶依賴的 token 并過濾掉不依賴的 token。
表 2 將簡化版的 LSTM 和 GRU(minLSTM 和 minGRU)與可以并行訓(xùn)練的著名循環(huán)序列模型進(jìn)行了比較:S4、H3、Hyena 和 Mamba (S6)。這些基線的結(jié)果引自 Mamba 論文。在所有這些基線中,只有 Mamba 論文中的 S6 能夠解決這一任務(wù)。minGRU 和 minLSTM 也能解決選擇性復(fù)制任務(wù),其性能與 S6 相當(dāng),并優(yōu)于所有其他基線。LSTM 和 GRU 利用內(nèi)容感知門控機(jī)制,使得這些最小版本足以解決許多熱門序列模型無法解決的這一任務(wù)。
強(qiáng)化學(xué)習(xí)。接下來,研究者討論了 D4RL 基準(zhǔn)中的 MuJoCo 運動任務(wù)。具體來說考慮了三種環(huán)境:HalfCheetah、Hopper 和 Walker。對于每種環(huán)境,模型都在三種不同數(shù)據(jù)質(zhì)量的數(shù)據(jù)集上進(jìn)行訓(xùn)練:中等數(shù)據(jù)集(M)、中等游戲數(shù)據(jù)集(M-R)和中等專家數(shù)據(jù)集(M-E)。
表 3 將 minLSTM 和 minGRU 與各種 Decision Transformer 變體進(jìn)行了比較,包括原始 Decision Transformer (DT)、Decision S4 (DS4)、Decision Mamba 和(Decision)Aaren。minLSTM 和 minGRU 的性能優(yōu)于 Decision S4,與 Decision Transformer、Aaren 和 Mamba 相比也不遑多讓。與其他循環(huán)方法不同,Decision S4 是一種循環(huán)轉(zhuǎn)換不感知輸入的模型,這影響了其性能。從 3 × 3 = 9 個數(shù)據(jù)集的平均得分來看,minLSTM 和 minGRU 優(yōu)于所有基線方法,只有 Decision Mamba 的差距很小。
語言建模。研究者使用 nanoGPT 框架對莎士比亞作品進(jìn)行字符級 GPT 訓(xùn)練。圖 2 用交叉熵?fù)p失繪制了學(xué)習(xí)曲線,將所提出的最小 LSTM 和 GRU(minLSTM 和 minGRU)與 Mamba 和 Transformers 進(jìn)行了比較。結(jié)果發(fā)現(xiàn),minGRU、minLSTM、Mamba 和 Transformers 的測試損失相當(dāng),分別為 1.548、1.555、1.575 和 1.547。Mamba 的表現(xiàn)略遜于其他模型,但訓(xùn)練速度更快,尤其是在早期階段,在 400 步時達(dá)到最佳表現(xiàn),而 minGRU 和 minLSTM 則分別持續(xù)訓(xùn)練到 575 步和 625 步。相比之下,Transformers 的訓(xùn)練速度明顯較慢,需要比 minGRU 多 2000 步(~ 2.5 倍)的訓(xùn)練步驟才能達(dá)到與 minGRU 相當(dāng)?shù)男阅?,這使得它的訓(xùn)練速度明顯更慢,資源消耗也更大(與 minGRU、minLSTM 和 Mamba 的線性復(fù)雜度相比,Transformers 的復(fù)雜度為二次方)。
更多研究細(xì)節(jié),可參考原論文。