美團(tuán) Flash Communication:LLM 推理的 AllReduce 通信優(yōu)化
一、背景
前段時(shí)間的文章里我們剛剛介紹過(guò)兩個(gè)對(duì) LLM 分布式推理場(chǎng)景中 AllReduce 的優(yōu)化工作,一個(gè)是 NVIDIA TensorRT-LLM 中的 MultiShot 無(wú)損優(yōu)化,另一個(gè)是 Recogni 提出的基于量化壓縮實(shí)現(xiàn)的 AllReduce 加速方案。本文中我們繼續(xù)介紹美團(tuán)新發(fā)表的 AllReduce 量化壓縮優(yōu)化方案。
對(duì)應(yīng)的論文為:[2412.04964] Flash Communication: Reducing Tensor Parallelization Bottleneck for Fast Large Language Model Inference [1]
二、摘要
隨著 LLM 規(guī)模的不斷增長(zhǎng),快速推理所需的分布式解決方案往往要利用多維并行性,將計(jì)算負(fù)載分散至 GPU 集群的多個(gè)設(shè)備上。然而,此方法往往會(huì)引入顯著的通信開(kāi)銷,尤其在帶寬受限的設(shè)備上(比如沒(méi)有 NVLink 或者跨機(jī)的情況)。
本文中作者提出 Flash Communication,一種新穎的低比特壓縮技術(shù),旨在緩解推理過(guò)程中 Tensor Parallelism(TP)的通信瓶頸。作者在多種最新的 LLM 上進(jìn)行的廣泛實(shí)驗(yàn),驗(yàn)證了該方法的有效性。該方法可以將節(jié)點(diǎn)內(nèi)通信速度提升 3 倍,并將首 Token 時(shí)間縮短 2 倍,同時(shí)幾乎不犧牲模型精度。
PS:上述的結(jié)論其實(shí)有點(diǎn)夸大,INT4 可以實(shí)現(xiàn)上述速度,但精度損失還是有點(diǎn)大的;而 INT8 可以保持精度,但加速比又沒(méi)這么多。
三、引言
3.1 硬件拓?fù)?/h3>
作者論文中評(píng)估主要采用了兩種機(jī)型,一種是 8 x L40 GPU 節(jié)點(diǎn),如下圖 Figure 12 所示。每個(gè)節(jié)點(diǎn)上有 8 個(gè) L40 GPU,每 2 個(gè) GPU 在一個(gè) PCIe Switch 下,沒(méi)有 NVLink + NVSwitch,并且每個(gè)節(jié)點(diǎn)只有 1 個(gè) 100 Gbps 的 NIC。因此:
- 如果節(jié)點(diǎn)內(nèi)的 TP 通信都需要走 PCIe 鏈路,如果是不同 CPU Socket 下的 GPU 通信,還需要通過(guò) CPU 之間的 UPI,因此通信效率可能比較低。
- 如果節(jié)點(diǎn)間通信,則必須通過(guò)節(jié)點(diǎn)的 NIC,最糟糕的情況是左側(cè)紅框和右側(cè)紅框的 GPU 組成 TP 組進(jìn)行通信。
- PS:本文中作者并沒(méi)有涉及節(jié)點(diǎn)間通信,甚至 L40 上的 TP=8 的 8 GPU 通信都沒(méi)有。?
而 A100 節(jié)點(diǎn)類似下圖所示,節(jié)點(diǎn)內(nèi)有 8 個(gè) A100 GPU,這些 GPU 通過(guò) NVLink + NVSwitch 實(shí)現(xiàn)全互聯(lián),任何兩個(gè) GPU 之間的通信帶寬都可以達(dá)到 600 GB/s。不過(guò)作者介紹節(jié)點(diǎn)間通信帶寬是 200 Gbps,那么可能節(jié)點(diǎn)上就沒(méi)有紅框中的 NIC,只有藍(lán)框中的 200 Gbps NIC。(PS:論文中也不涉及節(jié)點(diǎn)間通信)
此外,A100 GPU 有 108 個(gè) SM,而 L40 GPU 有 108 個(gè) SM。
3.2 ReduceScatter + AllGather
我們?cè)谥暗奈恼轮性敿?xì)介紹過(guò) AllReduce,這里再簡(jiǎn)單陳述一下。對(duì)于常見(jiàn)的基于 Ring 的 AllReduce 實(shí)現(xiàn)中,通常將一個(gè) AllReduce 操作拆分為一個(gè) ReduceScatter 和一個(gè) AllGather 操作,如下圖所示:
具體的 ReduceScatter 操作如下,每個(gè)設(shè)備(GPU)發(fā)送一部分?jǐn)?shù)據(jù)給下一個(gè)設(shè)備,同時(shí)接收上一個(gè)設(shè)備的數(shù)據(jù)并累加。這個(gè)過(guò)程執(zhí)行 K-1 步,ReduceScatter 后每個(gè)設(shè)備都包含一部分?jǐn)?shù)據(jù)的 Sum:
具體的 AllGather 操作如下,每個(gè)設(shè)備(GPU)將其持有的部分結(jié)果發(fā)送給下一個(gè)設(shè)備,同時(shí)接收上一個(gè)設(shè)備的部分結(jié)果,逐步匯集完整的結(jié)果,同樣需要 K-1 步。AllGather 后,每個(gè)設(shè)備都包含全量的數(shù)據(jù):
NVIDIA 在 3x Faster AllReduce with NVSwitch and TensorRT-LLM MultiShot | NVIDIA Technical Blog [2] 中并沒(méi)有介紹 ReduceScatter 的優(yōu)化,不過(guò)在我們推測(cè)其可能采用了下述的優(yōu)化方式:具體來(lái)說(shuō),Ring ReduceScatter 可以等效為一個(gè) All2All 操作實(shí)現(xiàn)數(shù)據(jù)的重排,然后在 Local 進(jìn)行 Reduce 操作(或者 NVSwitch 上進(jìn)行 Reduce 操作)。此過(guò)程只有一個(gè) All2All 的整體通信操作,雖然實(shí)際上與 Ring 實(shí)現(xiàn)的方式的通信量和計(jì)算量沒(méi)有變化,但可以避免 K-1 個(gè) Ring Step 的同步,進(jìn)而可以有效降低時(shí)延。
3.3 TP 推理
如下圖 Figure 3 所示,對(duì)于 LLaMA 模型推理,其一個(gè) Transformer Layer 需要 2 次 AllReduce 通信,不過(guò)需要 Attention 以及 FFN 都采用先列切再行切的方式。以 80 層的 LLaMA 3 70B 模型為例,一次 Forward 需要 180 次 AllReduce 通信。
3.4 延遲分析
如下圖 Figure 2 所示,作者在 4*L40 GPU 上測(cè)量了 LLaMA-3-70B 模型在不同序列長(zhǎng)度下各個(gè)部分的開(kāi)銷(Batch Size 為 8),可以看出,序列越長(zhǎng),AllReduce 的通信占比越大,在 4K 序列長(zhǎng)度時(shí) AllReduce 通信開(kāi)銷為 18% 左右,在序列長(zhǎng)度達(dá)到 32K 時(shí),通信開(kāi)銷占到 40% 左右。
在 A100 GPU 上雖然有 NVLink+NVSwitch 互聯(lián),最大的通信開(kāi)銷依然可以達(dá)到 20%(PS:不過(guò)作者這里沒(méi)有提供詳細(xì)的數(shù)據(jù))。
四、方案
4.1 量化挑戰(zhàn)
為了在準(zhǔn)確性與時(shí)延之間達(dá)成最佳平衡,作者選擇采用低比特量化技術(shù)。如下圖 Figure 4 所示,可觀察到,在大 Block 下進(jìn)行逐 Token 量化會(huì)導(dǎo)致 C4 困惑度的性能急劇下降,非對(duì)稱量化(Asym)相對(duì)較好,不過(guò)依然下降明顯,因此細(xì)粒度量化是必要的。
然而,作者發(fā)現(xiàn)在此情境下應(yīng)用低比特激活量化并非易事。因此,作者計(jì)算了 LLaMA-3-8B 模型在激活量化前后的層級(jí)均方誤差(MSE)來(lái)研究量化的敏感性。如下圖 Figure 5 左圖所示,下投影 dproj 的量化難度遠(yuǎn)高于輸出投 影oproj。
此外,All-Reduce 中 Reduce-Scatter 和 All-Gather 操作對(duì)應(yīng)的量化難度也各不相同,如下圖 Figure 5 右圖所示。這一現(xiàn)象符合預(yù)期,因?yàn)?Reduce-Scatter 前的量化僅引入舍入誤差,而在 All-Gather 中,則同時(shí)包含舍入誤差和累積誤差。作為替代方案,可以在 All-Gather 操作前采用更高精度的量化以提升準(zhǔn)確性。
4.2 通信算法
鑒于上述問(wèn)題,作者設(shè)計(jì)了一種兩步量化策略以替代傳統(tǒng)的 Ring AllReduce 方法,稱為 Flash AllReduce,如下圖 Figure 6 所示。該策略與 TP 的結(jié)合如上圖 Figure 3 所示。
如下圖 Figure 6 展示了本文 Flash Communication 的通信原理:
- 首先,將每個(gè) GPU 上的激活值按 Rank 的數(shù)量進(jìn)行劃分。
- 在激活值上進(jìn)行細(xì)粒度量化后,執(zhí)行All2All 通信(與我們猜測(cè)的 TRT-LLM 的 MultiShot 實(shí)現(xiàn)類似),使得每個(gè)設(shè)備接收其規(guī)約所需的計(jì)算負(fù)載。當(dāng)然,接收后也需要反量化操作。
- 在設(shè)備內(nèi)完成 Reduce,不涉及通信操作。
- 對(duì)得到的結(jié)果再次進(jìn)行量化以加速傳輸。然后進(jìn)行AllGather以匯總所有結(jié)果,并在每個(gè)設(shè)備上進(jìn)行反量化以恢復(fù)浮點(diǎn)數(shù)值。?
具體的算法過(guò)程也可以參考如下圖 Algorithm 1:
4.3 Kernel 設(shè)計(jì)
為了提升效率,作者開(kāi)發(fā)了一個(gè)融合的 Flash AllReduce Kernel,以囊括上述所有集合通信操作及量化操作。如下圖 Table 1 所示,相比 Ring AllReduce 操作,F(xiàn)lash AllReduce 將量化-反量化步驟從 N 次減少到 2 次,Reduce-Gather 步驟從 N-1 次縮減到 1 次。盡管總體數(shù)據(jù)個(gè)數(shù)保持不變,但每一份數(shù)據(jù)均被量化到較低位數(shù),從而大幅減少了傳輸?shù)臄?shù)據(jù)大小。
快速細(xì)粒度量化:每個(gè)節(jié)點(diǎn)的總通信量(個(gè)數(shù)) M 被劃分為 T 個(gè) Chunk 進(jìn)行傳輸。給定 Chunk 大小 C,如下圖 Figure 7 展示了 GPU 線程如何并行組織以處理 Chunk 信息。一個(gè) Chunk 被分割成 N 個(gè) Block,每個(gè) Block 對(duì)應(yīng) 32 個(gè) Warp,其中每個(gè) Warp 由 32 個(gè) Thread 組成,每個(gè) Thread 可處理 8 個(gè) FP16 元素。以采用 128 組大小的非對(duì)稱量化為例,使用 16 個(gè)線程對(duì)每組 128 個(gè)元素進(jìn)行量化。具體而言,利用 CUDA API 函數(shù) __shfl_xor_sync 通過(guò)迭代交換這些 Warp Thread 間的信息,高效實(shí)現(xiàn) Max/Min 歸約。
快速通信。不再使用 All2All 原語(yǔ),而是利用 CUDA Runtime API 中的 GPU Peer Direct Memory 訪問(wèn)來(lái)傳輸量化后的數(shù)據(jù)量,在此過(guò)程中,能夠直接從不同 Rank 獲取數(shù)據(jù),顯著提升通信速度。
快速反量化。一旦接收到量化后的數(shù)據(jù),需要將其反量化為 FP16 以進(jìn)行 Reduce Sum。由于單純的 INT4 到 FP16 轉(zhuǎn)換會(huì)產(chǎn)生開(kāi)銷,作者采用了 [2211.10017] Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production [3] 中的反量化布局。為了在線協(xié)調(diào)其順序,作者還采用了來(lái)自 LMDeploy 的快速 INT4 打包,如下圖 Figure 8 所示:
- 給定兩個(gè) 32 位無(wú)符號(hào)整數(shù) U0 和 U1,它們分別持有 4 個(gè) INT4 量化的激活值(每個(gè)存儲(chǔ)在 8 位中的低 4 位)用于傳輸。
- 首先執(zhí)行右移 12 位操作,然后對(duì)其自身進(jìn)行按位或運(yùn)算。
- 隨后,使用 CUDA Math API __byte_perm 從這兩個(gè)整數(shù)中選擇目標(biāo)位。通過(guò)這種方式,可以方便的按順序打包 8 個(gè) 4 位整數(shù)進(jìn)行反量化。
- 接下來(lái),應(yīng)用 lop3.b32 對(duì)打包變量執(zhí)行邏輯操作(0xF0 & 0xCC)| 0xAA,應(yīng)用掩碼 0x000F000F 和 0x64006400,然后減去 0x64006400,這有效地表示了 FP16 中的 W1 和 W0。
- 通過(guò)改變剩余 INT4 整數(shù)的掩碼,可以迭代進(jìn)行反量化。?
INT6 量化:鑒于在 All-Gather 之前進(jìn)行低比特量化會(huì)導(dǎo)致更大的損失,這里作者選擇采用 INT8 位寬,同時(shí)保持 ReduceSum 的 INT4 位寬,從而有效構(gòu)建了一個(gè) INT6 解決方案。INT6 配置在性能與通信效率之間達(dá)到了很好的平衡。
五、實(shí)驗(yàn) & 結(jié)果
5.1 實(shí)驗(yàn)配置
實(shí)驗(yàn)在前述的 L40 和 A100 GPU 進(jìn)行,對(duì)應(yīng)的輸入 Token 為 1024,輸出為 64,基線為 FP16 通信。
5.2 精度對(duì)比
5.2.1 FP16 Weight 實(shí)驗(yàn)
如下圖所示,作者針對(duì) LLaMA-2 和 LLaMA-3 系列模型使用 FP16 Weight 進(jìn)行了評(píng)估,可以看出,大部分情況下 Asym INT8 的損失都很小,基本無(wú)損(紅框);Asym INT6(INT8 + IINT4)在 LLaMA-2 損失較小,在 LLaMA-3 損失稍微有點(diǎn)大;而 LLaMA-3 的 INT4 方案損失比較大,這也與 [2411.04330] Scaling Laws for Precision [4] 的結(jié)論相符,LLaMA-3 用了更多訓(xùn)練數(shù)據(jù),相應(yīng)也更難量化):
PS:需要說(shuō)明的是,上述中我們沒(méi)有使用論文表格是因?yàn)檎撐闹谐霈F(xiàn)了一些錯(cuò)誤(v2 已修正),上述表格中:
- AVG 列:作者論文中計(jì)算的均值,如下圖 Table 2 所示,此結(jié)果計(jì)算有誤。
- New_AVG 列:我們自己根據(jù)表格中相關(guān)數(shù)據(jù)計(jì)算的均值。
- INT8_Weight_AVG:來(lái)自下述 Table 3 中對(duì)應(yīng) INT8 Weight 推理的均值??梢钥闯?INT8 Weight 的均值也和我們計(jì)算的 FP16 Weight 的結(jié)果均值接近,符合預(yù)期。?
5.2.2 INT8 Weight 實(shí)驗(yàn)
如下圖 Table 3 所示,作者同樣針對(duì) LLaMA-2 和 LLaMA-3 系列模型使用 INT8 Weight 進(jìn)行了評(píng)估,和上述 FP16 Weight 結(jié)論基本類似:
5.3 Flash AllReduce vs Ring AllReduce
在集成了一系列優(yōu)化技術(shù)后,F(xiàn)lash AllReduce 的速度顯著由于 Ring AllReduce。如下圖 Figure 10 所示,作者展示了通信量在 64MB - 1GB 時(shí)的通信時(shí)延??梢钥闯?,其 INT4 版本最高可以實(shí)現(xiàn) 3.18x 的 Kernel 加速,而 INT6 在速度和精度之間取得了不錯(cuò)的平衡(PS:需要注意的是,實(shí)際推理過(guò)程中通信量可能沒(méi)有這么大)。
如下圖 Figure 11 所示,作者也展示了不同 SM 數(shù)量對(duì)通信效率的影響。在通信量較小時(shí),較少的 SM 數(shù)量更為有利,因?yàn)檫@可以減少 Kernel 啟動(dòng)和 Block 間同步的開(kāi)銷。然而,隨著通信量增大,計(jì)算需求增加,也很有必要使用更多 SM。配置 48 個(gè) SM 可以在通信與計(jì)算之間達(dá)到了更佳的平衡。
5.4 時(shí)延和吞吐
如下圖 Figure 9 所示,作者也基于 LLaMA-3-8B 和 LLaMA-3-70B 模型在 L40 和 A100 上測(cè)量了 TTFT 的時(shí)延,可以看出,在 L40 上 TP=4 最多可以獲得 2.06x 的加速(對(duì)應(yīng)的 INT4,INT8 只有 1.42x);而在 A100 上 TP=8 最多可以獲得 1.19x 加速(對(duì)應(yīng)的 INT4,INT8 只有 1.1x)
如下圖 Figure 13 所示,在 L40 上 TP=2 的加速會(huì)更小一些:
PS:此外,LLM Inference 在 Prefill 階段的 AllReduce 通信量比較大,而在 Decoding 階段的 AllReduce 通信量比較小,作者并沒(méi)有進(jìn)行相關(guān)對(duì)比實(shí)驗(yàn)。
六、參考鏈接
- ???https://arxiv.org/abs/2412.04964???
- ???https://developer.nvidia.com/blog/3x-faster-allreduce-with-nvswitch-and-tensorrt-llm-multishot/???
- ???https://arxiv.org/abs/2211.10017???
- ???https://arxiv.org/abs/2411.04330???
本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談
