一文帶你了解MindSpore支持的萬億級參數(shù)超大模型關(guān)鍵技術(shù)!
前言
近來,增大模型規(guī)模成為了提升模型性能的主要手段。特別是NLP領(lǐng)域的自監(jiān)督預(yù)訓(xùn)練語言模型,規(guī)模越來越大,從GPT3的1750億參數(shù),到Switch Transformer的16000億參數(shù),又是一個數(shù)量級的增加。
模型規(guī)模的數(shù)量級的增大,雖然取得了一定程度的性能提升,甚至產(chǎn)生了某些意想不到的“神奇”效果(如GPT3),但其背后的計算開銷成了最大的問題,比如GPT3訓(xùn)練使用了萬級的GPU和數(shù)周的訓(xùn)練時間。如何既能利用超大規(guī)模的參數(shù)來提升模型表達(dá)和性能,又能控制計算量的較小的增加,成為了最主要的挑戰(zhàn)之一。以MoE為代表的動態(tài)神經(jīng)網(wǎng)絡(luò)技術(shù)被重點(diǎn)引入。大腦是典型的低能耗高效率的計算模式,稀疏激活是最重要的特性。除了巨型模型在訓(xùn)練推理特別是訓(xùn)練時的計算效能挑戰(zhàn)外,當(dāng)前巨型模型的訓(xùn)練優(yōu)化算法另一個更大的挑戰(zhàn)是(不在此處討論),BP算法是當(dāng)前最為可用的深度網(wǎng)絡(luò)優(yōu)化,但更理想的優(yōu)化算法需要高并行、優(yōu)化過程非對稱、并能夠在時空維度通過局部持續(xù)優(yōu)化完成整體優(yōu)化。
1. 傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)模型,前饋的時候,輸入的batch中,每一個樣本的處理,都將激活網(wǎng)絡(luò)中的每一個參數(shù)參與計算。
2. 條件計算最寬松的定義,指僅激活網(wǎng)絡(luò)中某些部分的一類算法。Conditional Computation refers to a class of algorithms that activate only some of the different parts in a network. 在具體某類條件計算實(shí)現(xiàn)中,條件選擇模式,可能按照輸入的batch中每sample獨(dú)立激活網(wǎng)絡(luò)不同部分,可能按照輸入數(shù)據(jù)空間上不同的部分(比如image不同區(qū)域或者channel),可能按照輸入數(shù)據(jù)時間上不同的部分(比如time series的不同slide window或者video的不同的frame。),可能按照目標(biāo)任務(wù)的不同每task獨(dú)立的,可能按照非可學(xué)習(xí)的固定的隨機(jī)分配不同的子網(wǎng)獨(dú)立計算。
3. 對不同的輸入(原始或者前層),按照一定條件,選擇性的執(zhí)行后續(xù)部分網(wǎng)絡(luò)的計算,這個技術(shù)下,有一些近似或相關(guān)的技術(shù),如:dynamic neural network(s), conditional computing, conditional activation, sparse activating, selective execution, mixture of experts (MoE), dynamic routing, …;強(qiáng)相關(guān)的一些模型比如 Switch Transformer等。
條件計算的分類(廣義)
1. 按照routing是否可學(xué)習(xí)可以分為:learnable routing conditional computation和 unlearnable routing conditional computation.
2. 按照activation是否不執(zhí)行non-activation計算,可以分為:hard conditional computation和soft conditional computation。對于hard-mode的條件計算,通過tensor挑選切分等操作,無論何種條件選擇模式,不需要激活的數(shù)據(jù)將完全不參與不激活的網(wǎng)絡(luò)部分的計算;soft-mode的條件計算,可能僅采取將相關(guān)數(shù)據(jù)置零等方式來避免產(chǎn)生計算效果,但還是和不需要激活網(wǎng)路部分實(shí)際執(zhí)行計算過程。
條件計算的主要優(yōu)勢
1. 計算有效,降低能耗:通過部分激活部分計算,以每樣本條件激活的條件計算為例,單個樣本只需要經(jīng)過整個SuperNet的一部分參與計算。
2. 更大網(wǎng)絡(luò),表達(dá)更強(qiáng):由于一處到多處的Route,各處(層)的Input被路由到不同的子網(wǎng)獨(dú)立計算,不同的輸入的相互在各層的表達(dá)相對獨(dú)立沒有影響,表達(dá)能力更強(qiáng),網(wǎng)絡(luò)可以更大,但表達(dá)效率降低了。
條件計算的網(wǎng)絡(luò)和計算形式
條件計算的網(wǎng)絡(luò)和計算形式比較靈活,部分構(gòu)建形式如:(此處省略具體模型和論文引用,參見: http:// intellabs.github.io/dis )
1. 按照CV等task的特點(diǎn),用多個獨(dú)立的CNN作為expert網(wǎng)絡(luò),按照task來獨(dú)立路由,尾部組合后給一個大網(wǎng)絡(luò)。
2. 使用更復(fù)雜的cascading等形式組合不同層級的不同的expert網(wǎng)絡(luò)。
3. 通過決策樹等方法做數(shù)據(jù)變換實(shí)現(xiàn)路由。
4. 通過可學(xué)習(xí)的網(wǎng)絡(luò)來選擇路由。其中策略學(xué)習(xí)的損失有多種構(gòu)建形式:直接使用分類等任務(wù)的主損失,對不同專家的重要性和負(fù)載構(gòu)建損失作為輔助損失等等。
條件計算的路由策略
1. non-learnable/hard-mode,通過某種確定性策略,如LSH等方式計算路由。
2. learnable-mode,通過可學(xué)習(xí)網(wǎng)絡(luò)計算路由。網(wǎng)絡(luò)規(guī)??纱罂尚?,簡單的可學(xué)習(xí)路由為單層權(quán)重:G(x) = P(X*W),G(x)為路由Gate函數(shù),X為輸入, W為通損失函數(shù)來度量的可學(xué)習(xí)路由權(quán)重,P為某種挑選函數(shù)(如topk, sort等),在實(shí)際實(shí)現(xiàn)中,X*W的輸入與權(quán)重計算結(jié)果可能作為后續(xù)網(wǎng)絡(luò)的輸入信息的一部分,不僅僅利用G(x)來選擇路由,則需要對X*W的結(jié)果做歸一化,更典型的形式則為:G(x)=P(N(X*W)),其中N為表達(dá)Normalization函數(shù),如Softmax。
條件計算的冗余策略
條件計算的冗余策略,可分為無冗余條件計算和冗余條件計算:
1. 無冗余條件計算可通過P(.)函數(shù)的實(shí)現(xiàn)如topk(k=1,…)來實(shí)現(xiàn);
2. 冗余條件計算,可以多種實(shí)現(xiàn)形式,可以通過P(.)函數(shù)的實(shí)現(xiàn)如topk(k=n,…),n>=2來實(shí)現(xiàn),也可以通過硬冗余模式,整個網(wǎng)絡(luò)中支持輸入的復(fù)制和多路計算實(shí)現(xiàn)。
條件計算的挑戰(zhàn)
1. 路由算法對模型質(zhì)量的影響無論輸入和路由權(quán)重作用的信息(X*W),是僅作為路由選擇并作為后續(xù)網(wǎng)絡(luò)單元的輸入,還是直接作為后續(xù)網(wǎng)絡(luò)單元的輸入的一部分,路由算法決定了輸入信息的處理流向,對模型的整體質(zhì)量都有很大影響。2. 路由(routing)/門(gate)的穩(wěn)定性隨機(jī)初始化的路由/門的權(quán)重,權(quán)重自身在不斷被訓(xùn)練調(diào)整;在前后層的網(wǎng)絡(luò)持續(xù)訓(xùn)練變化,同一樣本在訓(xùn)練的不同階段會被分派到不同的后續(xù)網(wǎng)絡(luò)單元中,這種動態(tài)變化過于劇烈,將嚴(yán)重影響整個網(wǎng)絡(luò)訓(xùn)練過程的穩(wěn)定性和收斂速度。3、路由的專家樣本重要性和負(fù)載的平衡性
訓(xùn)練階段,每專家和樣本批次中樣本的關(guān)聯(lián)度重要性,和每批次中樣本被均衡分派到不同專家的負(fù)載平衡性,這兩個指標(biāo)既相關(guān)又沖突。需要分別構(gòu)建損失函數(shù)作為輔助損失,來優(yōu)化這兩個指標(biāo)。在arxiv:1701.06538《Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer》做了相關(guān)討論。
關(guān)于條件計算/動態(tài)神經(jīng)網(wǎng)絡(luò)
關(guān)于條件計算/動態(tài)神經(jīng)網(wǎng)絡(luò),更多的信息在《Dynamic Neural Networks: A Survey》arxiv:2102.04906 ( http:// arxiv.org/abs/2102.0490 )一文中,作者對廣義的動態(tài)神經(jīng)網(wǎng)絡(luò),將各種動態(tài)網(wǎng)絡(luò)相關(guān)的技術(shù)按照實(shí)例級、時間級、空間級做了分類。
- 1. Instance-wise Dynamic NN:逐實(shí)例動態(tài),每樣本獨(dú)立激活不同的網(wǎng)絡(luò)和參數(shù)(MoE為這個方向)。Dynamic Architecture:Dynamic Depth、Dynamic Width、Dynamic Routing/MoE;Dynamic Parameter:Parameter Adjustment、Parameter Prediction、Dynamic Feature(s)
- 2. Spatial-wise Dynamic NN:空間級動態(tài):圖像等不同空間位置激活后續(xù)不同網(wǎng)絡(luò)和參數(shù)。(CNN等):Pixel Level、Region Level、Resolution Level
- 3. Temporal-wise Dynamic NN:時間級動態(tài):時序數(shù)據(jù)按時序維切分激活后續(xù)不同網(wǎng)絡(luò)和參數(shù)。(video-frames, text-sequence, time-series, stream, ...)Text-SequenceVideo-Frames
上述為該綜述論文對Dynamic NN的總體分類。
從超大規(guī)模網(wǎng)絡(luò)動態(tài)網(wǎng)絡(luò)技術(shù)支撐角度,高表達(dá)能力,低計算代價為主的來考慮分類,從兩個維度對動態(tài)網(wǎng)絡(luò)技術(shù)分類:
1. 按照在前饋計算時是否部分激活:
Hard-Dynamic:在前饋的時候,部分網(wǎng)絡(luò)絕對不激活參與計算
Soft-Dynamic:在前饋的時候,部分網(wǎng)絡(luò)經(jīng)過softmax等gate/route后,通過張量元素置零等方式,失去表達(dá)能力,但會參與計算。
2. 按照動態(tài)激活判定算法的輸入:
- 逐樣本級:(在輸入層)按照每樣本的實(shí)例來決定動態(tài)網(wǎng)絡(luò)的后續(xù)激活。
- 亞樣本級:(在輸入層)樣本內(nèi)時間/空間級激活不同的后續(xù)網(wǎng)絡(luò)單元。一般深度網(wǎng)絡(luò),不僅在輸入層會被選擇性激活執(zhí)行,在中間層也類似。
其中,智能平臺支持Hard-Dynamic逐樣本級的動態(tài)神經(jīng)網(wǎng)絡(luò),能比較自然的獲得網(wǎng)絡(luò)結(jié)構(gòu)大顆粒的稀疏激活,在超大模型中能實(shí)現(xiàn)訓(xùn)練和推理的高能效。
動態(tài)神經(jīng)網(wǎng)絡(luò)相比與靜態(tài)結(jié)構(gòu)的神經(jīng)網(wǎng)絡(luò),在相關(guān)研究中,從效能,表達(dá),泛化、魯棒,可解釋等方面做了大量對比研究。從智能平臺通過計算成本盡量低的支持超大規(guī)模網(wǎng)絡(luò)來提升模型性能的角度看,Efficiency和Representation最為重要:
1、Efficiency:靜態(tài)網(wǎng)絡(luò)“牽一發(fā)而動全身”,每一個樣本輸入整個網(wǎng)絡(luò)/所有參數(shù)都要響應(yīng),這對超大網(wǎng)絡(luò)來取得領(lǐng)先效果的模型能耗挑戰(zhàn)太大。
2、Representation: 參數(shù)量更大,表達(dá)容量更大;但MoE等結(jié)構(gòu)在深度網(wǎng)絡(luò)的各層特征的表達(dá)上,復(fù)用降低,每參數(shù)的表達(dá)效率更低。
實(shí)現(xiàn)策略
實(shí)現(xiàn)各種模型的帶有動態(tài)路由稀疏激活的超大規(guī)模參數(shù)版本,需要分模型研究和實(shí)現(xiàn)。
以Switch Transformer為例,其參數(shù)擴(kuò)展到部分在Transformer的FFN部分。其MoE化擴(kuò)展,如下圖:
(圖片來源:Switch Transformer論文)
可見,MoE化主要變化在需要Expert子網(wǎng)絡(luò)前后增加MoE相關(guān)的邏輯。本文主要介紹平臺上的實(shí)現(xiàn)。動態(tài)路由條件計算,主要包括四個步驟:路由計算、數(shù)據(jù)分派、獨(dú)立計算,結(jié)果合并。
1. 路由計算-Gate:根據(jù)輸入(可以為整個網(wǎng)絡(luò)的輸入,或者前面網(wǎng)絡(luò)單元/層的輸出),在路由單元完成計算,在以batch內(nèi)sample-wise的路由中,計算出每個樣本要分派的后續(xù)網(wǎng)絡(luò)路由(Mixture-of-Experts/MoE中的專家)。
2. 數(shù)據(jù)分派-Dispatch:從輸入的整體的Tensor中,按照路由計算的樣本-專家關(guān)系,收集合并出每個專家需要處理的Tensor。如果在固定expert-batch的設(shè)計中,要平衡每批訓(xùn)練中,分派到每個專家的樣本數(shù)和專家每輪訓(xùn)練最大容量,由于樣本輸入的隨機(jī)性,很難保證較為均勻的分派,對于低于最大容量的批次,對固定batch-size的要做pad,對于高于最大容量的樣本,可以采用延后重采樣等方式。為了維護(hù)正確的輸入輸出關(guān)系(Input/X – Label/Y)和訓(xùn)練是反向傳播的求導(dǎo)關(guān)系,實(shí)現(xiàn)中需要維護(hù)原始batch到每專家的sub-batch的index關(guān)系,在后來求導(dǎo)和結(jié)合合并時使用。
3. 獨(dú)立計算-Expert:并發(fā)(邏輯上可以先后)調(diào)用各個專家處理對應(yīng)的sub-batch。這也是智能平臺要支持的并發(fā)API之一。
4. 結(jié)果合并-Combine:合并每專家的結(jié)果tensor到整個batch的tensor,并按照數(shù)據(jù)分派索引,交換到原始輸入的順序。
在主流的深度學(xué)習(xí)智能平臺中,可以采用兩類主要的實(shí)現(xiàn)策略:
張量置零:對需要分派到不同的后續(xù)網(wǎng)絡(luò)單元(專家網(wǎng)絡(luò)子網(wǎng)等),對需要分派的專家拷貝若干份tensor,對于不應(yīng)輸入當(dāng)前專家處理的數(shù)據(jù)維度置零。該方式在保證置零計算邏輯正確的情況下,實(shí)現(xiàn)簡單,全張量操作,對平臺無特殊要求,適用于算法研究,僅體現(xiàn)條件計算前序數(shù)據(jù)被動態(tài)路由到不同的后續(xù)網(wǎng)絡(luò)單元,分析算法的效果。如果通過置零方式,該方法每個專家處理的tensor在batch維度大小是全batch,不能節(jié)省計算量和內(nèi)存使用量。
張量整理:對需要分派到不同的后續(xù)網(wǎng)絡(luò)單元(專家網(wǎng)絡(luò)子網(wǎng)等),對需要分派的專家拷貝若干份tensor,對于不應(yīng)輸入當(dāng)前專家處理的數(shù)據(jù)維度不保留。并維護(hù)好sample級的index在變換前后的對應(yīng)關(guān)系。在分布式友好的實(shí)現(xiàn)中,如果專家子網(wǎng)為單位被劃分到不同的計算節(jié)點(diǎn),那么專家網(wǎng)絡(luò)的實(shí)現(xiàn)最好從子網(wǎng)級的平臺對象繼承后實(shí)現(xiàn),比如:MindSpore中的mindspore.nn.Cell。詳細(xì)實(shí)現(xiàn)細(xì)節(jié)參見后續(xù)技術(shù)實(shí)現(xiàn)章節(jié)。
核心代碼
核心代碼:路由計算、數(shù)據(jù)分派、獨(dú)立計算,結(jié)果合并
參考代碼采用MindSpore示意實(shí)現(xiàn)。(注:import mindspore as ms)
Mixture of Experts的核心邏輯,對輸入I,經(jīng)過routing_network(最簡單*W即可),然后topk(若變種算法需要gate權(quán)重則需要softmax,否則可不),然后用tensor的操作(可按照batch)選擇出每個subnetwork/expert的張量。
為方便調(diào)試,采用了規(guī)模極小的非隨機(jī)的確定數(shù)值構(gòu)造輸入和路由權(quán)重,路由網(wǎng)絡(luò)采用簡單的X*W。
1、路由計算
當(dāng)上述輸入5行(僅3類,希望分派給3個專家)樣本,和Gate權(quán)重做矩陣乘后,可以明確算出每個樣本要分派的專家??梢杂胢atmul,也可以類似gates_weighted = einsum('bd,de->be', [data_inputs, gate_weights])第一輪矩陣乘的結(jié)果為:
輸入和權(quán)重乘法,在python中可以采用@,也可以采用matmul,也可以采用愛因斯坦求和簡記憶法函數(shù)einsum。當(dāng)是簡單的矩陣乘的時候,采用einsum在計算圖編譯的時候?qū)嶋H會拆分成多個算法,性能并不好;但當(dāng)輸入和權(quán)重超過2維,需要以batch維固定做路由計算的時候,使用einsum可以編程實(shí)現(xiàn)很簡單。
2、分派
條件計算的分派,主要邏輯是根據(jù)路由網(wǎng)絡(luò)的輸出,為每個樣本計算出top-k的專家。其實(shí)現(xiàn)可以通過topk函數(shù)實(shí)現(xiàn)。由于top選擇score可作為后續(xù)網(wǎng)絡(luò)單元的輸入信息(含路由的信息),所以一般要對路由輸出做softmax做歸一化。
按需計算1:all-N專家之間的歸一化權(quán)重 (please refer to #2) ,gates_weighted一樣,按照dim=-1做了歸一化而已其輸出為:
為batch中每個sample選擇Top-K個專家 這里為batch中每個的專家權(quán)重,可以從softmax-ed來top-k,也可以直接從gates_weighted來top-k;由于這里可能不做softmax或者延后,所以可gates_weighted,這里為batch中每個的專家序號
其輸出為:
接著:
按需計算2: top-n專家之間的歸一化權(quán)重
如何根據(jù)分派索引,從原始的輸入中,為每個專家提取出屬于該專家處理的tensor,在當(dāng)前的主流智能平臺,都沒有專門的算子,可以通過其他算子的組合來實(shí)現(xiàn)類似的效果。在MindSpore中,可以通過底層的C++實(shí)現(xiàn)算子,也可以通過Python中繼承Cell并實(shí)現(xiàn)bprob, 然后將原始 gate tensor中按照index組織到目標(biāo)輸出中。這里我們實(shí)現(xiàn)一個Dispatch類
3、獨(dú)立計算
直接并行調(diào)用后續(xù)的專家網(wǎng)絡(luò)。并行部分可以通過平臺來支持??梢酝ㄟ^特殊的函數(shù)或者annotation等標(biāo)識,也可以由平臺編譯時優(yōu)化為并行執(zhí)行。(在非動態(tài)路由條件計算的網(wǎng)絡(luò)模型中,一般不存在類似的優(yōu)化。)
4、合并
合并的邏輯相對簡單,先通過cat按照batch維度做拼接,然后構(gòu)造正確的zeros tensor用index_add按照索引將各個專家網(wǎng)絡(luò)的結(jié)果在保持input序合并到一起,做為該MoE模塊的輸出。
上述完成了整個MoE的完整計算過程。
代碼框架
我們按照上述基本動態(tài)路由條件計算的張量操作為主的邏輯,擴(kuò)展到一個完整的訓(xùn)練代碼框架中:
- class Dispatch(ms.nn.Cell): 實(shí)現(xiàn)路由中的分派邏輯
- class Combine(ms.nn.Cell): 實(shí)現(xiàn)路由中的組裝邏輯
- class Route(ms.nn.Cell): 完成整個動態(tài)路由邏輯,可以實(shí)現(xiàn)為相對通用的類
- class Expert(ms.nn.Cell): 平臺用戶自定義的專家網(wǎng)絡(luò)
- class Network(ms.nn.Cell): 平臺用戶自定義的大網(wǎng)絡(luò)
- class MSELoss(ms.nn.Cell):實(shí)現(xiàn)MSE損失,實(shí)現(xiàn)輔助損失的邏輯
- class OutputLossGraph(ms.nn.Cell):輸出infer和loss,PyNative模式單步
- class Dataset: 數(shù)據(jù)集類,僅滿足輸入shape和X-Y合理對應(yīng)關(guān)系,僅僅示例def train( …): 訓(xùn)練入口
條件計算實(shí)現(xiàn)技術(shù)點(diǎn)
1、動態(tài)路由
- 不可學(xué)習(xí)路由
如使用LSH (locality sensitive hashing)做路由:在整個可學(xué)習(xí)網(wǎng)絡(luò)的前端,使用LSH來分派樣本,這樣可以避免LSH部分求導(dǎo)問題;如果在網(wǎng)絡(luò)中間增加LSH模塊,需要通過梯度估計完成確定性算法部分梯度傳遞。
- 可學(xué)習(xí)路由
簡單的做法,定義gate_weights為可學(xué)習(xí)Parameter,對于二維的張量,通過python@或者matmul等完成權(quán)重路由計算;如果是更高維度的張量,且需固定batch維,einsum('bd*,*de->b*e')的形式完成計算。
2、topk和softmax的前后關(guān)系
在G_1(x)=softmax(topk(X*W)))和G_2(x)=topk(softmax(X*W)))兩類Gate實(shí)現(xiàn)中,
將softmax置于Topk前后,對top-k的選擇不變;當(dāng)需要將G_*作為后序網(wǎng)絡(luò)輸入的一部分,即將路由權(quán)重信息作為后續(xù)網(wǎng)絡(luò)輸入信息,則需要考慮:需要all-N專家之間的歸一化權(quán)重,則softmax置于top-k之前;否則softmax置于top-k之后,來計算top-N專家之間的歸一化權(quán)重。
3、如何每專家在批次處理中平衡
按照每樣本的路由權(quán)重求和,即對batch單個樣本被分配的1+個export的重要性和權(quán)重求和,計算出importance;按照每樣本的路由權(quán)重中非0的求和,計算出有負(fù)載的專家來求得load。將coefficient_of_variation(importance) + coefficient_of_variation(load)作為auxiliary_loss參與優(yōu)化,來平衡importance和load。變異系數(shù)(Coefficient of Variation)是用于無量綱度量數(shù)據(jù)的離散程度,越離散在此處表示均衡性越差,需要向更小優(yōu)化。
在Transformer等多層(多處)MoE的模型中,將多組auxiliary_loss聯(lián)合作為auxiliary_loss, 在加dominated_loss之后即可。