大模型千卡訓(xùn)練總結(jié)!難點(diǎn)到底在哪里?
本文經(jīng)自動駕駛之心公眾號授權(quán)轉(zhuǎn)載,轉(zhuǎn)載請聯(lián)系出處。
最近看到知乎一個(gè)回答,把千卡訓(xùn)練的難度吹上天了。但其實(shí)真正用過千卡就會發(fā)現(xiàn)也就那么幾個(gè)點(diǎn)。于是想寫一篇文章簡單講講。
本文將包括3個(gè)部分:首先我們將討論千卡訓(xùn)練的難題,以及應(yīng)該在什么時(shí)候使用千卡訓(xùn)練;接著,我們將討論如何在一千張卡上開始訓(xùn)練,如何讓他達(dá)到近乎線性的性能提升;最后我們將展開討論一些千卡訓(xùn)練當(dāng)中仍然懸而未決(至少對于開源社區(qū)來說)的問題。
為什么千卡訓(xùn)練是困難的?
千卡訓(xùn)練和八卡訓(xùn)練的區(qū)別是—顯卡多了一百多倍。
這意味著什么呢?
- 通信時(shí)間增加
- 故障概率增加
這倆問題都很好理解。
時(shí)間上,PyTorch內(nèi)部支持NCCL/Gloo/MPI三個(gè)通信后端(請務(wù)必使用NCCL。其中AllReduce操作會會根據(jù)具體硬件配置走Ring AllReduce和Tree AllReduce。Ring的時(shí)間復(fù)雜度是,Tree的時(shí)間復(fù)雜度是 。就算是理論上128節(jié)點(diǎn)也比單節(jié)點(diǎn)慢至少七倍,實(shí)踐當(dāng)中跨節(jié)點(diǎn)通訊要遠(yuǎn)比單節(jié)點(diǎn)慢得多。
故障上,一個(gè)節(jié)點(diǎn)出問題的概率是p,128個(gè)節(jié)點(diǎn)就是1-(1-p)^128。也就是說如果一個(gè)操作在一個(gè)訓(xùn)練當(dāng)中的出錯(cuò)概率是1%,那么在128節(jié)點(diǎn)當(dāng)中的出錯(cuò)概率就是72.37%。
此外,隨著規(guī)模的增大,許多問題都會變得難以忍受。比如數(shù)據(jù)增強(qiáng)要花0.1s,一億條數(shù)據(jù)就是278個(gè)小時(shí)(當(dāng)然這只是胡拆的一個(gè)數(shù)字,實(shí)際有各種機(jī)制所以不會有這么大影響。
因此,錢多燒手并不是使用千卡訓(xùn)練的理由。閑得蛋疼可能是,但你得多蛋疼才能想出這么折磨自己的idea?
千卡訓(xùn)練解決的問題是大模型&大數(shù)據(jù)問題。如果你的訓(xùn)練時(shí)間沒有超過8192GPU日,那么你絕對不需要一千張顯卡。
看到這里,絕大多數(shù)人已經(jīng)可以關(guān)掉這篇文章了。除非你的模型和數(shù)據(jù)都以B(十億)來作為計(jì)量單位。當(dāng)然如果你正在廁所里手機(jī)沒電想看點(diǎn)兒東西解悶兒的話(雖然我很懷疑是否會有人把他打出來……那么可以繼續(xù)往下看
如何使用一千張卡訓(xùn)練?
如何提高計(jì)算效率?
這件事情其實(shí)是一個(gè)case by case的事情。因?yàn)橥ㄐ?、?jì)算速度啥的受硬件影響更多。而每一個(gè)集群的硬件拓?fù)涠际遣灰粯拥?。同樣是A100集群,我全DGX節(jié)點(diǎn),每一張A100都是SXM接口并配一塊兒專屬的IB網(wǎng)卡。你一個(gè)小破普惠服務(wù)器插8張PCI-E A100,IB卡一個(gè)節(jié)點(diǎn)只給一張。那咱倆遇到的問題就完全不是一個(gè)問題。
因此,要討論如何提高訓(xùn)練效率、減少訓(xùn)練耗時(shí),我們首先要了解訓(xùn)練耗時(shí)在哪里。那么,一個(gè)訓(xùn)練步的耗時(shí)在哪里呢?需要謹(jǐn)記,沒有profile的優(yōu)化是沒有意義的。
你可能會說,forward backward sync。很好,這說明你了解PyTorch的基本流程。不過現(xiàn)實(shí)當(dāng)中要復(fù)雜得多。
- dataset讀取數(shù)據(jù),構(gòu)建輸出
- dataloader collate數(shù)據(jù),進(jìn)行數(shù)據(jù)預(yù)處理
- 模型forward計(jì)算輸出
- loss compute
- 模型backward計(jì)算梯度
- 模型sync梯度
- 優(yōu)化器step更新權(quán)重
- 打印log
當(dāng)然這是可以無限細(xì)分下去的,但一般這些就夠了。需要注意的是,除了4-7的耗時(shí)是真耗時(shí),其他都需要通過異步操作來蓋掉。這也是我們的優(yōu)化目標(biāo)。
異步執(zhí)行在PyTorch的dataloader、CUDA和分布式當(dāng)中都存在。前者可以通過設(shè)置num_workers和prefetch_count為0來關(guān)閉,后兩者可以通過cuda.synchornize和dist.barrier來執(zhí)行手動同步。在profile時(shí),我們需要首先需要測整個(gè)step的時(shí)長。然后再在每次測量前執(zhí)行手動同步來計(jì)算每個(gè)部分的時(shí)長。如果前者的總耗時(shí)等于后者4-7的耗時(shí)之和,那么通常不需要執(zhí)行任何操作。但這種情況在千卡操作中幾乎不可能發(fā)生。
第6步通信往往需要耗費(fèi)大量時(shí)間。因此,我們還需要進(jìn)一步優(yōu)化通信。
以下內(nèi)容是對《PyTorch Distributed: Experiences on Accelerating Data Parallel Training》論文的概括,有感興趣的同學(xué)建議通讀并背誦全文。
Paper: https://arxiv.org/abs/2006.15704
計(jì)算-通信重疊
在PyTorch當(dāng)中,梯度的通信和反向傳播是交疊進(jìn)行的。也就是說,每完成一層的梯度計(jì)算,都會立即觸發(fā)當(dāng)前層的同步。實(shí)現(xiàn)起來也很簡單,每個(gè)進(jìn)程在完成自己第k層的梯度計(jì)算后都會觸發(fā)一個(gè)鉤子來給計(jì)數(shù)器+1s。當(dāng)計(jì)數(shù)器達(dá)到進(jìn)程數(shù)是開火進(jìn)行梯度通信。有很多同學(xué)在計(jì)算梯度過程中遇到過RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one.錯(cuò)誤,這就是因?yàn)橛械哪K沒有參與計(jì)算loss,導(dǎo)致梯度同步卡住了。需要注意,當(dāng)find_unused_parameters=True時(shí),PyTorch分布式使用nn.Module.__init__當(dāng)中定義sub-module的反向順序來作為梯度桶的構(gòu)建順序。因此,確保模塊定義和調(diào)用的順序一致對于高效訓(xùn)練來說很重要。
梯度合桶
盡管理論上來說,同步發(fā)生的越及時(shí),重合度越高,性能越好。但實(shí)際上每次發(fā)起通信都是有開銷的。因此,現(xiàn)實(shí)當(dāng)中梯度同步并不是越多越好越快越好。為此,PyTorch引入了梯度合桶機(jī)制,通過把多個(gè)Tensor裝在一個(gè)桶里再通信桶來減少通信次數(shù)從而減少總耗時(shí)。合桶的Buffer Size等等參數(shù)往往需要針對硬件和模型來調(diào)整從而取得最好的通信效果。PyTorch的默認(rèn)參數(shù)是從0.x時(shí)代祖?zhèn)飨聛淼?,這一參數(shù)通常都需要調(diào)節(jié)。
梯度累加
當(dāng)你做完所有操作之后,驚喜的發(fā)現(xiàn)TMD怎么同步時(shí)間還是單節(jié)點(diǎn)的好幾倍。這其實(shí)是正常情況……實(shí)際上超過256卡的訓(xùn)練想要把通信蓋掉就是一件不可能的事情。你說老師我看FB論文說他們256卡就是線性提升啊…那這里不得不提的一個(gè)策略就是梯度累加了。梯度累加會執(zhí)行k次forward+backward之后再執(zhí)行優(yōu)化器步進(jìn)。這有很多好處,首先對于大模型batch size通常不能開多大,梯度累加可以提升等效batch size。其次累加期間的backward不需要通信梯度,加快了訓(xùn)練速度。
少即是快
Python是一種很慢的代碼。當(dāng)然你說JIT trace+torch.compile有提升我也不反對,但對于最高效率來說,只有必須要存在的代碼和不存在的代碼兩種。
抱抱臉的Transformers就是一個(gè)反例。兩個(gè)sub-Module就能寫完的TransformerLayer他們硬是能寫出來一堆…偏偏他們還信奉Single Model File Policy……我尋思你這完全不考慮繼承的封這么多層是要搞雞毛啊?正例反而是PyTorch……(笑死,我竟然會夸臉書代碼寫得好。具體來說就是nn.functional當(dāng)中的各種實(shí)現(xiàn)。你會發(fā)現(xiàn)他們第一行往往是handle_torch_func。熟悉Python裝飾器的小伙汁通常要問了,為啥這里不用個(gè)裝飾器統(tǒng)一一下?因?yàn)檠b飾器會引入額外的函數(shù)調(diào)用,額外的函數(shù)調(diào)用就是額外的開銷。
因此,如果你想確保最高的效率,寫一個(gè)簡單的訓(xùn)練代碼和模型代碼非常重要。畢竟,1%的效率提升,節(jié)省的可能是數(shù)百個(gè)GPU日。
如何平穩(wěn)訓(xùn)練
這一段當(dāng)中中咱們只討論你能控制的問題。
捕捉不致命的異常
故障率高的問題其實(shí)很好解決。在訓(xùn)練當(dāng)中,大部分異常都是非致命異常,捉住他們就好了。https://danling.org/utils/decorators/#danling.utils.decorators.catch 是我之前寫的一個(gè)裝飾器,它的作用就是catch異常,然后調(diào)回調(diào)函數(shù)(默認(rèn)當(dāng)然就是把錯(cuò)誤打印到log里)。所有你需要做的只是使用它來裝飾非fatal的操作。
在實(shí)際應(yīng)用當(dāng)中,我們遇到的最常見的問題是存ckpt寫滿了磁盤(不準(zhǔn)笑,從商湯到深勢再到上海AI Lab,這個(gè)問題在哪兒都有出現(xiàn)。咱也不知道為啥肯買那么多顯卡但不肯多插點(diǎn)兒硬盤,咱也不敢問)。catch住所有保存操作,如果你有閑心可以在回調(diào)里刪一下之前的ckpt。沒嫌心的話…大不了重訓(xùn)一次嘛(逃。第二常見的問題,你猜對了……存log寫滿了硬盤……所以所有l(wèi)ogging操作也都是要catch的。這就是為啥我都用tmux然后開很長的緩存窗口,總是能搶救一些log出來的。
咳咳,說點(diǎn)兒正經(jīng)的。任何聯(lián)網(wǎng)操作都是需要catch的,常見的聯(lián)網(wǎng)操作主要包括從ceph讀取數(shù)據(jù)和…寫log到遠(yuǎn)程(逃。其他就沒啥了吧,我見過有大哥嘗試恢復(fù)OOM的,但效果似乎不是很好,至少我自己沒用過。簡單來說,唯一不應(yīng)捕捉的錯(cuò)誤是集群炸了。
那有的大兄弟就說了,集群沒爆炸,但是有兩張卡突然掉了咋辦。這個(gè)咱第三部分再討論。
過程也很重要
有用過[丹靈]http://danling.org的同學(xué)可能比較熟悉。丹靈其他地方都很輕量,唯獨(dú)實(shí)驗(yàn)管理這里寫的很復(fù)雜?,F(xiàn)代丹靈會將創(chuàng)建一個(gè)三個(gè)級別的實(shí)驗(yàn)?zāi)夸?,project/experiment-run/timestamp。其中project是用戶給出的,experiment和run分別是通過代碼版本和配置計(jì)算出來的,timestamp就是運(yùn)行開始的時(shí)間。也就是說,如果代碼和配置是完全一樣的,丹靈就會認(rèn)為這是同一個(gè)運(yùn)行。在設(shè)置中打開auto_resum就會自動找最新的一個(gè)檢查點(diǎn)(這就是為啥最后一級要用時(shí)間戳)來加載。其實(shí)微軟用的amlt更好用,他甚至還會創(chuàng)建一個(gè)代碼的diff文件夾來幫助你回憶當(dāng)初代碼修改了些啥。
收斂,收斂,收斂
模型訓(xùn)著訓(xùn)著發(fā)散了幾乎是每個(gè)訓(xùn)大模型的人都會遇到的問題。輸出和loss只要有nan果斷丟掉。梯度先clip by value再clip by norm都是常規(guī)操作。哦對了,還有初始化……關(guān)于大模型收斂性的論文有一堆,此處不再贅述。
比更大,還更大,再更大
彈性訓(xùn)練
實(shí)際上當(dāng)你的訓(xùn)練超過2048個(gè)GPU日時(shí),在整個(gè)訓(xùn)練過程當(dāng)中發(fā)生單個(gè)GPU甚至單個(gè)節(jié)點(diǎn)下線是再正常不過的事情了。
PyTorch在1.10就引入了torchelastic彈性訓(xùn)練機(jī)制,用過的都罵娘。等下,讓我先罵一遍,呸。ok咱們繼續(xù)吧。
我印象當(dāng)中在微軟的最后一輪面試當(dāng)中被問到了這個(gè)問題:如何設(shè)計(jì)一個(gè)彈性分布式系統(tǒng)。
我的回答很教科書。每k分鐘,系統(tǒng)會做一次AllReduce來統(tǒng)計(jì)存活進(jìn)程數(shù),然后選舉出一個(gè)主進(jìn)程。主進(jìn)程會計(jì)算好每個(gè)進(jìn)程的rank和local rank進(jìn)行broadcast。所有進(jìn)程每次forward開始時(shí)向主進(jìn)程發(fā)送一個(gè)心跳包來匯報(bào)狀態(tài)。主進(jìn)程會根據(jù)心跳包來確定這一個(gè)step參與同步的機(jī)器有多少。但很可惜,2024年了。還是沒人去寫。
大小梯度同步
我一直認(rèn)為梯度同步不應(yīng)該以GPU/進(jìn)程為單位。而應(yīng)該分為大同步(節(jié)點(diǎn)間同步)和小同步(節(jié)點(diǎn)內(nèi)同步)。小同步可以更高頻的進(jìn)行,大同步則可以更慢的執(zhí)行。這樣不僅能提高實(shí)際的梯度同步頻率,降低同步總耗時(shí),并且還能天然的去結(jié)合小batch和大batch訓(xùn)練的優(yōu)點(diǎn)—節(jié)點(diǎn)內(nèi)小batch關(guān)注個(gè)體,節(jié)點(diǎn)間大batch關(guān)注整體。