突破百萬億參數(shù)規(guī)模,追求極致的效率和性價比:華人團(tuán)隊開源首個異構(gòu)并行推薦系統(tǒng)訓(xùn)練框架Persia
?個性化推薦是互聯(lián)網(wǎng)行業(yè)提升 DAU (Daily Active Users)和收入的核心技術(shù)手段。隨著深度學(xué)習(xí)的廣泛應(yīng)用,現(xiàn)代的推薦系統(tǒng)通過神經(jīng)網(wǎng)絡(luò)變相地「記住」用戶的行為習(xí)慣,從而精準(zhǔn)預(yù)測出用戶的喜好。在移動互聯(lián)網(wǎng)普及之后,用戶的行為數(shù)據(jù)呈現(xiàn)幾何級數(shù)增加,單位時間內(nèi)產(chǎn)生和收集的用戶行為數(shù)據(jù)更是極其龐大,因此需要更大的模型來對用戶的興趣編碼。
更大的數(shù)據(jù)規(guī)模意味著需要更大的模型容量,模型參數(shù)量從 5 年前的十億已經(jīng)迅速增長達(dá)到前段時間 Facebook 公開的十萬億參數(shù)規(guī)模。在這樣的趨勢下,更大規(guī)模的訓(xùn)練需求無疑將會成為下一個需要攻克的里程碑。
最近,由兩個華人團(tuán)隊聯(lián)合開源的訓(xùn)練框架 Persia 通過設(shè)計混合架構(gòu)并在 Google cloud 上成功地把模型規(guī)模又推向了一個新的量級 -- 百萬億參數(shù)量(需占用數(shù)百 T 的存儲),并能同時兼顧效率和精度。目前該框架已經(jīng)受邀集成進(jìn) Pytorch 生態(tài)圈 Pytorch Lightning。
- GitHub: https://github.com/PersiaML/PERSIA
- 論文: https://arxiv.org/abs/2111.05897
隨著模型的參數(shù)量隨著指數(shù)級別的增長,對于高性能的訓(xùn)練框架的需求也越來越迫切。傳統(tǒng)架構(gòu)在應(yīng)對越來越多參數(shù)量面前也顯得越來越力不從心。傳統(tǒng)架構(gòu)采用 CPU 的同構(gòu)并行機(jī)制,對應(yīng)的參數(shù)分布采用模型并行。
其最大的優(yōu)點是便于機(jī)器的水平擴(kuò)展以支持相對更大的模型,因此至今仍然在很多公司廣泛使用(雖然在各個公司有不同的命名方式,本文統(tǒng)稱為 mio 架構(gòu))。當(dāng)推薦模型從傳統(tǒng)的 Logistic regression 升級到基于 Deep Learning 的模型,且參數(shù)量急劇增大的時候,傳統(tǒng)的方案就顯得捉襟見肘 ---- 效率低下且難以兼顧精度。后來的進(jìn)化版通過引入 GPU 承擔(dān)了深度網(wǎng)絡(luò)部分的計算(本文稱為 mio + 架構(gòu)),由于采用的仍然是同構(gòu)的設(shè)計思路(只是把CPU機(jī)器換成了CPU帶GPU的機(jī)器),雖然能取得一定效率提升,部分緩解了效率和精度的矛盾,但是當(dāng)需要應(yīng)對不同規(guī)模的網(wǎng)絡(luò)結(jié)構(gòu)的時候,往往出現(xiàn)昂貴的 GPU 資源大量空閑的情況,導(dǎo)致性價比受損嚴(yán)重。
為了解決這兩個因模型規(guī)模不斷膨脹而帶來的難題,Persia 的核心設(shè)計思路如下:
- 采用異構(gòu)的架構(gòu)設(shè)計解決 GPU 資源利用率的問題。當(dāng) CPU 和 GPU 的配比綁定的時候,任何框架都難以同時保證在任何模型結(jié)構(gòu)下的資源利用率。因此 Persia 設(shè)計了一種靈活的異構(gòu)架構(gòu)來實現(xiàn)按模型需求分配資源,保證效率的前提下資源的充分利用,大幅提升了性價比。
- 采用同步和異步的混合訓(xùn)練模式同時兼顧效率和精度。傳統(tǒng)的方案中或是采用純同步的訓(xùn)練,或是采用純異步的訓(xùn)練。在模型越來越大、機(jī)器數(shù)量越來越多的情況下,同步的訓(xùn)練會導(dǎo)致機(jī)器之間相互等待,訓(xùn)練效率容易受損嚴(yán)重。而異步的訓(xùn)練方式雖然避免了機(jī)器之間相互等待,訓(xùn)練的效率顯著提高,但是隨著機(jī)器數(shù)量增加,模型的準(zhǔn)確率(Accuracy)會大幅下降。針對超大模型情景下這樣的挑戰(zhàn),Persia 設(shè)計了一種同步和異步 Hybrid 訓(xùn)練架構(gòu),集二者之長而避其短。并從理論和實踐兩個維度都驗證了 Persia 能同時達(dá)到同步訓(xùn)練的準(zhǔn)確率和異步的訓(xùn)練的效率。
這里簡要列出幾點 Persia 的特點:
- 原生支持 PyTorch 生態(tài):鑒于 PyTorch 極大地降低了研究人員定義模型的門檻,趨勢上在整個深度學(xué)習(xí)領(lǐng)域的占比越來越大,有別于已有的推薦訓(xùn)練框架(如 XDL,PaddlePaddle 等),Persia 決定基于 PyTorch 生態(tài)。用戶模型定義等操作可直接借助 PyTorch 實現(xiàn),因而即便是在研究領(lǐng)域最新最前沿模型(如 Transformer 等)也可直接調(diào)用,達(dá)到最大限度的靈活性與易用性。
- 高性能:在 Criteo 標(biāo)準(zhǔn)數(shù)據(jù)集上,相較其他流行的開源推薦模型訓(xùn)練框架,同樣資源條件下 Persia 可達(dá)到一倍以上的性能提升。Persia 支持 CPU-GPU 異構(gòu)訓(xùn)練,支持 GPU 與 GPU 直接通訊,顯著降低訓(xùn)練成本。
- 可擴(kuò)展性:Persia 在高達(dá) 100 萬億模型參數(shù)訓(xùn)練的 scale 下保持高訓(xùn)練效率。同時在多數(shù)場景能夠接近線性加速(投入 n 倍的資源量,訓(xùn)練效率提升接近 n 倍)。
- 工業(yè)級場景大規(guī)模驗證:Persia 為 Kubernetes 實現(xiàn)了定制化的 operator,支持云原生部署。并實現(xiàn)了各種容錯機(jī)制,經(jīng)過在線上生產(chǎn)環(huán)境穩(wěn)定運(yùn)行兩年以上的驗證。Persia 經(jīng)過多個億級 DAU 核心業(yè)務(wù)場景的實踐檢驗,取得了顯著的性能和業(yè)務(wù)指標(biāo)提升。
- 安全、故障易排查:Persia 由注重內(nèi)存安全、速度和并發(fā)性的 Rust 語言實現(xiàn),在編譯期就排除了大量的內(nèi)存安全問題。原生提供大量打點監(jiān)控,與 Grafana 完美結(jié)合,可自定義各類報警條件。同時基于 tracing 實現(xiàn)了分模塊、分層級的 log 輸出,使得實際場景中故障排查更加輕松。
- 靈活的特征處理:支持交叉特征等各種常見特征處理方式,且用戶通過 Python 腳本即可定義各種自定義特征處理模式。兼具靈活性與易用性。
- 線上線下一致性:離線訓(xùn)練和線上訓(xùn)練代碼統(tǒng)一,解決工程師常常需要花費(fèi)大量時間排查模型上線效果不一致等痛點問題。
Persia 設(shè)計思路
整體架構(gòu)
在推薦模型中,模型往往由下圖中的幾部分構(gòu)成:
- Embedding Layer: 用戶 id、item id 等 ID 類 feature 對應(yīng)的 Embedding 構(gòu)成的 Embedding 層。每個 id 對應(yīng)一個預(yù)設(shè)大小的向量(稱為 Embedding),由于 id 數(shù)量往往十分巨大,這些向量常常會占據(jù)整個模型體積的 99% 以上。
- Non-ID Type Features: 圖像信息、LDA 等實數(shù)向量特征。這部分將會與 id 對應(yīng)的 Embedding vector 組合在一起,輸入到 DNN 中預(yù)測點擊率等。
- Dense Neural Network (以下簡稱 NN): 這部分是一個神經(jīng)網(wǎng)絡(luò),接受 Embedding vector 和實數(shù)向量特征,輸出點擊率等希望預(yù)測的量。
這種推薦模型中, Embedding Layer 參數(shù)往往占模型體積的絕大部分,但 Embedding Layer 的計算量卻不大。而 NN 的參數(shù)量只占模型體積的很小部分,卻占了絕大部分計算量。這正對應(yīng)了:硬件上 CPU 的內(nèi)存較大,但算力較低,而 GPU 的顯存較小,但算力較高。
現(xiàn)有的訓(xùn)練框架雖然包含GPU算力,但是每個 GPU worker 都需要跟大量 PS 之間傳遞數(shù)據(jù)和模型,這常常會觸發(fā)通訊瓶頸,從而整個效率都被拖垮了。
因此,在 Persia 系統(tǒng)設(shè)計中,NN 被置于 GPU 顯存中,通過 GPU 進(jìn)行梯度計算。對于 NN 部分直接通過 GPU 與 GPU 之間的高效集合通訊同步,完全不經(jīng)過 PS。而 Embedding 則置于內(nèi)存中,通過 CPU 進(jìn)行計算。Persia 對于 PS 進(jìn)行兩層架構(gòu)設(shè)計 (Embedding PS, Embedding Worker,后文介紹),能夠在多數(shù)場景下進(jìn)一步降低 GPU worker 帶寬消耗,提升整體訓(xùn)練效率。
同步+異步混合訓(xùn)練
此外,現(xiàn)存系統(tǒng)往往采用全同步訓(xùn)練或全異步訓(xùn)練方式。在全同步訓(xùn)練中,所有 GPU worker 對一批數(shù)據(jù)進(jìn)行訓(xùn)練和模型更新,全部完成后再進(jìn)入下一批數(shù)據(jù)。在模型越來越大、機(jī)器數(shù)量越來越多的情況下,會導(dǎo)致機(jī)器之間相互等待、同步的時間大幅增加,難以在有限時間內(nèi)完成訓(xùn)練。這種情況下系統(tǒng)的訓(xùn)練過程如下圖中第一行 (Full Sync) 所示。在全異步訓(xùn)練中,每個 worker 獨(dú)立訓(xùn)練并更新 PS 參數(shù)。雖然 worker 之間不需要相互等待,訓(xùn)練的效率較高,但是隨著機(jī)器數(shù)量增加,每個 worker 上使用的模型的差異會變大,導(dǎo)致模型的訓(xùn)練效果大幅下降。這種情況下系統(tǒng)的訓(xùn)練過程如下圖中第二行 (Full Async) 所示。
針對這兩種方式的問題,Persia 設(shè)計了 Hybrid 訓(xùn)練架構(gòu),能夠在保證訓(xùn)練效果的同時,達(dá)到接近全異步的訓(xùn)練效率。推薦場景訓(xùn)練中的一個核心觀察是,Embedding 的更新非常稀疏,兩次更新之間往往交集很小,因此即使對 Embedding 做異步更新,對最終的訓(xùn)練結(jié)果影響也不大。而 NN 部分的更新則反之,每一次都會更新全部參數(shù),如果做異步訓(xùn)練,會導(dǎo)致訓(xùn)練結(jié)果的巨大差異。Persia 所提出的 Hybrid 訓(xùn)練方式,能夠?qū)?NN 部分同步訓(xùn)練,Embedding 部分異步訓(xùn)練。最終訓(xùn)練效率接近純異步訓(xùn)練的效率,同時模型效果保持和全同步訓(xùn)練一致。兼得兩方面的優(yōu)勢。這種情況下系統(tǒng)的訓(xùn)練過程如下圖中第三行 (Naive Hybrid) 所示。Persia 在此之上還對能夠并行執(zhí)行的通訊、計算操作進(jìn)行重疊,進(jìn)一步提升系統(tǒng)效率。最終系統(tǒng)的訓(xùn)練過程如下圖中第四行 (Persia) 所示:
理論保證
有別于現(xiàn)存系統(tǒng),Persia 對于 Hybrid 算法的設(shè)計給出了嚴(yán)格的理論保證。對于 Expectation of Loss 的優(yōu)化問題(比如推薦場景中最普遍的每個樣本對應(yīng)一個 loss 的場景):
其中 f(w) 代表整個數(shù)據(jù)集上的平均 loss,ξ 代表一個樣本,w 代表模型參數(shù),F(xiàn)(w; ξ) 代表樣本 ξ 上的 loss。模型訓(xùn)練的目標(biāo)是最小化整個數(shù)據(jù)集上的平均 loss。使用 Persia Hybrid 的訓(xùn)練方式,可以證明模型的收斂速度為:
其中 σ 為數(shù)據(jù)集方差,T 為迭代次數(shù),τ 為 GPU worker 數(shù)量,α 為 ID 類 feature 碰撞概率。其中前兩項為全同步訓(xùn)練的收斂速度,最后一項為 Hybrid 訓(xùn)練引入的誤差。在推薦場景中,因為 Embedding 的更新非常稀疏,碰撞概率 α 遠(yuǎn)小于 1,因此 Hybrid 收斂速度與全同步訓(xùn)練收斂速度幾乎完全一致,但因為同步開銷減少,每一步的訓(xùn)練執(zhí)行效率大幅提升。對于具體的理論證明,可以參考 Persia 的論文 [1]。
其他優(yōu)化
在算法創(chuàng)新的基礎(chǔ)上,為了發(fā)揮極致的性能。Persia 提供了大量的實現(xiàn)層優(yōu)化。比如:
- 所有 PS 服務(wù)通訊使用為訓(xùn)練場景優(yōu)化的 zero-copy Persia RPC 系統(tǒng),在訓(xùn)練場景下(特點是 payload 非常大,包括 Embedding 和梯度等等大量 tensor 數(shù)據(jù))性能遠(yuǎn)超傳統(tǒng) RPC 框架(如 gRPC、bRPC);
- GPU 之間通訊使用同為快手開源的 Bagua 訓(xùn)練加速框架,對 GPU 之間的集合通訊性能有顯著提升,并能通過梯度壓縮等算法進(jìn)一步降低通訊開銷;
- Embedding 在 PS 中通過特殊設(shè)計的數(shù)據(jù)結(jié)構(gòu) (Persia Embedding Array List) 存儲,大幅提升 PS 效率和模型存取效率。這包括運(yùn)行過程中無需動態(tài)申請新內(nèi)存,同時更好地利用 CPU cache 機(jī)制。支持 Embedding 逐出邏輯。模型保存和讀取過程簡化為對連續(xù)內(nèi)存的直接 Dump/Load 過程;
- 引入 Persia Embedding Worker 組件,將 Embedding Sum Pooling、處理原始數(shù)據(jù)等操作執(zhí)行后再發(fā)送給 GPU,大幅減少 GPU 帶寬占用;
- 原始數(shù)據(jù)處理為 Persia Compact Batch 格式,自帶 ID 去重和數(shù)據(jù)壓縮表征等性質(zhì),相比一般表示方式數(shù)據(jù)體積降低至 1/4,提升系統(tǒng)數(shù)據(jù)處理效率。
驗證和比較
測試選用 Alibaba-Ad,Avazu-Ad,Criteo-Ad 等多種開源數(shù)據(jù)集,訓(xùn)練效率整體有 8 倍以上提升:
Persia 可支持高達(dá) 100 萬億模型訓(xùn)練,并在模型規(guī)模變大時保持訓(xùn)練效率:
在資源量擴(kuò)大時,Persia 可以接近線性擴(kuò)展(投入 n 倍的資源量,訓(xùn)練效率提升接近 n 倍):
Persia 使用實例
使用 Persia 非常簡單,主要分為訓(xùn)練部署、模型定義、自定義數(shù)據(jù)集部分。
- 分布式部署:通過 Persia operator 可在 Kubernetes 集群上一鍵部署 PERSIA 任務(wù)
- 模型定義:直接使用 PyTorch
- 自定義數(shù)據(jù)集:自定義預(yù)處理邏輯,將結(jié)果通過 Persia 提供的 Python 工具包轉(zhuǎn)換成 Persia Compact Batch 即可
完整例子和更多場景,歡迎參考 Persia Tutorial 文檔(https://persiaml-tutorials.pages.dev/)。Persia 模型上線推理
Persia 訓(xùn)練的模型 Embedding 部分可通過線上部署 Embedding PS 和 Embedding Worker 直接提供服務(wù)。NN 部分為原生 PyTorch 模型,在 Persia Tutorial 中提供了通過 TorchServe 推理的簡單例子。用戶也可以通過原生 PyTorch 的各種工具,比如轉(zhuǎn)換成 TensorRT 模型,進(jìn)一步提升推理性能。?