解密萬億參數(shù)M6模型預(yù)訓(xùn)練背后的分布式框架Whale
最近,阿里云PAI團隊和達摩院智能計算實驗室一起發(fā)布“低碳版”巨模型M6,大幅降低萬億參數(shù)超大模型訓(xùn)練能耗。借助我們自研的Whale框架僅使用480卡GPU,即訓(xùn)練出了規(guī)模達人類神經(jīng)元10倍的萬億參數(shù)多模態(tài)大模型M6,與傳統(tǒng)海外公司實現(xiàn)萬億參數(shù)規(guī)模相比,能耗降低超八成、效率提升近11倍。
M6是國內(nèi)首個實現(xiàn)商業(yè)化落地的多模態(tài)大模型。M6擁有超越傳統(tǒng)AI的認知和創(chuàng)造能力,擅長繪畫、寫作、問答,在電商、制造業(yè)、文學(xué)藝術(shù)等諸多領(lǐng)域擁有廣泛應(yīng)用前景。
這里來為大家介紹支持萬億參數(shù)模型訓(xùn)練的Whale框架設(shè)計。
一、模型發(fā)展趨勢和挑戰(zhàn)
1.模型發(fā)展趨勢
隨著深度學(xué)習的火爆,模型的參數(shù)規(guī)模也增長迅速,OpenAI數(shù)據(jù)顯示:
- 2012年以前,模型計算耗時每2年增長一倍,和摩爾定律保持一致;
- 2012年后,模型計算耗時每3.4個月翻一倍,遠超硬件發(fā)展速度。
近一年模型參數(shù)規(guī)模飛速增長,谷歌、英偉達、阿里、智源研究院都發(fā)布了萬億參數(shù)模型,有大廠也發(fā)布了百億、千億參數(shù)模型。同時,隨著模型參數(shù)規(guī)模增大,模型效果也在逐步提高,Nvidia測試Bert模型不同參數(shù)規(guī)模,發(fā)現(xiàn)模型困惑度隨模型參數(shù)規(guī)模增加而降低。
Google在GShard paper中也發(fā)現(xiàn)MoETransformer 模型參數(shù)規(guī)模越大,翻譯質(zhì)量越高。
2.大模型訓(xùn)練的挑戰(zhàn)
大模型帶來模型效果提升的同時,也為訓(xùn)練框架帶來更大的挑戰(zhàn),例如當我們要訓(xùn)練一個萬億規(guī)模的模型時會面臨如下挑戰(zhàn):
- 訓(xùn)練難:
- GPU顯存已經(jīng)不夠存放模型副本,數(shù)據(jù)并行已經(jīng)不能滿足需求;
- 需要框架提供新的并行策略,協(xié)同多GPU能力來存放和訓(xùn)練模型;
- 如何給用戶提供簡潔、易用的接口,讓用戶能很容易實現(xiàn)分布式版模型;
- 超大規(guī)模模型對計算效率、通信效率都帶來很大挑戰(zhàn),如何提高計算和通信效率;
- 下游任務(wù)如何對接,如何支持批量預(yù)測和在線推理需求;
- 成本高:
- 以萬億模型為例,模型參數(shù)有4TB大小、梯度也有4TB,加上optimizer states和active tensor,顯存需求巨大;
- 業(yè)界訓(xùn)練同等規(guī)模模型需要的資源:英偉達 3072 A100、谷歌 2048 TPU v3,成本太高很難落地;
- 如何降本增效,使用更少的資源,更快的訓(xùn)練收斂;
當前已經(jīng)有一些分布式訓(xùn)練框架,例如:Horovod、Tensorflow Estimator、PyTorch DDP等支持數(shù)據(jù)并行,Gpipe、PipeDream、PipeMare等支持流水并行,Mesh Tensorflow、FlexFlow、OneFlow、MindSpore等支持算子拆分,但這些框架還有一些不足:
- 模式單一:很多框架只支持部分并行策略,不能完全支持各種混合并行;
- 接入門檻高:用戶實現(xiàn)模型分布式版本難度大、成本高,需要有領(lǐng)域?qū)<医?jīng)驗才能實現(xiàn)高效的分布式并行策略;
- 遷移代價大:不同分布式框架并行化實現(xiàn)割裂,不同框架有各自定義的DSL,當用戶要切換并行策略時,需要學(xué)習各種接口,重新改寫模型;
- 性能不理想:部分框架實現(xiàn)未考慮集群物理環(huán)境;
為了應(yīng)對當前分布式訓(xùn)練的挑戰(zhàn),我們研發(fā)了分布式訓(xùn)練框架Whale,主要目標是:
- 統(tǒng)一多種并行策略:在一個框架中支持各種并行策略以及這些策略的各種組合;
- 簡潔易用的接口:用戶只需添加幾行annotation即可完成并行策略的配置,模型代碼不需要改動;
- 高效的訓(xùn)練框架:結(jié)合硬件資源、網(wǎng)絡(luò)拓撲和模型進行協(xié)同優(yōu)化,打造高效分布式訓(xùn)練框架;
二、PAI自研Whale框架
1.Whale架構(gòu)
我們推出統(tǒng)一多種并行策略的高性能分布式訓(xùn)練框架Whale,從如下角度來應(yīng)對分布式訓(xùn)練的挑戰(zhàn):
- 將不同并行化策略進行統(tǒng)一抽象、封裝,在一套分布式訓(xùn)練框架中支持多種并行策略;
- 基于Tensorflow設(shè)計一套分布式并行接口,完全兼容Tensorflow,用戶僅僅只需添加幾行annotation就可以實現(xiàn)豐富的分布式并行策略;
- 結(jié)合模型結(jié)構(gòu)和網(wǎng)絡(luò)拓撲進行調(diào)度和通信優(yōu)化,提供高效的分布式訓(xùn)練能力。
Whale框架如下圖所示,主要分4個模塊:
- API:提供簡潔易用接口,讓用戶組合使用各種混合并行策略;
- Whale IR:將并行策略轉(zhuǎn)成內(nèi)部表達,通過TaskGraph、Multi-Dimension、VirtualDevices抽象來表達各種并行策略;
- Whale Engine:基于WhaleIR,通過圖編輯工具來構(gòu)建分布式執(zhí)行圖;
- Runtime:將分布式執(zhí)行圖轉(zhuǎn)成TFGraph,再調(diào)用TF 的Runtime來執(zhí)行;
2.Whale簡介易用接口
Whale提供簡潔易用的接口來描述各種并行策略,主要的原語:
- cluster:配置Virtual Device的劃分方法
- replica:數(shù)據(jù)并行
- stage:劃分TaskGraph
- pipeline:流水并行
- split:算子拆分
用這些接口可以組合各種并行策略,例如:
- 數(shù)據(jù)并行:
- 流水并行:
流水并行+數(shù)據(jù)并行:
更多并行策略示例:
3.Whale訓(xùn)練流程
使用Whale進行分布式訓(xùn)練流程:
- 并行策略配置:
- 使用Whale API來為模型配置并行策略,只需添加幾行annotation,無需修改模型代碼,方法如 2.2節(jié) 所示;
- 可以將模型劃分為多個TaskGraph,TaskGraph支持配置多個并行策略,每個TaskGraph可以配置不同的并行策略;
- 虛擬資源劃分:
- 按并行策略來劃分Virtual Device,每個TaskGraph對應(yīng)一個Virtual Device;
- 按GPU資源和網(wǎng)絡(luò)topo來為Virtual Device選擇Physical Device;
- 分布式執(zhí)行圖:
- 基于并行策略和資源分配信息,使用圖編輯工具來編輯執(zhí)行圖(圖拷貝、拆分、插入通信節(jié)點等),生成最終的分布式執(zhí)行圖;
- 調(diào)用TF的runtime來執(zhí)行分布式Graph;
三、萬億M6模型預(yù)訓(xùn)練
萬億模型的算力需求非常大,為了降低算力需求,Whale中實現(xiàn)了MoE(Mixture-of-Experts)結(jié)構(gòu),MoE的主要特點是稀疏激活,使用Gating(Router)來為輸入選擇Top k的expert進行計算(k常用取值1、2),從而大大減少算力需求。
Whale中實現(xiàn)了MoE(Mixture-of-Experts) layer,并支持專家并行,將experts拆分到多個Devices上,降低單個Device的顯存和算力需求。同時數(shù)據(jù)并行有利于提升訓(xùn)練的并發(fā)度,因此采用數(shù)據(jù)并行+專家并行組合的混合并行策略來訓(xùn)練M6模型:MoElayer采用專家并行,其他layer采用數(shù)據(jù)并行。
Whale中提供簡潔易用的接口來進行模型的混合并行訓(xùn)練,只需要增加幾行annotation來配置并行策略,模型本身不需要任何修改。M6模型采用數(shù)據(jù)并行+專家并行的策略,只需要增加如下圖的annotation:
同時為了節(jié)約訓(xùn)練資源,提高訓(xùn)練效率,Whale中提供各種優(yōu)化技術(shù):
顯存優(yōu)化:
- Auto Gradient Checkpoint,自動選擇最優(yōu)checkpoint節(jié)點,節(jié)約activation的顯存;
- Group-wise Apply,優(yōu)化Optimizer Apply階段的顯存;
- CPU Offload技術(shù),優(yōu)化Optimizer status和Weight的顯存;
- 通信池化,控制通信的數(shù)據(jù)塊大小和并發(fā),節(jié)約通信的顯存;
計算、通信加速:
- 采用DP+EP混合并行策略,降低算力需求;
- 采用分組融合通信、半精度通信、拓撲感知的All2All通信算子等技術(shù)來提高通信效率;
- 結(jié)合混合精度、編譯優(yōu)化等技術(shù)提高訓(xùn)練效率;
- 借助Whale框架,首次在480 V100 上,3天內(nèi)完成萬億M6模型的預(yù)訓(xùn)練。相比此前英偉達使用3072 A100 GPU實現(xiàn)萬億參數(shù)、谷歌使用2048 TPU實現(xiàn)1.6萬億參數(shù)大模型,此次達摩院僅使用480卡V100 32G GPU就實現(xiàn)了萬億模型M6,節(jié)省算力資源超80%,且訓(xùn)練效率提升近11倍。
四、結(jié)語
模型參數(shù)規(guī)模已越來越大,大模型已成為發(fā)展趨勢,為解決超大模型訓(xùn)練的挑戰(zhàn),我們自研Whale框架,將不同并行化策略進行統(tǒng)一抽象、封裝,在一套分布式訓(xùn)練框架中支持多種并行策略。Whale提供簡潔易用的接口,用戶只需添加幾行annotation即可實現(xiàn)各種并行策略,不需要對模型本身進行修改。同時我們結(jié)合硬件資源、網(wǎng)絡(luò)topo、模型進行軟硬件協(xié)同優(yōu)化,提供高效分布式訓(xùn)練框架。
通過Whale框架,我們用480 V100 GPU卡訓(xùn)練萬億規(guī)模模型,并在3天內(nèi)完成模型訓(xùn)練收斂,為超大規(guī)模模型訓(xùn)練落地提供了可能,后續(xù)我們會進一步完善Whale框架,從更大規(guī)模、更快速度、更高性價比3個維度去擴展Whale框架的能力。同時也會推動Whale能力在更多業(yè)務(wù)場景落地,讓技術(shù)能力到產(chǎn)品能力的轉(zhuǎn)變。