DeepSeek開源第三彈:V3/R1訓(xùn)練推理關(guān)鍵秘籍,核心代碼僅300行
開源周的第三天,DeepSeek把訓(xùn)練推理V3/R1背后的“動(dòng)力”給亮出來了——
DeepGEMM:一個(gè)FP8 GEMM(通用矩陣乘法)庫(kù),支持密集(dense)和混合專家(MoE)矩陣乘法運(yùn)算。
圖片
我們先來簡(jiǎn)單了解一下GEMM。
GEMM,即通用矩陣乘法,是線性代數(shù)中的基本運(yùn)算,是科學(xué)計(jì)算、機(jī)器學(xué)習(xí)、深度學(xué)習(xí)等領(lǐng)域中“??汀保彩窃S多高性能計(jì)算任務(wù)的核心。
但由于它的計(jì)算量往往都比較大,所以GEMM的性能優(yōu)化是至關(guān)重要的一點(diǎn)。
而DeepSeek這次開源的DeepGEMM,依舊是保持了“高性能+低成本”的特性,亮點(diǎn)如下:
- 高性能:在Hopper架構(gòu)的GPU上,DeepGEMM能夠?qū)崿F(xiàn)高達(dá)1350+FP8 TFLOPS的性能。
- 簡(jiǎn)潔性:核心邏輯僅約 300 行代碼,但性能卻優(yōu)于專家調(diào)優(yōu)的內(nèi)核。
- 即時(shí)編譯(JIT):采用完全即時(shí)編譯的方式,這意味著它可以在運(yùn)行時(shí)動(dòng)態(tài)生成優(yōu)化的代碼,從而適應(yīng)不同的硬件和矩陣大小。
- 無重依賴:這個(gè)庫(kù)設(shè)計(jì)得非常輕量級(jí),沒有復(fù)雜的依賴關(guān)系,可以讓部署和使用變得簡(jiǎn)單。
- 支持多種矩陣布局:支持密集矩陣布局和兩種 MoE 布局,這使得它能夠適應(yīng)不同的應(yīng)用場(chǎng)景,包括但不限于深度學(xué)習(xí)中的混合專家模型。
簡(jiǎn)單來說,DeepGEMM主要用于加速深度學(xué)習(xí)中的矩陣運(yùn)算,特別是在大規(guī)模模型訓(xùn)練和推理中,它特別適用于需要高效計(jì)算資源的場(chǎng)景,能夠顯著提升計(jì)算效率。
很多網(wǎng)友們對(duì)這次的開源都比較“買單”,有人將DeepGEMM比作數(shù)學(xué)界的超級(jí)英雄,認(rèn)為它比飛快的計(jì)算器還要快,比多項(xiàng)式方程還要強(qiáng)大。
圖片
也有人將DeepGEMM的發(fā)布比喻為量子態(tài)穩(wěn)定到一個(gè)新的現(xiàn)實(shí),稱贊其即時(shí)編譯的干凈利落。
圖片
當(dāng)然……也有人開始擔(dān)心起自己手上的英偉達(dá)股票了……
圖片
深入了解DeepGEMM
DeepGEMM是一個(gè)專門為實(shí)現(xiàn)簡(jiǎn)潔高效的FP8通用矩陣乘法(GEMMs)而打造的庫(kù),它還具備細(xì)粒度縮放功能,這一設(shè)計(jì)源于DeepSeek V3。
它既能處理普通的通用矩陣乘法,也能支持MoE分組的通用矩陣乘法。
這個(gè)庫(kù)是用CUDA編寫的,安裝的時(shí)候不需要編譯,因?yàn)樗鼤?huì)在運(yùn)行時(shí)通過一個(gè)輕量級(jí)的即時(shí)編譯(JIT)模塊來編譯所有的內(nèi)核程序。
目前,DeepGEMM只支持英偉達(dá)的Hopper張量核心。
為了解決FP8張量核心在計(jì)算累積時(shí)不夠精確的問題,它采用了CUDA核心的兩級(jí)累積(提升)方法。
雖然DeepGEMM借鑒了CUTLASS和CuTe里的一些理念,但并沒有過度依賴它們的模板或代數(shù)運(yùn)算。
相反,這個(gè)庫(kù)設(shè)計(jì)得很簡(jiǎn)潔,只有一個(gè)核心內(nèi)核函數(shù),代碼量大概300行左右。
這使得它成為一個(gè)簡(jiǎn)潔易懂的資源,方便大家學(xué)習(xí)Hopper架構(gòu)下的FP8矩陣乘法和優(yōu)化技術(shù)。
盡管其設(shè)計(jì)輕巧,但DeepGEMM的性能可以匹配或超過各種矩陣形狀的專家調(diào)優(yōu)庫(kù)。
那么具體性能如何呢?
團(tuán)隊(duì)在H800上使用NVCC 12.8測(cè)試了DeepSeek-V3/R1推理中可能使用的所有形狀(包括預(yù)填充和解碼,但沒有張量并行)。
下面這張圖展示的是用于密集模型的普通DeepGEMM的性能:
圖片
從測(cè)試結(jié)果來看,DeepGEMM計(jì)算性能最高可達(dá)1358 TFLOPS,內(nèi)存寬帶最高可達(dá)2668 GB/s。
加速比方面,與基于CUTLASS 3.6的優(yōu)化實(shí)現(xiàn)相比,最高可達(dá)2.7倍。
再來看下DeepGEMM支持MoE模型的連續(xù)布局(contiguous layout)的性能:
圖片
以及支持MoE模型掩碼布局(masked layout)的性能是這樣的:
圖片
如何使用?
要想使用DeepGEMM,需先注意一下幾個(gè)依賴項(xiàng),包括:
- 必須支持Hopper架構(gòu)的GPU,sm_90a。
- Python 3.8及以上。
- CUDA 12.3及以上(推薦12.8)。
- PyTorch 2.1及以上。
- CUTLASS 3.6及以上
Development代碼如下:
# Submodule must be cloned
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
# Make symbolic links for third-party (CUTLASS and CuTe) include directories
python setup.py develop
# Test JIT compilation
python tests/test_jit.py
# Test all GEMM implements (normal, contiguous-grouped and masked-grouped)
python tests/test_core.py
安裝代碼如下:
python setup.py install
在上述步驟之后,您的Python項(xiàng)目中導(dǎo)入deep_gemm即可。
在接口方面,對(duì)于普通的DeepGEMM,可調(diào)用deep_gemm.gemm_fp8_fp8_bf16_nt函數(shù),支持NT格式(非轉(zhuǎn)置LHS和轉(zhuǎn)置RHS)。
對(duì)于分組的DeepGEMM,連續(xù)布局情況下是m_grouped_gemm_fp8_fp8_bf16_nt_contiguous;掩碼布局情況下是m_grouped_gemm_fp8_fp8_bf16_nt_masked。
DeepGEMM還提供設(shè)置最大SM數(shù)量、獲取TMA對(duì)齊大小等工具函數(shù);支持環(huán)境變量,如DG_NVCC_COMPILER、DG_JIT_DEBUG等。
除此之外,DeepSeek團(tuán)隊(duì)還提供了幾種優(yōu)化的方式,包括:
- JIT設(shè)計(jì):所有內(nèi)核在運(yùn)行時(shí)編譯,無需安裝時(shí)編譯;支持動(dòng)態(tài)選擇最優(yōu)塊大小和流水線階段。
- 細(xì)粒度縮放:通過CUDA核心兩層累加解決FP8精度問題;支持非2的冪次方塊大小,優(yōu)化SM利用率。
- FFMA SASS交錯(cuò):通過修改SASS指令的yield和reuse位,提高性能。
圖片
感興趣的小伙伴可以戳文末GitHub鏈接查看詳情哦~
One More Thing
英偉達(dá)這幾天的股票……嗯……一直再跌:
圖片
不過在北京時(shí)間27日凌晨,英偉達(dá)2025財(cái)年第四季度業(yè)績(jī)報(bào)告也即將出爐,我們可以期待一下它的表現(xiàn)~