極長(zhǎng)序列、極快速度:面向新一代高效大語(yǔ)言模型的LASP序列并行
從國(guó)際頂流 GPT-4 128K、Claude 200K 到國(guó)內(nèi)「當(dāng)紅炸子雞」支持 200 萬(wàn)字上下文的 Kimi Chat,大語(yǔ)言模型(LLM)在長(zhǎng)上下文技術(shù)上不約而同地卷起來(lái)了。當(dāng)全世界最聰明的頭腦都在卷一件事的時(shí)候,這件事的重要性和難度就自然不言自明。
極長(zhǎng)的上下文可以極大拓展大模型的生產(chǎn)力價(jià)值。隨著 AI 的普及,用戶(hù)已經(jīng)不再滿(mǎn)足于調(diào)戲大模型幾個(gè)腦筋急轉(zhuǎn)彎,用戶(hù)開(kāi)始渴望利用大模型來(lái)真正提高生產(chǎn)力。畢竟從前花一周憋出來(lái)的 PPT,現(xiàn)在只需要喂給大模型一串提示詞和幾份參考文檔就分分鐘生成出來(lái),打工人誰(shuí)能不愛(ài)呢?
新型高效序列建模方法比如:Lightning Attention (TransNormerLLM), State Space Modeling (Mamba), Linear RNN (RWKV, HGRN, Griffin) 等最近成為炙手可熱的研究方向。研究人員渴望通過(guò)改造已經(jīng) 7 歲高齡的 Transformer 架構(gòu),獲得性能與之旗鼓相當(dāng),但復(fù)雜度僅為線(xiàn)性的新型架構(gòu)。這類(lèi)方法專(zhuān)注于模型架構(gòu)設(shè)計(jì),并提供了基于 CUDA 或 Triton 的硬件友好實(shí)現(xiàn),使其能夠像 FlashAttention 一樣在單卡 GPU 內(nèi)部高效計(jì)算。
與此同時(shí),另一個(gè)長(zhǎng)序列訓(xùn)練的殺手锏:序列并行獲得了越來(lái)越多的關(guān)注。通過(guò)把長(zhǎng)序列在序列維度切分為多個(gè)等分短序列,再將短序列分散至不同 GPU 卡并行訓(xùn)練,再輔以卡間通信便達(dá)到了序列并行訓(xùn)練的效果。從最早出現(xiàn)的 Colossal-AI 序列并行、到 Megatron 序列并行、再到 DeepSpeed Ulysses、以及近期的 Ring Attention,研究人員不斷設(shè)計(jì)更加優(yōu)雅高效的通信機(jī)制以提升序列并行的訓(xùn)練效率。當(dāng)然這些已知方法全部是為傳統(tǒng)注意力機(jī)制設(shè)計(jì)的,本文中我們稱(chēng)之為 Softmax Attention。這些方法也已經(jīng)有各路大神做了精彩分析,本文不過(guò)多探討。
那么問(wèn)題來(lái)了。如何讓新型高效序列建模方法實(shí)現(xiàn)序列并行,從而跨越單卡 GPU 顯存限制實(shí)現(xiàn)真正意義的無(wú)限序列長(zhǎng)度(當(dāng)然你得有無(wú)限 GPU)的高效大語(yǔ)言模型訓(xùn)練,成為了一個(gè)開(kāi)放的問(wèn)題。已經(jīng)成熟的序列并行方法如 DeepSpeed Ulysses, Megatron-SP 當(dāng)然可以應(yīng)用在線(xiàn)性序列建模方法上,但以 Softmax Attention 為設(shè)計(jì)藍(lán)本的它們注定天生不是最優(yōu)解。
- 論文標(biāo)題:Linear Attention Sequence Parallelism
- 論文地址:https://arxiv.org/abs/2404.02882
- LASP代碼地址:https://github.com/OpenNLPLab/LASP
本文即將介紹的 LASP 便應(yīng)運(yùn)而生。來(lái)自上海人工智能實(shí)驗(yàn)室的研究人員提出了 Linear Attention Sequence Parallelism (LASP) 方法以充分利用 Linear Attention 的線(xiàn)性右乘特性實(shí)現(xiàn)高效的序列并行計(jì)算。在 128 卡 A100 80G GPU、TransNormerLLM 1B 模型、FSDP backend 的配置下,LASP 可以最高將序列長(zhǎng)度擴(kuò)展至 4096K,即 4M。與成熟的序列并行方法相比,LASP 可訓(xùn)練的最長(zhǎng)序列長(zhǎng)度是 Megatron-SP 的 8 倍、DeepSpeed Ulysses 的 4 倍,速度則分別快了 136% 和 38%。
值得注意的是,雖然方法的名字包含 Linear Attention,LASP 并不局限于 Linear Attention 方法,而是可以廣泛應(yīng)用于包括 Lightning Attention (TransNormerLLM), State Space Modeling (Mamba), Linear RNN (RWKV, HGRN, Griffin) 等在內(nèi)的線(xiàn)性序列建模方法。
LASP 方法介紹
為了充分理解 LASP 的思路,讓我們先回顧下傳統(tǒng) Softmax Attention 的計(jì)算公式:O=softmax ((QK^T)⊙M) V,其 Q, K, V, M, O 分別為 Query, Key, Value, Mask 和 Output 矩陣,這里的 M 在單向任務(wù)(如 GPT)中是一個(gè)下三角的全 1 矩陣,在雙向任務(wù)(如 BERT)中則可以忽略,即雙向任務(wù)沒(méi)有 Mask 矩陣。我們下面將 LASP 拆為四點(diǎn)進(jìn)行解釋?zhuān)?/p>
Linear Attention 原理
Linear Attention 可以視為 Softmax Attention 一種變體。Linear Attention 去除了計(jì)算成本高昂的 Softmax 算子,Attention 的計(jì)算公式可以寫(xiě)為 O=((QK^T)⊙M) V 的簡(jiǎn)潔形式。但由于單向任務(wù)中 Mask 矩陣 M 的存在,使得該形式依然只能進(jìn)行左乘計(jì)算(即先計(jì)算 QK^T),從而不能獲得 O (N) 的線(xiàn)性復(fù)雜度。但對(duì)于雙向任務(wù),由于沒(méi)有 Mask 矩陣的存在,其計(jì)算公式可以進(jìn)一步簡(jiǎn)化為 O=(QK^T) V。Linear Attention 的巧妙之處在于,僅僅利用簡(jiǎn)單的矩陣乘法結(jié)合律,其計(jì)算公式就可以進(jìn)一步轉(zhuǎn)化為:O=Q (K^T V),這種計(jì)算形式被稱(chēng)之為右乘,可見(jiàn) Linear Attention 在這種雙向任務(wù)中可以達(dá)到誘人的 O (N) 復(fù)雜度!
LASP 數(shù)據(jù)分發(fā)
LASP 首先將長(zhǎng)序列數(shù)據(jù)從序列維度切分為多個(gè)等分的子序列,再將子序列分散發(fā)送至序列并行通信組內(nèi)的所有 GPU,使得每張 GPU 上各有一段子序列,以供后續(xù)序列并行的計(jì)算使用。
LASP 核心機(jī)制
隨著 decoder-only 的類(lèi) GPT 形式的模型逐漸成為 LLM 的事實(shí)標(biāo)準(zhǔn),LASP 的設(shè)計(jì)充分考慮了單向 Casual 任務(wù)的場(chǎng)景。由切分后子序列 Xi 計(jì)算而來(lái)的便是按照序列維度切分的 Qi, Ki, Vi,每一個(gè)索引 i 對(duì)應(yīng)一個(gè) Chunk 和一個(gè) Device(即一張 GPU)。由于 Mask 矩陣的存在,LASP 作者巧妙地將各個(gè) Chunk 對(duì)應(yīng)的 Qi, Ki, Vi 區(qū)分為兩種,即:Intra-Chunk 和 Inter-Chunk。其中 Intra-Chunk 為 Mask 矩陣分塊后對(duì)角線(xiàn)上的 Chunk,可以認(rèn)為仍然有 Mask 矩陣的存在,依然需要使用左乘;Inter-Chunk 則為 Mask 矩陣非對(duì)角線(xiàn)上的 Chunk,可以認(rèn)為沒(méi)有 Mask 矩陣的存在,可以使用右乘;顯然,當(dāng)切分的 Chunk 越多時(shí),對(duì)角線(xiàn)上的 Chunk 占比越少,非對(duì)角線(xiàn)上的 Chunk 占比越多,可以利用右乘實(shí)現(xiàn)線(xiàn)性復(fù)雜度 Attention 計(jì)算的 Chunk 就越多。其中,對(duì)于右乘的 Inter-Chunk 的計(jì)算,前向計(jì)算時(shí)每個(gè)設(shè)備需要使用點(diǎn)對(duì)點(diǎn)通信 Recive 上一個(gè)設(shè)備的 KV,并 Send 自己的更新后的 KV 給下一個(gè)設(shè)備。反向計(jì)算時(shí)則正好相反,只是 Send 和 Recive 的對(duì)象變?yōu)榱?KV 的梯度 dKV。其中前向計(jì)算過(guò)程如下圖所示:
LASP 代碼實(shí)現(xiàn)
為了提高 LASP 在 GPU 上的計(jì)算效率,作者對(duì) Intra-Chunk 和 Inter-Chunk 的計(jì)算分別進(jìn)行了 Kernel Fusion,并將 KV 和 dKV 的更新計(jì)算也融合到了 Intra-Chunk 和 Inter-Chunk 計(jì)算中。另外,為了在反向傳播過(guò)程中避免重新計(jì)算激活 KV,作者選擇在前向傳播計(jì)算后立即將其存儲(chǔ)在 GPU 的 HBM 中。在隨后的反向傳播過(guò)程中,LASP 直接訪(fǎng)問(wèn) KV 以供使用。需要注意的是,存儲(chǔ)在 HBM 中的 KV 大小為 d x d,完全不受序列長(zhǎng)度 N 的影響。當(dāng)輸入序列長(zhǎng)度 N 較大時(shí),KV 的內(nèi)存占用變得微不足道。在單張 GPU 內(nèi)部,作者實(shí)現(xiàn)了由 Triton 實(shí)現(xiàn)的 Lightning Attention 以減少 HBM 和 SRAM 之間的 IO 開(kāi)銷(xiāo),從而加速單卡 Linear Attention 計(jì)算。
想要了解更多細(xì)節(jié)的讀者,可以閱讀論文中的 Algorithm 2(LASP 前向過(guò)程)和 Algorithm 3(LASP 反向過(guò)程),以及文中詳細(xì)的推導(dǎo)過(guò)程。
通信量分析
LASP 算法中需要注意前向傳播需要在每個(gè) Linear Attention 模塊層進(jìn)行 KV 激活的通信。通信量為 Bd^2/h,其中 B 是 batch 大小,h 是頭數(shù)。相比之下,Megatron-SP 在每個(gè) Transformer 層中的兩個(gè) Layer Norm 層之后分別使用了一次 All-Gather 操作,并在 Attention 和 FFN 層之后分別使用了一次 Reduce-Scatter 操作,這導(dǎo)致其通信量為 2BNd + 4BNd/T,其中 T 為序列并行維度。DeepSpeed-Ulysses 使用了 All-to-All 集合通信操作來(lái)處理每個(gè) Attention 模塊層的輸入 Q, K, V 和輸出 O,導(dǎo)致通信量為 4BNd/T。三者的通信量對(duì)比如下表所示。其中 d/h 是頭維度,通常設(shè)置為 128。在實(shí)際應(yīng)用中,當(dāng) N/T>=32 時(shí),LASP 便能夠?qū)崿F(xiàn)最低的理論通信量。此外,LASP 的通信量不受序列長(zhǎng)度 N 或子序列長(zhǎng)度 C 的影響,這對(duì)于跨大型 GPU 集群的極長(zhǎng)序列并行計(jì)算是一個(gè)巨大的優(yōu)勢(shì)。
Data-Sequence 混合并行
數(shù)據(jù)并行(即 Batch-level 的數(shù)據(jù)切分)已經(jīng)是分布式訓(xùn)練的常規(guī)操作,在原始數(shù)據(jù)并行(PyTorch DDP)的基礎(chǔ)上,已經(jīng)進(jìn)化出了更節(jié)省顯存的切片式數(shù)據(jù)并行,從最初的 DeepSpeed ZeRO 系列到 PyTorch 官方支持的 FSDP,切片式數(shù)據(jù)并行已經(jīng)足夠成熟并被越來(lái)越多用戶(hù)使用。LASP 作為 Sequence-level 的數(shù)據(jù)切分方法,可以能夠和包括 PyTorch DDP, Zero-1/2/3, FSDP 在內(nèi)的各種數(shù)據(jù)并行方法兼容使用。這對(duì) LASP 的使用者來(lái)說(shuō)無(wú)疑是好消息。
精度實(shí)驗(yàn)
在 TransNormerLLM (TNL) 和 Linear Transformer 上的實(shí)驗(yàn)結(jié)果表明,LASP 作為一種系統(tǒng)優(yōu)化方法能夠和各種 DDP backends 結(jié)合,并均能達(dá)到與 Baseline 持平的性能。
可擴(kuò)展性實(shí)驗(yàn)
得益于高效的通信機(jī)制設(shè)計(jì),LASP 可以輕松擴(kuò)展至上百卡 GPU,并保持很好的可擴(kuò)展性。
速度對(duì)比實(shí)驗(yàn)
與成熟的序列并行方法 Megatron-SP 和 DeepSpeed-Ulysses 對(duì)比,LASP 可訓(xùn)練的最長(zhǎng)序列長(zhǎng)度是 Megatron-SP 的 8 倍、DeepSpeed-Ulysses 的 4 倍,速度則分別快了 136% 和 38%。
結(jié)語(yǔ)
為了方便大家試用,作者已經(jīng)提供了一個(gè)即裝即用的 LASP 代碼實(shí)現(xiàn),無(wú)需下載數(shù)據(jù)集和模型,只需 PyTorch 分分鐘體驗(yàn) LASP 的極長(zhǎng)極快序列并行能力。
代碼傳送門(mén):https://github.com/OpenNLPLab/LASP
本文轉(zhuǎn)自 機(jī)器之心 ,作者:機(jī)器之心
原文鏈接:??https://mp.weixin.qq.com/s/wPJsmgSAYgh3Si2eZ0_HzA??
