Bengio精簡(jiǎn)了傳統(tǒng)RNN,性能可與Transformer媲美
深度學(xué)習(xí)三巨頭之一的Yoshua Bengio,剛剛發(fā)布了一篇有趣的新論文——
RNN就是所需的全部嗎?
Were RNNs All We Needed?
不僅論文的名字有意思,其結(jié)論更是精彩。
研究表明,精簡(jiǎn)十幾年前的RNN們,性能上可以與最近序列模型(如Transformer等)相媲美!
具體而言,Bengio等人重新審視了諸如LSTM(1997)和GRU(2014)這些傳統(tǒng)的RNN,認(rèn)為這些模型的缺點(diǎn)是由于需要時(shí)間反向傳播 (BPTT) 而導(dǎo)致速度較慢。
所以他們直接大刀闊斧地移除了LSTM和GRU中的隱藏狀態(tài)依賴,讓它們不再需要BPTT,從而可以高效地并行訓(xùn)練。
而精簡(jiǎn)改良版的RNN們,名字分別叫做minLSTM和minGRU。
它們和傳統(tǒng)RNN相比,不僅訓(xùn)練時(shí)所需的參數(shù)量大幅減少,并且完全可并行化。
嗯,是頗有一種大道至簡(jiǎn)的感覺了。
那么Bengio等人具體又是如何實(shí)現(xiàn)的?我們繼續(xù)往下看。
精簡(jiǎn)版RNN
Transformer和它的變體們可以說(shuō)是近幾年大熱的架構(gòu),但與此同時(shí)缺點(diǎn)也是較為明顯,那便是在處理長(zhǎng)序列時(shí)的計(jì)算復(fù)雜度問題。
具體來(lái)說(shuō),Transformer模型在序列長(zhǎng)度上的計(jì)算復(fù)雜度是二次方的,這使得它在處理長(zhǎng)序列時(shí)資源的消耗就比較高。
因此就需要能夠在訓(xùn)練時(shí)有效地處理長(zhǎng)序列,同時(shí)在推理時(shí)保持高效性能的替代方案——簡(jiǎn)化版的RNN。
這個(gè)過(guò)程的關(guān)鍵便是隱藏狀態(tài)依賴,讓它們不再需要BPTT,讓效率直接飆升。
minGRU
首先我們來(lái)看下Bengio團(tuán)隊(duì)對(duì)GRU的處理,即minGRU,總共分為2步。
第一步,去除之前隱藏狀態(tài)的依賴。
在傳統(tǒng)的GRU模型中,更新門zt和候選隱藏狀態(tài)h~t的計(jì)算依賴于前一時(shí)刻的隱藏狀態(tài) ht-1。這導(dǎo)致模型在訓(xùn)練時(shí)無(wú)法實(shí)現(xiàn)并行處理,因?yàn)槊總€(gè)時(shí)間步的計(jì)算都依賴于前一個(gè)時(shí)間步的結(jié)果。
為了解決這個(gè)問題,minGRU對(duì)GRU進(jìn)行了修改,使更新門和候選隱藏狀態(tài)的計(jì)算僅依賴于當(dāng)前時(shí)刻的輸入xt,而不依賴于ht-1:
通過(guò)這種方式,minGRU的每一時(shí)刻的計(jì)算可以獨(dú)立于其他時(shí)刻并行執(zhí)行。
第二步,去除候選狀態(tài)的范圍限制。
在第一步中,候選隱藏狀態(tài)h~t仍然使用雙曲正切函數(shù)(tanh)來(lái)限制其值的范圍在 [?1,1][?1,1] 之間。雖然這有助于模型的穩(wěn)定性,但它并不是并行化所必需的。
minGRU進(jìn)一步簡(jiǎn)化模型,去除了對(duì)h~t的范圍限制,將其替換為一個(gè)無(wú)需激活函數(shù)的線性變換:
這樣,候選隱藏狀態(tài)的計(jì)算變得更加簡(jiǎn)單,并且沒有任何范圍限制。
在這種結(jié)構(gòu)下,minGRU不僅減少了模型參數(shù),而且可以利用并行掃描算法在訓(xùn)練時(shí)實(shí)現(xiàn)并行化,從而顯著提高了處理長(zhǎng)序列的速度。
此外,minGRU的輸出尺度在時(shí)間上是獨(dú)立的,這有助于優(yōu)化過(guò)程中的數(shù)值穩(wěn)定性。整體變化如下:
minLSTM
接下來(lái),我們?cè)賮?lái)看下Bengio團(tuán)隊(duì)對(duì)LSTM的處理,即minLSTM,共分為三步。
第一步,去除之前隱藏狀態(tài)的依賴。
在傳統(tǒng)的LSTM模型中,遺忘門ft、輸入門it和候選細(xì)胞狀態(tài)c~t的計(jì)算依賴于前一時(shí)刻的隱藏狀態(tài)ht-1。
這導(dǎo)致模型在訓(xùn)練時(shí)無(wú)法實(shí)現(xiàn)并行處理,因?yàn)槊總€(gè)時(shí)間步的計(jì)算都依賴于前一個(gè)時(shí)間步的結(jié)果。
為了解決這個(gè)問題,minLSTM對(duì)LSTM進(jìn)行了修改,使遺忘門、輸入門和候選細(xì)胞狀態(tài)的計(jì)算僅依賴于當(dāng)前時(shí)刻的輸入xt,而不依賴于ht-1:
通過(guò)這種方式,minLSTM的每一時(shí)刻的計(jì)算可以獨(dú)立于其他時(shí)刻并行執(zhí)行。
第二步,去除候選狀態(tài)的范圍限制。
在第一步中,候選細(xì)胞狀態(tài)c~t仍然使用雙曲正切函數(shù)(tanh)來(lái)限制其值的范圍在 [?1,1][?1,1] 之間。雖然這有助于模型的穩(wěn)定性,但它并不是并行化所必需的。
minLSTM進(jìn)一步簡(jiǎn)化模型,去除了對(duì)c~t的范圍限制,將其替換為一個(gè)無(wú)需激活函數(shù)的線性變換:
這樣,候選細(xì)胞狀態(tài)的計(jì)算變得更加簡(jiǎn)單,并且沒有任何范圍限制。
第三步,確保輸出在時(shí)間上是獨(dú)立的。
在許多序列建模設(shè)置中(例如文本生成),優(yōu)化目標(biāo)/輸出在時(shí)間上是獨(dú)立的。
為了確保LSTM的輸出在時(shí)間上是獨(dú)立的,minLSTM對(duì)遺忘門和輸入門進(jìn)行了歸一化,確保它們的和為1,并且細(xì)胞狀態(tài)的尺度在時(shí)間上是獨(dú)立的:
通過(guò)這種方式,minLSTM確保了其輸出在時(shí)間上是獨(dú)立的,這有助于優(yōu)化過(guò)程中的數(shù)值穩(wěn)定性。
minLSTM的最終形式為:
Were RNNs All We Needed?
在精簡(jiǎn)了RNN們之后,Bengio團(tuán)隊(duì)也展示了實(shí)驗(yàn)結(jié)果。
例如下圖顯示了minGRU、minLSTM和Mamba模型在訓(xùn)練效率方面的比較,具體包括訓(xùn)練運(yùn)行時(shí)間、加速比和內(nèi)存占用。
這些指標(biāo)是在T4 GPU上,以64的批次大小進(jìn)行測(cè)量的:
以及在下圖中,還展示了在Shakespeare語(yǔ)言建模任務(wù)中,不同模型的學(xué)習(xí)曲線。
這個(gè)任務(wù)使用字符級(jí)生成對(duì)抗訓(xùn)練,目的是評(píng)估模型在文本生成任務(wù)中的表現(xiàn),簡(jiǎn)化RNN模型在處理語(yǔ)言建模任務(wù)時(shí)具有較好的有效性和高效率(特別是在需要快速訓(xùn)練和部署的應(yīng)用場(chǎng)景中):
總而言之,Bengio團(tuán)隊(duì)認(rèn)為,經(jīng)過(guò)簡(jiǎn)化的RNN可能仍然是處理長(zhǎng)序列任務(wù)的理想選擇,尤其是在資源有限的場(chǎng)景下,因此也提出了問題 “Were RNNs All We Needed?”
華人一作
在這項(xiàng)研究中,作者除了Bengio之外,還有一點(diǎn)值得關(guān)注,那便是一作是一位華人,Leo Feng。
從公開的個(gè)人網(wǎng)站來(lái)看,Leo Feng師從Bengio,目前是蒙特利爾大學(xué)的博士生,目前正在Borealis AI進(jìn)行研究實(shí)習(xí)。
Leo Feng的研究范圍包括元學(xué)習(xí)和高效模型的設(shè)計(jì),其本科畢業(yè)于牛津大學(xué)。
那么你覺得精簡(jiǎn)版RNN這項(xiàng)研究如何?歡迎在評(píng)論區(qū)留言討論。
論文地址:https://arxiv.org/abs/2410.01201