Sample Packing 綜述:LLM 效果與效率的 Tradeoff
本文中我們通過(guò)幾篇論文來(lái)具體介紹 Sample Packing 相關(guān)的方案和對(duì)應(yīng)的各種問(wèn)題,比如 GraphCore 的PackedBert、Meta 的 In-Context-Pretraining、智譜 AI 的 LongAlign、Amazon 的 Fewer Truncations 以及 IBM 的 Packing with FlashAttention。
一、背景
上一篇文章(???Sample Packing:長(zhǎng)序列 LLM 訓(xùn)練的 Attention 問(wèn)題及優(yōu)化??)中我們簡(jiǎn)單介紹了 Sample Packing 相關(guān)的問(wèn)題和部分簡(jiǎn)單實(shí)驗(yàn)。本文中我們通過(guò)幾篇論文來(lái)具體介紹 Sample Packing 相關(guān)的方案和對(duì)應(yīng)的各種問(wèn)題,比如 GraphCore 的PackedBert、Meta 的 In-Context-Pretraining、智譜 AI 的 LongAlign、Amazon 的 Fewer Truncations 以及 IBM 的 Packing with FlashAttention。
二、方法
Sample Packing 可以看成是一個(gè)經(jīng)典的 Bin Packing 組合優(yōu)化(Combinatorial Optimization)問(wèn)題,核心思想是將一組不同體積的物品放入容量固定的箱子中,目標(biāo)是最小化所需箱子數(shù)量。
Bin Packing 是一個(gè)典型的 NP-Hard 問(wèn)題,因此往往需要關(guān)注其效率。對(duì)于 N 個(gè) Sample 的數(shù)據(jù)集,通常需要先排序,對(duì)應(yīng) O(N*logN) 的時(shí)間復(fù)雜度;然后 Packing,需要 O(N*logN) 的時(shí)間復(fù)雜度。不過(guò)由于 Sample 的長(zhǎng)度通常比較有限,可以采用計(jì)數(shù)排序的方式進(jìn)行優(yōu)化;此外,整個(gè)訓(xùn)練集通常有多個(gè)數(shù)據(jù)子集組成,為了保證每個(gè)數(shù)據(jù)子集具有不同的采樣權(quán)重,往往一個(gè) Sequence 中的 Sample 來(lái)自同一個(gè)數(shù)據(jù)子集,一個(gè)大的 Bin Packing 問(wèn)題也變成了多個(gè)小的 Bin Packing 問(wèn)題,可以進(jìn)一步降低復(fù)雜度。本文中我們就不再具體介紹 Bin Packing 的優(yōu)化問(wèn)題。
如果從物品順序角度考慮,Bin Packing 可以分為兩類:
- Online Bin Packing:物品按順序到達(dá),必須立即決定放入哪個(gè)箱子,無(wú)法預(yù)知后續(xù)物品的大小。對(duì)于 LLM 訓(xùn)練而言,Online 方式依然可以實(shí)現(xiàn) Batch 內(nèi)(窗口內(nèi))的亂序,也可以通過(guò)梯度累加增加 Batch 的大小。
- Offline Bin Packing:預(yù)先知道所有物品大小,可以全局排序。對(duì)于 LLM 訓(xùn)練而言,相當(dāng)于訓(xùn)練前就預(yù)先知道所有序列的長(zhǎng)度,并對(duì)所有 Sample 打包。
當(dāng)然,Bin Packing 方案也可能有不同的約束條件:
- 一次性 Packing:每個(gè)物品只能放進(jìn)一個(gè)箱子,不允許拆分物品。對(duì)于 LLM 而言,可以理解為 Sample 不能截?cái)唷?/li>
- 多維 Packing:物品和箱子不僅有體積,還有其他屬性,比如重量、形狀,需要同時(shí)滿足多維約束。對(duì)于 LLM,可以約束同一個(gè) Batch 內(nèi)不同 Sequence 的計(jì)算量盡量類似,以實(shí)現(xiàn)更好的負(fù)載均衡。
整體來(lái)說(shuō),對(duì)于 LLM 訓(xùn)練可以從數(shù)據(jù)分布和計(jì)算效率的角度考慮相關(guān)的方案,數(shù)據(jù)分布可能影響模型訓(xùn)練的效果,計(jì)算效率會(huì)影響訓(xùn)練的速度,往往需要綜合考量。
影響數(shù)據(jù)分布的幾個(gè)因素:
- 是否隨機(jī)采樣,比如排序機(jī)制可能引入分布的不一致。
- 是否交叉污染,比如 Packing 是否采用 Document Level 的 Mask(Block Diagonal Mask,也可以通過(guò) Position ID 區(qū)分)。
- 是否有 Sample 的截?cái)唷?/li>
影響計(jì)算效率的幾個(gè)因素:
- 是否 Padding,Padding 是否參與計(jì)算。
- Packing 是否采用 Document Level Mask。
- 是否存在負(fù)載不均衡問(wèn)題(主要是指稀疏度不同,計(jì)算量不同)。
三、GraphCore PackedBert
3.1 概述
GraphCore 在 2021 年 07 月的 [2107.02027] Efficient Sequence Packing without Cross-contamination: Accelerating Large Language Models without Impacting Performance 中已經(jīng)討論了 Sample Packing 導(dǎo)致的交叉污染(Cross Contamination)問(wèn)題。作者研究了新的 Packing 算法,并通過(guò)修改 Attention Mask 和 Position ID 來(lái)避免交叉污染,提升 Bert 模型的訓(xùn)練效率。
如下圖 Table 1 所示,不同的 Packing 方式有不同的 EFF(有效率),其中 Baseline 的 None 表示有 50% 左右的 Padding Token,而 SPFHP 和 NNLSHP(本文提出) 能獲得相對(duì)比較高的 EFF。
3.2 結(jié)果
如下圖 Figure 4 所示可以看出,需要相應(yīng)修改 Attention Mask 才能保證精度:
如下圖 Figure 5 所示,隨著 Accelerator 的增加,本文的方案能獲得比較穩(wěn)定的加速,非常接近理論速度(最上的藍(lán)線);而非 Padding 的方案可能隨著 Accelerator 的增加出現(xiàn)明顯的降速。
四、Meta In-Context-Pretraining
4.1 概述
Meta 作者在 [2310.10638] In-context Pretraining: Language Modeling Beyond Document Boundaries 也關(guān)注了 Sample Packing 的問(wèn)題。作者指出,之前的預(yù)訓(xùn)練流程在訓(xùn)練時(shí)將隨機(jī)的短文檔拼接起來(lái)形成輸入上下文,這些文檔之間沒(méi)有提供預(yù)測(cè)下一個(gè)文檔的信號(hào),導(dǎo)致計(jì)算效率不高。因此作者提出了新的方案:In-context Pretraining,通過(guò)改變文檔的順序,使得每個(gè)上下文包含相關(guān)的文檔,從而明確鼓勵(lì)模型跨文檔邊界進(jìn)行閱讀和推理。
4.2 方案
如下圖所示,在預(yù)訓(xùn)練前會(huì)計(jì)算文檔的相似性,在 Packing 時(shí)利用上這種相似性,保證 Sequence 中文檔盡可能相關(guān)。
4.3 結(jié)果
作者使用 CommonCrawl 數(shù)據(jù)集,預(yù)訓(xùn)練了 0.3B 到 7B 參數(shù)量的多個(gè)模型,并在多種任務(wù)上評(píng)估,包括上下文學(xué)習(xí)、閱讀理解、對(duì)先前上下文的忠實(shí)度、長(zhǎng)上下文推理和檢索增強(qiáng)等。與使用標(biāo)準(zhǔn)方法預(yù)訓(xùn)練的模型相比,In-context Pretraining 方法訓(xùn)練出的模型(ICLM)顯示出顯著的性能提升。
如下圖 Figure 3 所示,本文 ICLM 訓(xùn)練的模型獲得了更低的困惑度:
如下圖 Table 1 所示,基于 ICLM 訓(xùn)練的模型在下游 In-context Learning 任務(wù)上也獲得了更好的效果:
五、智譜 LongAlign
5.1 概述
智譜 AI 在 [2401.18058] LongAlign: A Recipe for Long Context Alignment of Large Language Models 中討論了部分 Sample Packing 相關(guān)問(wèn)題。如下圖 Figure 3 左圖所示,Sequence 的長(zhǎng)度各不相同,從 0 - 60K,如果采用 Naive Batching 方式,會(huì)導(dǎo)致明顯的 Bubble 問(wèn)題(雖然 NoPadding 技術(shù)可以避免重復(fù)計(jì)算,但是如果采用 Data Parallelism 方式,比較快的設(shè)備需要等待比較慢的設(shè)備計(jì)算完成)。為了解決效率和效果問(wèn)題,作者提出了 3 種解決方案:Packing、Loss Weighting 和 Sorted Batching。
5.2 Packing
如 Figure 3 右上圖所示,就是我們之前介紹的 Sample Packing:將不同的 Sample 拼接在一個(gè) Sequence 里,并且保證盡可能接近 Max Sequence Length,末尾的部分 Token 進(jìn)行 Padding。然后通過(guò) Block Diagonal Attention Mask 來(lái)區(qū)別不同的 Sample,以避免 Sample 之間的交叉污染,也就是 Document Level Attention。
PS:作者介紹,這里同樣是使用了 FlashAttention2 的 Varlen 特性。
5.3 Loss Weighting
假設(shè)訓(xùn)練時(shí)的 Batch Size 為 K,總共包含 M 個(gè) Sample,第 i 個(gè) Sample 的 Token 數(shù)為 Ni,則對(duì)應(yīng)的 Loss 如下圖所示:
然而,增加 Sample Packing 之后會(huì)引入一個(gè)問(wèn)題,如下圖所示,一個(gè) Sequence 中的不同 Sample 會(huì)被看成一個(gè) Sample 來(lái)計(jì)算損失。當(dāng)有些 Sample 比較長(zhǎng),其對(duì)應(yīng)的 Token 很多,那么這個(gè) Sample 對(duì) Loss 的貢獻(xiàn)就更大,模型可能會(huì)在訓(xùn)練時(shí)更傾向于優(yōu)化長(zhǎng) Sample 的表現(xiàn),進(jìn)而可能會(huì)導(dǎo)致對(duì)短 Sample 的學(xué)習(xí)有所欠缺。
為了解決這個(gè)問(wèn)題,作者提出 Loss Weighting,也就是對(duì)不同 Sample 的 Loss 加權(quán)。如下圖所示,保證其和上述公式(2)等價(jià)。作者聲稱可以在下游任務(wù)上帶來(lái) 10% 左右的效果提升。
PS:不過(guò)這里還會(huì)引入另外一個(gè)問(wèn)題,因?yàn)椴捎昧?Sample Packing,那么實(shí)際上不同 Step 中 Sample 的個(gè)數(shù)在不斷變化。比如一個(gè) Batch 里都是短 Sample,那么對(duì)應(yīng)的 M 會(huì)比較大;如果一個(gè) Batch 里都是長(zhǎng) Sample,相應(yīng)的 M 會(huì)比較小。這樣可能引入兩個(gè)問(wèn)題,1. 相當(dāng)于一個(gè) Batch Size 在不斷變化;2. 同樣一個(gè) Sample 可能會(huì)因?yàn)轫樞虻脑虮毁x予不同的權(quán)重。因此需要盡量保證Batch 中 Sample 的平均個(gè)數(shù)比較穩(wěn)定。
5.4 Sorted Batching
如上圖 Figure 3 右下圖所示,可以將所有 Sample 進(jìn)行排序,在組 Batch 時(shí)盡量保證一個(gè) Batch 中的 Sample 長(zhǎng)度相同(沒(méi)有 Packing)。這樣可以保證不同設(shè)備的計(jì)算盡可能的均衡。然而,這種方式也不可避免地引入不同 Batch 的數(shù)據(jù)分布的偏差,有些 Batch 都是長(zhǎng)序列,有些 Batch 都是短序列,對(duì)于 SGD 優(yōu)化來(lái)說(shuō)可能并不友好。不過(guò)作者發(fā)現(xiàn)這種方式可以顯著加快訓(xùn)練速度,而不會(huì)對(duì)效果產(chǎn)生明顯的負(fù)面影響。這可能是因?yàn)槭褂昧舜蟮奶荻壤奂樱∕icro Batch 中長(zhǎng)度類似,但整個(gè) Batch 中包含各種長(zhǎng)度的 Sample)。
PS:這種方式也就對(duì)應(yīng) Transformer 等工作中的 LengthGroupedSampler,如下圖所示,排序后可以有效降低無(wú)效計(jì)算(圖片參考 數(shù)據(jù)分組— XTuner 0.1.23 文檔)。
5.5 結(jié)果
如下圖 Figure 5 所示,作者在 8xA800 GPU 上進(jìn)行速度對(duì)比,可以看成,Packing 和 Sorted Batching 相比 Naive Batching 都有 2x-3x 的加速:
如下圖 Table 3 所示,作者基于 ChatGLM3-6B-64K 和 LLaMA-2-7B-64K 進(jìn)行了相關(guān)效果驗(yàn)證。可以看出,Loss Weighting 在 LongBench-Chat 上能帶來(lái) 5%-10% 的提升,但是在其他任務(wù)上并不明顯,并且這些方法看著都不是特別魯棒。
5.6 Packing 負(fù)載均衡
智譜 AI 在訓(xùn)練 GLM4(GLM Long: Scaling Pre-trained Model Contexts to Millions | by ChatGLM | Medium) 模型時(shí)進(jìn)一步解決了 Packing 帶來(lái)的負(fù)載均衡問(wèn)題。如下圖所示,雖然都 Packing 到了相同的長(zhǎng)度,但是由于其中的 Sample 個(gè)數(shù)、長(zhǎng)度不同,導(dǎo)致其稀疏度差距很大,計(jì)算量也相應(yīng)差距很大。如果它們?cè)诓煌脑O(shè)備上執(zhí)行,同樣會(huì)存在計(jì)算不均導(dǎo)致的 Bubble 問(wèn)題。
如下表所示,我們?cè)谏弦黄恼轮械膶?shí)驗(yàn)也能說(shuō)明這個(gè)問(wèn)題,隨著 Sequence 中 Sample 分布的不同,計(jì)算的耗時(shí)甚至可能差 10x:
如下圖所示,作者發(fā)現(xiàn)訓(xùn)練中每個(gè) Step 的時(shí)間存在較大幅度的波動(dòng),這種現(xiàn)象在短文本的 Packing SFT 中并不明顯。這是因?yàn)槎涛谋緯r(shí) Attention 的計(jì)算占比并不高,而超長(zhǎng)文本訓(xùn)練中會(huì)尤其明顯。
為了解決上述問(wèn)題,作者進(jìn)一步提出了 Sorted Packing。具體來(lái)說(shuō),作者在構(gòu)建 Batch 數(shù)據(jù)時(shí)考慮了計(jì)算復(fù)雜度,以確保每個(gè) Batch 中計(jì)算復(fù)雜度相似,從而減少 Bubble 時(shí)間。(PS:這里需要注意,計(jì)算復(fù)雜度不等于執(zhí)行速度,如果能針對(duì)預(yù)估計(jì)算速度來(lái)打包也許能獲得更優(yōu)的效果)
作者指出也使用了 layer accumulation 技術(shù)(PS:上述介紹的梯度累加?)來(lái)避免排序?qū)е碌钠脝?wèn)題。
六、Amazon Fewer Truncations
6.1 概述
在 [2404.10830] Fewer Truncations Improve Language Modeling 中,作者探討了數(shù)據(jù)截?cái)鄦?wèn)題對(duì)模型效果的影響。作者指出,截?cái)鄷?huì)損害數(shù)據(jù)完整性,從而阻礙模型學(xué)習(xí)基于完整上下文撰寫邏輯連貫且事實(shí)一致的內(nèi)容的能力。
為了解決這個(gè)問(wèn)題,作者提出了 Best-fit Packing,通過(guò)長(zhǎng)度感知組合優(yōu)化來(lái) Packing,可以完全消除不必要的截?cái)?,同時(shí)基本不影響訓(xùn)練效率。通過(guò)文本和代碼預(yù)訓(xùn)練的實(shí)驗(yàn)結(jié)果表明,提出的方法可以在閱讀理解上相對(duì)提升 4.7%,上下文跟隨提升 16.8%,程序生成提升 9.2%。此外,也可以有效減少 58.3% 的封閉域幻覺(jué)問(wèn)題。
PS:論文中對(duì)比實(shí)驗(yàn)時(shí) Baseline 的 Concatenation 方案中有截?cái)啵⑶覜](méi)有使用 Block Diagonal Mask;而 Best-Fit Packing 使用了 Block Diagonal Mask,且沒(méi)有截?cái)唷?/p>
6.2 方案
如下圖 Figure 1 所示為本文 Best-Fit Packing 與傳統(tǒng)方案的對(duì)比。
- 右圖所示為傳統(tǒng)方案:可以理解為將所有樣本排成一行,然后按照 Max Sequence Length 進(jìn)行切割,會(huì)導(dǎo)致大量 Sample 被截?cái)唷?/li>
- 左圖為本文的方案:首先將所有 Document 按照 Max Sequence Length 截?cái)?,然后使?Best-Fit Decreasing 算法來(lái)進(jìn)行組合優(yōu)化(Bin Packing 優(yōu)化)。?
如下圖 Table 2 所示,使用本文的 Best-Fit Packing 可以保證只增加不到 0.003% Sequence 數(shù)量,對(duì)訓(xùn)練效率的影響也就微乎其微。
6.3 結(jié)果
如下圖 Table 3 所示,作者訓(xùn)練了 3 個(gè)不同規(guī)模、序列長(zhǎng)度的模型:
如下圖 Table 4 所示,提出方案訓(xùn)練出的模型的閱讀理解能力有明顯提升:
如下圖 Table 9 所示,幻覺(jué)問(wèn)題也可以明顯降低:
如下圖 Table 10 所示,作者也針對(duì) Attention Mask 進(jìn)行了相關(guān)消融實(shí)驗(yàn),可以看出原始 Concatenation 方案加入 Block Diagonal Mask 后也有一定的提升??梢宰C明 Attention Mask 和截?cái)喽紩?huì)對(duì)效果有一定影響,不過(guò) Packing(避免截?cái)啵?的影響似乎更大一些。
七、IBM Packing with FlashAttention2
7.1 概述
在 [2407.09105] Enhancing Training Efficiency Using Packing with Flash Attention 中,作者總結(jié)了不同 Packing 策略、Mask 方式及與 FlashAttention 結(jié)合的優(yōu)勢(shì)。此外,作者也將相關(guān)工作提交到了 Huggingface Transformer 中,提供了新的 DataCollatorWithFlattening,具體可以參考:通過(guò)打包Flash Attention 來(lái)提升Hugging Face 訓(xùn)練效率。
7.2 相關(guān)方案
如下圖 Table 1 所示,作者分析了不同的 Packing 方案以及它們的影響,具體包含如下幾種方式:
- RandomSampling + Padding:最傳統(tǒng)的隨機(jī)采樣,然后 Padding 的方式。存在冗余計(jì)算,并且占比很高。
- GroupByLength+Padding:先排序,然后盡量保證每個(gè) Batch 中的序列長(zhǎng)度接近??梢詼p少 Padding 的占比。
- RandomSampling + PosID:隨機(jī)采樣,但是不 Padding,而是通過(guò) PosID 支持變長(zhǎng)序列。幾乎沒(méi)有冗余計(jì)算,但可能存在明顯的負(fù)載不均衡(計(jì)算量)。
- FixedLengthPacking:隨機(jī)采樣,隨機(jī) Packing,并且最后一個(gè)Sample 可能截?cái)?,保證填滿 Max Sequence Length。沒(méi)有區(qū)分不同 Sample,也就是 Causal Mask,沒(méi)有冗余計(jì)算,并且負(fù)載很均衡。
- FixedLengthPacking + PosID:相比FixedLengthPacking多了 PosID,也就是可以區(qū)分不同 Sample,對(duì)應(yīng) Block Diagonal Mask。但依然會(huì)存在末尾截?cái)?,并且可能?fù)載不均衡。
- MultiPack + PosID:使 Sequence 中的數(shù)據(jù)盡量接近 Batch 的 Max Sequence Length,降低 Sequence 中的長(zhǎng)度不均衡,可以參考GitHub - imoneoi/multipack_sampler: Multipack distributed sampler for fast padding-free training of LLMs。需要對(duì)數(shù)據(jù)進(jìn)行排序。
- SortedPacking + PosID:通過(guò)排序,使同一個(gè) Batch 中的計(jì)算復(fù)雜度盡量接近??梢员M可能降低計(jì)算負(fù)載不均衡問(wèn)題。
- RandomPacking + PosID:與FixedLengthPacking + PosID相比主要的區(qū)別就是最后一個(gè) Sample不截?cái)?,可能存在部?Bubble。?
7.3 結(jié)果
作者通過(guò)微調(diào)任務(wù)對(duì)比了一系列模型使用不同方案的效果和速度,其中 Max Sequence Length(msl)為 4096,每個(gè) GPU 的 Mini-Batch Size 為 4。對(duì)應(yīng)配置如下:
- no:表示最原始的RandomSampling + Padding,可以作為基線,冗余計(jì)算比較多,速度最慢,但是效果有保障。
- yes:表示FixedLengthPacking,存在交叉污染。
- flat:表示訓(xùn)練前 Offline Packing 好,也就是引入了排序,并且使用 PosID 實(shí)現(xiàn) Document Level Mask。
- mini:表示訓(xùn)練中 mini batch 的 Online Packing,和 Random 類似,并且使用 PosID 實(shí)現(xiàn) Document Level Mask。
如下圖所示(PS:這里只是部分結(jié)果,全量請(qǐng)參考論文,結(jié)論基本一致,可以看出 yes 和 flat 都會(huì)對(duì)精度(VLoss)有比較大的影響,但速度(Tok/s)確實(shí)快了很多,可以達(dá)到 Baseline 的 3x-4x;而 mini 可以在保證精度的情況實(shí)現(xiàn)實(shí)現(xiàn) 2x 左右加速。
作者使用 FLAN_20k 數(shù)據(jù)集在 Mistral-7B 上針對(duì)之前提到的幾種方案做了更多實(shí)驗(yàn)(PS:這里只是部分,gas 表示梯度累加次數(shù)),看起來(lái) Packing 的方案都能獲得不錯(cuò)的速度,但是精度的影響因素就比較多。但整體來(lái)說(shuō) RandomPacking + PosID 的方案還不錯(cuò),看著似乎也不能有太多的梯度累加。
八、參考鏈接
- ??https://arxiv.org/abs/2107.02027??
- ??https://arxiv.org/abs/2310.10638??
- ??https://arxiv.org/abs/2401.18058??
- ??https://xtuner.readthedocs.io/zh-cn/latest/acceleration/length_grouped_sampler.html#length-grouped-sampler??
- ??https://medium.com/@ChatGLM/glm-long-scaling-pre-trained-model-contexts-to-millions-caa3c48dea85??
- ??https://arxiv.org/abs/2404.10830??
- ??https://arxiv.org/abs/2407.09105??
- ??https://huggingface.co/blog/zh/packing-with-FA2??
- ??https://github.com/imoneoi/multipack_sampler??
本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談
