用什么tricks能讓模型訓(xùn)練得更快?先了解下這個(gè)問題的第一性原理
每個(gè)人都想讓模型訓(xùn)練得更快,但是你真的找對(duì)方法了嗎?在康奈爾大學(xué)本科生、曾在 PyTorch 團(tuán)隊(duì)實(shí)習(xí)的 Horace He 看來,這個(gè)問題應(yīng)該分幾步解決:首先,你要知道為什么你的訓(xùn)練會(huì)慢,也就是說瓶頸在哪兒,其次才是尋找對(duì)應(yīng)的解決辦法。在沒有了解基本原理(第一性原理)之前就胡亂嘗試是一種浪費(fèi)時(shí)間的行為。
在這篇文章中,Horace He 從三個(gè)角度分析可能存在的瓶頸:計(jì)算、內(nèi)存帶寬和額外開銷,并提供了一些方式去判斷當(dāng)前處于哪一個(gè)瓶頸,有助于我們更加有針對(duì)性地加速系統(tǒng)。這篇文章得到了陳天奇等多位資深研究者、開發(fā)者的贊賞。
以下是原文內(nèi)容:
怎樣才能提高深度學(xué)習(xí)模型的性能?一般人都會(huì)選擇網(wǎng)上博客中總結(jié)的一些隨機(jī)技巧,比如「使用系統(tǒng)內(nèi)置的運(yùn)算算子,把梯度設(shè)置為 0,使用 PyTorch1.10.0 版本而不是 1.10.1 版本……」
在這一領(lǐng)域,當(dāng)代(特別是深度學(xué)習(xí))系統(tǒng)給人的感覺不像是科學(xué),反而更像煉丹,因此不難理解用戶為什么傾向于采用這種隨機(jī)的方法。即便如此,這一領(lǐng)域也有些第一性原理可以遵循,我們可以據(jù)此排除大量方法,從而使得問題更加容易解決。
比如,如果你的訓(xùn)練損失遠(yuǎn)低于測(cè)試損失,那么你可能遇到了「過擬合」問題,而嘗試著增加模型容量就是在浪費(fèi)時(shí)間。再比如,如果你的訓(xùn)練損失和你的驗(yàn)證損失是一致的,那對(duì)模型正則化就顯得不明智了。
類似地,你也可以把高效深度學(xué)習(xí)的問題劃分為以下三個(gè)不同的組成部分:
- 計(jì)算:GPU 計(jì)算實(shí)際浮點(diǎn)運(yùn)算(FLOPS)所花費(fèi)的時(shí)間;
- 內(nèi)存:在 GPU 內(nèi)傳輸張量所花費(fèi)的時(shí)間;
- 額外開銷:花在其它部分的時(shí)間。
在訓(xùn)練機(jī)器學(xué)習(xí)模型的時(shí)候,知道你遇到的是哪類問題非常關(guān)鍵,使模型高效的問題也是如此。例如,當(dāng)模型花費(fèi)大量時(shí)間進(jìn)行內(nèi)存到 GPU 的轉(zhuǎn)移的時(shí)候(也就是內(nèi)存帶寬緊張的時(shí)候),增加 GPU 的 FLOPS 就不管用。另一方面,如果你正在運(yùn)行大量的矩陣乘法運(yùn)算(也就是計(jì)算緊張的時(shí)候),將你的程序重寫成 C++ 去減輕額外開銷就不會(huì)管用。
所以,如果你想讓 GPU 絲滑運(yùn)行,以上三個(gè)方面的討論和研究就是必不可少的。
慘痛教訓(xùn)的背后有大量工程師保持 GPU 高效運(yùn)行。
注意:這個(gè)博客中的大多數(shù)內(nèi)容是基于 GPU 和 PyTorch 舉例子的,但這些原則基本是跨硬件和跨框架通用的。
計(jì)算
優(yōu)化深度學(xué)習(xí)系統(tǒng)的一個(gè)方面在于我們想要最大化用于計(jì)算的時(shí)間。你花錢買了 312 萬億次浮點(diǎn)數(shù)運(yùn)算,那你肯定希望這些都能用到計(jì)算上。但是,為了讓你的錢從你昂貴的矩陣乘法中得到回報(bào),你需要減少花費(fèi)在其他部分的時(shí)間。
但為什么這里的重點(diǎn)是最大化計(jì)算,而不是最大化內(nèi)存的帶寬?原因很簡單 —— 你可以減少額外開銷或者內(nèi)存消耗,但如果不去改變真正的運(yùn)算,你幾乎無法減少計(jì)算量。
與內(nèi)存帶寬相比,計(jì)算的增長速度增加了最大化計(jì)算利用率的難度。下表顯示了 CPU 的 FLOPS 翻倍和內(nèi)存帶寬翻倍的時(shí)間 (重點(diǎn)關(guān)注黃色一欄)。
一種理解計(jì)算的方式是把它想象成工廠。我們把指令傳達(dá)給我們的工廠(額外消耗),把原始材料送給它(內(nèi)存帶寬),所有這些都是為了讓工廠運(yùn)行得更加高效(計(jì)算)。
所以,如果工廠容量擴(kuò)展的速度高于我們提供給它原材料的速度,它就很難達(dá)到一個(gè)頂峰效率。
即使我們工廠容量(FLOP)翻倍,但帶寬跟不上,我們的性能也不能翻倍。
關(guān)于 FLOPS 還有一點(diǎn)要說,越來越多的機(jī)器學(xué)習(xí)加速器都有專門針對(duì)矩陣乘法的硬件配置,例如英偉達(dá)的「Tensor Cores」。
所以,你要是不做矩陣乘法的話,你只能達(dá)到 19.5 萬億次運(yùn)算,而不是 312 萬億次。注意,并不是只有 GPU 這么特殊,事實(shí)上 TPU 是比 GPU 更加專門化的計(jì)算模塊。
除了矩陣乘法以外,GPU 處理其他運(yùn)算時(shí)都比較慢,這一現(xiàn)象乍看上去似乎有問題:比如像是層歸一化或者激活函數(shù)的其它算子怎么辦呢?事實(shí)上,這些算子在 FLOPS 上僅僅像是矩陣乘法的舍入誤差一樣。例如,看看下表對(duì)于 BERT 中的不同算子類型占用的 FLOP 數(shù),其中的「Tensor Contraction」就是指矩陣乘法。
可以看到,非矩陣乘法運(yùn)算僅僅占所有運(yùn)算的 0.2%,所以即使它們的速度僅為矩陣乘法的 1/15 也沒什么問題。
事實(shí)上,歸一化運(yùn)算和逐點(diǎn)(pointwise)運(yùn)算使用的 FLOPS 僅為矩陣乘法的 1/250 和 1/700。那為什么非矩陣乘法運(yùn)算會(huì)遠(yuǎn)比它們應(yīng)該使用的運(yùn)行時(shí)間更多呢?
回到前文「工廠」的類比,罪魁禍?zhǔn)捉?jīng)常還是如何將原始材料運(yùn)到以及運(yùn)出工廠,換句話說,也就是「內(nèi)存帶寬」。
帶寬
帶寬消耗本質(zhì)上是把數(shù)據(jù)從一個(gè)地方運(yùn)送到另一個(gè)地方的花費(fèi),這可能是指把數(shù)據(jù)從 CPU 移動(dòng)到 GPU,從一個(gè)節(jié)點(diǎn)移動(dòng)到另一個(gè)節(jié)點(diǎn),甚至從 CUDA 的全局內(nèi)存移動(dòng)到 CUDA 的共享內(nèi)存。最后一個(gè)是本文討論的重點(diǎn),我們一般稱其為「帶寬消耗」或者「內(nèi)存帶寬消耗」。前兩者一般叫「數(shù)據(jù)運(yùn)輸消耗」或者「網(wǎng)絡(luò)消耗」,不在本文敘述范圍之內(nèi)。
還是回到「工廠」的類比。雖然我們?cè)诠S中從事實(shí)際的工作,但它并不適合大規(guī)模的存儲(chǔ)。我們要保證它的存儲(chǔ)是足夠高效的,并且能夠很快去使用(SRAM),而不是以量取勝。
那么我們?cè)谀睦锎鎯?chǔ)實(shí)際的結(jié)果和「原材料」呢?一般我們要有一個(gè)倉庫,那兒的地足夠便宜,并且有大量的空間(DRAM)。之后我們就可以在它和工廠之間運(yùn)送東西了(內(nèi)存帶寬)。
這種在計(jì)算單元之間移動(dòng)?xùn)|西的成本就是所謂的「內(nèi)存帶寬」成本。事實(shí)上,nvidia-smi 命令中出現(xiàn)的那個(gè)「內(nèi)存」就是 DRAM,而經(jīng)常讓人抓狂的「CUDA out of memory」說的就是這個(gè) DRAM。
值得注意的是:我們每執(zhí)行一次 GPU 核運(yùn)算都需要把數(shù)據(jù)運(yùn)出和運(yùn)回到我們的倉庫 ——DRAM。
現(xiàn)在想象一下,當(dāng)我們執(zhí)行一個(gè)一元運(yùn)算(如 torch.cos)的時(shí)候,我們需要把數(shù)據(jù)從倉庫(DRAM)運(yùn)送到工廠(SRAM),然后在工廠中執(zhí)行一小步計(jì)算,之后再把結(jié)果運(yùn)送回倉庫。運(yùn)輸是相當(dāng)耗時(shí)的,這種情況下,我們幾乎把所有的時(shí)間都花在了運(yùn)輸數(shù)據(jù),而不是真正的計(jì)算上。
因?yàn)槲覀冋阉械臅r(shí)間都花費(fèi)在內(nèi)存帶寬上,這種運(yùn)算也被稱作內(nèi)存限制運(yùn)算(memory-bound operation),它意味著我們沒有把大量時(shí)間花費(fèi)在計(jì)算上。
顯然,這并不是我們想要的。那我們能做什么呢?讓我們來看看算子序列長什么樣子。
一個(gè)逐點(diǎn)算子序列可能的樣子。
在全局內(nèi)存和計(jì)算單元之間來回傳輸數(shù)據(jù)的做法顯然不是最佳的。一種更優(yōu)的方式是:在數(shù)據(jù)工廠中一次性執(zhí)行完全部運(yùn)算再把數(shù)據(jù)傳回。
這就是算子融合(operator fusion)—— 深度學(xué)習(xí)編譯器中最重要的優(yōu)化。簡單地說,這種方法不會(huì)為了再次讀取而將數(shù)據(jù)寫入全局內(nèi)存,而是通過一次執(zhí)行多個(gè)計(jì)算來避免額外的內(nèi)存訪問。
例如,執(zhí)行 x.cos ().cos () 運(yùn)算,寫入內(nèi)存的方式需要 4 次全局讀寫。
x1 = x.cos() # Read from x in global memory, write to x1
x2 = x1.cos() # Read from x1 in global memory, write to x2
而算子融合只需要 2 次全局內(nèi)存讀寫,這樣就實(shí)現(xiàn)了 2 倍加速。
x2 = x.cos().cos() # Read from x in global memory, write to x2
但是這種做法也并不容易,需要一些條件。首先,GPU 需要知道執(zhí)行完當(dāng)前運(yùn)算后下一步會(huì)發(fā)生什么,因此無法在 PyTorch 的 Eager 模式(一次運(yùn)行一個(gè)運(yùn)算符)下進(jìn)行此優(yōu)化。其次,我們需要編寫 CUDA 代碼,這也不是一件簡單的事。
并不是所有的算子融合都像逐點(diǎn)算子那樣簡單。你可以將逐點(diǎn)算子融合到歸約(reduction)或矩陣乘法上。甚至矩陣乘法本身也可以被認(rèn)為是一種融合了廣播乘法(broadcasting multiply)和歸約的運(yùn)算。
任何 2 個(gè) PyTorch 算子都可以被融合,從而節(jié)省了讀取 / 寫入全局內(nèi)存的內(nèi)存帶寬成本。此外,許多現(xiàn)有編譯器通??梢詧?zhí)行「簡單」的融合(例如 NVFuser 和 XLA)。然而,更復(fù)雜的融合仍然需要人們手動(dòng)編寫,因此如果你想嘗試自己編寫自定義 CUDA 內(nèi)核,Triton 是一個(gè)很好的起點(diǎn)。
令人驚訝的是,融合后的 x.cos ().cos () 運(yùn)算將花費(fèi)幾乎與單獨(dú)調(diào)用 x.cos () 相同的時(shí)間。這就是為什么激活函數(shù)的成本幾乎是一樣的,盡管 gelu 顯然比 relu 包含更多的運(yùn)算。
因此,重新實(shí)現(xiàn) / 激活檢查點(diǎn)會(huì)產(chǎn)生一些有趣的結(jié)果。從本質(zhì)上講,進(jìn)行額外的重新計(jì)算可能會(huì)導(dǎo)致更少的內(nèi)存帶寬,從而減少運(yùn)行時(shí)間。因此,我們可以通過重新實(shí)現(xiàn)來減少內(nèi)存占用和運(yùn)行時(shí)間,并在 AOTAutograd 中構(gòu)建一個(gè)簡潔的 min-cut 優(yōu)化通道。
推理內(nèi)存帶寬成本
對(duì)于簡單的運(yùn)算,直接推理內(nèi)存帶寬是可行的。例如,A100 具有 1.5 TB / 秒的全局內(nèi)存帶寬,可以執(zhí)行 19.5 teraflops / 秒的計(jì)算。因此,如果使用 32 位浮點(diǎn)數(shù)(即 4 字節(jié)),你可以在 GPU 執(zhí)行 20 萬億次運(yùn)算的同時(shí)加載 4000 億個(gè)數(shù)字。
此外,執(zhí)行簡單的一元運(yùn)算(例如將張量 x2)實(shí)際上需要將張量寫回全局內(nèi)存。
因此直到執(zhí)行大約一百個(gè)一元運(yùn)算之前,更多的時(shí)間是花在了內(nèi)存訪問而不是實(shí)際計(jì)算上。
如果你執(zhí)行下面這個(gè) PyTorch 函數(shù):
def f(x: Tensor[N]):
for _ in range(repeat):
x = x * 2
return x
并使用融合編譯器對(duì)其進(jìn)行基準(zhǔn)測(cè)試,就可以計(jì)算每個(gè) repeat 值的 FLOPS 和內(nèi)存帶寬。增大 repeat 值是在不增加內(nèi)存訪問的情況下增加計(jì)算量的簡單方法 - 這也稱為增加計(jì)算強(qiáng)度 (compute intensity)。
具體來說,假設(shè)我們對(duì)這段代碼進(jìn)行基準(zhǔn)測(cè)試,首先要找出每秒執(zhí)行的迭代次數(shù);然后執(zhí)行 2N(N 是張量大?。┐蝺?nèi)存訪問和 N *repeat FLOP。因此,內(nèi)存帶寬將是 bytes_per_elem * 2 * N /itrs_per_second,而 FLOPS 是 N * repeat /itrs_per_second。
現(xiàn)在,讓我們繪制計(jì)算強(qiáng)度的 3 個(gè)函數(shù)圖象:運(yùn)行時(shí)間、flops 和內(nèi)存帶寬。
請(qǐng)注意,在執(zhí)行 64 次乘法之前,運(yùn)行時(shí)間根本不會(huì)顯著增加。這意味著在此之前主要受內(nèi)存帶寬的限制,而計(jì)算大多處于空閑狀態(tài)。
一開始 FLOPS 的值是 0.2 teraflops。當(dāng)我們將計(jì)算強(qiáng)度加倍時(shí),這個(gè)數(shù)字會(huì)線性增長,直到接近 9.75 teraflops 的峰值,一旦接近峰值 teraflops 就被認(rèn)為是「計(jì)算受限的」。
最后,可以看到內(nèi)存帶寬從峰值附近開始,隨著我們?cè)黾佑?jì)算強(qiáng)度開始下降。這正是我們所期待的,因?yàn)檫@說明執(zhí)行實(shí)際計(jì)算的時(shí)間越來越多,而不是訪問內(nèi)存。
在這種情況下,很容易看出何時(shí)受計(jì)算限制以及何時(shí)受內(nèi)存限制。repeat< 32 時(shí),內(nèi)存帶寬接近飽和,而未進(jìn)行充分的計(jì)算;repeat> 64 時(shí),計(jì)算接近飽和(即接近峰值 FLOPS),而內(nèi)存帶寬開始下降。
對(duì)于較大的系統(tǒng),通常很難說是受計(jì)算限制還是內(nèi)存帶寬限制,因?yàn)樗鼈兺ǔ0?jì)算限制和內(nèi)存限制兩方面的綜合原因。衡量計(jì)算受限程度的一種常用方法是計(jì)算實(shí)際 FLOPS 與峰值 FLOPS 的百分比。
然而,除了內(nèi)存帶寬成本之外,還有一件事可能會(huì)導(dǎo)致 GPU 無法絲滑運(yùn)行。
額外開銷
當(dāng)代碼把時(shí)間花費(fèi)在傳輸張量或計(jì)算之外的其他事情上時(shí),額外開銷(overhead)就產(chǎn)生了,例如在 Python 解釋器中花費(fèi)的時(shí)間、在 PyTorch 框架上花費(fèi)的時(shí)間、啟動(dòng) CUDA 內(nèi)核(但不執(zhí)行)所花費(fèi)的時(shí)間, 這些都是間接開銷。
額外開銷顯得重要的原因是現(xiàn)代 GPU 的運(yùn)算速度非??臁100 每秒可以執(zhí)行 312 萬億次浮點(diǎn)運(yùn)算(312TeraFLOPS)。相比之下 Python 實(shí)在是太慢了 ——Python 在一秒內(nèi)約執(zhí)行 3200 萬次加法。
這意味著 Python 執(zhí)行單次 FLOP 的時(shí)間,A100 可能已經(jīng)運(yùn)行了 975 萬次 FLOPS。
更糟糕的是,Python 解釋器甚至不是唯一的間接開銷來源,像 PyTorch 這樣的框架到達(dá) actual kernel 之前也有很多層調(diào)度。PyTorch 每秒大約能執(zhí)行 28 萬次運(yùn)算。如果使用微型張量(例如用于科學(xué)計(jì)算),你可能會(huì)發(fā)現(xiàn) PyTorch 與 C++ 相比非常慢。
例如在下圖中,使用 PyTorch 執(zhí)行單次添加,僅有一小塊圖是實(shí)際執(zhí)行計(jì)算的內(nèi)容,其他的部分都是純粹的額外開銷。
鑒于此,你可能會(huì)對(duì) PyTorch 成為主流框架的現(xiàn)象感到不解,而這是因?yàn)楝F(xiàn)代深度學(xué)習(xí)模型通常執(zhí)行大規(guī)模運(yùn)算。此外,像 PyTorch 這樣的框架是異步執(zhí)行的。因此,大部分框架開銷可以完全忽略。
如果我們的 GPU 算子足夠大,那么 CPU 可以跑在 GPU 之前(因此 CPU 開銷是無關(guān)緊要的)。另一方面,如果 GPU 算子太小,那么 GPU 將在 paperweight 上花費(fèi)大部分時(shí)間。
那么,如何判斷你是否處于這個(gè)問題中?由于額外開銷通常不會(huì)隨著問題的規(guī)模變化而變化(而計(jì)算和內(nèi)存會(huì)),所以最簡單的判斷方法是簡單地增加數(shù)據(jù)的大小。如果運(yùn)行時(shí)間不是按比例增加,應(yīng)該可以說遇到了開銷限制。例如,如果將批大小翻倍,但運(yùn)行時(shí)間僅增加 10%,則可能會(huì)受到開銷限制。
另一種方法是使用 PyTorch 分析器。如下圖,粉紅色塊顯示了 CPU 內(nèi)核與 GPU 內(nèi)核的匹配情況。
CPU 運(yùn)行地比 GPU 更超前。
另一方面,nvidia-smi 中的「GPU-Util」(不是「Volatile GPU-Util」)入口會(huì)測(cè)量實(shí)際運(yùn)行的 GPU 內(nèi)核的百分占比,所以這是另一種觀察是否遇到開銷限制的好方法。這種開銷是 PyTorch 等所有靈活的框架所具有的,本質(zhì)上都需要花費(fèi)大量時(shí)間來「弄清楚要做什么」。
這可能來自 Python(查找屬性或調(diào)度到正確的函數(shù))或 PyTorch 中的代碼。例如,當(dāng)你執(zhí)行 a + b 時(shí),需要執(zhí)行以下步驟:
- Python 需要在 a 上查找__add__調(diào)度到的內(nèi)容。
- PyTorch 需要確定張量的很多屬性(比如 dtype、device、是否需要 autograd)來決定調(diào)用哪個(gè)內(nèi)核。
- PyTorch 需要實(shí)際啟動(dòng)內(nèi)核。
從根本上說,這種開銷來自能夠在每個(gè)步驟中執(zhí)行不同運(yùn)算的靈活性。如果不需要這種靈活性,解決這種靈活性的一種方法是跟蹤它,例如使用 jit.trace、FX 或 jax.jit?;蛘?,可以換用 CUDA Graphs 之類的東西在更低的級(jí)別上執(zhí)行此運(yùn)算。
不幸的是,這是以失去靈活性為代價(jià)的。一種兩全其美的方法是,通過在 VM 級(jí)別進(jìn)行 introspect 來編寫更多符合「真實(shí)」的 JIT 的內(nèi)容。有關(guān)更多信息,可參閱 TorchDynamo (https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)。
總結(jié)
如果你想加速深度學(xué)習(xí)系統(tǒng),最重要的是了解模型中的瓶頸是什么,因?yàn)槠款i決定了適合加速該系統(tǒng)的方法是什么。
很多時(shí)候,我看到研究人員和其他對(duì)加速 PyTorch 代碼感興趣的人,會(huì)在不了解所處問題的情況下盲目嘗試。
當(dāng)然,另一方面,如果用戶需要考慮這些東西,也反映了框架的部分失敗。盡管 PyTorch 是一個(gè)活躍的關(guān)注領(lǐng)域,但 PyTorch 的編譯器或配置文件 API 并不是最容易使用的。
總而言之,我發(fā)現(xiàn)對(duì)系統(tǒng)基本原理的理解幾乎總是有用的,希望這對(duì)你也有用。