近8年后,谷歌Transformer繼任者「Titans」來了,上下文記憶瓶頸被打破
終于,在 2017 年推出影響 AI 行業(yè)長達(dá) 8 年的 Transformer 架構(gòu)之后,谷歌帶來了全新的架構(gòu) Titans。這次,谷歌的重點(diǎn)是將推理領(lǐng)域非常重要的測(cè)試時(shí)(test-time)計(jì)算用在了記憶(memory)層面。
在談到推出 Titans 的初衷時(shí),論文一作 Ali Behrouz 表示,「注意力機(jī)制一直是大多數(shù) LLM 進(jìn)展的重要組成部分,不過它無法擴(kuò)展到長上下文。因此,Titans 應(yīng)運(yùn)而出,它成為了一種同時(shí)具備注意力機(jī)制和元上下文記憶的結(jié)構(gòu),可以在測(cè)試時(shí)學(xué)習(xí)記憶。該架構(gòu)可以將上下文窗口擴(kuò)展到 200 萬 tokens?!?/span>
圖源:https://x.com/behrouz_ali/status/1878859086227255347
這意味著,谷歌 Transformer 迎來了它的「繼任者」。
圖源:https://x.com/mark_k/status/1878896628654022993
多年來,研究人員一直在廣泛探究如何有效地利用循環(huán)模型和注意力機(jī)制,其中循環(huán)模型旨在將數(shù)據(jù)壓縮到固定大小的記憶(稱為隱狀態(tài))中,而注意力機(jī)制允許處理整個(gè)上下文窗口,捕捉所有 token 的直接依賴。不過,更準(zhǔn)確的依賴建模往往伴隨著二次成本,導(dǎo)致模型只能處理固定長度的上下文。
因此,谷歌提出了一種新的長期神經(jīng)記憶模塊(neural memory module),它能夠?qū)W習(xí)記憶歷史上下文,并幫助注意力機(jī)制在利用過去已久信息的同時(shí)處理當(dāng)前上下文。結(jié)果表明,這種神經(jīng)記憶具有快速并行化訓(xùn)練的優(yōu)勢(shì),同時(shí)還能保持快速推理。
從記憶的角度來看,谷歌認(rèn)為注意力機(jī)制雖然受限于上下文但可以更準(zhǔn)確地建模依賴關(guān)系,因此可以起到短期記憶的作用;而神經(jīng)記憶能夠?qū)?shù)據(jù)進(jìn)行記憶,起到了長期、更持久的記憶作用?;谶@兩個(gè)模塊,谷歌引入了一個(gè)全新的系列架構(gòu) —— Titans,通過三種變體有效地將記憶融合到該系統(tǒng)架構(gòu)中,它們分別是記憶作為上下文(Memory as a Context,MAC)、記憶作為門(Memory as a Gate,MAG)和記憶作為層(Memory as a Layer,MAL)。
在語言建模、常識(shí)推理、基因組學(xué)和時(shí)序預(yù)測(cè)任務(wù)上的實(shí)驗(yàn)結(jié)果表明,Titans 架構(gòu)比 Transformer 和近年來的現(xiàn)代線性循環(huán)模型更有效。另外,在大海撈針(needle-in-haystack)中,Titans 架構(gòu)能夠有效地?cái)U(kuò)展到超過 200 萬 tokens 的上下文窗口,并且比基準(zhǔn)模型實(shí)現(xiàn)了更高的準(zhǔn)確性。
- 論文標(biāo)題:Titans: Learning to Memorize at Test Time
- 論文地址:https://arxiv.org/pdf/2501.00663v1
另外,論文作者之一 Peilin Zhong 為谷歌 NYC 算法與優(yōu)化團(tuán)隊(duì)的研究科學(xué)家,2021 年加入谷歌。他本科畢業(yè)于清華姚班,博士畢業(yè)于哥倫比亞大學(xué)。
目前,已經(jīng)有人搞出了有關(guān) Titans 架構(gòu)的非官方實(shí)現(xiàn),感興趣的讀者可以去看一下。
GitHub 地址:https://github.com/lucidrains/titans-pytorch
學(xué)習(xí)測(cè)試時(shí)記憶
谷歌詳細(xì)介紹了長期神經(jīng)記憶模塊,它成為了一種可以在測(cè)試時(shí)學(xué)習(xí)記憶的元模型。
長期記憶
為了設(shè)計(jì)一個(gè)長期神經(jīng)記憶模塊,我們需要模型能夠?qū)⑦^去歷史的抽象編碼到其參數(shù)中。因此,一個(gè)簡單的思路是訓(xùn)練神經(jīng)網(wǎng)絡(luò)并期望它能夠記住自己的訓(xùn)練數(shù)據(jù),然而記憶幾乎一直是神經(jīng)網(wǎng)絡(luò)中令人頭疼的現(xiàn)象,它限制了模型的泛化能力,還引發(fā)隱私問題,因此導(dǎo)致測(cè)試時(shí)性能不佳。
基于此,谷歌認(rèn)為需要一個(gè)在線元模型來學(xué)習(xí)如何在測(cè)試時(shí)記憶或忘記數(shù)據(jù)。在這種設(shè)置下,模型學(xué)習(xí)一個(gè)能夠記憶的函數(shù),但不會(huì)過擬合訓(xùn)練數(shù)據(jù),從而在測(cè)試時(shí)實(shí)現(xiàn)更好的泛化性能。
學(xué)習(xí)過程和意外指標(biāo)(Learning Process and Surprise Metric)。訓(xùn)練長期記憶的關(guān)鍵思路是將訓(xùn)練視為在線學(xué)習(xí)問題,其中將過去信息 x_1, …, x_t-1 壓縮到長期神經(jīng)記憶模塊中。人類往往能夠記住背離預(yù)期(令人驚訝)的事件,受此啟發(fā),模型意外可以簡單定義為它相對(duì)于輸入的梯度。梯度越大,輸入數(shù)據(jù)與過去數(shù)據(jù)的偏差就越大。因此,使用這個(gè)意外分?jǐn)?shù),可以將記憶更新如下:
這一意外指標(biāo)可以導(dǎo)致在重大意外時(shí)刻之后出現(xiàn)重要信息缺失。從人類記憶的角度來看,即使一個(gè)事件令人難忘,但它可能不會(huì)在長時(shí)間內(nèi)持續(xù)讓我們感到驚訝。為了改進(jìn)這一現(xiàn)象,谷歌將意外指標(biāo)分解為了(1)過去意外,它衡量最近過去的意外程度;(2)瞬時(shí)意外,它衡量傳入數(shù)據(jù)的意外。
這些意外指標(biāo)基于一個(gè)損失函數(shù),它就是我們的記憶在測(cè)試時(shí)學(xué)習(xí)充當(dāng)?shù)哪繕?biāo)。也就是說,記憶模塊是一個(gè)元模型,它基于損失函數(shù)
來學(xué)習(xí)一個(gè)函數(shù)。
在本文中,谷歌則專注于聯(lián)想記憶,目的是將過去的數(shù)據(jù)存儲(chǔ)為鍵(keys)和值(values)對(duì)。類似于 Transformer,在給定 x_t 的情況下,谷歌使用兩個(gè)線性層將 x_t 投影到鍵和值中:
接下來,谷歌希望記憶模塊可以學(xué)習(xí)鍵和值之間的關(guān)聯(lián),為此將損失定義如下:
遺忘機(jī)制(Forgetting Mechanism)。在處理非常大的序列(比如百萬 tokens)時(shí),管理哪些過去信息應(yīng)該被遺忘非常重要,即使使用深度或者非常大的矩陣值記憶時(shí)也是如此。因此,谷歌使用了一種自適應(yīng)遺忘機(jī)制,允許記憶忘記不再需要的信息,從而更好地管理有限的記憶容量。也就是說,給定下一個(gè) token x_t,谷歌將更新規(guī)則做如下修改:
記憶架構(gòu)(Memory Architecture)。谷歌重點(diǎn)將具有 L_M≥1 層的簡單 MLP 作為長期記憶架構(gòu),選擇它們的原因在于希望能夠更好地激勵(lì)長期記憶設(shè)計(jì)以及將其融入架構(gòu)的方法。谷歌表示,本文的架構(gòu)開辟了一個(gè)新的研究方向,有助于設(shè)計(jì)更有效且高效記憶數(shù)據(jù)的神經(jīng)架構(gòu)。
檢索記憶(Retrieving a Memory)。在探討如何設(shè)計(jì)和訓(xùn)練一個(gè)可以在測(cè)試時(shí)學(xué)習(xí)記憶的長期記憶模塊之后,剩下的關(guān)鍵問題便是如何從記憶中檢索信息?谷歌僅僅使用了沒有更新權(quán)重的前向傳遞(即推理)來檢索與查詢相對(duì)應(yīng)的記憶。在形式上,給定一個(gè)輸入 x_t,谷歌使用線性層 W_Q 來投影輸入,即 q_t = x_tW_Q,并通過以下公式從記憶 y_t 中檢索相應(yīng)(或有用)的信息。
并行化長期記憶訓(xùn)練
理論上,長期記憶模塊的訓(xùn)練需要 FLOPS,其中 N 為序列長度。不過在實(shí)踐中,我們需要并行化訓(xùn)練過程并充分利用 TPU、GPU 等硬件加速器,同時(shí)需要張量化該過程并使用更多矩陣乘法(matmuls)。
接下來,谷歌表示,使用小批量梯度下降、數(shù)據(jù)學(xué)習(xí)率和權(quán)重衰減來計(jì)算內(nèi)循環(huán)權(quán)重的方式可以重新來表示,以便它只使用矩陣乘法和求和(sums)。這里將序列拆分為大小為 b ≥ 1 的塊,并將小批量梯度下降寫做:
此外,谷歌解釋了 M_t = W_t 為線性時(shí)的情況。對(duì)于 N_p ≥ 2 的 MLP,過程類似。使用本文的損失函數(shù)可以得到如下:
最后,谷歌擴(kuò)展這一表示,以便可以合并動(dòng)量項(xiàng)。在具有動(dòng)量的塊式梯度下降中,如果觀察動(dòng)量項(xiàng)則可以得到如下:
作為函數(shù)塊的參數(shù)(Parameters as the Function of Chunks)。谷歌沒有讓參數(shù) a_t、θ_t 和 η_t 依賴于輸入,而是讓它們成為函數(shù)塊。盡管失去了表達(dá)能力,但可以幫助更快地訓(xùn)練。在這種情況下,谷歌在每個(gè)塊中對(duì)每一個(gè) a、θ 和 η 都使用了相同的值。在實(shí)驗(yàn)中,谷歌將這些參數(shù)作為了 token 的函數(shù),并表示,這種簡化(即作為塊函數(shù))可能是未來工作感興趣的地方,以便以更高效的方式訓(xùn)練更大的模型。
下圖 1 展示了如何并行并在使用矩陣乘法時(shí)完成神經(jīng)記憶訓(xùn)練。
如何融合記憶?
接下來需要解決的一個(gè)重要問題是:如何有效且高效地將神經(jīng)記憶融合到深度學(xué)習(xí)架構(gòu)中?
從記憶的角度來看,Transformer 中的 K 和 V 矩陣對(duì)可以解釋為聯(lián)想記憶塊。由于它們對(duì)依賴關(guān)系的精確建模以及有限的上下文窗口,它們可以被用作短期記憶模塊,以處理當(dāng)前上下文窗口大小。另一方面,神經(jīng)記憶能夠不斷從數(shù)據(jù)中學(xué)習(xí)并存儲(chǔ)在其權(quán)重中,因而可以發(fā)揮長期記憶的作用。谷歌通過三個(gè)不同的 Titans 變體來回答以上問題。
記憶作為上下文(Memory as a Context,MAC)
Titans 的第一個(gè)變體 MAC 的架構(gòu)設(shè)計(jì)如下圖 2 所示,將記憶作為當(dāng)前信息的上下文。
也就是說,給定一個(gè)長序列,首先將序列分成固定大小的片段 S^(??),其中 ?? = 1,...,??/??。給定傳入片段 S^(??),谷歌將它視為當(dāng)前上下文,將其過去的片段視為歷史信息。因此,谷歌讓 M_??-1 成為片段 S^(??) 之前的長期記憶狀態(tài),使用輸入上下文作為對(duì)記憶 M^??-1 的查詢,以從長期記憶中檢索相應(yīng)的信息。谷歌如下所示檢索與 S^(??) 相對(duì)應(yīng)的過去信息:
接下來將這些歷史信息與持久記憶參數(shù)一起作為注意力模塊的輸入序列:
該架構(gòu)具有兩個(gè)關(guān)鍵優(yōu)勢(shì):一是注意力模塊同時(shí)具有歷史和當(dāng)前上下文,能夠根據(jù)當(dāng)前數(shù)據(jù)決定是否需要長期記憶信息,二是注意力模塊幫助長期記憶只存儲(chǔ)來自當(dāng)前上下文的有用信息。這意味著,并非每個(gè)片段中的所有 token 都是有用的,記憶所有 token 可能會(huì)導(dǎo)致內(nèi)存溢出。因此,注意力模塊幫助記憶了解哪些信息是有用的,從而更好地管理內(nèi)存容量。
另外,在測(cè)試時(shí),(i)持久記憶參數(shù)是固定的,它們編碼了有關(guān)任務(wù)的知識(shí),不應(yīng)改變;(ii)注意力模塊權(quán)重是上下文學(xué)習(xí)器;(iii)長期記憶模塊在測(cè)試時(shí)仍然學(xué)習(xí)(記憶)信息。也就是說,即使在測(cè)試時(shí),神經(jīng)記憶的權(quán)重也會(huì)更新,這是因?yàn)闄?quán)重對(duì)過去已久的抽象進(jìn)行了編碼。
記憶作為門(Memory as a Gate,MAG)
Titans 第二個(gè)變體 MAG 的架構(gòu)設(shè)計(jì)如下圖 4 所示:
在其中一個(gè)分支中,谷歌直接使用輸入數(shù)據(jù)來更新長期記憶;在第二個(gè)分支中,谷歌使用了滑動(dòng)窗口注意力(SWA):
該架構(gòu)的整體注意力掩碼如下圖 3b 所示,其中滑動(dòng)窗口注意力(SWA)充當(dāng)精確的短期記憶,而神經(jīng)記憶模塊充當(dāng)模型的衰減記憶。該設(shè)計(jì)也可以看作是多頭架構(gòu),其中各頭的結(jié)構(gòu)不同。
記憶作為層(Memory as a Layer,MAL)
Titans 的第三個(gè)變體 MAL 使用了深度神經(jīng)網(wǎng)絡(luò),這種架構(gòu)設(shè)計(jì)在文獻(xiàn)中更為常見,其中混合模型堆疊具有完整或滑動(dòng)窗口注意力的循環(huán)模型。
給定輸入 x,可以得到以下:
其中 SW-Attn 是滑動(dòng)窗口注意力。
無注意力記憶(Memory Without Attention)。從記憶的角度來看,谷歌期望記憶系統(tǒng)的每個(gè)組件都能獨(dú)立工作,即使其他組件受到了干擾。因此,即使沒有短期記憶(即注意力),長期記憶模塊仍然應(yīng)該是一個(gè)強(qiáng)大的模型。谷歌在實(shí)驗(yàn)中將這種變體稱為 Titans (LMM)。
架構(gòu)細(xì)節(jié)
在所有塊中,谷歌使用了殘差連接;在實(shí)現(xiàn)中,谷歌使用 SiLU (.) 激活函數(shù)作為計(jì)算查詢、鍵和值的非線性激活,并使用對(duì)查詢和鍵進(jìn)行歸一化。
卷積(Convolution)。遵循最近的現(xiàn)代線性循環(huán)模型,谷歌在每個(gè)查詢、鍵和值投影后都融合了一個(gè) 1D 深度可分離卷積層。這些 1D 卷積可以提升性能,并且計(jì)算高效。
門控(Gating)。谷歌還在最終輸出投影之前利用線性層進(jìn)行歸一化和門控。
實(shí)驗(yàn)結(jié)果
谷歌在實(shí)驗(yàn)部分關(guān)注上述三種 Titans 變體,分別是 MAC、MAG 和 MAL,以及單獨(dú)的神經(jīng)記憶模塊。對(duì)于每個(gè)模型,谷歌使用了四種尺寸的模型,參數(shù)分別是 (i) 170M、(ii) 340M、(iii) 400M 和 (iv) 760M。
語言建模
谷歌首先關(guān)注模型在語言建模和常識(shí)推理任務(wù)中的困惑度。下表 1 報(bào)告了 Titans 變體和三種不同大小(340M、400M 和 760M)基線的結(jié)果。在包括 Transformer++ 在內(nèi)的非混合模型中,神經(jīng)記憶模塊在困惑度和準(zhǔn)確度測(cè)量方面均取得了最佳性能。
谷歌還發(fā)現(xiàn),Titans 的三種變體(MAC, MAG 和 MAL)都優(yōu)于 Samba (Mamba + 注意力)和 Gated DeltaNet-H2(Gated DeltaNet + 注意力)。
大海撈針
下表 2 結(jié)果顯示,與基線相比,神經(jīng)記憶模塊均取得了最佳結(jié)果。
谷歌將這種卓越的表現(xiàn)歸因于 Titans 與現(xiàn)有序列模型的三個(gè)關(guān)鍵差異:(1)與 TTT 相比,神經(jīng)記憶能夠通過使用動(dòng)量和遺忘機(jī)制(即權(quán)重衰減)更好地處理記憶容量。因此,隨著序列長度的增加,神經(jīng)記憶的性能不會(huì)下降,呈現(xiàn)出一致的趨勢(shì);(2)與具有門控(遺忘)機(jī)制的 Mamba2 相比,Titans 具有深度非線性記憶,從而實(shí)現(xiàn)了更好的記憶管理。此外,與神經(jīng)記憶和 DeltaNet 不同,Mamba2 無法移除記憶,因此在增加序列長度時(shí),其性能會(huì)出現(xiàn)顯著下降;(3)與 DeltaNet 相比,盡管它能夠使用增量規(guī)則移除記憶,但無法擦除記憶,缺乏遺忘機(jī)制。
最終,正如預(yù)期的那樣,使用 Titans 變體時(shí)能看到相當(dāng)或更好的結(jié)果,其中最佳結(jié)果來自 MAC。
BABILong 基準(zhǔn)
在微調(diào)設(shè)置中,谷歌將小型微調(diào)版本的 Titans (MAC) 與其他模型進(jìn)行了比較。
Titans 和基線的結(jié)果如下圖 6b 所示。Titans 的表現(xiàn)優(yōu)于所有模型,甚至比 GPT4 這樣的超大型模型還要好。此外,與基于 Transformer 的 RMT 等記憶模型相比,Titans 表現(xiàn)出更好的性能,這主要?dú)w功于其強(qiáng)大的記憶。
深度記憶的影響
接下來的實(shí)驗(yàn)評(píng)估了深度記憶對(duì) wall-clock 訓(xùn)練時(shí)間和模型性能的影響。
下圖 7 中報(bào)告了 Titans(LMM)和基線的困惑度與序列長度的關(guān)系。有趣的是,隨著記憶深度的增加,該模型可以在所有序列長度上實(shí)現(xiàn)更好的困惑度。此外,當(dāng)模型的參數(shù)量較少時(shí),更深的記憶模塊對(duì)序列長度的魯棒性更強(qiáng)。隨著參數(shù)量的增加,所有模型在較長的序列上都表現(xiàn)出更好的性能。
時(shí)序預(yù)測(cè)
為了展示記憶模塊在更廣泛任務(wù)中的有效性,谷歌評(píng)估了 Titans 在時(shí)序預(yù)測(cè)任務(wù)中的表現(xiàn)。結(jié)果如下表 3 所示,谷歌的神經(jīng)記憶模塊優(yōu)于所有基線,包括基于 Mamba、線性和 Transformer 的架構(gòu)。
DNA 建模
谷歌還進(jìn)一步評(píng)估了神經(jīng)記憶模塊在 DNA 建模任務(wù)上的表現(xiàn),結(jié)果如下 4 所示,相較于當(dāng)前的 SOTA 架構(gòu),Titans(LMM)在不同的下游基因組任務(wù)中仍具有競(jìng)爭力。
效率
谷歌還對(duì) Titans 與當(dāng)前 SOTA 序列模型的效率進(jìn)行了比較,下圖 9 顯示了不同序列長度 x 批大小的模型的訓(xùn)練吞吐量。可以看到,谷歌神經(jīng)記憶模塊比 Mamba2 和 Gated DeltaNet 稍慢,不過 Titans (MAL) 比基線和神經(jīng)記憶模塊都要快。
更多技術(shù)細(xì)節(jié)和實(shí)驗(yàn)結(jié)果請(qǐng)參閱原論文。


2024-09-30 14:10:00
2022-10-28 16:24:33




