MemLong:用于長文本建模的記憶增強檢索
?一、結(jié)論寫在前面
論文標(biāo)題:MemLong: Memory-Augmented Retrieval for Long Text Modeling
論文鏈接:https://arxiv.org/pdf/2408.16967
LLMs在各個領(lǐng)域的最新進展取得了顯著的成功。然而,由于注意力機制的二次時間和空間復(fù)雜性以及生成過程中鍵值緩存的內(nèi)存消耗不斷增加,處理長上下文仍然是LLMs的一個重大挑戰(zhàn)。
論文提出了MemLong,一種高效且輕量化的方法,用于擴展大型語言模型(LLMs)的上下文窗口。其核心思想是將過去的上下文和知識存儲在一個不可訓(xùn)練的記憶庫中,并進一步利用這些存儲的嵌入來檢索塊級別的鍵值(K-V)對,以輸入到模型中。MemLong適用于任何僅解碼器的預(yù)訓(xùn)練語言模型,通過結(jié)合(1)一個額外的記憶檢索組件用于記憶和檢索,以及( 2 )一個檢索因果注意力模塊用于整合局部和記憶信息。MemLong的記憶和檢索過程如圖1(b)所示。在生成過程中,超出模型最大處理長度的文本被存儲為上下文信息在記憶庫中。隨后,在長文檔中生成最近的一個文本塊時,論文使用檢索器顯式地檢索過去的相關(guān)信息,通過索引對齊獲取額外的上下文信息。
MemLong提供了幾個優(yōu)勢:(1)分布一致性:與之前在信息存儲到記憶中時經(jīng)歷分布偏移的模型不同,MemLong確保緩存信息的分布保持一致。(2)訓(xùn)練高效:論文凍結(jié)模型的底層,只需微調(diào)上層,這大大減少了計算成本。在論文的實驗中,在0.5B個token上微調(diào)一個3B參數(shù)版本的MemLong僅需八塊3090 GPU運行八小時。(3)廣泛的上下文窗口:由于只需記憶單層的K-V對,MemLong能夠在單塊3090 GPU上輕松擴展上下文窗口至80k個token。
大量實驗表明,與其他領(lǐng)先的LLMs相比,MemLong在多個方面表現(xiàn)出優(yōu)越的性能。在多個長上下文語言建模數(shù)據(jù)集上,MemLong優(yōu)于OpenLLaMA和其他基于檢索的模型。在檢索增強的上下文學(xué)習(xí)任務(wù)中,MemLong比OpenLLaMA提高了多達10.2個百分點。
二、論文的簡單介紹
2.1 論文的背景
由于傳統(tǒng)的注意力機制的二次方時間和空間復(fù)雜性,擴展上下文長度具有挑戰(zhàn)性,這為涉及長序列任務(wù)的應(yīng)用帶來了顯著的限制,例如長文檔摘要和多輪對話。因此,LLMs通常被期望具備長時間的工作能力(即長上下文LLMs),以有效應(yīng)對這些苛刻的場景。
圖1:檢索增強生成(RAG)和MemLong的記憶檢索流程示意圖。(a) 當(dāng)檢索信息的長度超過模型處理能力時,RAG甚至可能降低生成性能(黃色)。(b) 論文的方法利用外部檢索器獲取歷史信息,然后將其作為K-v對傳遞給模型,而不是以文本形式傳遞。
為了解決計算瓶頸問題,已經(jīng)進行了大量努力。第一類工作側(cè)重于通過采用稀疏注意力操作來減少普通注意力機制的計算量。盡管這些方法可以將計算復(fù)雜度降低到大約 O(n),但通常會以模型容量為代價。因此,一些工作轉(zhuǎn)向了記憶選擇。這些方法,作為token級別的記憶選擇,可能導(dǎo)致語義信息的截斷。另一類最近的工作是檢索增強語言建模。這些工作通常引入檢索機制來增強模型處理長文本的能力。
然而,這些方法有幾個缺點。首先,由于訓(xùn)練過程中模型參數(shù)的變化,存儲在記憶中的信息可能會經(jīng)歷分布偏移。其次,這些方法通常需要重新訓(xùn)練,這在大型模型時代是不切實際的。最后,這些模型往往傾向于以犧牲預(yù)訓(xùn)練模型的原始能力為代價來處理長文本輸入。為了解決先前研究的局限性,論文提出了以下問題:論文能否利用檢索器的顯式檢索能力來近似模型內(nèi)部的隱式檢索過程?
2.2預(yù)備知識
2.2.1 任務(wù)定義
語言模型旨在定義一系列token的概率分布,從而有效地預(yù)測給定語言中序列的可能性。給定一個序列 x_1, ..., x_n,標(biāo)準(zhǔn)方法建模其概率為 p ( x_1, ..., x_n )= sum pθ ( x_i | x< i ),其中 x< i , := x_1, ..., x_i-1 是 x_i 之前的token序列。
與標(biāo)準(zhǔn)語言建模目標(biāo)不同,論文不僅使用當(dāng)前上下文進行下一個token預(yù)測,還利用外部檢索獲取相關(guān)信息并在模型的上層進行知識融合。具體來說,給定一個由 l個token組成的序列,每個塊的大小為 ν=l/τ,論文將序列劃分為 τ個不重疊的塊,記為 C=( c_1, ..., cν )。相應(yīng)地,其文本形式被劃分為ν個文本塊,記為 T=( t_1, ..., t_ν)。在每一步中,論文在下層對 c_i 進行因果語言建模,而在上層,論文對 t_i 進行細(xì)粒度可控檢索以融合額外信息。完成此操作后,論文的語言建模目標(biāo)變?yōu)?/p>
其中 R( t_i) 表示檢索 t_i 所在塊的鄰近塊。
2.2.2 模塊與操作定義
如圖 2 所示,Ret-Mem 模塊由一個檢索器(Retriever)和一個記憶組件(Memory)組成,用于信息交換。首先,論文將記憶組件定義為M,將檢索器定義為 R,以及它們對應(yīng)的操作M(.) 和 R(.)。此外,論文指定模型的維度為 d_model,檢索器的維度作為 d_ret 。記憶模塊包括兩個部分:K-V 對和相應(yīng)的表示嵌入。鍵和值的維度表示為R^d_model,而嵌入的維度表示為R^d_ret。必須強調(diào)的是,實際的檢索過程涉及表示塊的嵌入,而不是 K-V 對。檢索器本質(zhì)上是一個預(yù)訓(xùn)練的密集嵌入器,具有出色的表示能力。MemLong 使用它將每個塊編碼為表示嵌入。由于它為每個塊生成一維表示向量,即使內(nèi)存大小很大,內(nèi)存占用也保持最小。
圖 2 :MemLong 的一個示例:在較低層,模型保持靜態(tài),對整個塊 c_i執(zhí)行因果語言建模,隨后,c_i 以嵌入和 K-V 對的形式緩存。最后,上層被微調(diào)以協(xié)調(diào)檢索偏好并整合檢索到的內(nèi)容。
2.3 MemLong
2.3.1 概述
如圖 2 所示,每個步驟涉及一個塊 c_i 的輸入,其中該塊的原始文本為 t_i。在模型凍結(jié)的較低層,對整個 c_i 應(yīng)用標(biāo)準(zhǔn)因果注意力。對于較低層的最后一層,論文稱之為記憶層。在每次遍歷記憶層后,執(zhí)行兩個關(guān)鍵操作。第一個操作是檢索,由紅線表示,其中 t_i 用于獲取最相關(guān)的 K-V 對。第二個操作,由藍線表示,涉及緩存獲得的 K-V 對及其相關(guān)塊表示。在模型的上層,檢索到的 K-V 對與當(dāng)前輸入上下文整合,隨后調(diào)整模型參數(shù)以校準(zhǔn)檢索參考。后續(xù)部分將探討 MemLong 框架的各個方面及其細(xì)節(jié),包括檢索器和動態(tài)內(nèi)存管理 (Dynamic Memory Management),注意力重構(gòu) ( Attention Reformulation ),以及使用 MemLong 進行推理 (Inference with MemLong)。
2.3.2 檢索器與動態(tài)內(nèi)存管理
論文提供了一個關(guān)于檢索過程和內(nèi)存管理動態(tài)的全面解釋。
檢索過程。鑒于論文的目標(biāo)是用顯式檢索取代基于K-V對的傳統(tǒng)kNN檢索,論文旨在在可行的情況下預(yù)取所需信息,然后再緩存模型輸入。具體來說,對于每個潛在的查詢塊c^q=c_i及其對應(yīng)的文本塊t^q=t_i,論文首先將其傳遞給檢索器,然后獲得一個表示嵌入r^q = R ( t^q )。隨后,論文使用這個表示嵌入對\mathcal{M}中的嵌入進行檢索,以獲取所需的k個塊級索引。論文計算檢索表示r^q與存儲在內(nèi)存M中的嵌入之間的余弦相似度。最后,論文得到c^q的top-k索引z^q =Top K{Cos( r^q ) }。
記憶過程。記憶過程同步存儲來自記憶層的K-V對以及先前計算的用于檢索的表示嵌入,確保\mathrm{K-V}對的索引與其表示嵌入準(zhǔn)確對應(yīng)(見圖2,右側(cè),藍線)。對于每個可能的塊記憶c^m=c_i及其對應(yīng)的文本塊t^m=t_i,論文將記憶過程分為兩部分:第一部分詳細(xì)說明如何緩存K-V對,第二部分解釋如何存儲相應(yīng)的表示。首先,論文將c^m輸入到MemLong中,并從記憶層獲取輸出。論文的記憶操作非常高效,因為它僅涉及存儲檢索所需的表示r^m=r^q,從而避免了冗余。在所有塊對檢索完成后,記憶操作——表示為M ( k, v ; r^m )——同步更新記憶,包括鍵值對及其對應(yīng)的表示。
動態(tài)記憶更新。當(dāng)記憶溢出時,論文使用計數(shù)器智能更新記憶。在論文的實驗中,論文保留最新10%的記憶內(nèi)容,因其可能具有相關(guān)性,丟棄最舊的10%,因其可能已過時,并根據(jù)檢索頻率優(yōu)先處理中間的80%,刪除最少訪問的條目,直到記憶使用量降至50%。這種選擇性修剪平衡了時效性和相關(guān)性,保留了有價值的信息并刪除了不太相關(guān)的數(shù)據(jù)。
圖3:檢索因果注意力示意圖。局部因果注意力應(yīng)用于最近的上下文,而通過檢索方法獲得的塊級K-V對由于其歷史性質(zhì)而能夠?qū)崿F(xiàn)雙向注意力,而不會發(fā)生信息泄露。
2.3.3注意力重構(gòu)
在模型的可訓(xùn)練上層中,論文重構(gòu)了注意力機制以融合長期記憶。如圖3所示,與傳統(tǒng)的Transformer解碼層使用多頭注意力不同,論文提出了一種檢索因果注意力機制,將其擴展為聯(lián)合注意力機制,并提出了一種長期記憶融合過程,使得每個token既能關(guān)注局部上下文,也能關(guān)注具有完整和連續(xù)語義的塊級過去上下文。下一層的隱藏狀態(tài)H^l計算如下:
為了避免訓(xùn)練初期檢索注意力分?jǐn)?shù)o_m的干擾,論文采用了多注意力機制,遵循LLaMA-adapter
最后,論文將V和V連接起來得到H^l:
2.3.4 使用MemLong進行推理
當(dāng)MemLong接收到超過長度的輸入時,論文將其視為兩個部分:前綴p和主體。論文將分別描述在推理階段對長輸入的編碼和長輸出的生成。當(dāng)MemLong接收到長輸入時,它首先將前綴分成多個不重疊的塊,并從其內(nèi)存層計算,這確保了參與注意力的token數(shù)量等于塊大小,遠(yuǎn)小于輸入長度。需要注意的是,每個塊是相互關(guān)聯(lián)的(例如,第t個塊需要處理前t-1個塊的)。
第二步是根據(jù)塊級檢索表示選擇與主內(nèi)容最相關(guān)的k個塊,并獲取它們的關(guān)鍵和值表示。在此之后,對于上層檢索層,檢索的注意力窗口相當(dāng)于k * τ,這也小于輸入長度。最后,高效地執(zhí)行長度受限的因果注意和檢索注意。
2.4 實驗
論文在需要內(nèi)存中長上下文處理的各項任務(wù)上評估論文提出的MemLong模型:(a) 長上下文語言建模和檢索增強語言建模;(b) 能夠處理內(nèi)存中大量演示示例的可擴展上下文學(xué)習(xí)。
2.4.1 實現(xiàn)細(xì)節(jié)
訓(xùn)練細(xì)節(jié)。論文使用OpenLLaMA-3B作為預(yù)訓(xùn)練的骨干LLM,采用旋轉(zhuǎn)位置編碼(rotation position coding)。由于硬件限制,論文選擇使用LoRA技術(shù)訓(xùn)練論文的模型。骨干LLM具有L=26, H=32, d=100的架構(gòu)。除非另有說明,論文使用第13層作為內(nèi)存層,[14, 18, 22, 26]層作為檢索增強層。檢索增強適應(yīng)的訓(xùn)練僅在0.5B個token上迭代,序列長度為1024。Mem-Long的可訓(xùn)練參數(shù)來自14到26層。論文利用slimpajama數(shù)據(jù)集采樣作為論文的訓(xùn)練語料庫。
位置重映射。在生成過程中,M中檢索到的塊級別K-V有多個。由于每一步檢索的不確定性,論文需要將位置嵌入重映射到檢索到的塊。與之前的工作(Tworkowski et al., 2024)相同,局部上下文(最多2048個token)接收標(biāo)準(zhǔn)旋轉(zhuǎn)位置編碼,而內(nèi)存鍵則被編碼為在局部上下文窗口中具有位置0。
2.4.2 長上下文語言建模
論文首先在長上下文語言建?;鶞?zhǔn)上評估MemLong,以評估其基本的語言建模能力。由于K-V緩存提供了顯著的背景和上下文信息,MemLong能夠快速檢索相關(guān)的K-V緩存并充分利用它,從而在長上下文建模任務(wù)中增強模型的性能。
數(shù)據(jù)集。論文在四個廣泛的文本基準(zhǔn)數(shù)據(jù)集上對論文的模型進行了評估:英語書籍PG-19和BookCorpus,維基百科文章Wikitext-103,以及數(shù)學(xué)論文Proof-Pile。實驗結(jié)果表明,所有數(shù)據(jù)集上的困惑度都有顯著改善。論文的模型在從1024到32768個token的不同長度范圍內(nèi)進行了測試。通過利用外部檢索器和內(nèi)存,論文的模型在所有數(shù)據(jù)集上展示了顯著的性能提升,且內(nèi)存開銷最小。
設(shè)置。按照(Yen et al., 2024),論文計算每個序列最后2048個token的困惑度。此實驗設(shè)置旨在驗證不同檢索器大小對模型整體性能的影響。為了實現(xiàn)高效的細(xì)粒度檢索,論文使用faiss工具包在GPU上構(gòu)建精確搜索索引,以存儲文本塊的表示嵌入并執(zhí)行高效檢索。對于MemLong,論文將token分割并放入finetune-length = 1024的M中,用于進一步檢索。
基線。在論文的實驗中,論文采用OpenLLaMA-3B模型作為基線。為了確保公平比較,論文使用相同的LoRA配置并微調(diào)了模型在相同數(shù)量的slimpajama數(shù)據(jù)集上的數(shù)據(jù)。此外,論文比較了LongLLaMA-3B,該模型使用Focused Transformer(FoT)方法和5B token進行了微調(diào)。為了進行更全面的比較,論文還測試了兩個7B模型:LLaMA-2-7B和LongLoRA-7B-32K,以及兩個位置編碼模型:Yarn-7b-128k和Phi3-128k。
表1:不同上下文窗口擴展模型在PG19、Proof-pile、BookCorpus、Wikitext-103上的滑動窗口困惑度。所有實驗都在一塊3090 24GB GPU上進行。LongLLaMA-3B和MemLong-3Btoken為表示在沒有Memory的情況下評估,LongLLaMA-3Btoken表示在無限Memory的情況下評估。論文還評估了MemLong在4K/32K Memory場景下的表現(xiàn)。"- / 6.95"表示模型在單GPU上導(dǎo)致內(nèi)存不足(OOM)錯誤,而在雙GPU上則產(chǎn)生相應(yīng)結(jié)果。*
結(jié)果。結(jié)果如表1所示。論文采用困惑度(PPL)作為語言模型的評估指標(biāo)。較低的PPL表示更強的語言建模能力。與兩個完全微調(diào)的模型相比,OpenLLaMA-3B和LLaMA-2-7B,論文的模型在測試長度在其預(yù)訓(xùn)練限制內(nèi)(OpenLLaMA-3B為2048,LLaMA-2-7B為4096)時,在多個數(shù)據(jù)集上表現(xiàn)出相當(dāng)?shù)男阅堋?/p>
然而,一旦測試長度超過這些預(yù)訓(xùn)練限制,論文的模型即使在微調(diào)長度1024和預(yù)訓(xùn)練長度2048之后,仍能繼續(xù)降低困惑度,展示了其優(yōu)越的泛化能力。
相比之下,OpenLLaMA-3B和LLaMA-2-7B模型無法泛化到超出其預(yù)訓(xùn)練長度的輸入,并且由于注意力機制的二次復(fù)雜性,表現(xiàn)出顯著增加的內(nèi)存開銷。論文還與LongLoRA進行了比較。盡管LongLoRA中提出的Shifted Sparse Attention顯著減少了內(nèi)存使用,但它也削弱了模型在短文本上的性能。
相比之下,LongLLaMA由于其內(nèi)存使用無限增長,在測試長度變得過長時也會遇到OOM問題。位置編碼模型具有強大的泛化能力。然而,這種方法的性能只能保證長距離生成性能不下降。與這些方法相比,MemLong利用外部檢索器處理更長的輸入token,并實現(xiàn)了更好的困惑度改進。同時,由于高存儲效率,MemLong可以有效控制GPU的使用,避免OOM問題。
2.4.3 上下文學(xué)習(xí)
傳統(tǒng)的上下文學(xué)習(xí)(ICL)將少量的非參數(shù)化示例與查詢一同輸入模型。然而,這些方法通常受限于模型的輸入長度。在本實驗中,由于MemLong可以將示例以參數(shù)化形式存儲在其記憶中,論文主要研究MemLong是否能有效利用其記憶中存儲的知識以增強其突現(xiàn)能力。結(jié)果如表2所示。
與僅依賴非參數(shù)化知識的OpenLLaMA相比,在相同數(shù)量的上下文示例下,MemLong可以利用其記憶中存儲的額外示例。隨著記憶中示例數(shù)量的增加,性能進一步提升或保持穩(wěn)定。在與LongLLaMA的對比分析中,論文觀察到在相同的保留內(nèi)存示例條件下,論文的模型在大多數(shù)數(shù)據(jù)集上表現(xiàn)優(yōu)于LongLLaMA。值得注意的是,與LongLLaMA相比,論文的模型在訓(xùn)練參數(shù)(2億對比0.3億)和微調(diào)數(shù)據(jù)量(0.5億對比5億)方面顯著減少。這凸顯了論文模型在利用外部檢索器進行信息獲取方面的效率,展示了在資源大幅減少的情況下,能夠更有效地綜合和利用知識的能力。
2.5 Ablation Study
2.5.1 訓(xùn)練設(shè)置
在訓(xùn)練階段,論文探討了不同檢索層對模型的影響,并考察了論文的方法是否能充分解決MemTrm中討論的分布偏移問題。如前所述,論文的方法為分布偏移提供了一種低成本的解決方案。如圖4所示,棕色線(圖片頂部的線條;訓(xùn)練方法類似于MemTrm,微調(diào)模型的所有參數(shù),并且在內(nèi)存層之后的所有層都參與檢索)在性能和擬合速度方面明顯劣于論文所有其他方法(即使是設(shè)置最不合理的方法)。論文將在稍后分析推理階段的性能。
2.5.2 推理性能
圖 4:訓(xùn)練階段PPL的變化程度。y軸的指標(biāo)為PPL。論文主要關(guān)注訓(xùn)練參數(shù)和檢索層。
Q1:記憶長度是否影響模型的性能?如圖5所示,論文對同一模型在不同記憶大小下的性能進行了檢查,結(jié)果表明記憶容量與模型效率之間存在明顯的相關(guān)性。趨勢表明,記憶大小的增加會逐漸提升性能。此外,在記憶大小為65536時,模型能力經(jīng)歷了一個顯著的飛躍。這表明,雖然擴展記憶提供了實質(zhì)性的好處,但其有效性存在一個實際的上限,這可能受到數(shù)據(jù)分布細(xì)微差別的影響。
Q2:論文需要引入多少層額外的記憶信息?如圖4所示(粉色線)和表3(RPL+TH)中展示的,當(dāng)檢索層的數(shù)量設(shè)置為[13,17,21,25]時,模型表現(xiàn)最佳。根據(jù)經(jīng)驗認(rèn)為,如果在模型的所有上層都引入檢索信息,會導(dǎo)致模型對局部上下文的注意力下降。因此,以適當(dāng)?shù)拈g隔選擇檢索層實際上可以增強模型的能力。
表3:不同的檢索層會影響MemLong的性能。token為的MemLong表示在沒有Memory的情況下進行評估。所有使用Memory的方法的大小設(shè)置為32768。RA表示跨所有上層檢索;TA表示不凍結(jié)參數(shù)的訓(xùn)練;RP表示跨較少上層檢索,RPL表示跨更少上層檢索。
圖 5:在不同記憶大小下評估不同數(shù)據(jù)集。在每個子圖中,除記憶大小外,所有參數(shù)均相同。
