GPT-4的32k輸入框還是不夠用?Unlimiformer把上下文長度拉到無限長
Transformer 是時(shí)下最強(qiáng)大的 seq2seq 架構(gòu)。預(yù)訓(xùn)練 transformer 通常具有 512(例如 BERT)或 1024 個(gè)(例如 BART)token 的個(gè)上下文窗口,這對于目前許多文本摘要數(shù)據(jù)集(XSum、CNN/DM)來說是足夠長的。
但 16384 并不是生成所需上下文長度的上限:涉及長篇敘事的任務(wù),如書籍摘要(Krys-′cinski et al.,2021)或敘事問答(Kociskyet al.,2018),通常輸入超過 10 萬個(gè) token。維基百科文章生成的挑戰(zhàn)集(Liu*et al.,2018)包含超過 50 萬個(gè) token 的輸入。生成式問答中的開放域任務(wù)可以從更大的輸入中綜合信息,例如回答關(guān)于維基百科上所有健在作者的文章的聚合屬性的問題。圖 1 根據(jù)常見的上下文窗口長度繪制了幾個(gè)流行的摘要和問答數(shù)據(jù)集的大??;最長的輸入比 Longformer 的上下文窗口長 34 倍以上。
在這些超長輸入的情況下,vanilla transformer 無法進(jìn)行縮放,因?yàn)樵⒁饬C(jī)制具有平方級的復(fù)雜度。長輸入 transformer 雖然比標(biāo)準(zhǔn) transformer 更高效,但仍需要大量的計(jì)算資源,這些資源隨著上下文窗口大小的增加而增加。此外,增加上下文窗口需要用新的上下文窗口大小從頭開始重新訓(xùn)練模型,計(jì)算上和環(huán)境上的代價(jià)都不小。
在「Unlimiformer: Long-Range Transformers with Unlimited Length Input」一文中,來自卡內(nèi)基梅隆大學(xué)的研究者引入了 Unlimiformer。這是一種基于檢索的方法,這種方法增強(qiáng)了預(yù)訓(xùn)練的語言模型,以在測試時(shí)接受無限長度的輸入。
論文鏈接:https://arxiv.org/pdf/2305.01625v1.pdf
Unlimiformer 可以被注入到任何現(xiàn)有的編碼器 - 解碼器 transformer 中,能夠處理長度不限的輸入。給定一個(gè)長的輸入序列,Unlimiformer 可以在所有輸入 token 的隱藏狀態(tài)上構(gòu)建一個(gè)數(shù)據(jù)存儲。然后,解碼器的標(biāo)準(zhǔn)交叉注意力機(jī)制能夠查詢數(shù)據(jù)存儲,并關(guān)注前 k 個(gè)輸入 token。數(shù)據(jù)存儲可以存儲在 GPU 或 CPU 內(nèi)存中,能夠次線性查詢。
Unlimiformer 可以直接應(yīng)用于經(jīng)過訓(xùn)練的模型,并且可以在沒有任何進(jìn)一步訓(xùn)練的情況下改進(jìn)現(xiàn)有的 checkpoint。Unlimiformer 經(jīng)過微調(diào)后,性能會得到進(jìn)一步提高。本文證明,Unlimiformer 可以應(yīng)用于多個(gè)基礎(chǔ)模型,如 BART(Lewis et al.,2020a)或 PRIMERA(Xiao et al.,2022),且無需添加權(quán)重和重新訓(xùn)練。在各種長程 seq2seq 數(shù)據(jù)集中,Unlimiformer 不僅在這些數(shù)據(jù)集上比 Longformer(Beltagy et al.,2020b)、SLED(Ivgi et al.,2022)和 Memorizing transformers(Wu et al.,2021)等強(qiáng)長程 Transformer 表現(xiàn)更好,而且本文還發(fā)現(xiàn) Unlimiform 可以應(yīng)用于 Longformer 編碼器模型之上,以進(jìn)行進(jìn)一步改進(jìn)。
Unlimiformer 技術(shù)原理
由于編碼器上下文窗口的大小是固定的,Transformer 的最大輸入長度受到限制。然而,在解碼過程中,不同的信息可能是相關(guān)的;此外,不同的注意力頭可能會關(guān)注不同類型的信息(Clark et al.,2019)。因此,固定的上下文窗口可能會在注意力不那么關(guān)注的 token 上浪費(fèi)精力。
在每個(gè)解碼步驟中,Unlimiformer 中每個(gè)注意力頭都會從全部輸入中選擇一個(gè)單獨(dú)的上下文窗口。通過將 Unlimiformer 查找注入解碼器來實(shí)現(xiàn):在進(jìn)入交叉注意力模塊之前,該模型在外部數(shù)據(jù)存儲中執(zhí)行 k 最近鄰 (kNN) 搜索,在每個(gè)解碼器層中的每個(gè)注意力頭中選一組 token 來參與。
編碼
為了將比模型的上下文窗口長度更長的輸入序列進(jìn)行編碼,本文按照 Ivgi et al. (2022) 的方法對輸入的重疊塊進(jìn)行編碼 (Ivgi et al. ,2022),只保留每個(gè) chunk 的輸出的中間一半,以確保編碼過程前后都有足夠的上下文。最后,本文使用 Faiss (Johnson et al., 2019) 等庫對數(shù)據(jù)存儲中的編碼輸入進(jìn)行索引(Johnson et al.,2019)。
檢索增強(qiáng)的交叉注意力機(jī)制
在標(biāo)準(zhǔn)的交叉注意力機(jī)制中,transformer 的解碼器關(guān)注編碼器的最終隱狀態(tài),編碼器通常截?cái)噍斎耄H對輸入序列中的前 k 個(gè) token 進(jìn)行編碼。
本文不是只關(guān)注輸入的這前 k 個(gè) token,對于每個(gè)交叉注意頭,都檢索更長的輸入系列的前 k 個(gè)隱狀態(tài),并只關(guān)注這前 k 個(gè)。這樣就能從整個(gè)輸入序列中檢索關(guān)鍵字,而不是截?cái)嚓P(guān)鍵字。在計(jì)算和 GPU 內(nèi)存方面,本文的方法也比處理所有輸入 token 更便宜,同時(shí)通常還能保留 99% 以上的注意力性能。
圖 2 顯示了本文對 seq2seq transformer 架構(gòu)的更改。使用編碼器對完整輸入進(jìn)行塊編碼,并將其存儲在數(shù)據(jù)存儲中;然后,解碼時(shí)查詢編碼的隱狀態(tài)數(shù)據(jù)存儲。kNN 搜索是非參數(shù)的,并且可以被注入到任何預(yù)訓(xùn)練的 seq2seq transformer 中,詳情如下。
實(shí)驗(yàn)結(jié)果
長文檔摘要
表 3 顯示了長文本(4k 及 16k 的 token 輸入)摘要數(shù)據(jù)集中的結(jié)果。
在表 4 的訓(xùn)練方法中,Unlimiformer 能夠在各項(xiàng)指標(biāo)上達(dá)到最優(yōu)。
書籍摘要
表 5 顯示了在書籍摘要上的結(jié)果??梢钥吹?,基于 BARTbase 和 PRIMERA,應(yīng)用 Unlimiformer 都能取得一定的改進(jìn)效果。