想把半本《紅樓夢》搬進(jìn)ChatGPT輸入框?先把這個(gè)問題解決掉
過去兩年,斯坦福大學(xué) Hazy Research 實(shí)驗(yàn)室一直在從事一項(xiàng)重要的工作:增加序列長度。
他們有一種觀點(diǎn):更長的序列將開啟機(jī)器學(xué)習(xí)基礎(chǔ)模型的新時(shí)代 —— 模型可以從更長的上下文、多種媒體源、復(fù)雜的演示等中學(xué)習(xí)。
目前,這項(xiàng)研究已經(jīng)取得了新進(jìn)展。Hazy Research 實(shí)驗(yàn)室的 Tri Dao 和 Dan Fu 主導(dǎo)了 FlashAttention 算法的研究和推廣,他們證明了 32k 的序列長度是可能的,且在當(dāng)前這個(gè)基礎(chǔ)模型時(shí)代將得到廣泛應(yīng)用(OpenAI、Microsoft、NVIDIA 和其他公司的模型都在使用 FlashAttention 算法)。
- 論文地址:https://arxiv.org/abs/2205.14135
- 代碼地址:https://github.com/HazyResearch/flash-attention
正如 GPT4 的相關(guān)資料所指出的,它允許近 50 頁的文本作為上下文,而且像 Deepmind Gato 使用圖像作為上下文那樣實(shí)現(xiàn) tokenization/patching。
在這篇文章中,作者介紹了關(guān)于在高層級上增加序列長度的新方法,并提供了連接一組新原語的「橋梁」。
Transformer 變得越來越深,越來越寬,但在長序列上訓(xùn)練它們?nèi)匀缓芾щy。研究人員遇到的一個(gè)基本問題是,Transformer 的注意力層在序列長度方面是按二次方比例增長:就是說從 32k 長度增加到 64k 長度,成本不只增加 2 倍,而是增加了 4 倍。因此,這促使研究人員探索具有線性時(shí)間復(fù)雜度的序列長度模型。在 Hazy Research 實(shí)驗(yàn)室,這項(xiàng)工作從 Hippo 開始,然后是 S4、H3,再到現(xiàn)在的 Hyena。這些模型有可能處理數(shù)百萬、甚至十億級別的上下文長度。
FlashAttention 可以加速注意力并減少其內(nèi)存占用 —— 無需任何近似。「自從我們在 6 個(gè)月前發(fā)布 FlashAttention 以來,我們很高興看到許多組織和研究實(shí)驗(yàn)室采用 FlashAttention 來加速他們的訓(xùn)練和推理?!共┛椭袑懙?。
FlashAttention 是一種對注意力計(jì)算進(jìn)行重新排序并利用經(jīng)典技術(shù)(平鋪、重新計(jì)算)加快速度并將內(nèi)存使用從序列長度的二次減少到線性的算法。對于每個(gè)注意力頭,為了減少內(nèi)存讀 / 寫,F(xiàn)lashAttention 使用經(jīng)典的平鋪技術(shù)將查詢、鍵和值塊從 GPU HBM(其主內(nèi)存)加載到 SRAM(其快速緩存),計(jì)算關(guān)于該塊的注意力,并將輸出寫回 HBM。在大多數(shù)情況下,這種內(nèi)存讀 / 寫的減少帶來了顯著的加速(2-4 倍)。
接下來,讓我們看一下研究細(xì)節(jié)。
Long Range Arena 基準(zhǔn)和 S4
谷歌的研究人員在 2020 年推出了 Long Range Arena (LRA) 基準(zhǔn)測試,以評估不同模型處理長程依賴的能力。LRA 能夠測試一系列任務(wù),涵蓋多種不同的數(shù)據(jù)類型和模式,例如文本、圖像和數(shù)學(xué)表達(dá)式,序列長度可達(dá) 16K(Path-X:對已展開成像素的圖像進(jìn)行分類,沒有任何空間歸納偏置)。關(guān)于將 Transformer 擴(kuò)展到更長的序列方面已經(jīng)有很多出色的工作,但其中許多似乎會(huì)犧牲準(zhǔn)確性(如下圖所示)。請注意 Path-X 那一列:所有 Transformer 方法及其變體表現(xiàn)甚至不如隨機(jī)猜測。
現(xiàn)在讓我們認(rèn)識一下由 Albert Gu 主導(dǎo)研發(fā)的 S4。受到 LRA 基準(zhǔn)測試結(jié)果的啟發(fā),Albert Gu 想要找出如何更好地對長程依賴關(guān)系建模,在正交多項(xiàng)式和遞歸模型與卷積模型之間關(guān)系的長期研究基礎(chǔ)上,推出了 S4—— 一種基于結(jié)構(gòu)化狀態(tài)空間模型(SSMs)的新的序列模型。
很關(guān)鍵的一點(diǎn)是,SSM 在將長度為 N 的序列拓展到 2N 時(shí)的時(shí)間復(fù)雜度為,而不像注意力機(jī)制一樣呈平方級別增長!S4 成功地對 LRA 中的長程依賴進(jìn)行了建模,并成為首個(gè)在 Path-X 上獲得高于平均性能的模型(現(xiàn)在可以獲得 96.4%的準(zhǔn)確度?。W?S4 發(fā)布以來,許多研究人員在此基礎(chǔ)上發(fā)展和創(chuàng)新,出現(xiàn)了像 Scott Linderman 團(tuán)隊(duì)的 S5 模型、Ankit Gupta 的 DSS(以及 Hazy Research 實(shí)驗(yàn)室后續(xù)的 S4D)、Hasani 和 Lechner 的 Liquid-S4 等新模型。
另外,當(dāng) Hazy Research 發(fā)布 FlashAttention 時(shí),已經(jīng)能夠增加 Transformer 的序列長度。他們還發(fā)現(xiàn),僅通過將序列長度增加到 16K,Transformer 也能在 Path-X 上獲得不凡的表現(xiàn)(63%)。
建模方面的不足
但是 S4 在語言建模方面的質(zhì)量存在的差距高達(dá) 5% 的困惑度(對于上下文,這是 125M 模型和 6.7B 模型之間的差距)。為了縮小這一差距,研究人員研究了諸如聯(lián)想回憶之類的合成語言,以確定語言應(yīng)該具備哪些屬性。最終設(shè)計(jì)了 H3(Hungry Hungry Hippos):一個(gè)堆疊兩個(gè) SSM 的新層,并將它們的輸出與乘法門相乘。
使用 H3,Hazy Research 的研究人員替換了 GPT 式 Transformer 中的幾乎所有注意力層,并能夠在從 Pile 訓(xùn)練的 400B 規(guī)模的 token 時(shí),在困惑度和下游評估方面與 transformer 相媲美。
由于 H3 層建立在 SSM 上,因此在序列長度上,它的計(jì)算復(fù)雜度也以的速度增長。兩個(gè)注意力層使得整個(gè)模型的復(fù)雜度仍然是
,稍后會(huì)詳細(xì)討論這個(gè)問題。
當(dāng)然,Hazy Research 不是唯一考慮這個(gè)方向的人:GSS 也發(fā)現(xiàn)帶有門控的 SSM 可以與語言建模中的注意力很好地協(xié)同工作(這啟發(fā)了 H3),Meta 發(fā)布了 Mega 模型,它也將 SSM 和注意力結(jié)合起來,BiGS 模型則替換了 BERT-style 模型中的注意力,而 RWKV 一直在研究完全循環(huán)的方法。
新進(jìn)展:Hyena
根據(jù)前面的一系列工作,啟發(fā) Hazy Research 的研究人員開發(fā)了新的架構(gòu):Hyena。他們試圖擺脫 H3 中最后兩個(gè)注意力層,并獲得一個(gè)幾乎呈線性增長的模型,以適應(yīng)更長的序列長度。事實(shí)證明,兩個(gè)簡單的想法是找到答案的關(guān)鍵:
- 每個(gè) SSM 都可以看作是一個(gè)長度與輸入序列相同的卷積濾波器。因此,可以用一個(gè)大小等于輸入序列的卷積來替換 SSM,以獲得在相同計(jì)算量下更加強(qiáng)大的模型。具體來說,通過另一個(gè)小型神經(jīng)網(wǎng)絡(luò)來隱式地參數(shù)化卷積濾波器,這借鑒了關(guān)于神經(jīng)場文獻(xiàn)中的強(qiáng)大方法和 CKConv/FlexConv 的研究成果。此外,卷積可以在 O (NlogN) 的時(shí)間內(nèi)計(jì)算,其中 N 是序列長度,實(shí)現(xiàn)了近乎線性的擴(kuò)展;
- H3 中的門控行為可以概括為:H3 采用輸入的三個(gè)投影,并迭代地進(jìn)行卷積和應(yīng)用門控。在 Hyena 中,只需添加更多投影和更多的門,這有助于泛化到更具表現(xiàn)力的架構(gòu)并縮小與注意力的差距。
Hyena 首次提出了完全近線性時(shí)間卷積模型,它可以在困惑度和下游任務(wù)上與 Transformer 相匹配,并在實(shí)驗(yàn)中取得了很好的結(jié)果。并且在 PILE 的子集上訓(xùn)練了中小型模型,其表現(xiàn)與 Transformer 相媲美:
通過一些優(yōu)化(更多內(nèi)容見下文),在序列長度為 2K 時(shí),Hyena 模型的速度略慢于相同大小的 Transformer,但在更長的序列長度上會(huì)更快。
接下來仍需思考的是,究竟能將這些模型推廣到什么程度?是否能將它們擴(kuò)展到 PILE 的全尺寸(400B 個(gè) token)?如果結(jié)合 H3 和 Hyena 的思想精華,會(huì)發(fā)生什么,能走多遠(yuǎn)?
FFT 還是更基本的方法?
在所有這些模型中,一個(gè)常見的基本操作是 FFT,它是高效計(jì)算卷積的方式,只需要 O (NlogN) 的時(shí)間。然而,F(xiàn)FT 在現(xiàn)代硬件上的支持很差,因?yàn)楝F(xiàn)代硬件主流架構(gòu)是專用的矩陣乘法單元和 GEMMs(例如 NVIDIA GPU 上的張量核心)。
可以通過將 FFT 重寫為一系列矩陣乘法操作來縮小效率差距。研究小組的成員利用蝴蝶矩陣來探索稀疏訓(xùn)練,從而實(shí)現(xiàn)這個(gè)目標(biāo)。最近,Hazy Research 研究人員利用這個(gè)連接構(gòu)建了快速卷積算法,例如 FlashConv 和 FlashButterfly,通過使用蝴蝶分解將 FFT 計(jì)算轉(zhuǎn)化為一系列矩陣乘法操作。
此外,通過借鑒之前的工作,還能建立更深入的聯(lián)系:包括讓這些矩陣被學(xué)習(xí),這同樣需要相同的時(shí)間,但會(huì)增加額外的參數(shù)。研究人員已經(jīng)開始在一些小型數(shù)據(jù)集上探索這種聯(lián)系,并取得了初步成效。我們可以清楚地看到這種聯(lián)系可以帶來什么(比如,如何使其適用于語言模型):
這一擴(kuò)展值得更深入的探索:這個(gè)擴(kuò)展學(xué)習(xí)的是哪類轉(zhuǎn)換,它能讓你做什么?當(dāng)將它應(yīng)用于語言建模時(shí)會(huì)發(fā)生什么?
這些方向都是令人興奮的,接下來會(huì)是越來越長的序列和新的架構(gòu),讓我們能夠進(jìn)一步探索這個(gè)新領(lǐng)域。我們需要特別關(guān)注那些能夠受益于長序列模型的應(yīng)用,比如高分辨率成像、新的數(shù)據(jù)形式,能夠閱讀整本書的語言模型等等。想象一下,把整本書給語言模型閱讀,并讓它總結(jié)故事情節(jié),或者讓一個(gè)代碼生成模型基于你寫的代碼來生成新的代碼。這些可能的場景非常非常多,都是讓人感到非常興奮的事情。