阿里云通義大模型新技術(shù):MoE模型訓(xùn)練專(zhuān)家平衡的關(guān)鍵細(xì)節(jié)
本周,在阿里云通義千問(wèn) Qwen 團(tuán)隊(duì)提交的一篇論文中,研究人員發(fā)現(xiàn)了目前最熱門(mén)的 MoE(混合專(zhuān)家模型)訓(xùn)練中存在的一個(gè)普遍關(guān)鍵問(wèn)題,并提出一種全新的方法——通過(guò)輕量的通信將局部均衡放松為全局均衡,使得 MoE 模型的性能和專(zhuān)家特異性都得到了顯著的提升。
- 論文:《Demons in the Detail: On Implementing Load Balancing Loss for Training Specialized Mixture-of-Expert Models》
- 論文鏈接:https://arxiv.org/abs/2501.11873
MoE 模型訓(xùn)練中的關(guān)鍵問(wèn)題
混合專(zhuān)家模型(MoEs)通過(guò)路由機(jī)制動(dòng)態(tài)并稀疏地激活模型參數(shù),使得能高效地增大模型參數(shù)規(guī)模。基于 TopK 機(jī)制的稀疏激活會(huì)在訓(xùn)練中會(huì)遇到專(zhuān)家激活不均衡的問(wèn)題:少數(shù)被頻繁選擇的專(zhuān)家會(huì)被優(yōu)化得更多,進(jìn)一步使得這些專(zhuān)家被更頻繁地選擇,最終導(dǎo)致只選擇少數(shù)專(zhuān)家,造成剩余專(zhuān)家的冗余。因此,MoE 在訓(xùn)練中需要引入額外輔助的負(fù)載均衡損失(load balance loss,LBL)來(lái)鼓勵(lì)專(zhuān)家的選擇趨于均衡。
目前主流 MoE 訓(xùn)練框架中實(shí)現(xiàn)的 LBL 的優(yōu)化目標(biāo)是局部(micro-batch)的負(fù)載均衡,這使得模型需要將一個(gè)micro-batch的輸入都均勻分配給不同的專(zhuān)家。然而,一個(gè)micro-batch的輸入往往只來(lái)自個(gè)別領(lǐng)域,局部負(fù)載均衡會(huì)讓模型將每個(gè)領(lǐng)域的輸入都均勻分配。這種均勻分配會(huì)阻礙某些專(zhuān)家更多處理特定領(lǐng)域的數(shù)據(jù),也即阻礙專(zhuān)家出現(xiàn)領(lǐng)域?qū)哟蔚姆只卣?。我們發(fā)現(xiàn),將局部的負(fù)載均衡放松到全局的負(fù)載均衡,能顯著增強(qiáng)專(zhuān)家的特異化并提高模型性能。
背景
混合專(zhuān)家(Mixture-of-Experts,MoE)是一種高效的在訓(xùn)練時(shí)擴(kuò)展模型參數(shù)規(guī)模的技術(shù)。通常,一個(gè)MoE層由一個(gè)路由器(通常是一個(gè)線性層)和一組專(zhuān)家組成(對(duì)于Transformer的模型,每個(gè)專(zhuān)家是一個(gè)前饋神經(jīng)網(wǎng)絡(luò))。給定一個(gè)輸入,只有部分專(zhuān)家會(huì)被激活,然后它們的輸出會(huì)根據(jù)路由器分配的權(quán)重進(jìn)行聚合。具體來(lái)說(shuō):
負(fù)載均衡損失
負(fù)載均衡損失是訓(xùn)練 MoE 網(wǎng)絡(luò)中的一種重要正則化技術(shù),其核心思想是鼓勵(lì)所有專(zhuān)家的均衡激活。它可以通過(guò)以下公式計(jì)算:
其中, 是專(zhuān)家 的激活頻率, 是分配給專(zhuān)家 的平均路由分?jǐn)?shù)。
然而,大多數(shù)現(xiàn)有的MoE訓(xùn)練框架(例如Megatron-core)實(shí)現(xiàn)的是局部(micro-batch)層次的均衡,這意味著在每個(gè) micro-batch 內(nèi)計(jì)算 LBL ,然后在全局(global-batch)層次上進(jìn)行平均,即:
其中 為 micro-batch 數(shù), 是在第 個(gè) micro-batch 上計(jì)算的負(fù)載均衡損失, 為在第 個(gè) micro-batch 上統(tǒng)計(jì)出的激活頻率和路由分?jǐn)?shù)。
我們關(guān)注的關(guān)鍵點(diǎn)是,如果一個(gè) micro-batch 中的數(shù)據(jù)不夠多樣化,這種實(shí)現(xiàn)方式可能會(huì)阻礙專(zhuān)家的特異化。例如,假設(shè)一個(gè) micro-batch 中只包含代碼數(shù)據(jù),上述負(fù)載均衡損失仍然會(huì)推動(dòng)路由器將這些代碼輸入均勻分配給所有專(zhuān)家。而理想狀況下,處理代碼數(shù)據(jù)的專(zhuān)家網(wǎng)絡(luò)應(yīng)該對(duì)代碼數(shù)據(jù)有更高的激活頻率。在訓(xùn)練基于 MoE 的大型語(yǔ)言模型時(shí),這種情況更常見(jiàn):一個(gè)較小的 micro-batch (通常為 1)中的數(shù)據(jù)通常來(lái)自同一領(lǐng)域。這在一定程度上解釋了為什么當(dāng)前大多數(shù)基于 MoE 的大語(yǔ)言模型中都沒(méi)有觀察到明顯的領(lǐng)域?qū)哟蔚膶?zhuān)家特異化。
這一缺點(diǎn)促使我們將當(dāng)前局部均衡的方法想辦法擴(kuò)展到全局(global-batch)均衡。
從局部均衡到全局均衡
得得益于 LBL 計(jì)算的格式,我們可以通過(guò)通信不同節(jié)點(diǎn)的 來(lái)將局部 轉(zhuǎn)化為全局的 :1)在所有 micro-batch 之間同步專(zhuān)家選擇頻率 ;2)在每個(gè)GPU上計(jì)算負(fù)載均衡損失;3)在所有 micro-batch 之間聚合損失。具體來(lái)說(shuō):
其中 是對(duì)全局統(tǒng)計(jì)的激活頻率和門(mén)控分?jǐn)?shù),第一個(gè)等式為 的計(jì)算方式,第二個(gè)等式為全局路由分?jǐn)?shù)可以由局部路由分?jǐn)?shù)平均而來(lái),第三個(gè)等式表示用全局激活頻率參與局部計(jì)算后再平均聚合等價(jià)于全局均衡損失。因?yàn)?nbsp; 只是一個(gè)專(zhuān)家數(shù)大小的向量,即使是在全局通信的情況下也不會(huì)帶來(lái)明顯的開(kāi)銷(xiāo)。此外由于 LBL 的計(jì)算與模型其它部分的計(jì)算相對(duì)獨(dú)立,還可以用計(jì)算掩蓋等策略進(jìn)一步消除同步 的通信開(kāi)銷(xiāo)。
此外,對(duì)于需要梯度積累的情景,我們還提出了緩存機(jī)制來(lái)累積各個(gè)積累步統(tǒng)計(jì)的專(zhuān)家激活頻率,使得計(jì)算節(jié)點(diǎn)較少、只進(jìn)行一次通信達(dá)到的均衡范圍有限的情況下,也能逐漸近似全局統(tǒng)計(jì)的激活頻率。
擴(kuò)大均衡的范圍帶來(lái)穩(wěn)定的提升
我們?cè)谌N參數(shù)規(guī)模(3.4B 激活 0.6B, 15B 激活 2.54B,43B 激活 6.6B)下分別訓(xùn)練了 120B 和 400B tokens,對(duì)比了不同的均衡范圍(Balance BSZ)對(duì)模型性能的影響。所有模型都使用了細(xì)粒度專(zhuān)家、共享專(zhuān)家及 dropless 策略(專(zhuān)家不會(huì)拋棄超過(guò)容量的tokens)??梢钥吹?,將均衡范圍從一般框架實(shí)現(xiàn)的 4,8 或者 16 增大到 128 以上后模型在 Benchmark 指標(biāo)和 PPL 都有明顯提升。
我們?cè)?3.4B 激活 0.6B 的模型訓(xùn)練 400B tokens 到設(shè)置上進(jìn)一步對(duì)比了模型效果隨著均衡范圍的變化,可以看到 balance BSZ 從 2 到 128 模型的 PPL 在快速降低,在 128 后逐漸飽和。目前主流 MoE 框架中即使是進(jìn)行了機(jī)內(nèi)通信,對(duì)于較大的模型 balance BSZ 也一般在 8 到 16 的,這進(jìn)一步體現(xiàn)了我們通信方法的意義。
分析實(shí)驗(yàn)
假設(shè)驗(yàn)證
前文提到,這篇工作的出發(fā)點(diǎn)是在一個(gè) micro-batch 中,數(shù)據(jù)的來(lái)源較為單一的,進(jìn)而導(dǎo)致 MoE 模型需要將類(lèi)似來(lái)源的數(shù)據(jù)均勻分配到所有expert上,我們改進(jìn)了這一點(diǎn)進(jìn)而得到了提升。
然而,我們也可以假設(shè) global batch 是因?yàn)槭褂昧烁嗟?token 來(lái)統(tǒng)計(jì) expert 激活頻率進(jìn)而減少了方差,使得負(fù)載均衡損失更加穩(wěn)定,進(jìn)而提升訓(xùn)練洗哦啊過(guò)。位了更加嚴(yán)謹(jǐn)?shù)貙?duì)比這兩種假設(shè),我們引入了一種對(duì)比的實(shí)驗(yàn)設(shè)置:Shffuled batch balance, 即我們從global batch中隨機(jī)抽取一個(gè)子集(這個(gè)子集的大小等于micro batch的大小)統(tǒng)計(jì)專(zhuān)家激活頻率,進(jìn)而計(jì)算負(fù)載均衡損失。Shuffled batch balance 和 micro-batch balance擁有相同的token數(shù)目,和 global-batch balance擁有相同的token分布。
我們發(fā)現(xiàn),shuffled batch balance 和 global batch balance 的表現(xiàn)幾乎一致,都顯著好于 micro batch balance。說(shuō)明,引入 global-batch 獲得提升的首要原因是在一個(gè)更加通用、多樣的 token 集合上計(jì)算損失。進(jìn)而驗(yàn)證了我們的出發(fā)點(diǎn)和假設(shè)。
添加少量局部均衡損失
能提高模型效率
只使用全局均衡會(huì)導(dǎo)致局部均衡狀況有所降低,這會(huì)一定程度影響 MoE 的計(jì)算效率。我們進(jìn)一步實(shí)驗(yàn)了在主要使用全局均衡的情況下,在訓(xùn)練過(guò)程中添加局部均衡(默認(rèn)實(shí)現(xiàn)的 LBL,損失權(quán)重為全局 LBL 的 1%)限制對(duì)于模型性能和效率的影響??梢钥吹?,添加局部均衡能提升模型的速度(每個(gè)更新步耗時(shí)從 1.64秒提升到1.59秒),同時(shí)模型的效果也幾乎不受影響。
同期相關(guān)工作以及討論
已有工作 GRIN 也提出了 Global Load Balance Loss Adaptations,然而更多將這一均衡方法作為訓(xùn)練框架只使用張量并行、不使用專(zhuān)家并行的優(yōu)勢(shì)。GRIN 中并沒(méi)有從 specialization 或是對(duì)模型 performance 影響等方面討論使用 Global Load Balance 的動(dòng)機(jī),也沒(méi)有展示單一使用 Global Load Balance 的影響。
Wang et al. 提出在基于MoE的大語(yǔ)言模型訓(xùn)練中,負(fù)載均衡損失和語(yǔ)言模型損失如同杠桿一樣需要權(quán)衡,因?yàn)閮烧叩膬?yōu)化目標(biāo)并不一致。因此,他們提出了一種基于專(zhuān)家選擇頻率更新的偏差項(xiàng)(bais term),在不改變路由分?jǐn)?shù)的情況下平衡專(zhuān)家選擇,從而去掉了用來(lái)輔助訓(xùn)練的負(fù)載均衡損失(auxiliary-loss free)?;趯?zhuān)家選擇頻率更新的偏置項(xiàng),以在不改變路由評(píng)分的情況下平衡專(zhuān)家選擇。但是,他們沒(méi)有比較該方法在專(zhuān)家選擇頻率是根據(jù) micro-batch 計(jì)算和根據(jù) global-batch 計(jì)算時(shí)的性能差異。
這項(xiàng)工作也被應(yīng)用到 deepseek-v3 的訓(xùn)練中。deepseek-v3 的技術(shù)報(bào)告(同期工作)中強(qiáng)調(diào)了這項(xiàng)技術(shù)的專(zhuān)家選擇頻率是基于 global-batch 進(jìn)行計(jì)算,并在小規(guī)模上討論了基于global batch 使用 LBL 的結(jié)果,也發(fā)現(xiàn)這兩種方法結(jié)果相似。
而我們的工作不僅在大規(guī)模上系統(tǒng)驗(yàn)證了這種方法的有效性,還詳細(xì)析了均衡范圍對(duì)性能的影響,并消融證明了 global-batch 是通過(guò)納入更多樣化的領(lǐng)域信息從而顯著提性能。
結(jié)論
我們回顧了目前 MoE 訓(xùn)練框架中均衡損失,發(fā)現(xiàn)目前的實(shí)現(xiàn)方式會(huì)將所有來(lái)自相同領(lǐng)域的局部輸入都均勻分配,限制了專(zhuān)家的分化。通過(guò)輕量的通信將局部均衡放松為全局均衡,MoE 模型的性能和專(zhuān)家特異性都得到了顯著的提升。我們認(rèn)為這一進(jìn)展解決了現(xiàn)有MoE訓(xùn)練中的一個(gè)關(guān)鍵問(wèn)題,為MoE模型的優(yōu)化提供了新的視角,并有助于構(gòu)建更加可解釋的模型。盡管我們的實(shí)驗(yàn)主要集中在基于語(yǔ)言的任務(wù)上,我們希望我們的工作能夠?yàn)樵诓煌I(lǐng)域訓(xùn)練更大規(guī)模、更有效的 MoE 模型提供幫助。