首個基于統(tǒng)計學(xué)的線性注意力機制ToST,高分拿下ICLR Spotlight
本文第一作者為加州大學(xué)伯克利分校三年級博士生吳梓陽,導(dǎo)師為馬毅教授。吳的主要研究方向為表征學(xué)習(xí)與多模態(tài)學(xué)習(xí)。該工作由多所學(xué)校與機構(gòu)的研究者共同完成,包括加州大學(xué)伯克利分校、賓夕法尼亞大學(xué)、密歇根大學(xué)、清華大學(xué)、憶生科技、香港大學(xué)、約翰·霍普金斯大學(xué)等。據(jù)悉,馬毅教授已受邀在今年四月的ICLR大會上就和此項成果相關(guān)的一系列白盒神經(jīng)網(wǎng)絡(luò)相關(guān)工作,進行為時一小時的主題報告(Keynote)。
Transformer 架構(gòu)在過去幾年中通過注意力機制在多個領(lǐng)域(如計算機視覺、自然語言處理和長序列任務(wù))中取得了非凡的成就。然而,其核心組件「自注意力機制」 的計算復(fù)雜度隨輸入 token 數(shù)量呈二次方增長,導(dǎo)致資源消耗巨大,難以擴展到更長的序列或更大的模型。
Token Statistics Transformer (ToST) 提出了一種新的注意力機制,它的時間復(fù)雜度是線性的。通過對序列特征的統(tǒng)計建模,ToST 提高了序列處理任務(wù)中的效率。文章探討了基于變分編碼率縮減(Variational Rate Reduction, VRR)的框架,并通過實驗驗證了其在不同任務(wù)中的性能,通過革新傳統(tǒng)注意力機制,解決了這些長期困擾 Transformer 架構(gòu)的效率瓶頸。
ToST 也作為 Spotlight 論文,入選了 ICLR 2025 大會。
- 論文標題:Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction
- 論文地址:https://arxiv.org/abs/2412.17810
- 項目主頁:https://robinwu218.github.io/ToST/
- 目前該工作已開源:https://github.com/RobinWu218/ToST
研究背景與動機
一直以來,自注意力機制依賴于對輸入 token 兩兩相似性的計算,這一過程雖然有效,但其資源開銷顯著;尤其當(dāng)輸入 token 數(shù)量極大時,傳統(tǒng)注意力機制(如 Transformer 中的全局注意力)在計算復(fù)雜度和內(nèi)存使用上的瓶頸問題愈發(fā)顯著。
為了應(yīng)對這一挑戰(zhàn),本文提出了一種基于統(tǒng)計學(xué)特征的注意力機制:Token Statistics Self-Attention (TSSA)。它通過避免兩兩相似性的計算,僅依賴于 token 特征的統(tǒng)計量,顯著降低了計算復(fù)雜度。
Token Statistics Transformer (ToST) 的架構(gòu)。Token Statistics Self-Attention (TSSA) 運算符通過對投影后的 token 進行行標量化變換,從而實現(xiàn)了線性復(fù)雜度。
核心方法
ToST 的核心方法是通過特定的概率分布函數(shù)對輸入序列進行建模,減少冗余信息并提取關(guān)鍵特征。具體包括:
1. 統(tǒng)計特征提取:對序列中的每個 token 提取其統(tǒng)計特征。
2. 變分編碼率縮減:利用 VRR 框架對特征進行壓縮,減少信息冗余。
3. 線性復(fù)雜度實現(xiàn):通過一系列優(yōu)化,其計算復(fù)雜度從 O (n2) 降低為 O (n)。
ToST 的方法概述。在 CRATE 的理論基礎(chǔ)上,ToST 通過幾何空間的結(jié)構(gòu)化特征實現(xiàn) token 分組和映射。
網(wǎng)絡(luò)架構(gòu)的推導(dǎo)
該團隊通過擴展先前的 CRATE 工作推導(dǎo)出網(wǎng)絡(luò)架構(gòu)。CRATE 顯示,一種 Transformer 風(fēng)格的架構(gòu)可以通過 "白盒" 架構(gòu)設(shè)計自然生成,其中網(wǎng)絡(luò)的每一層都旨在實現(xiàn)最大編碼率縮減目標 (MCR2) 的增量優(yōu)化步驟。
具體來說,該團隊推導(dǎo)了 MCR2 目標的一個新穎的變分形式,并表明通過對該變分目標進行展開梯度下降所得到的架構(gòu)會引入一種新的注意力模塊,稱為 Token Statistics Self-Attention (TSSA)。TSSA 擁有線性的計算和內(nèi)存復(fù)雜度,并從根本上不同于典型的注意力架構(gòu),其后者通過計算 token 之間的兩兩相似性來實現(xiàn)。
關(guān)鍵公式 MCR2 目標函數(shù)定義
技術(shù)細節(jié)
1. 線性時間注意力機制:Token Statistics Self-Attention (TSSA)
通過白盒設(shè)計方法(algorithmic unrolling),TSSA 從最大編碼率減少(Maximal Coding Rate Reduction, MCR2 )的變分形式中推導(dǎo)而來。
傳統(tǒng) Transformer 依賴于 pairwise 相似度計算,而 TSSA 則基于 token 特征的統(tǒng)計量構(gòu)建注意力機制,其計算復(fù)雜度從 O (n2) 降低為 O (n),內(nèi)存占用同樣顯著減少。
2. 創(chuàng)新性的網(wǎng)絡(luò)結(jié)構(gòu):Token Statistics Transformer (ToST)
ToST 通過將 TSSA 替代標準的自注意力模塊,不僅實現(xiàn)了顯著的效率提升,還增強了模型的可解釋性。
與傳統(tǒng)模型不同,ToST 架構(gòu)中的注意力操作基于統(tǒng)計量的低秩投影,通過減少不必要的計算路徑,大幅優(yōu)化了資源使用。
3. 理論支撐與數(shù)學(xué)推導(dǎo)
基于 MCR2 的變分形式,提出了一種新穎的壓縮項公式,可對大型矩陣進行有效的特征提取。
通過設(shè)計數(shù)據(jù)相關(guān)的低秩投影,TSSA 在保留關(guān)鍵信息的同時,消除了冗余方向。
實驗驗證與性能分析
實驗覆蓋了自然言語處理(NLP)、計算機視覺(CV)等多個領(lǐng)域的任務(wù),包括文本分類、機器翻譯、圖像識別等。結(jié)果表明,ToST 在保證模型性能的同時,大幅降低了計算資源消耗。
1. 計算和內(nèi)存的線性復(fù)雜度分析
實驗結(jié)果顯示,與現(xiàn)有的注意力機制相比,TSSA 的時間和內(nèi)存復(fù)雜度更低。具體而言,TSSA 的復(fù)雜度為 O (pn),顯著優(yōu)于傳統(tǒng) Transformer 的 O (n2)。
ToST 在計算時間和內(nèi)存使用上均隨序列長度實現(xiàn)線性擴展,使其顯著優(yōu)于標準 Transformer 的效率。如下:
復(fù)雜度分析對比
在 GPU 上評估的速度和內(nèi)存使用對比
2. 視覺任務(wù)性能分析
在 ImageNet-1k 等主流視覺數(shù)據(jù)集上的實驗表明,ToST 的性能可與傳統(tǒng) Transformer 架構(gòu)(如 ViT 和 XCiT)相媲美,同時顯著減少了模型參數(shù)量和計算開銷。
遷移學(xué)習(xí)實驗中,ToST 在 CIFAR、Oxford Flowers 等數(shù)據(jù)集上的表現(xiàn)進一步驗證了其在多種視覺任務(wù)中的適應(yīng)性。
結(jié)果展示了與傳統(tǒng) Transformer 相當(dāng)?shù)男阅?,同時在計算效率上顯著更高。
3. 長序列任務(wù)和語言建模
- 長序列任務(wù)
在長序列任務(wù)基準測試(如 Long-Range Arena)中,ToST 展現(xiàn)出優(yōu)異的長距離建模能力,其性能超越了現(xiàn)有 Transformer 變體。
- 語言建模
ToST 可以擴展并適用于多種任務(wù)場景,包括因果語言建模。針對語言建模,ToST 采用了一種因果版本的 TSSA,在多個數(shù)據(jù)集上實現(xiàn)了高效的預(yù)測能力。此外,即使在參數(shù)規(guī)模擴大的情況下,ToST 依然保持了優(yōu)異的時間和內(nèi)存效率。
NLP 任務(wù)中的表現(xiàn)
4. 有原理支持的模型設(shè)計
由于 ToST 是通過展開從學(xué)習(xí)目標中推導(dǎo)出來的,我們可以以有原理支持的方式逐層分析學(xué)習(xí)到的模型行為。
ToST 模型不同層次的 TSSA 輸出的變分壓縮項
5. 學(xué)習(xí)表示的可解釋性分析
ToST 通過統(tǒng)計量驅(qū)動的注意力機制,使每一層的注意力操作更加透明,便于解釋和分析。其分組機制展現(xiàn)了 token 特征在低維空間中的聚類效果,直觀反映了模型的決策過程。
ToST 在無需復(fù)雜的自監(jiān)督訓(xùn)練的情況下,自然生成了可解釋的注意力模式。
倒數(shù)第二個全局類注意力層中最后一個頭部的 [CLS] token 注意力圖的比較
在 TSSA 層中,可視化估計的隸屬矩陣 Π 的每一行(經(jīng)過重塑后)
可能對未來產(chǎn)生的影響
1. 大模型的高效化
隨著語言模型、生成模型和多模態(tài)模型規(guī)模的持續(xù)擴展,計算效率成為核心瓶頸。ToST 展示的統(tǒng)計量驅(qū)動注意力機制,為實現(xiàn)線性復(fù)雜度的大模型提供了可能性。
2. 推動 Transformer 的普適化應(yīng)用
高效的注意力機制使得 ToST 能夠更廣泛地應(yīng)用于資源受限場景,如邊緣計算、實時系統(tǒng)、嵌入式設(shè)備等。這為人工智能技術(shù)從中心化計算向分布式、邊緣化方向的發(fā)展奠定了基礎(chǔ)。
3. 多模態(tài)融合的可能性
ToST 的低復(fù)雜度機制為處理多模態(tài)長序列任務(wù)提供了新的技術(shù)框架,使未來多模態(tài)大模型在生成、分析和交互中的效率顯著提升。
4. 促進跨學(xué)科應(yīng)用
ToST 對數(shù)學(xué)理論與工程實現(xiàn)的有機結(jié)合,不僅在傳統(tǒng) AI 任務(wù)中表現(xiàn)突出,還可能推動其在新興領(lǐng)域(如量子計算、生物信息學(xué)和材料設(shè)計)中的應(yīng)用。
Token Statistics Transformer (ToST) 重塑了注意力機制,它不需要計算 token 之間的兩兩交互,而是基于投影后 token 特征的二階矩統(tǒng)計量構(gòu)建,其基于數(shù)據(jù)壓縮和表示學(xué)習(xí)的理論原則目標,為 Transformer 的發(fā)展開辟了新路徑。其基于統(tǒng)計特性的低復(fù)雜度設(shè)計,不僅優(yōu)化了現(xiàn)有架構(gòu)的性能,還為未來大模型的高效化、多模態(tài)融合和跨學(xué)科應(yīng)用提供了啟示。