Sample Packing:長(zhǎng)序列 LLM 訓(xùn)練的 Attention 問(wèn)題及優(yōu)化
一、背景
之前看過(guò)部分 Megatron-LM 的源碼,也詳細(xì)分析過(guò)對(duì)應(yīng)的 Dataset 和 DataLoader,想當(dāng)然的認(rèn)為在 LLM 預(yù)訓(xùn)練時(shí)會(huì)使用 Document Level 的 Mask,也就是常說(shuō)的 Sample Packing 技術(shù)。最近我們?cè)谧鲩L(zhǎng)序列訓(xùn)練相關(guān)工作時(shí)發(fā)現(xiàn)并非如此,并且出現(xiàn)了一些很奇怪的性能問(wèn)題,因此重新看了相關(guān)工作,并進(jìn)行了部分實(shí)驗(yàn)。
Sample Packing 中有很多可以討論的技術(shù)點(diǎn),比如 Attention 的實(shí)現(xiàn)和優(yōu)化,Sample 的組合及負(fù)載均衡問(wèn)題(有點(diǎn)類似調(diào)度問(wèn)題)以及不同方案對(duì)效果的影響等。我們這里只是先簡(jiǎn)單介紹一下相關(guān)問(wèn)題和實(shí)驗(yàn),后續(xù)會(huì)進(jìn)一步探索更多工作,比如 Document Level 的 Mask 到底對(duì)預(yù)訓(xùn)練效果影響有多大,對(duì) Attention 進(jìn)行優(yōu)化還能帶來(lái)多少提升,如何設(shè)計(jì)一個(gè)比較好的 Packing 策略等?
相關(guān)工作可以參考我們之前的文章:
- ??大規(guī)模分布式 AI 模型訓(xùn)練系列——序列并行??
- ??LLM 預(yù)訓(xùn)練語(yǔ)料、預(yù)處理和數(shù)據(jù)集索引、加載總結(jié)??
- ??LLaMA 3 技術(shù)報(bào)告解讀:全面梳理 LLM 相關(guān)技術(shù)棧??
二、Dataset + Dataloader
之前的文章(???LLM 預(yù)訓(xùn)練語(yǔ)料、預(yù)處理和數(shù)據(jù)集索引、加載總結(jié)??)中詳細(xì)介紹過(guò) Megatron-LM(DeepSpeed-Megatron)中預(yù)訓(xùn)練 Dataset 的存儲(chǔ)格式和 Dataloader 的加載、混合方式。簡(jiǎn)單說(shuō)來(lái),預(yù)訓(xùn)練通常包含很多不同的數(shù)據(jù)集,每個(gè)數(shù)據(jù)集又包含許多 Document。為了提升訓(xùn)練效率,在實(shí)際訓(xùn)練的時(shí)候一個(gè) Sample(Sequence)里面可能會(huì)包含多個(gè)不同的 Document(Sample Packing)。比如 8K 的預(yù)訓(xùn)練 Sequence Length,則一個(gè) Sample 可以包含 8 個(gè) 1K 的 Document。
如下圖所示,簡(jiǎn)單展示了 Megatron-LM 中如何 Packing 多個(gè) Document,實(shí)際上就是一個(gè)多級(jí)的索引。需要說(shuō)明的是,這里其實(shí)會(huì)引入很多隨機(jī)讀操作,會(huì)極大影響讀的性能。不過(guò)一般 LLM 計(jì)算代價(jià)都很高,這里也往往不會(huì)導(dǎo)致瓶頸。
三、Attention Mask
對(duì)于單個(gè) Document 而言,Decoder Only 的 GPT 模型具有 Causal 特性,也就是每個(gè) Token 不能看到之后的 Token,因此在實(shí)際訓(xùn)練中需要添加 Attention Mask。如下圖所示,這種情況下 Attention Mask 是一個(gè)標(biāo)準(zhǔn)的下三角矩陣(Causal Mask),也就是綠色部分為 1,其他部分為 0:
如果一個(gè) Sample 里包含多個(gè)樣本,則 Attention Mask 矩陣需要變成如下圖所示的塊對(duì)角矩陣形式(Block Diagonal Mask)。比如 Sequence Length 為 16,4 個(gè) Document 的長(zhǎng)度分別為 3,4,5,4,則對(duì)應(yīng) Attention Mask 矩陣如下圖所示,對(duì)角線上的 4 個(gè)矩陣(紅框)都是標(biāo)準(zhǔn)的下三角矩陣。按照這種方式可以保證和 4 個(gè) Document 單獨(dú)作為 Sample 訓(xùn)練是等價(jià)的:
四、Reset Attention Mask
4.1 是否需要
那么在實(shí)際使用中是否需要嚴(yán)格按照 Block Diagonal Mask 的方式使用呢?答案是否定的,比如 Megatron-LM 可以通過(guò) reset_attention_mask 來(lái)控制是使用 Block Diagonal Mask 還是標(biāo)準(zhǔn)的 Causal Mask,默認(rèn)值為 False。很多模型在預(yù)訓(xùn)練時(shí)也會(huì)采用默認(rèn)配置,即使用 Causal Mask。
在浪潮的 Yuan-1.0 報(bào)告(“源1.0”大模型技術(shù)白皮書(shū))中有提到,為了避免不同 Document 之間的相互干擾而將 reset_attention_mask 設(shè)置為 True,也就是 Block Diagonal Mask:
在 Meta 的 LLaMA 3.1 技術(shù)報(bào)告([2407.21783] The Llama 3 Herd of Models)中也提到,在 LLaMA 3.1 模型的預(yù)訓(xùn)練中會(huì)打開(kāi)這個(gè)配置。不過(guò)作者也做了說(shuō)明,對(duì)于 8K Sequence Length 的預(yù)訓(xùn)練而言,對(duì)模型最終的效果影響不大,對(duì)長(zhǎng)序列的 Continuous PreTraining 影響比較大:
在 [2402.08268] World Model on Million-Length Video And Language With Blockwise RingAttention 中作者提出了“世界模型”,為了提升超長(zhǎng)序列的訓(xùn)練效率,作者采用了 Sample Packing 的策略,并且做了相關(guān)消融實(shí)驗(yàn)。如下圖 Table 10 所示,采用 Naive Packing(不對(duì) Attention Mask 特殊處理)相比使用了 Block Diagonal Mask 的 LWM 的性能會(huì)差很多:
PS:當(dāng)然,目前還沒(méi)有更多有關(guān)預(yù)訓(xùn)練中是否 reset_attention_mask 的消融實(shí)驗(yàn),我們后續(xù)會(huì)進(jìn)行相關(guān)測(cè)試。此外,如果采用絕對(duì)位置編碼,Position-id 也需要相應(yīng)的調(diào)整,在 Megatron-LM 中對(duì)應(yīng) reset_position_id 選項(xiàng)。
4.2 性能問(wèn)題
如下圖為 Megatron-LM/megatron/core/datasets/gpt_dataset.py 中 reset_attention_mask 的實(shí)現(xiàn)方式,首先會(huì)將 attention_mask 初始化為標(biāo)準(zhǔn)的 Causal Mask 形式,然后從第二個(gè) Document 開(kāi)始,將之前的 mask 置為 0:
具體來(lái)說(shuō)如下圖所示,初始是一個(gè)標(biāo)準(zhǔn)的 Causal Mask 矩陣,然后會(huì)將 4x3、5x(3+4) 和 4x(3+4+5) 的區(qū)域依次置為 0,之后會(huì)變成 Block Diagonal Mask:
實(shí)際上我們已經(jīng)知道這里是標(biāo)準(zhǔn)的 Block Diagonal Mask,可以使用 torch.block_diag() 快速創(chuàng)建。實(shí)測(cè)當(dāng)序列比較長(zhǎng)時(shí)(比如 32K),兩種方式速度可能會(huì)差幾十倍,導(dǎo)致 reset_attention_mask 可能成為訓(xùn)練瓶頸:
除此之外,當(dāng)序列非常長(zhǎng)時(shí),Attention Mask 也會(huì)占據(jù)很大的存儲(chǔ)空間,為了計(jì)算效率,往往會(huì)使用整型而不是 Bool 類型。假設(shè)以 int8 存儲(chǔ),32K 序列長(zhǎng)度對(duì)應(yīng)的 Mask 大小為 32K * 32K = 1GB,128K 時(shí)更是高達(dá) 16GB。為了避免顯存浪費(fèi),其實(shí)不必將其拼成大的 Block Diagonal Mask,而保留幾個(gè)小的 Causal Mask 即可。
五、Attention 優(yōu)化
5.1 FlashAttention
當(dāng)前 LLM 預(yù)訓(xùn)練基本都會(huì)使用 FlashAttention,其對(duì) Casual Mask 的方式進(jìn)行了優(yōu)化,如下圖所示,假設(shè) 16x16 的 Attention Mask,在計(jì)算時(shí)按照 4x4 分塊,則可以將其分為 3 種情況:
- 有些塊對(duì)應(yīng)的 Mask 都是 0(紅框右上部分,比如藍(lán)框),無(wú)需再計(jì)算。
- 有些塊中部分 Mask 為 0,部分 Mask 為 1(紅框),需要相應(yīng)特殊處理。
- 有些塊對(duì)應(yīng)的 Mask 都是 1(紅框左下部分,比如黃框),全部計(jì)算即可。?
對(duì)于上述 Block Diagonal Mask,依然可以使用 Causal Mask 的方式計(jì)算,不過(guò)會(huì)導(dǎo)致大量的無(wú)效計(jì)算。幸運(yùn)的是,F(xiàn)lashAttention V2 支持可變序列長(zhǎng)度(Varlen)的 Batching Attention 計(jì)算,可以避免 Padding 導(dǎo)致的無(wú)效計(jì)算。因此也就可以借用這種機(jī)制來(lái)對(duì) Block Diagonal Mask 進(jìn)行解構(gòu),重新分解為多個(gè) Causal Mask 分別計(jì)算,可以避免很多無(wú)效計(jì)算。如下圖所示,可以將其看成 4 個(gè)獨(dú)立的 Attention 計(jì)算,具體可以參考 FlashAttention Github 上的相關(guān)討論:How to implement example packing with flash_attn v2? · Issue #654 · Dao-AILab/flash-attention · GitHub 和 Will attention_mask be extended to 3D? (concatenate short samples for efficient training) · Issue #432 · Dao-AILab/flash-attention · GitHub。
在 GLM-4([2406.12793] ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools)中也應(yīng)用了 Sample Packing 方案,并且同樣使用了 Block Diagonal Mask 機(jī)制來(lái)區(qū)分不同的 Document。并且作者也是基于 FlashAttention 的 Varlen 功能來(lái)實(shí)現(xiàn)。
5.2 Pytorch FlashAttention
Pytorch 的 scaled_dot_product_attention 提供了高效的 Attention 實(shí)現(xiàn),也集成了 FlashAttention2 的實(shí)現(xiàn),然而其不支持上述的可變序列長(zhǎng)度的功能,導(dǎo)致針對(duì) Block Diagonal Mask 場(chǎng)景時(shí)會(huì)存在大量的重復(fù)計(jì)算。
此外,我們?cè)谥暗奈恼轮幸捕啻翁岬剑?dāng)序列比較短時(shí),Attention 部分計(jì)算的占比并不是特別大,因此其中的冗余計(jì)算可能對(duì)整體訓(xùn)練速度影響不大;但當(dāng)序列比較長(zhǎng)時(shí),Attention 部分計(jì)算的占比會(huì)越來(lái)越大,冗余計(jì)算可能會(huì)對(duì)訓(xùn)練速度有比較大的影響,也就需要對(duì)其進(jìn)行優(yōu)化。
5.3 FlexAttention
Pytorch 在 2.5.0 版本引入了 FlexAttention(FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention),可以很容易支持各種 Attention Mask 變種,比如標(biāo)準(zhǔn) Causal Mask、Sliding Window + Causal、Prefix Mask 以及 Document Mask(Block Diagonal Mask)等,相比 FlashAttention 也更加的靈活。
我們基于 FlexAttention 進(jìn)行了相關(guān)測(cè)試,以驗(yàn)證使用 Block Diagonal Mask 的性能優(yōu)勢(shì)。首先以兩個(gè) 16K Document 拼接為一個(gè) 32K Sample 為例,Attention Mask 大概是如下圖所示方式,對(duì)應(yīng)的稀疏度為 74.80%(整個(gè) Mask 中 0 的占比):
如下圖所示我們?cè)?H100 GPU 上進(jìn)行的 Attention 相關(guān)性能測(cè)試。可以看出, Pytorch 的 Causal + FlashAttention2 方式確實(shí)可以達(dá)到非常高的 TFLOPS,明顯高于 FlexAttention。然而,因?yàn)?FlexAttention 中避免了很多無(wú)效計(jì)算,實(shí)際的 Forward 和 Backward 時(shí)間反而更短:
當(dāng)然,也并不意外著 FlexAttention 總是更優(yōu)的,還和 Sample 中 Document 長(zhǎng)度有關(guān)。如下圖所示為相應(yīng)測(cè)試結(jié)果,32K 表示 Sample 中只有一個(gè) Document,2K + 30K 表示 Sample 中有 2 個(gè) Document,一個(gè)長(zhǎng)度 2K,一個(gè)長(zhǎng)度 30K。從下圖基本上可以得出這樣一個(gè)結(jié)論:當(dāng) Sample 中最長(zhǎng)的 Document 的長(zhǎng)度 <= Sequence Length/2 時(shí),使用 FlexAttention 可能會(huì)帶來(lái)更大的收益:
那么為什么“最長(zhǎng)的 Document 的長(zhǎng)度 <= Sequence Length/2”時(shí)會(huì)有收益呢?其實(shí)可以簡(jiǎn)單從稀疏度的角度考慮:假設(shè) a1 + a2 + a3 + ... + an = S,并且 0 < a1 <= a2 <= a3 <= ... <= an <= S/2,那么可以用數(shù)學(xué)歸納法得出 (a1)^2 + (a2)^2 + (a3)^2 + ... + (an)^2 <= S^2/2。也就是說(shuō),最長(zhǎng)的 Document 的長(zhǎng)度 <= Sequence Length/2 時(shí),稀疏度會(huì) >= 75%(還要考慮 Causal 特性),相應(yīng)的 FlashAttention 中至少有一半的冗余計(jì)算。
因此,我們也需要充分考慮在長(zhǎng)文本訓(xùn)練過(guò)程中短文本的占比,極端情況下訓(xùn)練數(shù)據(jù)全部是超長(zhǎng)文本,每個(gè) Sample 中都只有一個(gè) Document,Block Diagonal Mask 會(huì)退化為 Causal Mask。不過(guò)有些時(shí)候?yàn)榱吮苊饽P统霈F(xiàn)災(zāi)難性遺忘,也會(huì)混合一些短文本數(shù)據(jù),或者高質(zhì)量的預(yù)訓(xùn)練數(shù)據(jù),不可避免的會(huì)出現(xiàn)冗余計(jì)算的問(wèn)題。
5.4 Sequence Parallel
我們?cè)谥暗男蛄胁⑿形恼拢???大規(guī)模分布式 AI 模型訓(xùn)練系列——序列并行??)中也提到過(guò),針對(duì)長(zhǎng)序列場(chǎng)景通常會(huì)采用 RingAttention 和 USP 等,然而不管是 RingAttention 還是其 LoadBalance 版本(如下圖 Figure 3 所示)等都沒(méi)有太多討論 Sample Packing 的情況。對(duì)于 Block Diagonal Mask 場(chǎng)景,其相應(yīng)的優(yōu)化,LoadBalance 策略也可能需要對(duì)應(yīng)調(diào)整:
在 [2402.08268] World Model on Million-Length Video And Language With Blockwise RingAttention 中作者(也是 RingAttention 的作者)聲稱針對(duì) Block Diagonal Mask 場(chǎng)景對(duì) RingAttention 進(jìn)行相關(guān)優(yōu)化,但并沒(méi)有對(duì)比優(yōu)化前后訓(xùn)練速度的提升。
PS:整體來(lái)說(shuō),在各種序列并行技術(shù)中更好的兼容 Block Diagonal Mask 場(chǎng)景又會(huì)有更多的挑戰(zhàn),我們留作后續(xù)介紹。
六、參考鏈接
- ??https://www.inspur.com/lcjtww/resource/cms/article/2526910/2726086/2022082918565451491.pdf??
- ??https://arxiv.org/abs/2407.21783??
- ??https://arxiv.org/abs/2402.08268??
- ??https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/datasets/gpt_dataset.py??
- ??https://github.com/Dao-AILab/flash-attention/issues/654??
- ??https://github.com/Dao-AILab/flash-attention/issues/432??
- ??https://arxiv.org/abs/2406.12793??
- ??https://pytorch.org/blog/flexattention/??
本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談
