分布式深度學(xué)習(xí)新進(jìn)展:讓“分布式”和“深度學(xué)習(xí)”真正深度融合
近年來,作為人工智能發(fā)展迅速的領(lǐng)域之一的深度學(xué)習(xí)在NLP、圖像識(shí)別、語音識(shí)別、機(jī)器翻譯等方面都取得了驚人的成果。但是,深度學(xué)習(xí)的應(yīng)用范圍卻日益受到數(shù)據(jù)量和模型規(guī)模的限制。如何才能高效地進(jìn)行深度學(xué)習(xí)模型訓(xùn)練?微軟亞洲研究院機(jī)器學(xué)習(xí)組主管研究員陳薇和她的團(tuán)隊(duì)基于對(duì)機(jī)器學(xué)習(xí)的完整理解,將分布式技術(shù)和深度學(xué)習(xí)緊密結(jié)合在一起,探索全新的真正合二為一“分布式深度學(xué)習(xí)”算法。
隨著大數(shù)據(jù)和高效計(jì)算資源的出現(xiàn),深度學(xué)習(xí)在人工智能的很多領(lǐng)域中都取得了重大突破。然而,面對(duì)越來越復(fù)雜的任務(wù),數(shù)據(jù)和深度學(xué)習(xí)模型的規(guī)模都變得日益龐大。例如,用來訓(xùn)練圖像分類器的有標(biāo)簽的圖像數(shù)據(jù)量達(dá)數(shù)百萬、甚至上千萬張。大規(guī)模訓(xùn)練數(shù)據(jù)的出現(xiàn)為訓(xùn)練大模型提供了物質(zhì)基礎(chǔ),因此近年來涌現(xiàn)出了很多大規(guī)模的機(jī)器學(xué)習(xí)模型,例如2015年微軟亞洲研究院開發(fā)出的擁有超過兩百億個(gè)參數(shù)的LightLDA主題模型。然而,當(dāng)訓(xùn)練數(shù)據(jù)詞表增大到成百上千萬時(shí),如果不做任何剪枝處理,深度學(xué)習(xí)模型可能會(huì)擁有上百億、甚至是幾千億個(gè)參數(shù)。
為了提高深度學(xué)習(xí)模型的訓(xùn)練效率,減少訓(xùn)練時(shí)間,我們普遍會(huì)采用分布式技術(shù)來執(zhí)行訓(xùn)練任務(wù)——同時(shí)利用多個(gè)工作節(jié)點(diǎn),分布式地、高效地訓(xùn)練出性能優(yōu)良的神經(jīng)網(wǎng)絡(luò)模型。分布式技術(shù)是深度學(xué)習(xí)技術(shù)的加速器,能夠顯著提高深度學(xué)習(xí)的訓(xùn)練效率、進(jìn)一步增大其應(yīng)用范圍。
深度學(xué)習(xí)的目標(biāo)是從數(shù)據(jù)中挖掘出規(guī)律,幫助我們進(jìn)行預(yù)測(cè)。深度學(xué)習(xí)算法的一般框架是,利用優(yōu)化算法迭代地最小化訓(xùn)練數(shù)據(jù)上的經(jīng)驗(yàn)風(fēng)險(xiǎn)。由于數(shù)據(jù)的統(tǒng)計(jì)性質(zhì)、優(yōu)化的收斂性質(zhì)、以及學(xué)習(xí)的泛化性質(zhì)在多機(jī)執(zhí)行時(shí)的靈活度更高,相比于其它的計(jì)算任務(wù),深度學(xué)習(xí)算法在并行化執(zhí)行過程中實(shí)際上并不需要計(jì)算節(jié)點(diǎn)通過通信嚴(yán)格地執(zhí)行單機(jī)版本算法。因而,當(dāng)“分布式”遇到“深度學(xué)習(xí)”,不應(yīng)只局限在對(duì)串行算法進(jìn)行多機(jī)實(shí)現(xiàn)以及底層實(shí)現(xiàn)方面的技術(shù),我們更應(yīng)該基于對(duì)機(jī)器學(xué)習(xí)的完整理解,將分布式和深度學(xué)習(xí)緊密結(jié)合在一起,結(jié)合深度學(xué)習(xí)的特點(diǎn),設(shè)計(jì)全新的真正合二為一的“分布式深度學(xué)習(xí)”算法。
圖1 分布式深度學(xué)習(xí)框架
分布式深度學(xué)習(xí)框架中,包括數(shù)據(jù)/模型切分、本地單機(jī)優(yōu)化算法訓(xùn)練、通信機(jī)制、和數(shù)據(jù)/模型聚合等模塊?,F(xiàn)有的算法一般采用隨機(jī)置亂切分的數(shù)據(jù)分配方式,隨機(jī)優(yōu)化算法(例如隨機(jī)梯度法)的本地訓(xùn)練算法,同步或者異步通信機(jī)制,以及參數(shù)平均的模型聚合方式。
結(jié)合深度學(xué)習(xí)算法的特點(diǎn),微軟亞洲研究院機(jī)器學(xué)習(xí)組重新設(shè)計(jì)/理解了這些模塊,我們?cè)诜植际缴疃葘W(xué)習(xí)領(lǐng)域主要做了三個(gè)方面的工作:第一個(gè)工作,針對(duì)異步機(jī)制中的梯度延遲問題,我們?yōu)樯疃葘W(xué)習(xí)設(shè)計(jì)了“帶有延遲補(bǔ)償?shù)漠惒剿惴?rdquo;;第二個(gè)工作,注意到神經(jīng)網(wǎng)絡(luò)的非凸性質(zhì),我們提出了比參數(shù)平均更加有效的集成聚合方式,并設(shè)計(jì)了“集成-壓縮”并行深度學(xué)習(xí)算法;第三個(gè)工作,我們首次分析了隨機(jī)置亂切分方式下分布式深度學(xué)習(xí)算法的收斂速率,為算法設(shè)計(jì)提供了理論指導(dǎo)。
DC-ASGD算法:補(bǔ)償異步通信中梯度的延遲
隨機(jī)梯度下降法(SGD)是目前最流行的深度學(xué)習(xí)的優(yōu)化算法之一,更新公式為:
公式 1
其中,wt為當(dāng)前模型,(xt, yt)為隨機(jī)抽取的數(shù)據(jù),g(wt; xt, yt)為(xt, yt)所對(duì)應(yīng)的經(jīng)驗(yàn)損失函數(shù)關(guān)于當(dāng)前模型wt的梯度,η為步長(zhǎng)/學(xué)習(xí)率。
假設(shè)系統(tǒng)中有多個(gè)工作節(jié)點(diǎn)并行地利用隨機(jī)梯度法優(yōu)化神經(jīng)網(wǎng)絡(luò)模型,同步和異步是兩種常用的通信同步機(jī)制。
同步隨機(jī)梯度下降法(Synchronous SGD)在優(yōu)化的每輪迭代中,會(huì)等待所有的計(jì)算節(jié)點(diǎn)完成梯度計(jì)算,然后將每個(gè)工作節(jié)點(diǎn)上計(jì)算的隨機(jī)梯度進(jìn)行匯總、平均并按照公式1更新模型。之后,工作節(jié)點(diǎn)接收更新之后的模型,并進(jìn)入下一輪迭代。由于Sync SGD要等待所有的計(jì)算節(jié)點(diǎn)完成梯度計(jì)算,因此好比木桶效應(yīng),Sync SGD的計(jì)算速度會(huì)被運(yùn)算效率最低的工作節(jié)點(diǎn)所拖累。
異步隨機(jī)梯度下降法(Asynchronous SGD)在每輪迭代中,每個(gè)工作節(jié)點(diǎn)在計(jì)算出隨機(jī)梯度后直接更新到模型上,不再等待所有的計(jì)算節(jié)點(diǎn)完成梯度計(jì)算。因此,異步隨機(jī)梯度下降法的迭代速度較快,也被廣泛應(yīng)用到深度神經(jīng)網(wǎng)絡(luò)的訓(xùn)練中。然而,Async SGD雖然快,但是用以更新模型的梯度是有延遲的,會(huì)對(duì)算法的精度帶來影響。什么是“延遲梯度”?我們來看下圖。
圖2 異步隨機(jī)梯度下降法
在Async SGD運(yùn)行過程中,某個(gè)工作節(jié)點(diǎn)Worker(m)在第t次迭代開始時(shí)獲取到模型的最新參數(shù)wt和數(shù)據(jù)(xt, yt),計(jì)算出相應(yīng)的隨機(jī)梯度gt,并將其返回并更新到全局模型w上。由于計(jì)算梯度需要一定的時(shí)間,當(dāng)這個(gè)工作節(jié)點(diǎn)傳回隨機(jī)梯度gt時(shí),模型wt已經(jīng)被其他工作節(jié)點(diǎn)更新了τ輪,變?yōu)榱藈t+τ。也就是說,Async SGD的更新公式為:
公式 2
對(duì)比公式1,公式2中對(duì)模型wt+τ上更新時(shí)所使用的隨機(jī)梯度是g(wt; xt, yt),相比SGD中應(yīng)該使用的隨機(jī)梯度g(wt+τ; xt+τ, yt+τ)產(chǎn)生了τ步的延遲。因而,我們稱Async SGD中隨機(jī)梯度為“延遲梯度”。
延遲梯度所帶來的最大問題是,由于每次用以更新模型的梯度并非是正確的梯度(請(qǐng)注意g(wt; xt, yt) ≠ g(wt+τ; xt+τ, yt+τ)),所以導(dǎo)致Async SGD會(huì)損傷模型的準(zhǔn)確率,并且這種現(xiàn)象隨著機(jī)器數(shù)量的增加會(huì)越來越嚴(yán)重。如下圖所示,隨著計(jì)算節(jié)點(diǎn)數(shù)目的增加,Async SGD的精度逐漸變差。
圖3 異步隨機(jī)梯度下降法的性能
那么,如何能讓異步隨機(jī)梯度下降法在保持訓(xùn)練速度的同時(shí),獲得更高的精度呢?我們?cè)O(shè)計(jì)了可以補(bǔ)償梯度延遲的DC-ASGD(Delay-compensated Async SGD)算法。
為了研究正確梯度g(wt+τ)和延遲梯度g(wt)之間的關(guān)系,我們將g(wt+τ)在wt處進(jìn)行泰勒展開:
其中,∇g(wt)為梯度的梯度,也就是損失函數(shù)的Hessian矩陣,H(g(wt))為梯度的Hessian矩陣。顯然,延遲梯度實(shí)則為真實(shí)梯度的零階近似,而其余各項(xiàng)造成了延遲。于是,一個(gè)自然的想法是,如果我們將所有的高階項(xiàng)都計(jì)算出來,就可以修正延遲梯度為準(zhǔn)確梯度了。然而,由于余項(xiàng)擁有無窮項(xiàng),所以無法被準(zhǔn)確計(jì)算。因此,我們選擇用上述公式中的一階項(xiàng)進(jìn)行延遲補(bǔ)償:
眾所周知,在現(xiàn)代的深度神經(jīng)網(wǎng)絡(luò)模型中有上百萬甚至更多的參數(shù),計(jì)算和存儲(chǔ)Hessian矩陣∇g(wt)成為了一件幾乎無法完成的事情。因此,尋找Hessian矩陣的一個(gè)良好近似是能否補(bǔ)償梯度延遲的關(guān)鍵。根據(jù)費(fèi)舍爾信息矩陣的定義,梯度的外積矩陣
是Hessian矩陣的一個(gè)漸近無偏估計(jì),因此我們選擇用G(wt)來近似估計(jì)Hessian矩陣。根據(jù)前人的研究,如果在神經(jīng)網(wǎng)絡(luò)模型中用Hessian矩陣的對(duì)角元來近似Hessian矩陣,在顯著降低運(yùn)算和存儲(chǔ)復(fù)雜度的同事還可以保持算法精度,于是我們采用diag(G(wt))作為Hessian矩陣的近似。為了進(jìn)一步降低近似的方差,我們使用一個(gè)(0,1]之間參數(shù)λ來對(duì)偏差和方差進(jìn)行調(diào)節(jié)。綜上,我們?cè)O(shè)計(jì)了如下帶有延遲補(bǔ)償?shù)漠惒诫S機(jī)梯度下降法(DC-ASGD),
其中,對(duì)延遲梯度g(wt)的補(bǔ)償項(xiàng)中只包含一階梯度信息,幾乎不增加計(jì)算和存儲(chǔ)代價(jià)。
我們?cè)贑IFAR10數(shù)據(jù)集和ImageNet數(shù)據(jù)集上對(duì)DC-ASGD算法進(jìn)行了評(píng)估,實(shí)驗(yàn)結(jié)果見以下兩圖。
圖4 DC-ASGD的訓(xùn)練/測(cè)試誤差_CIFAR-10
圖5 DC-ASGD的訓(xùn)練/測(cè)試誤差_ImageNet
可以觀察到,DC-ASGD算法與Async SGD算法相比,在相同的時(shí)間內(nèi)獲得的模型準(zhǔn)確率有顯著的提升,并且也高于Sync SGD,基本可以達(dá)到SGD相同的模型準(zhǔn)確率。
Ensemble-Compression算法:改進(jìn)非凸模型的聚合方法
參數(shù)平均是現(xiàn)有的分布式深度學(xué)習(xí)算法中非常普遍的模型聚合方法。如果損失函數(shù)關(guān)于模型參數(shù)是凸的,以下不等式成立:
其中,K為計(jì)算節(jié)點(diǎn)個(gè)數(shù),wk是局部模型,
為參數(shù)平均后的模型,(x, y)為任意樣本數(shù)據(jù)。該不等式的左端是平均模型所對(duì)應(yīng)的損失函數(shù),右端是各個(gè)局部模型的損失函數(shù)值的平均值??梢?,凸問題中參數(shù)平均可以保持模型的性能。
但是,對(duì)于非凸的神經(jīng)網(wǎng)絡(luò)模型,以上不等式將不再成立,因而平均模型的性能不再具有保證。這一點(diǎn)在實(shí)驗(yàn)上也得到了驗(yàn)證:如圖6所示,對(duì)于不同的交互頻率(尤其是較低頻的交互),參數(shù)平均通常會(huì)大幅度拉低訓(xùn)練精度,使得訓(xùn)練的過程極不穩(wěn)定。
圖6 基于參數(shù)平均的分布式算法訓(xùn)練曲線(DNN模型)
為了解決這個(gè)問題,我們提出用模型集成替代模型平均,作為分布式深度學(xué)習(xí)中的模型聚合方式。雖然神經(jīng)網(wǎng)絡(luò)的損失函數(shù)關(guān)于模型參數(shù)是非凸的,但是關(guān)于模型的輸出一般是凸的(比如深度學(xué)習(xí)中常用的交叉熵?fù)p失)。這時(shí),利用凸性可以得到如下不等式:
其中,不等式左側(cè)是集成(ensemble)模型的損失函數(shù)取值??梢?,對(duì)于非凸模型,集成模型可以保持性能。
然而,每經(jīng)過一次集成,神經(jīng)網(wǎng)絡(luò)模型的規(guī)模就會(huì)增加倍,從而出現(xiàn)模型規(guī)模爆炸的問題。那么,有沒有一種既能利用模型集成的優(yōu)點(diǎn),又能避免增大模型的方法呢?我們提出了一種同時(shí)基于模型集成和模型壓縮的模型聚合方法, 即集成-壓縮(ensemble-compression)方法。在每次集成之后,我們對(duì)集成所得的模型進(jìn)行一次壓縮。
算法具體分為三個(gè)步驟:
-
各個(gè)計(jì)算節(jié)點(diǎn)依照本地優(yōu)化算法訓(xùn)練和局部數(shù)據(jù)訓(xùn)練出局部模型;
-
計(jì)算節(jié)點(diǎn)之間相互通信局部模型得到集成模型,并對(duì)(一部分)局部數(shù)據(jù)標(biāo)注上集成模型對(duì)它們的輸出值;
-
利用模型壓縮技術(shù)(比如知識(shí)蒸餾),結(jié)合數(shù)據(jù)的再標(biāo)注信息,在每個(gè)工作節(jié)點(diǎn)上分別進(jìn)行模型壓縮,獲得與局部模型大小相同的新模型作為最終的聚合模型。為了進(jìn)一步節(jié)省計(jì)算量,可以將蒸餾的過程和本地模型訓(xùn)練的過程結(jié)合在一起。
這種集成-壓縮的聚合方法,既可以通過集成獲得性能提升,又可以在學(xué)習(xí)的迭代過程中保持全局模型的規(guī)模。CIFA-10和ImageNet上的實(shí)驗(yàn)結(jié)果也很好地驗(yàn)證了集成-壓縮聚合方法的有效性(見圖7和圖8)。當(dāng)工作節(jié)點(diǎn)之間相互通信的頻率較低時(shí),參數(shù)平均方法表現(xiàn)會(huì)很差,但模型集成-壓縮的方法卻依然能取得理想的效果。這是因?yàn)?strong>集成學(xué)習(xí)在子模型具有多樣性時(shí)效果更好,而較低的通信頻率會(huì)導(dǎo)致各個(gè)局部模型更加分散,多樣性更強(qiáng);同時(shí),較低的通信頻率意味著較低的通信代價(jià)。因此,模型集成-壓縮的方法更適用于網(wǎng)絡(luò)環(huán)境比較差的場(chǎng)景。
圖7 CIFAR數(shù)據(jù)集上對(duì)比各種分布式算法
圖8 ImageNet數(shù)據(jù)集上對(duì)比各種分布式算法
基于模型集成的分布式算法是一個(gè)比較新的研究領(lǐng)域,還存在很多沒有解決的問題。比如說當(dāng)工作節(jié)點(diǎn)非常多的時(shí)候或者本地模型本身就很大的時(shí)候,集成模型的規(guī)模會(huì)變得非常大,這會(huì)帶來較大的網(wǎng)絡(luò)開銷。另外,當(dāng)集成模型較大時(shí),模型壓縮也會(huì)變成一個(gè)較大的開銷。值得注意的是,在ICLR 2018上,Hinton等人提出的Co-distillation方法,盡管在動(dòng)機(jī)上和這個(gè)工作不同,但是其算法和這個(gè)工作非常相似。如何理解這些關(guān)聯(lián)和解決這些局限性將催生新的研究,感興趣的讀者可以對(duì)此進(jìn)行思考。
隨機(jī)重排下算法的收斂性分析:改進(jìn)分布式深度學(xué)習(xí)理論
最后,簡(jiǎn)單介紹一個(gè)我們最近在改進(jìn)分布式深度學(xué)習(xí)理論方面的工作。
分布式深度學(xué)習(xí)中常用的數(shù)據(jù)分配策略是隨機(jī)重排之后均等切分。具體來說,是將所有訓(xùn)練數(shù)據(jù)隨機(jī)地打亂順序,得到數(shù)據(jù)的一個(gè)重排列,然后將數(shù)據(jù)集按順序等分,并將每一份存放在計(jì)算節(jié)點(diǎn)上。在數(shù)據(jù)被過完一輪后,如果收集所有的局部數(shù)據(jù)并重復(fù)上述過程,一般被稱為“全局重排”,如果只是對(duì)局部數(shù)據(jù)進(jìn)行隨機(jī)重排,一般被稱為“局部重排”。
現(xiàn)有的絕大部分分布式深度學(xué)習(xí)理論中都假定數(shù)據(jù)是獨(dú)立同分布的。然而,基于Fisher-Yates算法的隨機(jī)重排實(shí)際上與無放回抽樣等價(jià),訓(xùn)練數(shù)據(jù)之間不再是獨(dú)立同分布的了。于是,每一輪所計(jì)算的隨機(jī)梯度也不再是精確梯度的無偏估計(jì),所以以往分布式隨機(jī)優(yōu)化算法的理論分析方法不再適用,已有的收斂性結(jié)果也不一定仍然成立。
我們利用Transductive Rademancher Complexity作為工具,給出了隨機(jī)梯度相對(duì)于精確梯度的偏差上界,證明了隨機(jī)重排下分布式深度學(xué)習(xí)算法的收斂性分析。
假設(shè)目標(biāo)函數(shù)光滑(不一定是凸函數(shù)),系統(tǒng)中有K個(gè)計(jì)算節(jié)點(diǎn),訓(xùn)練輪數(shù)(epoch)為S,總訓(xùn)練數(shù)據(jù)有n個(gè),考慮分布式SGD算法。
(1)如果采用全局隨機(jī)重排數(shù)據(jù)分配策略,算法的收斂速率為,其中非i.i.d.性質(zhì)帶來的額外誤差為
。因此,當(dāng)過數(shù)據(jù)的輪數(shù)遠(yuǎn)遠(yuǎn)小于訓(xùn)練樣本數(shù)目時(shí)(S ≪n),額外誤差的影響可以忽略??紤]到現(xiàn)有的分布式深度學(xué)習(xí)任務(wù)中,S ≪n是很容易滿足的,所以全局隨機(jī)重排不會(huì)影響分布式算法的收斂速率。
(2)如果采用局部重排策略數(shù)據(jù)分配策略,算法的收斂速率為
,其中非i.i.d.性質(zhì)帶來的更大的額外誤差
。原因是,由于隨機(jī)重排是本地局部進(jìn)行的,不同計(jì)算節(jié)點(diǎn)之間的數(shù)據(jù)沒有進(jìn)行交互,數(shù)據(jù)的差異性更大,隨機(jī)梯度的偏差也更大。當(dāng)過數(shù)據(jù)的輪數(shù)S≪n/K2時(shí),額外誤差的影響可以忽略。也就是說,使用局部重排數(shù)據(jù)分配策略時(shí),算法中過數(shù)據(jù)的輪數(shù)要收到計(jì)算節(jié)點(diǎn)數(shù)目的影響。如果計(jì)算節(jié)點(diǎn)數(shù)比較多,過的輪數(shù)不能太大。
目前,分布式深度學(xué)習(xí)領(lǐng)域的發(fā)展非常迅速,而以上工作只是我們研究組所做的一些初步探索。希望本文能夠讓更多的研究人員了解“分布式”需要與“深度學(xué)習(xí)”進(jìn)行深度融合,大家一起推動(dòng)分布式深度學(xué)習(xí)的新發(fā)展!
作者簡(jiǎn)介:
陳薇,微軟亞洲研究院機(jī)器學(xué)習(xí)組主管研究員,研究機(jī)器學(xué)習(xí)各個(gè)分支的理論解釋和算法改進(jìn),尤其關(guān)注深度學(xué)習(xí)、強(qiáng)化學(xué)習(xí)、分布式機(jī)器學(xué)習(xí)、博弈機(jī)器學(xué)習(xí)、排序?qū)W習(xí)等。陳薇于2011年加入微軟亞洲研究院,負(fù)責(zé)機(jī)器學(xué)習(xí)理論項(xiàng)目,先后在NIPS、ICML、AAAI、IJCAI等相關(guān)領(lǐng)域頂級(jí)國際會(huì)議和期刊上發(fā)表論文。
參考文獻(xiàn):
-
Shuxin Zheng, Qi Meng, Taifeng Wang, Wei Chen, Zhi-Ming Ma and Tie-Yan Liu, Asynchronous Stochastic Gradient Descent with Delay Compensation, ICML2017
-
Shizhao Sun, Wei Chen, Jiang Bian, Xiaoguang Liu and Tie-Yan Liu, Ensemble-Compression: A New Method for Parallel Training of Deep Neural Networks, ECML 2017
-
Qi Meng, Wei Chen, Yue Wang, Zhi-Ming Ma and Tie-Yan Liu, Convergence Analysis of Distributed Stochastic Gradient Descent with Shuffling, https://arxiv.org/abs/1709.10432