字節(jié) TileLink:編譯生成高效的計算和通信 Overlap Kernel
一、背景
筆者之前的文章(萬字綜述 LLM 訓(xùn)練中的 Overlap 優(yōu)化:字節(jié) Flux 等 7 種方案)中詳細(xì)介紹過各種計算與通信 Overlap 的方案,這里進(jìn)一步介紹字節(jié)最近發(fā)表的 TileLink,其中提到的大部分工作已經(jīng)包含在我們之前的綜述中,建議優(yōu)先閱讀,比如 CoCoNet、Centauri、Flux 等。
對應(yīng)的論文:[2503.20313] TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives [1]
二、摘要
大規(guī)模深度學(xué)習(xí)模型通常需要分布式系統(tǒng)以實(shí)現(xiàn)高效的訓(xùn)練與推理,分布式模型執(zhí)行的基礎(chǔ)構(gòu)建模塊是層內(nèi)并行算子。提升層內(nèi)并行算子性能的最有效方法在于實(shí)現(xiàn)計算與通信的 Overlap。這種 Overlap 可通過算子分解(Operator Decomposition)或 Kernel 融合(Fusion)兩種方式達(dá)成:
- Operator Decomposition 雖易于實(shí)現(xiàn),但性能往往欠佳。
- 將通信 Kernel 與計算 Kernel 相融合則需深厚的專業(yè)知識且易出錯。
本文中,作者提出 TileLink,旨在高效編譯并生成計算-通信 Overlap 執(zhí)行的 Kernel。TileLink 由前端(Frontend)和后端(Backend)構(gòu)成:
- 在前端,系統(tǒng)通過以 Tile 為中心的原語將通信與計算的設(shè)計空間解耦并建立關(guān)聯(lián)。
- 在后端,將這些原語轉(zhuǎn)換為底層指令,整合通信與計算組件以實(shí)現(xiàn) Overlap 執(zhí)行。
實(shí)驗(yàn)表明,TileLink 相較于非 Overlap 基線實(shí)現(xiàn)了 1.17x 至 20.76x 的加速,并在 GPU 上達(dá)到了與當(dāng)前最優(yōu) Overlap 執(zhí)行庫相當(dāng)?shù)男阅芩健?/p>
三、引言
3.1 北大 Centauri
北大在 [ASPLOS 24.04] Centauri: Enabling Efficient Scheduling for Communication-Computation Overlap in Large Model Training via Communication Partitioning [2] 中介紹了 Centauri 框架,其構(gòu)建了一個由三個固有抽象維度組成的切分空間:原語替換、拓?fù)涓兄M切分及工作負(fù)載切分。這些維度共同構(gòu)成了一個全面的優(yōu)化空間,用于高效 Overlap。為確定通信與計算的高效 Overlap,作者將混合并行訓(xùn)練中的調(diào)度任務(wù)分解為 OP、Layer 和模型三個層次。
如下圖 Figure 3 所示,Centauri 的工作流程包含兩個核心環(huán)節(jié):通信切分與層次調(diào)度。以 DP 與 FSDP 混合并行訓(xùn)練為例:
- 通信切分:通過考量三個基本維度,生成潛在切分空間,并為每種集合通信選擇高效策略。
- 層次調(diào)度:在上述全面但較大的切分空間下,優(yōu)化整圖的 Overlap 調(diào)度成為一項(xiàng)復(fù)雜的任務(wù),為了簡化復(fù)雜的調(diào)度任務(wù),作者將復(fù)雜的混合并行集合通信分解為三個層次,每個集合通信被分配至特定調(diào)度層級。各層級選取開銷較低的切分與調(diào)度方案,旨在實(shí)現(xiàn)整體優(yōu)化 Overlap 方案。
有一系列類似 Centauri 的算子分解方法,其核心是:將通信和計算 Kernel 拆分為更小規(guī)模的同構(gòu) Kernel,隨后將其分配到多個通信-計算 Kernel 對中。這些拆分后的小 Kernel 可被調(diào)度到不同的 Stream 上,使得通信 Kernel 和計算 Kernel 能同時對切分的數(shù)據(jù)分片進(jìn)行操作。
然而,類似上述算子分解的方法有一些局限性:
- 分解后的 Kernel 間的同步機(jī)制需要 Host 端介入,會在運(yùn)行中引入不可忽略的開銷。
- L2 Cache 利用率降低、資源量化效率不足,導(dǎo)致分解后的 Kernel 性能可能出現(xiàn)惡化。
這里的資源量化效率不足(Resource Quantization Inefficient)是指計算資源切分不均衡等導(dǎo)致的浪費(fèi),如下圖 Stream-K([2301.03598] Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU [3])中提到的問題:
3.2 字節(jié) Flux
字節(jié)在 [2406.06858] FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion [4] 中提出 Flux,旨在通過依賴計算隱藏 GPU 間的通信時延。Flux 將通信和計算操作分解為更細(xì)粒度的操作,并進(jìn)一步融合成更大的 Kernel,從而在不損害 Kernel 效率的前提下有效隱藏通信。在融合 Kernel 的情況下,F(xiàn)lux 有望重疊高達(dá) 96% 的通信時間。
如下圖 Figure 5 展示 Flux 中 ReduceScatter 里 Overlap 與其他方案的差異。現(xiàn)有 Overlap 方案 Tm 理論上可能比原始方法 Tc 執(zhí)行得更快,但通常情況下,Tm 仍慢于原始 GEMM 操作時間 Tg。主要原因在于,將一個 GEMM Kernel 拆分為一系列較小的 GEMM Kernel 會降低 GPU GEMM 的執(zhí)行效率。GEMM 通常需要合理大小的矩陣才能充分利用 GPU 的計算能力。這些具有數(shù)據(jù)依賴性的小型 GEMM 操作序列進(jìn)一步阻礙了 GEMM Kernel 通過 GPU 多路復(fù)用技術(shù)并行運(yùn)行,因此,Tensor 并行度越高,GPU 上的 GEMM 效率越低。
相比之下,作者提出的技術(shù)不存在上述限制。作者的 Overlap 方案 Tf 能夠在極小開銷下實(shí)現(xiàn)與原始 GEMM 操作 Tg 相當(dāng)?shù)男阅?。其?xì)粒度分解策略完美契合現(xiàn)代 GPU 設(shè)計特性,即通過上下文切換的 Warp 和數(shù)百個在 SM 間并發(fā)活躍的 Warp 來隱藏延遲,如下圖 Figure 5 底部所示。最終,作者的方法在不影響 GEMM 計算效率的前提下,僅在執(zhí)行末尾引入少量通信開銷。
然而,雖然這種方式實(shí)現(xiàn)的 Kernel 效率很高,但是開發(fā)成本同樣很高,尤其是針對不同場景、模型可能需要開發(fā)特定的 Kernel。DeepSeek 可以做深度的 DeepEP、DualPipe 等優(yōu)化的一個前提就是其模型、硬件相對恒定,可以一勞永逸。
四、方案
4.1 概覽
本文工作主要聚焦于層內(nèi)并行,為了說明 TileLink 的優(yōu)勢,作者以 MLP 的 Tensor Parallelism(TP) 為例,如下圖 Figure 1 所示,其實(shí)現(xiàn)包含 AllGather + GEMM(AG+GEMM)與 GEMM+ ReduceScatter(GEMM + RS),其配置與 LLaMA-7B 一致:
如下圖 Table 2 所示,采用不同技術(shù)方案的性能進(jìn)行對比,其中 Non-Overlap 為直接使用 cuBLAS 和 NCCL 的無 Overlap 方案;Decomposition 則為采用算子分解技術(shù)。可以看出,Decomposition 是性能最差的,F(xiàn)usion 方案在 AG + GEMM 中最優(yōu),TileLink 在 GEMM + RS 中最優(yōu),同時 AG + GEMM 與 FLUX 性能接近(約達(dá) 99%)。同時,F(xiàn)LUX 需要 2000 行 CUDA 代碼,而 TileLink 僅需 200 行 Python 代碼,編程效率提升 10x。
PS:之前的 CoCoNet([ASPLOS 22] [2105.05720] Breaking the Computation and Communication Abstraction Barrier in Distributed Machine Learning Workloads [5]) 和 Dist-Einsum(Overlap Communication with Dependent Computation via Decomposition in Large Deep Learning Models | OpenReview [6]) 也可以生成 Overlap Kernel,但是其只能生成特定 Overlap Pattern 的算子,不夠靈活。
4.2 前端原語(Frontend Primitives)
4.2.1 解耦設(shè)計空間
設(shè)計計算+通信融合 Kernel 存在兩種方式:一種是將兩部分優(yōu)化選擇緊密耦合;另一種是解耦計算 Kernel 和 通信 Kernel 設(shè)計。本文的 TileLink 選擇后者,因?yàn)槠浣鈽?gòu)設(shè)計空間能為 Kernel 設(shè)計提供更多靈活性,從而可能獲得更優(yōu)性能。(PS:是否也有可能喪失聯(lián)合設(shè)計的優(yōu)勢,比如更加均衡的資源分配?)
解耦設(shè)計空間分為 3 個子空間:
- 分塊尺寸(Tile Size):如下圖 Figure 2a 所示,通信組件每次傳輸 128x128 的 Tile;計算組件每次處理 128x256 的 Tile。Tile 的大小與使用的處理核心數(shù)相關(guān),比如通信組件占用更多核心時使用較小的 Tile 可以更充分的利用全部核心資源;反之,核心數(shù)較少時更大的 Tile 更加高效。
- 分塊順序(Tile Order) :如下圖 Figure 2b 所示,通信組件可以與計算組件采用不同的分塊順序。分塊順序的選擇也存在權(quán)衡:若計算組件等待多個 Rank 的數(shù)據(jù)分塊,則能在處理更大數(shù)據(jù)塊時獲得更好的 Cache 效率,但可能等待時間變長;反之,僅等待單個 Rank 的數(shù)據(jù)分塊,可提早開始計算,但整體計算效率可能降低。圖中例子為通信采用 Ring 順序而每次迭代等待 2 個 Rank 數(shù)據(jù)。
- 資源映射(Resource Mapping):如下圖 Figure 2c 所示,通信與計算組件可映射到不同單元或相同單元。比如,如果通信組件使用 Copy Engine(DMA),可以避免與計算組件的資源沖突,但需要承擔(dān) Host 帶來的額外開銷;但是如果采用計算核心執(zhí)行數(shù)據(jù)拷貝,則可以消除 Host 開銷,但可能引發(fā)資源沖突,這適用于計算組件無法充分利用所有處理核心的場景。
4.2.2 Tile 為中心的基礎(chǔ)原語
解耦通信與計算的設(shè)計空間也會引入同步的挑戰(zhàn)。由于這兩個組件采用不同的分塊尺寸、分塊順序和資源映射方案,實(shí)現(xiàn)二者同步需要進(jìn)行復(fù)雜的底層編程并插入通信指令。以 GPU 為例,要求使用諸如 ld.global.acquire 和 red.release 等特殊指令。然而,這類指令的編程模型和代碼生成編譯器的工作機(jī)制存在根本性差異,現(xiàn)有編譯器普遍缺乏對內(nèi)存一致性模型的原生支持。
- ld.global.acquire:獲取語義,確保之后的操作不會提前執(zhí)行,確保讀取的變量是最新的,防止 CPU 或其他 GPU 線程的舊值污染數(shù)據(jù)。
- red.release:釋放語義,確保之前的寫入對其他線程可見,確保數(shù)據(jù)在此操作之前全部寫入,防止寫入亂序執(zhí)行。
- 這兩個指令通常用于同步機(jī)制,特別是生產(chǎn)者-消費(fèi)者、互斥鎖、信號量等場景,以保證不同 GPU 線程間的正確通信。
為解決上述問題,TileLink 提供了一套以 Tile 為中心的基礎(chǔ)原語。這些原語引入了內(nèi)存一致性(Memory Consistency)語義,并遵循編譯器采用的 Tile 級抽象,與現(xiàn)有框架提供的以算子為中心的原語形成顯著區(qū)別。如下圖 Figure 3 所示,TileLink 原語分為信號原語(Siganl Primitive)和數(shù)據(jù)原語(Data Primitive)兩大類,每類均包含 Device-side 原語和 Host-Side 原語兩個子類。
涉及的所有原語如下表 Table 3 所示:
4.2.3 信號原語
信號原語:旨在管理通信和計算之間的屏障,包括:
- producer(peer)_tile_notify:生產(chǎn)者或 Peer 通知
- consumer(peer)_tile_wait:消費(fèi)者或 Peer 等待
- rank_notify(wait):Rank 通知和等待
在 Device-side:
- producer_tile_notify 和 consumer_tile_wait 適用于生產(chǎn)者-消費(fèi)者關(guān)系,例如 AllGather 與 GEMM 運(yùn)算中各 Tile 的交互;
- peer_tile_notify 和 peer_tile_wait 主要用于跨不同 Rank 的同一算子 Tile,使用戶能夠構(gòu)建多樣化的 Tile 執(zhí)行順序。
在 Host-side:
- rank_notify 和 rank_wait 用于管理 Copy Engine 和計算核心間的同步屏障。當(dāng)通信任務(wù)映射至 Copy Engine 時,這些原語可有效協(xié)調(diào)通信與計算間的 Tile 執(zhí)行順序。如上圖 Figure 3a 所示。
Notify 原語需通過 Mode Argument 或 Rank argument 明確待通知的遠(yuǎn)端 Rank 范圍。TileLink 為 Mode Agrument 提供兩種選項(xiàng):p2p 和 broadcast。
- p2p 僅通知單個目標(biāo) Rank,其數(shù)值由給定 Tile 標(biāo)識(tile_id)在全局張量視圖中的偏移量計算得出;
- broadcast 則向所有 Rank 發(fā)送通知信號。
內(nèi)存一致性:在并行執(zhí)行過程中,不同進(jìn)程/線程執(zhí)行的內(nèi)存操作可能以非一致順序?qū)ζ渌M(jìn)程/線程可見。內(nèi)存一致性模型通過設(shè)定約束條件,確保各進(jìn)程/線程觀測到的操作順序不存在歧義。信號原語提供了嚴(yán)格的內(nèi)存一致性語義:
- 通知類原語具有釋放語義(release semantics),保證所有在 producer(peer)_tile_notify 和 rank_notify 之前的內(nèi)存訪問操作不得被重排到這些通知原語之后;
- 等待類原語則具有獲取語義(acquire semantics),確保所有在 consumer(peer)_tile_wait 和 rank_wait 之后的內(nèi)存訪問操作不得被重排到這些等待原語之前。
這種嚴(yán)格的內(nèi)存一致性約束在后端編譯階段同樣需要予以考慮。
4.2.4 數(shù)據(jù)原語
數(shù)據(jù)原語促進(jìn)了數(shù)據(jù)傳輸過程,主要包括 tile_push(pull)_data 和 rank_copy_data 兩類原語。這些原語精確控制著傳輸數(shù)據(jù)的資源映射與 Tile 大小。
- Device-side 的 tile_push(pull)_data 原語將通信映射至處理核心。
- Host-side 的 rank_copy_data 原語則將通信映射至 Copy Engine。
數(shù)據(jù)傳輸存在拉取(pull)與推送(push)兩種模式,各自適配不同的同步機(jī)制:
- 在 pull 模式下,生產(chǎn)者從所有其他 Rank 讀取數(shù)據(jù),并通過本地屏障通知其消費(fèi)者;
- 與之相反,push 模式允許生產(chǎn)者將本地數(shù)據(jù)寫入所有其他 Rank,同時向遠(yuǎn)端消費(fèi)者發(fā)送數(shù)據(jù)到達(dá)通知。
如上圖 Figure 3b 清晰展示了兩種模式的差異。模式選擇可能影響性能表現(xiàn),具體取決于數(shù)據(jù)形態(tài)、分塊策略及可用硬件資源等要素。值得注意的是,rank_copy_data 原語通過 P2P 復(fù)制技術(shù)支持雙模式運(yùn)行,其數(shù)據(jù)傳輸方向由源指針與目標(biāo)指針的排列順序顯式指定。
4.3 后端映射(Backend Mapping)
TileLink 后端負(fù)責(zé)將通信與計算組件共同編譯為底層設(shè)備代碼。為實(shí)現(xiàn)分布式系統(tǒng)的代碼生成,TileLink 采用了一種以計算單元為核心的映射技術(shù),該技術(shù)能夠?qū)⑼ㄐ拍K與計算模塊進(jìn)行關(guān)聯(lián)整合。
TileLink 采用以 Tile 為中心的映射方法,將前端原語編譯為底層代碼。以 Tile 為中心的映射包含三個組成部分:
- 形狀映射(shape mapping):將每個 tile_id 與特定的 Tensor Shape Tile 相關(guān)聯(lián)。
- Rank 映射(rank mapping):將每個 tile_id 與 Device Rank 相關(guān)聯(lián)。
- 通道映射(channel mapping):為每個tile_id 分配通信屏障(communication barrier)。
作者分別用 fS、fR、fC 表示這三種映射。根據(jù)工作負(fù)載類型的不同,應(yīng)采用不同的映射函數(shù)。作者將不同映射劃分為兩類:
靜態(tài)映射(Static Mapping):指可在編譯時確定的映射關(guān)系,通常用于數(shù)據(jù)分片策略固定的場景,例如 Tensor-Parallel MLP 和 Sequence-Parallel Self-Attention。作者采用仿射運(yùn)算(Affine Operation)處理靜態(tài)映射(此時 fS、fR、fC 均為仿射函數(shù))。以包含 R 個設(shè)備(每 Rank 對應(yīng) C 個通道/屏障)的系統(tǒng)上執(zhí)行 AllGather(pull 模式)+ GEMM(問題規(guī)模 M×N×K)為例:生產(chǎn)者 AllGather 操作的 Tile 尺寸為 Tmp × Tnp,輸入 Tensor 沿 M 維分片。給定生產(chǎn)者 Tile 的 tile_idp,其形狀范圍、源 Rank 及通道可通過以下公式計算。類似地,可以計算出從消費(fèi)者 tile_idc 到形狀范圍、 Rank 和通道的映射關(guān)系:
動態(tài)映射(Dynamic Mapping):是指在運(yùn)行時計算的映射關(guān)系,這對于具有動態(tài)數(shù)據(jù)分片需求的工作負(fù)載至關(guān)重要。例如,在 MoE 數(shù)據(jù)分片策略中,動態(tài)路由決定了數(shù)據(jù)分布,每個 tile 可能需要來自其他任意 Rank 的 Token。在編譯時無法確定需要從哪些 Rank 收集數(shù)據(jù)或在哪個通道等待屏障同步。因此,必須在運(yùn)行時計算這些映射關(guān)系。為支持動態(tài)映射,TileLink 將這些映射轉(zhuǎn)換為查找表,其值可在運(yùn)行時填充,而對這些查找表的訪問操作則在編譯時確定。從形式化角度來看,動態(tài)映射如下所示(其中 fS_low,fS_high,fR 和 fC 是查找表,其值在運(yùn)行中動態(tài)調(diào)整):
內(nèi)存一致性編譯:在后端編譯過程中,前端具有內(nèi)存一致性語義的原語被編譯為相應(yīng)的設(shè)備指令(如 ld.global.acquire 和 red.release)。然而,直接翻譯這些原語并不足以確保內(nèi)存一致性。對于大多數(shù)計算 Kernel,采用多級流水線技術(shù)來提升負(fù)載-計算平衡并優(yōu)化整體性能。將原始程序編譯為多級流水線版本需要進(jìn)行算子重排,在此過程中某些內(nèi)存訪問操作可能會意外地被重排至 TileLink 原語之前或之后。為解決這一問題,TileLink 在其原語與后續(xù) load/store 操作之間建立了嚴(yán)格的數(shù)據(jù)依賴關(guān)系,從而確保其原語能夠通過流水線處理階段被正確重排序和展開。
其他編譯優(yōu)化:除上述技術(shù)外,TileLink 還采用單設(shè)備優(yōu)化策略以實(shí)現(xiàn)高性能,該策略在已有研究中得到充分論證。優(yōu)化主要體現(xiàn)在內(nèi)存優(yōu)化與流水線優(yōu)化兩方面:
- 內(nèi)存優(yōu)化通過自動分配片上寄存器緩存和計算用共享存儲緩沖區(qū),對全局緩沖區(qū)的數(shù)據(jù)訪問進(jìn)行合并操作,并重構(gòu)共享存儲器訪問模式以避免存儲體沖突;
- 流水線優(yōu)化則通過重組數(shù)據(jù) load/store 操作與計算任務(wù),構(gòu)建多級流水線架構(gòu)。其中,本地數(shù)據(jù)拷貝被映射至專用異步引擎(如 GPU 的 TMA),而計算任務(wù)則被分配至高性能運(yùn)算單元(如 GPU 的 Tensor Core)。
4.4 Kernel 設(shè)計
為展示 TileLink 的靈活性與普適性,作者闡釋了如何為 GEMM + Ring ReduceScatter、AllGather + MoE 以及 AllGather KV + Self Attention 機(jī)制設(shè)計 Overlap 計算 Kernel。這三個案例具有代表性:它們分別采用了不同的分片順序(Ring 和 All2All)、不同的映射策略(靜態(tài)與動態(tài))以及不同的硬件資源(Device-side 和 Host-side)。
如下圖 Figure 4 展示了 GEMM + Ring ReduceScatter Kernel 的偽代碼實(shí)現(xiàn),該案例采用靜態(tài)映射策略,演示了生產(chǎn)者-消費(fèi)者和 P2P 雙向通信的編程范式。
- 其中計算與通信均采用 SM,分配了 20 個 SM 專用于通信(見第 1 行)。
- 生產(chǎn)者 GEMM 將部分計算結(jié)果存儲于本地 Tensor,并通過 producer_tile_notify 通知消費(fèi)者(第 9 行)。
- 消費(fèi)者 ReduceScatter 通過 consumer_tile_wait(第 16 行)等待生產(chǎn)者就緒。
- 一旦數(shù)據(jù)可用即執(zhí)行 local reduce 操作(第 20 行),并將部分結(jié)果通過 tile_push_data 傳遞給前序節(jié)點(diǎn)(第 24 行)。
- 節(jié)點(diǎn)間的信號控制通過 peer_tile_wait(第 19 行)和 peer_tile_notify(第 26 行)原語實(shí)現(xiàn)。
如下圖 Figure 5 展示了 AllGather + MoE 的偽代碼實(shí)現(xiàn)。
- 同樣采用 20 個 SM 處理通信任務(wù)(第 1 行)。值得注意的是,MoE 需要基于動態(tài)路由(輸入中的 topk_ids)為每個 token 選擇專家,必須采用動態(tài)映射。因此使用 table 數(shù)據(jù)結(jié)構(gòu)存儲形狀映射、Rank 映射及通道映射的查找表。所有相關(guān)原語均需以 table 為參數(shù),以確保 TileLink 能基于動態(tài)映射生成正確代碼。
- 此外,load 原語需要借助 table 中的形狀映射來收集當(dāng)前分片所需的正確 token(第 11 行)及其對應(yīng)的 topk_ids(第 12 行)。
如下圖 Figure 6 展示了 AllGather KV + Self Attention(序列并行)的偽代碼。本案例中通信操作通過 Copy Engine 實(shí)現(xiàn),采用 Host 原語來觸發(fā) Copy Engine。通信與計算分別在兩個獨(dú)立的流上執(zhí)行:
- 通信部分通過 rank_copy_data 原語完成,其分塊尺寸為 KV Cache 序列長度(S)除以總 Rank 數(shù)(WORLD_SIZE)。
- 計算部分則采用不同的分塊尺寸。通過基于分塊的 Kernel 映射機(jī)制,確保通信與計算環(huán)節(jié)間的屏障操作正確執(zhí)行。
4.5 實(shí)現(xiàn)
TileLink 基于 Triton,使用 Python 語言實(shí)現(xiàn)。作者在 Python 層面實(shí)現(xiàn)了以計算塊為中心的原語操作,從而擴(kuò)展了 Triton 的語言特性,而面向計算塊的映射機(jī)制則通過 Python 抽象語法樹(AST)轉(zhuǎn)換實(shí)現(xiàn)。其實(shí)現(xiàn)方案可輕松適配至 TVM、MLIR 等其他編譯器框架。
如下圖 Figure 7 所示,編譯器輸入為融合 TileLink 原語與 Triton 原生原語的純 Python 程序。通過特殊參數(shù) BlockChannel 為計算和通信提供以計算塊為核心的映射上下文,BlockChannel 封裝了分布式映射元數(shù)據(jù),包括當(dāng)前進(jìn)程 Rank、總 Rank 數(shù)、同步屏障配置及生產(chǎn)者/消費(fèi)者計算塊關(guān)系等。
- Python 程序經(jīng)解析生成 AST 后轉(zhuǎn)換為 Triton 中間表示(IR),在此過程中 BlockChannel 參數(shù)被分解,利用其內(nèi)嵌元數(shù)據(jù)構(gòu)建面向計算塊的映射關(guān)系,TileLink 原語則轉(zhuǎn)換為 Triton 的 ElementwiseInlineAsmOp 操作。
- 隨后 Triton IR 被進(jìn)一步降級為 Triton GPU IR 和 TileLink 新增的 Distributed IR,后者用于將通過 ElementwiseInlineAsmOp 表達(dá)的特殊指令轉(zhuǎn)換為 LLVM IR,最終編譯為適用于 NVIDIA GPU 的 PTX 代碼。
- 通過將 LLVM IR 轉(zhuǎn)換為目標(biāo)架構(gòu)特定的底層匯編,可支持更多后端硬件。
- 運(yùn)行時:
- 采用 NVSHMEM 初始化分布式執(zhí)行環(huán)境并分配共享內(nèi)存。
- 生成的代碼在所有進(jìn)程上啟動以執(zhí)行并發(fā)計算與通信。
- 運(yùn)行結(jié)束后正確釋放共享內(nèi)存空間。
五、評估
如下圖 Figure 8 所示,作者在 8xH00 集群上測試:
- 對 AG+GEMM 場景,Async-TP PyTorch 由于分解后的 GEMM 運(yùn)算規(guī)模過小無法充分占用設(shè)備資源,未能實(shí)現(xiàn)加速效果。FLUX 憑借高度優(yōu)化的實(shí)現(xiàn)取得了最高加速比(相較于 cuBLAS + NCCL 達(dá)1.34x)。TileLink 同樣實(shí)現(xiàn)了優(yōu)于 cuBLAS + NCCL 的加速效果(1.27x),達(dá)到 FLUX 性能的 94.5%。
- 對于 GEMM + ReduceScatter 場景,TileLink 展現(xiàn)出最佳性能:較 cuBLAS + NCCL 提升 1.25x,較 Async-TP PyTorch 提升 2.22x,較 FLUX 提升 1.28x。
如下圖 Figure 9 所示,MoE 層相較于 MLP 層復(fù)雜度顯著提升,在編譯階段需進(jìn)行動態(tài)映射。該層可分解為兩個核心部分:AG + Gather + Group GEMM 與 Group GEMM + Scatter + Topk Reduce + RS。這兩類算子可融合為 Group GEMM Kernel,vLLM 已實(shí)現(xiàn)此類融合運(yùn)算。
- 在第一部分:TileLink 憑借通信-計算 Overlap 優(yōu)化,在 vLLM 基礎(chǔ)上進(jìn)一步實(shí)現(xiàn) 1.51x 平均加速。
- 在第二部分:TileLink 相較 vLLM 獲得 1.31x 平均加速,較 CUTLASS + NCCL 組合提升 10.56x。
- 需特別指出,F(xiàn)LUX、Async-TP PyTorch 等現(xiàn)有庫均不支持 MoE 層 Overlap 執(zhí)行,而 TileLink 憑借靈活的原語體系與動態(tài)映射機(jī)制實(shí)現(xiàn)了該功能支持。
如下圖 Figure 10 所示,作者針對 16K 到 128K 序列長度的 Self Attention 機(jī)制進(jìn)行了評估。實(shí)驗(yàn)表明,在所有序列長度條件下,TileLink 方案相較 PyTorch 非 Overlap 實(shí)現(xiàn)(Torch)與RingAttention(RingAttn)均展現(xiàn)出穩(wěn)定的加速優(yōu)勢。經(jīng)量化分析,TileLink 平均可獲得 5.04x 于Torch、1.97x 于 RingAttn 的性能提升。
作者將 TileLink 集成至 PyTorch 框架,并在 H800 集群上對 8 種不同 LLM 進(jìn)行端到端性能評估。
- 首先在單節(jié)點(diǎn)(8×H800 GPU)環(huán)境下進(jìn)行測試,結(jié)果如下圖 Figure 11 左半部分所示。前五種為 Dense 模型,后三種為 MoE。其中 Qwen1.5 采用 MoE 共享專家機(jī)制,通過將 MLP 層與 MoE 層合并來實(shí)現(xiàn)共享專家支持。實(shí)驗(yàn)設(shè)置 Batch Size 為 4、序列長度 8192。結(jié)果表明, TileLink 相較 PyTorch 實(shí)現(xiàn)平均 1.32x 加速。Dense 模型平均加速比為 1.20x,與單層 MLP 加速效果一致——盡管 Self Attention 獲得顯著加速,但端到端性能仍由 MLP 層主導(dǎo)。MoE 模型平均加速比為 1.54x,低于單層 MoE 加速效果,因其 MLP 層與 MoE 層各占約 50% 執(zhí)行時間,最終加速比介于二者之間。
- 在多節(jié)點(diǎn)部署評估中,鑒于節(jié)點(diǎn)間帶寬限制,采用節(jié)點(diǎn)內(nèi) TP 與節(jié)點(diǎn)間 DP 的混合策略。雙節(jié)點(diǎn)(各 8×H800 GPU)測試結(jié)果與單節(jié)點(diǎn)基本一致(Batch 規(guī)模倍增),整體加速比為 1.29x,因節(jié)點(diǎn)間通信開銷略有下降。
六、參考鏈接
- ??https://arxiv.org/abs/2503.20313??
- ??https://dl.acm.org/doi/10.1145/3620666.3651379??
- ??https://arxiv.org/abs/2301.03598??
- ??https://arxiv.org/abs/2406.06858??
- ??https://arxiv.org/abs/2105.05720??
- ??https://openreview.net/forum?id=MIJtDiMUX9??
本文轉(zhuǎn)載自???AI閑談???,作者:AI閑談
