新架構(gòu)RNN反超Transformer:每個(gè)隱藏狀態(tài)都是一個(gè)模型,一作:從根本上改變語(yǔ)言模型
新架構(gòu),再次向Transformer發(fā)起挑戰(zhàn)!
核心思想:將RNN中的隱藏狀態(tài)換成可學(xué)習(xí)的模型。
甚至在測(cè)試時(shí)都可以學(xué)習(xí),所以該方法稱為TTT(Test-Time Training)。
共同一作UC伯克利的Karen Dalal表示:我相信這將從根本上改變語(yǔ)言模型。
圖片
一個(gè)TTT層擁有比RNN表達(dá)能力更強(qiáng)的隱藏狀態(tài),可以直接取代Transformer中昂貴的自注意力層。
在實(shí)驗(yàn)中,隱藏狀態(tài)是線性模型的TTT-Linear表現(xiàn)超過了Transformer和Mamba,用更少的算力達(dá)到更低的困惑度(左),也能更好利用長(zhǎng)上下文(右)。
圖片
此外,隱藏狀態(tài)是MLP模型的TTT-MLP在32k長(zhǎng)上下文時(shí)表現(xiàn)還要更好。
圖片
Karen Dalel還指出,理論上可學(xué)習(xí)的隱藏狀態(tài)可以是任意模型,對(duì)于更長(zhǎng)上下文來(lái)說,可以是CNN、甚至可以是完整的Transformer來(lái)套娃。
目前剛剛出爐的TTT論文已經(jīng)在學(xué)術(shù)界引起關(guān)注和討論,斯坦福博士生Andrew Gao認(rèn)為,這篇論文或許能成為下一篇Attention is all you need。
圖片
另外有人表示,眾多新架構(gòu)能否真正擊敗Transformer,還要看能不能擴(kuò)展到更大規(guī)模。
Karen Dalel透露,馬上就會(huì)推出7B模型。
圖片
用機(jī)器學(xué)習(xí)模型來(lái)壓縮上下文
傳統(tǒng)RNN,隱藏狀態(tài)固定大小表達(dá)能力受限,也不好并行訓(xùn)練。
Transformer強(qiáng)大,但自注意力機(jī)制隨上下文長(zhǎng)度呈平方復(fù)雜度,非常昂貴。
最近一系列基于RNN的架構(gòu)創(chuàng)新中:
RWKV,用線性注意力結(jié)合RNN和Transformer的優(yōu)點(diǎn),在訓(xùn)練時(shí)可以并行計(jì)算。
Mamba,賦予模型選擇性記住或遺忘信息的能力來(lái)壓縮上下文,同時(shí)設(shè)計(jì)了面向硬件的高效并行算法。
它們的表現(xiàn)在短上下文時(shí)追上甚至超越了Transformer,但在32k超長(zhǎng)上下文以上,Trasformer依舊稱霸。
圖片
TTT團(tuán)隊(duì)的想法來(lái)自于:與其讓隱藏狀態(tài)被動(dòng)地儲(chǔ)存信息,不如讓它主動(dòng)學(xué)習(xí)。
就像Transformer模型作為一個(gè)整體在壓縮互聯(lián)網(wǎng)數(shù)據(jù)到參數(shù)中一樣,可學(xué)習(xí)的隱藏狀態(tài)模型也在少量參數(shù)上不斷縮上下文信息。
這種“隱藏狀態(tài)模型”隨著時(shí)間的推移仍然具有固定的大小(固定的模型參數(shù)),但表達(dá)能力更強(qiáng)了。
論文的聯(lián)合指導(dǎo)UCSD助理教授王小龍認(rèn)為:
Transformer顯式地儲(chǔ)存所有輸入token,如果你認(rèn)為個(gè)神經(jīng)網(wǎng)絡(luò)是壓縮信息的好方法,那么壓縮這些token也將是有意義的。
圖片
如此一來(lái),整個(gè)框架的時(shí)間復(fù)雜度還是線性的,
圖片
至此,序列建模被拆解為兩個(gè)嵌套的學(xué)習(xí)循環(huán),外循環(huán)負(fù)責(zé)整體的語(yǔ)言建模,內(nèi)循環(huán)通過自監(jiān)督學(xué)習(xí)壓縮上下文信息。
外循環(huán)的參數(shù)變成了內(nèi)循環(huán)的超參數(shù),也就是元學(xué)習(xí)的一個(gè)變種了。
標(biāo)準(zhǔn)的元學(xué)習(xí)是訓(xùn)練一個(gè)適應(yīng)不同任務(wù)的模型,而TTT是讓模型去適應(yīng)每一個(gè)測(cè)試樣本。單個(gè)樣本雖然信息量小,但用來(lái)訓(xùn)練隱藏狀態(tài)模型也綽綽有余。
圖片
特別的,在內(nèi)循環(huán)是一個(gè)線性模型時(shí),相當(dāng)于線性注意力。當(dāng)內(nèi)循環(huán)是一個(gè)Nadaraya-Watson estimator時(shí),TTT等價(jià)于自注意力。
圖片
在測(cè)試時(shí)學(xué)習(xí)
在TTT層里,使用自監(jiān)督學(xué)習(xí)方法將上下文壓縮到隱藏狀態(tài)。
上下文就是未標(biāo)記的數(shù)據(jù)集,隱藏狀態(tài)不再是一個(gè)固定的向量,可以是線性模型、小型神經(jīng)網(wǎng)絡(luò)或任何機(jī)器學(xué)習(xí)模型,更新規(guī)則采用了在自監(jiān)督損失上的一步梯度下降。
這樣一來(lái),隱藏狀態(tài)模型可以記住產(chǎn)生大梯度的輸入,并且可以獲得比選擇性遺忘機(jī)制更強(qiáng)的擬合和泛化能力,并且在測(cè)試時(shí)仍然為每個(gè)輸入序列訓(xùn)練不同的參數(shù)。
圖片
到目前為止,樸素的TTT層已經(jīng)有效了,但還無(wú)法并行化。
團(tuán)隊(duì)提出的解決方案為mini-batch梯度下降,把一個(gè)batch內(nèi)的梯度計(jì)算并行化。
再通過Dual form方法,只在mini-batch結(jié)束時(shí)計(jì)算權(quán)重以及輸出token,避免冗余計(jì)算。在JAX版實(shí)現(xiàn)中快了5倍以上。
圖片
TTT能否成為“Transformer殺手”?
理論上都走的通了,那么TTT在實(shí)驗(yàn)中表現(xiàn)到底如何?
最簡(jiǎn)單干凈的測(cè)試方法,應(yīng)該是直接替換掉Transformer中的自注意力層。
但是在研究過程中,團(tuán)隊(duì)發(fā)現(xiàn)Mamba等現(xiàn)代RNN的骨干中在RNN層之前還包含時(shí)間卷積,對(duì)TTT也有幫助。
所以實(shí)驗(yàn)中TTT-Linear和TTT-MLP主要應(yīng)用到Mamba骨干上,其他訓(xùn)練細(xì)節(jié)也嚴(yán)格遵照Mamba論文中的設(shè)置。
最終在Pile數(shù)據(jù)集短上下文測(cè)試中:
- 2k上下文時(shí),TTT-Linear、Mamba和Transform具有相當(dāng)?shù)男阅?,TTT-MLP的表現(xiàn)略差。
- 8k上下文時(shí),TTT-Linear和TTT-MLP都優(yōu)于Mamba和Transformer,應(yīng)用在Transformer骨干的TTT-MLP(T)在1.3B參數(shù)左右也略好與Mamba。
總的來(lái)說,隨著上下文長(zhǎng)度的增長(zhǎng),TTT層相對(duì)于Mamba的優(yōu)勢(shì)也會(huì)擴(kuò)大。
另外團(tuán)隊(duì)猜測(cè),線性模型比MLP表達(dá)能力差,因此從Mamba骨干的卷積中受益更多。
圖片
長(zhǎng)上下文實(shí)驗(yàn)使用Pile的子集Books3:
- 32k上下文,TTT-Linear和TTT-MLP的表現(xiàn)都優(yōu)于曼巴,類似于Pile 8k的觀察。即使是帶有Transformer骨干的TTT-MLP(T)表現(xiàn)也略好于曼巴。
- 1.3B參數(shù)尺度上,TTT-MLP(T)僅比TTT-MLP(M)稍差,Transformer骨干可能更適合論文評(píng)估范圍之外的更大模型和更長(zhǎng)的上下文。
圖片
在A100上測(cè)試速度,TTT-Linear在預(yù)填充階段比Mamba稍快,解碼階段幾乎與Mamba速度相同。TTT-MLP相比Transformer整體上也有線性復(fù)雜度的優(yōu)勢(shì)。
圖片
共同一作Karan Dala表示:我一直被問到的一個(gè)問題是,我們是否相信TTT就是“Transformer殺手”,我仍然認(rèn)為我們需要繼續(xù)努力。
隱藏狀態(tài)可以是任意模型,但目前的研究只涉及了線性模型和小型MLP,更復(fù)雜的還有待研究。
隱藏狀態(tài)模型的學(xué)習(xí)可以用Adam代替普通的梯度下降等等。
還可用于視頻建模
三位共同一作中:
Yu Sun博士畢業(yè)于UC Berkeley,目前是斯坦福大學(xué)博士后。
圖片
Xinhao Li是電子科技大學(xué)校友,碩士畢業(yè)于UCSD。
圖片
Karan Dalel本科畢業(yè)于UC Berkley,正在機(jī)器人初創(chuàng)公司1X實(shí)習(xí)。
圖片
最后,聯(lián)合指導(dǎo)UCSD助理教授王小龍還透露,TTT方法除了語(yǔ)言模型,還適用于視頻。
TTT就是“Transformer殺手”,我仍然認(rèn)為我們需要繼續(xù)努力。
將來(lái)在對(duì)長(zhǎng)視頻進(jìn)行建模時(shí),我們可以密集地采樣幀而不是采樣1 FPS,這些密集幀對(duì)Transformer來(lái)說是一種負(fù)擔(dān),但對(duì)TTT層來(lái)說是一種福音。
圖片
論文地址:https://arxiv.org/abs/2407.04620
參考鏈接:[1]https://x.com/karansdalal/status/1810338845659131940[2]https://x.com/xiaolonw/status/1810387662060269668