微調(diào)大模型,AMD MI300X就夠了!跟著這篇博客微調(diào)Llama 3.1 405B,效果媲美H100
隨著 AI 模型的參數(shù)量越來越大,對(duì)算力的需求也水漲船高。
比如最近,Llama-3.1 登上了最強(qiáng)開源大模型的寶座,但超大杯 405B 版本的內(nèi)存就高達(dá) 900 多 GB,這對(duì)算力構(gòu)成了更加苛刻的挑戰(zhàn)。
如何降低算力的使用成本和使用門檻,已經(jīng)成為許多公司尋求突破的關(guān)鍵。Felafax 就是其中的一家創(chuàng)業(yè)公司,致力于簡化 AI 訓(xùn)練集群的搭建流程。
Nikhil Sonti 和 Nikhin Sonti 創(chuàng)立了 Felafax,他們的口號(hào)是在構(gòu)建開源 AI 平臺(tái),為下一代 AI 硬件服務(wù),將機(jī)器學(xué)習(xí)的訓(xùn)練成本降低 30%。
與英偉達(dá)相比,AMD 的 GPU,尤其是 MI300X 系列,提供了更高的性價(jià)比,按每美元計(jì)算,其性能表現(xiàn)更為出色。
最近,F(xiàn)elafax 的聯(lián)合創(chuàng)始人 Nikhil Sonti 發(fā)布了一篇博客,詳細(xì)分享了如何通過 8 張 AMD MI300X GPU 和 JAX 微調(diào) LLaMA 3.1 405B 模型的方法,所有代碼現(xiàn)已開源。
Github 鏈接:https://github.com/felafax/felafax
機(jī)器之心對(duì)博客內(nèi)容進(jìn)行了不改變原意的編譯、整理,以下是博客內(nèi)容:
JAX 尤其適合非英偉達(dá)硬件
JAX 是一個(gè)強(qiáng)大的機(jī)器學(xué)習(xí)庫,結(jié)合了類似 NumPy 的 API、自動(dòng)微分功能以及 Google 的 XLA 編譯器。它在模型并行化方面提供了優(yōu)秀的 API,因此非常適合像 LLaMA 3.1 405B 這樣的超大模型訓(xùn)練。
在使用 AMD 硬件時(shí),JAX 有幾個(gè)明顯的優(yōu)勢:
- 多硬件并行支持:JAX 采用 XLA(加速線性代數(shù))編譯器,將計(jì)算編譯為硬件無關(guān)的中間表示(HLO),這意味著同樣的 JAX 代碼無需修改便可高效運(yùn)行在不同硬件后端,包括 AMD GPU。
- 獨(dú)立于底層硬件:XLA 編譯器的優(yōu)化策略是通用的,不針對(duì)某個(gè)特定的硬件平臺(tái)。這使得任何支持 XLA 的硬件設(shè)備(如 CPU、GPU、TPU)都能受益于這些優(yōu)化,獲得更好的性能表現(xiàn)。
- 極高的適應(yīng)性:從 NVIDIA 轉(zhuǎn)移到 AMD(或其他硬件)時(shí),JAX 只需做極少的代碼改動(dòng)。而相較之下,PyTorch 與英偉達(dá)的 CUDA 生態(tài)系統(tǒng)緊密耦合,遷移過程相對(duì)復(fù)雜。
因此,JAX 成為了我們在非英偉達(dá)硬件上的最佳選擇。
拉取 Docker 鏡像:
docker pull rocm/jax:latest
啟動(dòng) Docker 容器:
# Pull the Docker Image:
docker pull rocm/jax:latest
# Start the Docker Container:
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest
# Verify the Installation:
python3 -c 'import jax; print(jax.devices())'
驗(yàn)證安裝
python3 -c 'import jax; print (jax.devices ())'
訓(xùn)練使用了一個(gè)配備了 8 張 AMD MI300x GPU 的 AMD 節(jié)點(diǎn)。每張 MI300x 擁有 192GB 的 HBM3 內(nèi)存,性能表現(xiàn)與最新的英偉達(dá) H100 GPU 相比非常出色。
與英偉達(dá) H100 的比較,來源:TensorWave
訓(xùn)練 LLaMA 405B:性能與可擴(kuò)展性
使用 JAX,可以成功地在 AMD GPU 上訓(xùn)練 LLaMA 405B 模型。我們使用 LoRA 微調(diào),將所有模型權(quán)重和 LoRA 參數(shù)都設(shè)為 bfloat16,LoRA rank 設(shè)為 8,LoRA alpha 設(shè)為 16:
- 模型大小:LLaMA 模型的權(quán)重占用了約 800GB 的顯存。
- LoRA 權(quán)重 + 優(yōu)化器狀態(tài):大約占用了 400GB 的顯存。
- 顯存總使用量:占總顯存的 77%,約 1200GB。
- 限制:由于 405B 模型的規(guī)模過大,batch 大小和序列長度的空間有限,使用的 batch size 為 16,序列長度為 64。
- JIT 編譯:由于空間限制,無法運(yùn)行 JIT 編譯版本;它可能需要比急切模式稍多的空間。
- 訓(xùn)練速度:使用 JAX 急切模式,約為 35 tokens / 秒。
- 內(nèi)存效率:穩(wěn)定在約 70% 左右。
- 擴(kuò)展性:在 8 張 GPU 上,使用 JAX 的擴(kuò)展性接近線性。
由于硬件和顯存的限制,我們無法運(yùn)行 JIT 編譯版本的 405B 模型,整個(gè)訓(xùn)練過程是在 JAX 的急切模式下執(zhí)行的,因此還有很大的進(jìn)步空間。
下圖中顯示了在一次微調(diào)訓(xùn)練步驟中,8 張 GPU 的顯存利用率和 rocm-smi 輸出:
GPU 利用率:
顯存利用率:
rocm-smi 輸出:
訓(xùn)練設(shè)置
將 LLaMA 3.1 從 PyTorch 移植到 JAX
此前,Nikhil Sonti 分享過如何將 LLaMA 3.1 從 PyTorch 移植到 JAX。他指出,目前 90% 的大型語言模型(LLM)都運(yùn)行在 NVIDIA GPU 上,但實(shí)際上還有一些同樣強(qiáng)大且性價(jià)比更高的替代方案。例如,在 Google TPU 上訓(xùn)練和部署 Llama 3.1 的成本比 NVIDIA GPU 低約 30%。
然而,支持非 NVIDIA 硬件的開發(fā)工具較為匱乏。Sonti 最初嘗試使用 PyTorch XLA 在 TPU 上訓(xùn)練 Llama 3.1,但過程并不順利。XLA 與 PyTorch 的集成不夠完善,缺少一些關(guān)鍵的庫(如 bitsandbytes 無法正常運(yùn)行),同時(shí)還遇到了一些難以解決的 HuggingFace 錯(cuò)誤。
為此,他決定調(diào)整策略,將 Llama 3.1 從 PyTorch 移植到 JAX,成功解決了這些問題。Sonti 還錄制了詳細(xì)的教程視頻,并開源了所有代碼:
- 方法演示:https://dub.sh/felafax-demo
- 代碼倉庫:https://github.com/felafax/felafax
加載模型,并把模型參數(shù)分片
處理像 LLaMA 405B 這樣的超大模型,需要在多個(gè)設(shè)備之間高效地進(jìn)行參數(shù)分片。以下是如何通過 JAX 實(shí)現(xiàn)這一點(diǎn)的。
在 JAX 中進(jìn)行參數(shù)分片
為了將巨大的 LLaMA 405B 模型高效地分布到 8 張 AMD GPU 上,需要使用 JAX 的設(shè)備網(wǎng)格(device mesh)功能。
部署代碼:https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69
JAX 的設(shè)備網(wǎng)格可以幫助我們把可用的設(shè)備組織成一個(gè)網(wǎng)格,讓我們可以指定如何把模型的參數(shù)和計(jì)算分配到不同的 GPU 上。
在本文的設(shè)置中,需要?jiǎng)?chuàng)建一個(gè)形狀為(1, 8, 1)的網(wǎng)格,并將軸分別命名為數(shù)據(jù)并行(dp)、全分片數(shù)據(jù)并行(fsdp)和模型并行(mp)。然后,為模型的每個(gè)張量定義特定的分片規(guī)則,指定這些維度如何沿著這些網(wǎng)格軸進(jìn)行分片。
DEVICES = jax.devices ()
DEVICE_COUNT = len (DEVICES)
DEVICE_MESH = mesh_utils.create_device_mesh ((1, 8, 1))
MESH = Mesh (devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))
可視化分片
可以使用以下代碼來可視化分片結(jié)果,從而方便地驗(yàn)證分片規(guī)則是否按預(yù)期應(yīng)用。
jax.debug.visualize_array_sharding
分片規(guī)則
模型不同組件的分片規(guī)則如下所示:
- 參數(shù)如何分片:
參數(shù)要在 8 個(gè) GPU 之間分配。例如,LM head(lm_head/kernel)張量有兩個(gè)軸,按照 PS ("fsdp", "mp") 進(jìn)行分片。在本例中是 8 和 1,因此可以看到該張量在第一個(gè)軸上沿著 8 個(gè) GPU 被拆分。
- Non-Replicated 參數(shù):
沒有任何分片規(guī)范的參數(shù)會(huì)在所有設(shè)備上進(jìn)行復(fù)制。例如,層歸一化(attention_norm/kernel 和 ffn_norm/kernel)沒有設(shè)置分片規(guī)范,是 PS (None)。
應(yīng)用分片函數(shù)
在加載模型時(shí),使用以下分片函數(shù)逐步對(duì)模型權(quán)重進(jìn)行分片:
def make_shard_and_gather_fns (partition_specs):
def make_shard_fn (partition_spec):
out_sharding = NamedSharding (mesh, partition_spec)
def shard_fn (tensor):
return jax.device_put (tensor, out_sharding).block_until_ready ()
return shard_fn
shard_fns = jax.tree_util.tree_map (make_shard_fn, partition_specs)
return shard_fns
# Create shard functions based on partitioning rules
shard_fns = make_shard_and_gather_fns (partitioning_rules)
這使得我們能夠?qū)⒚總€(gè)參數(shù)放置在指定的設(shè)備上,并按照設(shè)定的分片進(jìn)行處理。
分片訓(xùn)練 Batch
最初,訓(xùn)練 Batch 是正常創(chuàng)建的,但在輸入模型之前,需要按照下面的代碼在 GPU 上進(jìn)行分片:
train_batch = jax.device_put ( train_batch,
NamedSharding (self.mesh, PS ("dp", "fsdp")))
在這里,我們指定訓(xùn)練 Batch 應(yīng)該在 "dp" 和 "fsdp" 軸上進(jìn)行分片,在本例中分別對(duì)應(yīng)于被分成 1 和 8 份,如果把結(jié)果可視化出來,如下所示:
分片前:
在調(diào)用 jax.device_put 之后:
加入 LoRA
LoRA 通過將權(quán)重更新分解為低秩矩陣,減少了可訓(xùn)練參數(shù)的數(shù)量,這對(duì)于微調(diào)大型模型特別有效。以下是在 AMD GPU 上微調(diào) Llama 3.1-405 的 LoRA 的要點(diǎn):
- 將 LoRA 參數(shù)(lora_a 和 lora_b)與主模型參數(shù)分開。
- 使用 jax.lax.stop_gradient (kernel) 來防止對(duì)主模型權(quán)重的更新。
- 使用 lax.dot_general 進(jìn)行快速、精確控制的矩陣運(yùn)算。
- LoRA 輸出在添加到主輸出之前會(huì)被縮放為 (self.lora_alpha/self.lora_rank)。
LoRADense 層
在此設(shè)定一個(gè)自定義的 LoRADense 層,該層集成了 LoRA 參數(shù):
class LoRADense (nn.Module):
features: int
lora_rank: int = 8
lora_alpha: float = 16.0
@nn.compact
def __call__(self, inputs: Any) -> Any:
# Original kernel parameter (frozen)
kernel = self.param ('kernel', ...)
y = lax.dot_general (inputs, jax.lax.stop_gradient (kernel), ...)
# LoRA parameters (trainable)
lora_a = self.variable ('lora_params', 'lora_a', ..., ...)
lora_b = self.variable ('lora_params', 'lora_b', ..., ...)
# Compute LoRA output
lora_output = lax.dot_general (inputs, lora_a.value, ...)
lora_output = lax.dot_general (lora_output, lora_b.value, ...)
# Combine original output with LoRA modifications
y += (self.lora_alpha/self.lora_rank) * lora_output
return y.astype (self.dtype)
分片 LoRA 參數(shù)
為了高效地在設(shè)備之間分配 LoRA 參數(shù),我們也通過 JAX 設(shè)定了分片規(guī)則,這確保了 LoRA 參數(shù)與主模型參數(shù)的分片一致,優(yōu)化了內(nèi)存使用和計(jì)算效率。
LoRA A matrices (lora_a)
LoRA A 矩陣(lora_a)
- 分片規(guī)則:PS ("fsdp", "mp")
- 可視化結(jié)果:如下圖所示,lora_a 參數(shù)被分片為 (8, 1),這意味著第一個(gè)軸在 8 個(gè)設(shè)備上進(jìn)行分片("fsdp" 軸),而第二個(gè)軸未進(jìn)行分片。
LoRA B 矩陣(lora_b)
- 分片規(guī)則:PS ("mp", "fsdp")
- 可視化結(jié)果:如下圖所示,lora_b 參數(shù)被分片為 (1, 8),這意味著第二個(gè)軸在 8 個(gè)設(shè)備上進(jìn)行分片(fsdp 軸),而第一個(gè)軸未進(jìn)行分片。
這種分片策略優(yōu)化了參數(shù)的分配,減少了通信開銷,并在訓(xùn)練過程中增強(qiáng)了并行性。它確保每個(gè)設(shè)備僅持有一部分 LoRA 參數(shù),使得大模型如 LLaMA 405B 的高效擴(kuò)展成為可能。
僅更新 LoRA 參數(shù)
為了優(yōu)化訓(xùn)練,在微調(diào) LLaMA 405B 模型,只計(jì)算 LoRA 參數(shù)的梯度,保持主模型參數(shù)不變。這個(gè)方法減少了內(nèi)存使用,并加速了訓(xùn)練,因?yàn)橹桓螺^少的參數(shù)。可以移步 GitHub 倉庫,查看實(shí)現(xiàn)細(xì)節(jié)。
在訓(xùn)練過程中,每一步都涉及將一批輸入數(shù)據(jù)通過模型進(jìn)行處理。由于只有 LoRA 參數(shù)是可訓(xùn)練的,因此模型的預(yù)測和計(jì)算的損失僅依賴于這些參數(shù),然后對(duì) LoRA 參數(shù)進(jìn)行反向傳播。只更新這些參數(shù)簡化了訓(xùn)練過程,使得在多個(gè) GPU 上高效微調(diào)像 LLaMA 405B 這樣的大型模型成為可能。
更多研究細(xì)節(jié),請參考原博客。