Bengio等人新作:注意力可被視為RNN,新模型媲美Transformer,但超級省內(nèi)存
序列建模的進(jìn)展具有極大的影響力,因為它們在廣泛的應(yīng)用中發(fā)揮著重要作用,包括強(qiáng)化學(xué)習(xí)(例如,機(jī)器人和自動駕駛)、時間序列分類(例如,金融欺詐檢測和醫(yī)學(xué)診斷)等。
在過去的幾年里,Transformer 的出現(xiàn)標(biāo)志著序列建模中的一個重大突破,這主要得益于 Transformer 提供了一種能夠利用 GPU 并行處理的高性能架構(gòu)。
然而,Transformer 在推理時計算開銷很大,主要在于內(nèi)存和計算需求呈二次擴(kuò)展,從而限制了其在低資源環(huán)境中的應(yīng)用(例如,移動和嵌入式設(shè)備)。盡管可以采用 KV 緩存等技術(shù)提高推理效率,但 Transformer 對于低資源領(lǐng)域來說仍然非常昂貴,原因在于:(1)隨 token 數(shù)量線性增加的內(nèi)存,以及(2)緩存所有先前的 token 到模型中。在具有長上下文(即大量 token)的環(huán)境中,這一問題對 Transformer 推理的影響更大。
為了解決這個問題,加拿大皇家銀行 AI 研究所 Borealis AI、蒙特利爾大學(xué)的研究者在論文《Attention as an RNN 》中給出了解決方案。值得一提的是,我們發(fā)現(xiàn)圖靈獎得主 Yoshua Bengio 出現(xiàn)在作者一欄里。
- 論文地址:https://arxiv.org/pdf/2405.13956
- 論文標(biāo)題:Attention as an RNN?
具體而言,研究者首先檢查了 Transformer 中的注意力機(jī)制,這是導(dǎo)致 Transformer 計算復(fù)雜度呈二次增長的組件。該研究表明注意力機(jī)制可以被視為一種特殊的循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN),具有高效計算的多對一(many-to-one)RNN 輸出的能力。利用注意力的 RNN 公式,該研究展示了流行的基于注意力的模型(例如 Transformer 和 Perceiver)可以被視為 RNN 變體。
然而,與 LSTM、GRU 等傳統(tǒng) RNN 不同,Transformer 和 Perceiver 等流行的注意力模型雖然可以被視為 RNN 變體。但遺憾的是,它們無法高效地使用新 token 進(jìn)行更新。
為了解決這個問題,該研究引入了一種基于并行前綴掃描(prefix scan)算法的新的注意力公式,該公式能夠高效地計算注意力的多對多(many-to-many)RNN 輸出,從而實現(xiàn)高效的更新。
在此新注意力公式的基礎(chǔ)上,該研究提出了 Aaren([A] ttention [a] s a [re] current neural [n] etwork),這是一種計算效率很高的模塊,不僅可以像 Transformer 一樣并行訓(xùn)練,還可以像 RNN 一樣高效更新。
實驗結(jié)果表明,Aaren 在 38 個數(shù)據(jù)集上的表現(xiàn)與 Transformer 相當(dāng),這些數(shù)據(jù)集涵蓋了四種常見的序列數(shù)據(jù)設(shè)置:強(qiáng)化學(xué)習(xí)、事件預(yù)測、時間序列分類和時間序列預(yù)測任務(wù),同時在時間和內(nèi)存方面更加高效。
方法介紹
為了解決上述問題,作者提出了一種基于注意力的高效模塊,它能夠利用 GPU 并行性,同時又能高效更新。
首先,作者在第 3.1 節(jié)中表明,注意力可被視為一種 RNN,具有高效計算多對一 RNN(圖 1a)輸出的特殊能力。利用注意力的 RNN 形式,作者進(jìn)一步說明,基于注意力的流行模型,如 Transformer(圖 1b)和 Perceiver(圖 1c),可以被視為 RNN。然而,與傳統(tǒng)的 RNN 不同的是,這些模型無法根據(jù)新 token 有效地更新自身,從而限制了它們在數(shù)據(jù)以流的形式到達(dá)的序列問題中的潛力。
為了解決這個問題,作者在第 3.2 節(jié)中介紹了一種基于并行前綴掃描算法的多對多 RNN 計算注意力的高效方法。在此基礎(chǔ)上,作者在第 3.3 節(jié)中介紹了 Aaren—— 一個計算效率高的模塊,它不僅可以并行訓(xùn)練(就像 Transformer),還可以在推理時用新 token 高效更新,推理只需要恒定的內(nèi)存(就像傳統(tǒng) RNN)。
將注意力視為一個多對一 RNN
查詢向量 q 的注意力可被視為一個函數(shù),它通過 N 個上下文 token x_1:N 的鍵和值
將其映射到單一輸出 o_N = Attention (q, k_1:N , v_1:N ) 。給定 s_i = dot (q,k_i),輸出 o_N 可表述為:
,分母為
。將注意力視為 RNN,可以在 k = 1,...,...... 時,以滾動求和的方式迭代計算
和
。然而,在實踐中,這種實現(xiàn)方式并不穩(wěn)定,會因有限的精度表示和可能非常小或非常大的指數(shù)(即 exp (s))而遇到數(shù)值問題。為了緩解這一問題,作者用累積最大值項
來重寫遞推公式,計算
和
。值得注意的是,最終結(jié)果是相同的
,m_k 的循環(huán)計算如下:
計算注意力的方法:通過將注意力視為一個 RNN,可以看到計算注意力的不同方法:在 O (1) 內(nèi)存中逐個 token 循環(huán)計算(即順序計算);或以傳統(tǒng)方式計算(即并行計算),需要線性 O (N) 內(nèi)存。由于注意力可以被看作是一個 RNN,因此計算注意力的傳統(tǒng)方法也可以被看作是計算注意力多對一 RNN 輸出的高效方法,即 RNN 的輸出以多個上下文 token 為輸入,但在 RNN 結(jié)束時只輸出一個 token(見圖 1a)。最后,也可以將注意力計算為一個逐塊處理 token 的 RNN,而不是完全按順序或完全并行計算,這需要 O (b) 內(nèi)存,其中 b 是塊的大小。
將現(xiàn)有的注意力模型視為 RNN。通過將注意力視為 RNN,現(xiàn)有的基于注意力的模型也可以被視為 RNN 的變體。例如,Transformer 的自注意力是 RNN(圖 1b),上下文 token 是其初始隱藏狀態(tài)。Perceiver 的交叉注意力是 RNN(圖 1c),其初始隱藏狀態(tài)是與上下文相關(guān)的潛變量。通過利用其注意力機(jī)制的 RNN 形式,這些現(xiàn)有模型可以高效地計算其輸出存儲。
然而,當(dāng)將現(xiàn)有的基于注意力的模型(如 Transformers)視為 RNN 時,這些模型又缺乏傳統(tǒng) RNN(如 LSTM 和 GRU)中常見的重要屬性。
值得注意的是,LSTM 和 GRU 能夠僅在 O (1) 常量內(nèi)存和計算中使用新 token 有效地更新自身,相比之下, Transformer 的 RNN 視圖(見圖 1b)會通過將一個新的 token 作為初始狀態(tài)添加一個新的 RNN 來處理新 token。這個新的 RNN 處理所有先前的 token,需要 O (N) 的線性計算量。
在 Perceiver 中,由于其架構(gòu)的原因,潛變量(圖 1c 中的 L_i)是依賴于輸入的,這意味著它們的值在接收新 token 時會發(fā)生變化。由于其 RNN 的初始隱藏狀態(tài)(即潛變量)發(fā)生變化,Perceiver 因此需要從頭開始重新計算其 RNN,需要 O (NL) 的線性計算量,其中 N 是 token 的數(shù)量,L 是潛變量的數(shù)量。
將注意力視為一個多對多 RNN
針對這些局限性,作者建議開發(fā)一種基于注意力的模型,利用 RNN 公式的能力來執(zhí)行高效更新。為此,作者首先引入了一種高效的并行化方法,將注意力作為多對多 RNN 計算,即并行計算
的方法。為此,作者利用并行前綴掃描算法(見算法 1),這是一種通過關(guān)聯(lián)算子 ⊕ 從 N 個連續(xù)數(shù)據(jù)點計算 N 個前綴的并行計算方法。該算法可高效計算
回顧
,其中
,
為了高效計算
,可以通過并行掃描算法計算
和
,然后結(jié)合 a_k 和 c_k 計算
。
為此,作者提出了以下關(guān)聯(lián)算子⊕,該算子作用于形式為(m_A、u_A、w_A)的三元組,其中 A 是一組索引,
,
,
。并行掃描算法的輸入為
。該算法遞歸應(yīng)用算子 ⊕,其工作原理如下:
,其中,
,
。
在完成遞歸應(yīng)用算子后,算法輸出
。也被稱作
。結(jié)合輸出元組的最后兩個值,檢索
從而產(chǎn)生一種高效的并行方法,將注意力計算為多對多 RNN(圖 3)。
Aaren:[A] ttention [a] s a [re] current neural [n] etwork
Aaren 的接口與 Transformer 相同,即將 N 個輸入映射到 N 個輸出,而第 i 個輸出是第 1 到第 i 個輸入的聚合。此外,Aaren 還自然可堆疊,并且能夠計算每個序列 token 的單獨損失項。然而,與使用因果自注意力的 Transformers 不同,Aaren 使用上述計算注意力的方法作為多對多 RNN,使其更加高效。Aaren 形式如下:
與 Transformer 不同,在 Transformer 中查詢是輸入到注意力的 token 之一,而在 Aaren 中,查詢 token q 是在訓(xùn)練過程中通過反向傳播學(xué)習(xí)得到的。
下圖展示了一個堆疊 Aaren 模型的例子,該模型的輸入上下文 token 為 x_1:3,輸出為 y_1:3。值得注意的是,由于 Aaren 利用了 RNN 形式的注意力機(jī)制,堆疊 Aarens 也相當(dāng)于堆疊 RNN。因此,Aarens 也能夠高效地用新 token 進(jìn)行更新,即 y_k 的迭代計算僅需要常量計算,因為它僅依賴于 h_k-1 和 x_k。
基于 Transformer 的模型需要線性內(nèi)存(使用 KV 緩存時)并且需要存儲所有先前的 token ,包括中間 Transformer 層中的那些,但基于 Aaren 的模型只需要常量內(nèi)存,并且不需要存儲所有先前的 token ,這使得 Aarens 在計算效率上顯著優(yōu)于 Transformer。
實驗
實驗部分的目標(biāo)是比較 Aaren 和 Transformer 在性能和所需資源(時間和內(nèi)存)方面的表現(xiàn)。為了進(jìn)行全面比較,作者在四個問題上進(jìn)行了評估:強(qiáng)化學(xué)習(xí)、事件預(yù)測、時間序列預(yù)測和時間序列分類。
強(qiáng)化學(xué)習(xí)
作者首先比較了 Aaren 和 Transformer 在強(qiáng)化學(xué)習(xí)方面的表現(xiàn)。強(qiáng)化學(xué)習(xí)在機(jī)器人、推薦引擎和交通控制等交互式環(huán)境中很受歡迎。
表 1 中的結(jié)果表明,在所有 12 個數(shù)據(jù)集和 4 種環(huán)境中,Aaren 與 Transformer 的性能都不相上下。不過,與 Transformer 不同的是,Aaren 也是一種 RNN,因此能夠在持續(xù)計算中高效處理新的環(huán)境交互,從而更適合強(qiáng)化學(xué)習(xí)。
事件預(yù)測
接下來,作者比較了 Aaren 和 Transformer 在事件預(yù)測方面的表現(xiàn)。事件預(yù)測在許多現(xiàn)實環(huán)境中都很流行,例如金融(如交易)、醫(yī)療保健(如患者觀察)和電子商務(wù)(如購買)。
表 2 中的結(jié)果顯示,Aaren 在所有數(shù)據(jù)集上的表現(xiàn)都與 Transformer 相當(dāng)。Aaren 能夠高效處理新輸入,這在事件預(yù)測環(huán)境中尤為有用,因為在這種環(huán)境中,事件會以不規(guī)則流的形式出現(xiàn)。
時間序列預(yù)測
然后,作者比較了 Aaren 和 Transformer 在時間序列預(yù)測方面的表現(xiàn)。時間序列預(yù)測模型通常用在與氣候(如天氣)、能源(如供需)和經(jīng)濟(jì)(如股票價格)相關(guān)的領(lǐng)域。
表 3 中的結(jié)果顯示,在所有數(shù)據(jù)集上,Aaren 與 Transformer 的性能相當(dāng)。不過,與 Transformer 不同的是,Aaren 能高效處理時間序列數(shù)據(jù),因此更適合與時間序列相關(guān)的領(lǐng)域。
時間序列分類
接下來,作者比較了 Aaren 和 Transformer 在時間序列分類方面的表現(xiàn)。時間序列分類在許多重要的應(yīng)用中很常見,例如模式識別(如心電圖)、異常檢測(如銀行欺詐)或故障預(yù)測(如電網(wǎng)波動)。
從表 4 中可以看出,在所有數(shù)據(jù)集上,Aaren 與 Transformer 的表現(xiàn)不相上下。
分析
最后,作者比較了 Aaren 和 Transformer 所需的資源。
內(nèi)存復(fù)雜性:在圖 5(左)中,作者比較了 Aaren 和 Transformer(使用 KV 緩存)在推理時的內(nèi)存使用情況。可以看到,伴隨 KV 緩存技術(shù)的使用,Transformer 的內(nèi)存使用量呈線性增長。相比之下,Aaren 只使用恒定的內(nèi)存,無論 token 數(shù)量如何增長,因此它的效率要高得多。
時間復(fù)雜度:在圖 5(右圖)中,作者比較了 Aaren 和 Transformer(使用 KV 緩存)按順序處理一串 token 所需的累計時間。對于 Transformer,累計計算量是 token 數(shù)的二次方,即 O (1 + 2 + ... + N) = O (N^2 )。相比之下,Aaren 的累計計算量是線性的。在圖中,可以看到模型所需的累計時間也是類似的結(jié)果。具體來說,Transformer 所需的累計時間呈二次增長,而 Aaren 所需的累計時間呈線性增長。
參數(shù)數(shù)量:由于要學(xué)習(xí)初始隱藏狀態(tài) q,Aaren 模塊需要的參數(shù)略多于 Transformer 模塊。不過,由于 q 只是一個向量,因此差別不大。通過在同類模型中進(jìn)行實證測量,作者發(fā)現(xiàn) Transformer 使用了 3, 152, 384 個參數(shù)。相比之下,等效的 Aaren 使用了 3, 152, 896 個參數(shù),參數(shù)增加量僅為 0.016%—— 對于內(nèi)存和時間復(fù)雜性的顯著差異來說,這只是微不足道的代價。
本文轉(zhuǎn)自 機(jī)器之心 ,作者:機(jī)器之心
