Meta探索大模型記憶層,擴(kuò)展至1280億個(gè)參數(shù),優(yōu)于MoE
預(yù)訓(xùn)練語言模型通常在其參數(shù)中編碼大量信息,并且隨著規(guī)模的增加,它們可以更準(zhǔn)確地回憶和使用這些信息。對于主要將信息編碼為線性矩陣變換權(quán)重的密集深度神經(jīng)網(wǎng)絡(luò)來說,參數(shù)大小的擴(kuò)展直接與計(jì)算和能量需求的增加相關(guān)。語言模型需要學(xué)習(xí)的一個(gè)重要信息子集是簡單關(guān)聯(lián)。雖然前饋網(wǎng)絡(luò)原則上(給定足夠的規(guī)模)可以學(xué)習(xí)任何函數(shù),但使用聯(lián)想記憶(associative memory)會更高效。
記憶層(memory layers)使用可訓(xùn)練的鍵值查找機(jī)制向模型添加額外的參數(shù),而不會增加 FLOP。從概念上講,稀疏激活的記憶層補(bǔ)充了計(jì)算量大的密集前饋層,提供了廉價(jià)地存儲和檢索信息的專用容量。
最近,Meta 的一項(xiàng)新研究使記憶層超越了概念驗(yàn)證,證明了它們在大型語言模型(LLM)擴(kuò)展中的實(shí)用性。
- 論文標(biāo)題:Memory Layers at Scale
- 論文地址:https://arxiv.org/pdf/2412.09764
- 項(xiàng)目地址:https://github.com/facebookresearch/memory
在下游任務(wù)中,通過改進(jìn)的記憶層增強(qiáng)的語言模型的性能優(yōu)于計(jì)算預(yù)算兩倍以上的密集模型,以及在計(jì)算和參數(shù)相當(dāng)?shù)膶<一旌希∕oE)模型。
這項(xiàng)工作表明,當(dāng)記憶層得到充分改進(jìn)和擴(kuò)展時(shí),可以用于增強(qiáng)密集神經(jīng)網(wǎng)絡(luò),從而帶來巨大的性能提升。通過用記憶層替換一個(gè)或多個(gè) transformer 層的前饋網(wǎng)絡(luò)(FFN)來實(shí)現(xiàn)這一點(diǎn)(保持其他層不變)。這些優(yōu)勢在各種基本模型大?。◤?1.34 億到 80 億參數(shù))和內(nèi)存容量(最多 1280 億參數(shù))中都是一致的。這意味著存儲容量實(shí)現(xiàn)了兩個(gè)數(shù)量級的飛躍。
記憶增強(qiáng)架構(gòu)
可訓(xùn)練的記憶層類似于注意力機(jī)制。給定一個(gè)查詢,一組鍵
,以及值
。輸出是值的軟組合,根據(jù) q 和相應(yīng)鍵之間的相似性進(jìn)行加權(quán)。
在使用時(shí),記憶層與注意力層之間存在兩個(gè)區(qū)別。
- 首先,記憶層中的鍵和值是可訓(xùn)練參數(shù),而不是激活參數(shù);
- 其次,記憶層在鍵和值的數(shù)量方面通常具有更大的規(guī)模,因此稀疏查詢和更新是必需的。
該研究將鍵-值對的數(shù)量擴(kuò)展到數(shù)百萬。在這種情況下,只有 top-k 最相似的鍵和相應(yīng)的值被輸出。一個(gè)簡單的記憶層可以用下面的等式來描述:
其中,I 是一組指標(biāo),,輸出
。
擴(kuò)展記憶層
擴(kuò)展記憶層時(shí)面臨的一個(gè)瓶頸是「查詢 - 鍵」檢索機(jī)制。簡單的最近鄰搜索需要比較每一對查詢 - 鍵,這對于大型記憶來說很快就變得不可行。雖然可以使用近似向量相似性技術(shù),但當(dāng)鍵正在不斷訓(xùn)練并需要重新索引時(shí),將它們整合起來是一個(gè)挑戰(zhàn)。相反,本文采用了可訓(xùn)練的「product-quantized」鍵。
并行記憶。記憶層是記憶密集型的,主要是由于可訓(xùn)練參數(shù)和相關(guān)優(yōu)化器狀態(tài)的數(shù)量龐大導(dǎo)致的。該研究在多個(gè) GPU 上并行化嵌入查找和聚合,記憶值在嵌入維度上進(jìn)行分片。在每個(gè)步驟中,索引都從進(jìn)程組中收集,每個(gè) worker 進(jìn)行查找,然后將嵌入的部分聚合到分片中。此后,每個(gè) worker 收集與其自身索引部分相對應(yīng)的部分嵌入。該過程如圖 2 所示。
共享記憶。深度網(wǎng)絡(luò)在不同層上以不同的抽象級別對信息進(jìn)行編碼。向多個(gè)層添加記憶可能有助于模型以更通用的方式使用其記憶。與以前的工作相比,該研究在所有記憶層中使用共享記憶參數(shù)池,從而保持參數(shù)數(shù)量相同并最大化參數(shù)共享。
該研究通過引入具有 silu 非線性的輸入相關(guān)門控來提高記憶層的訓(xùn)練性能。等式 (1) 中的輸出變?yōu)椋?/span>
其中 silu (x) = x sigmoid (x),⊙是元素的乘法(參見圖 3)。
實(shí)驗(yàn)及結(jié)果
首先,該研究固定記憶大小,并與密集基線以及參數(shù)大致匹配的 MOE 和 PEER 模型進(jìn)行比較。
從表 1 中我們可以看出,Memory 模型比密集基線模型有了大幅改進(jìn),在 QA 任務(wù)上的表現(xiàn)通常與密集參數(shù)數(shù)量為其兩倍的模型相當(dāng)。
Memory+ (有 3 個(gè)記憶層)比 Memory 有了進(jìn)一步的改進(jìn),其性能通常介于計(jì)算能力高出其 2 到 4 倍的密集模型之間。
對于相同數(shù)量的參數(shù),PEER 架構(gòu)的表現(xiàn)與 Memory 模型相似,但落后于 Memory+。MOE 模型的表現(xiàn)遠(yuǎn)不及 Memory 變體。
圖 4 顯示了不同大小的 Memory、MOE 和密集模型在 QA 任務(wù)上的擴(kuò)展性能。
圖 1 表明 Memory+ 模型的實(shí)際 QA 性能隨著記憶大小的增加而不斷的增加。
在 6400 萬個(gè)鍵(1280 億個(gè)記憶參數(shù))下,1.3B Memory 模型的性能接近 Llama2 7B 模型,后者使用了 10 倍以上的 FLOPs(見表 2)。
最后,本文在 8B 基礎(chǔ)模型和 4096^2 個(gè)記憶值的基礎(chǔ)上 (64B 記憶參數(shù))擴(kuò)展了 Memory+ 模型,表 2 報(bào)告了結(jié)果,發(fā)現(xiàn)記憶增強(qiáng)模型的表現(xiàn)明顯優(yōu)于密集基線。