自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

微調(diào)大模型,AMD MI300X就夠了!跟著這篇博客微調(diào)Llama 3.1 405B,效果媲美H100

人工智能 新聞
為了優(yōu)化訓(xùn)練,在微調(diào) LLaMA 405B 模型,只計(jì)算 LoRA 參數(shù)的梯度,保持主模型參數(shù)不變。

隨著 AI 模型的參數(shù)量越來越大,對(duì)算力的需求也水漲船高。

比如最近,Llama-3.1 登上了最強(qiáng)開源大模型的寶座,但超大杯 405B 版本的內(nèi)存就高達(dá) 900 多 GB,這對(duì)算力構(gòu)成了更加苛刻的挑戰(zhàn)。

如何降低算力的使用成本和使用門檻,已經(jīng)成為許多公司尋求突破的關(guān)鍵。Felafax 就是其中的一家創(chuàng)業(yè)公司,致力于簡化 AI 訓(xùn)練集群的搭建流程。

Nikhil Sonti 和 Nikhin Sonti 創(chuàng)立了 Felafax,他們的口號(hào)是在構(gòu)建開源 AI 平臺(tái),為下一代 AI 硬件服務(wù),將機(jī)器學(xué)習(xí)的訓(xùn)練成本降低 30%。

與英偉達(dá)相比,AMD 的 GPU,尤其是 MI300X 系列,提供了更高的性價(jià)比,按每美元計(jì)算,其性能表現(xiàn)更為出色。

最近,F(xiàn)elafax 的聯(lián)合創(chuàng)始人 Nikhil Sonti 發(fā)布了一篇博客,詳細(xì)分享了如何通過 8 張 AMD MI300X GPU 和 JAX 微調(diào) LLaMA 3.1 405B 模型的方法,所有代碼現(xiàn)已開源。

圖片

Github 鏈接:https://github.com/felafax/felafax

機(jī)器之心對(duì)博客內(nèi)容進(jìn)行了不改變原意的編譯、整理,以下是博客內(nèi)容:

JAX 尤其適合非英偉達(dá)硬件

JAX 是一個(gè)強(qiáng)大的機(jī)器學(xué)習(xí)庫,結(jié)合了類似 NumPy 的 API、自動(dòng)微分功能以及 Google 的 XLA 編譯器。它在模型并行化方面提供了優(yōu)秀的 API,因此非常適合像 LLaMA 3.1 405B 這樣的超大模型訓(xùn)練。

在使用 AMD 硬件時(shí),JAX 有幾個(gè)明顯的優(yōu)勢:

  • 多硬件并行支持:JAX 采用 XLA(加速線性代數(shù))編譯器,將計(jì)算編譯為硬件無關(guān)的中間表示(HLO),這意味著同樣的 JAX 代碼無需修改便可高效運(yùn)行在不同硬件后端,包括 AMD GPU。
  • 獨(dú)立于底層硬件:XLA 編譯器的優(yōu)化策略是通用的,不針對(duì)某個(gè)特定的硬件平臺(tái)。這使得任何支持 XLA 的硬件設(shè)備(如 CPU、GPU、TPU)都能受益于這些優(yōu)化,獲得更好的性能表現(xiàn)。
  • 極高的適應(yīng)性:從 NVIDIA 轉(zhuǎn)移到 AMD(或其他硬件)時(shí),JAX 只需做極少的代碼改動(dòng)。而相較之下,PyTorch 與英偉達(dá)的 CUDA 生態(tài)系統(tǒng)緊密耦合,遷移過程相對(duì)復(fù)雜。

因此,JAX 成為了我們在非英偉達(dá)硬件上的最佳選擇。

拉取 Docker 鏡像:

docker pull rocm/jax:latest

啟動(dòng) Docker 容器:

# Pull the Docker Image:
docker pull rocm/jax:latest 


# Start the Docker Container:
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \ 
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest


# Verify the Installation: 
python3 -c 'import jax; print(jax.devices())'

驗(yàn)證安裝

python3 -c 'import jax; print (jax.devices ())'

訓(xùn)練使用了一個(gè)配備了 8 張 AMD MI300x GPU 的 AMD 節(jié)點(diǎn)。每張 MI300x 擁有 192GB 的 HBM3 內(nèi)存,性能表現(xiàn)與最新的英偉達(dá) H100 GPU 相比非常出色。

圖片

與英偉達(dá) H100 的比較,來源:TensorWave

訓(xùn)練 LLaMA 405B:性能與可擴(kuò)展性

使用 JAX,可以成功地在 AMD GPU 上訓(xùn)練 LLaMA 405B 模型。我們使用 LoRA 微調(diào),將所有模型權(quán)重和 LoRA 參數(shù)都設(shè)為 bfloat16,LoRA rank 設(shè)為 8,LoRA alpha 設(shè)為 16:

  • 模型大小:LLaMA 模型的權(quán)重占用了約 800GB 的顯存。
  • LoRA 權(quán)重 + 優(yōu)化器狀態(tài):大約占用了 400GB 的顯存。
  • 顯存總使用量:占總顯存的 77%,約 1200GB。
  • 限制:由于 405B 模型的規(guī)模過大,batch 大小和序列長度的空間有限,使用的 batch size 為 16,序列長度為 64。
  • JIT 編譯:由于空間限制,無法運(yùn)行 JIT 編譯版本;它可能需要比急切模式稍多的空間。
  • 訓(xùn)練速度:使用 JAX 急切模式,約為 35 tokens / 秒。
  • 內(nèi)存效率:穩(wěn)定在約 70% 左右。
  • 擴(kuò)展性:在 8 張 GPU 上,使用 JAX 的擴(kuò)展性接近線性。

由于硬件和顯存的限制,我們無法運(yùn)行 JIT 編譯版本的 405B 模型,整個(gè)訓(xùn)練過程是在 JAX 的急切模式下執(zhí)行的,因此還有很大的進(jìn)步空間。 

下圖中顯示了在一次微調(diào)訓(xùn)練步驟中,8 張 GPU 的顯存利用率和 rocm-smi 輸出:

GPU 利用率:

圖片

顯存利用率:

圖片

rocm-smi 輸出:

圖片

訓(xùn)練設(shè)置 

將 LLaMA 3.1 從 PyTorch 移植到 JAX 

圖片

此前,Nikhil Sonti 分享過如何將 LLaMA 3.1 從 PyTorch 移植到 JAX。他指出,目前 90% 的大型語言模型(LLM)都運(yùn)行在 NVIDIA GPU 上,但實(shí)際上還有一些同樣強(qiáng)大且性價(jià)比更高的替代方案。例如,在 Google TPU 上訓(xùn)練和部署 Llama 3.1 的成本比 NVIDIA GPU 低約 30%。

然而,支持非 NVIDIA 硬件的開發(fā)工具較為匱乏。Sonti 最初嘗試使用 PyTorch XLA 在 TPU 上訓(xùn)練 Llama 3.1,但過程并不順利。XLA 與 PyTorch 的集成不夠完善,缺少一些關(guān)鍵的庫(如 bitsandbytes 無法正常運(yùn)行),同時(shí)還遇到了一些難以解決的 HuggingFace 錯(cuò)誤。

為此,他決定調(diào)整策略,將 Llama 3.1 從 PyTorch 移植到 JAX,成功解決了這些問題。Sonti 還錄制了詳細(xì)的教程視頻,并開源了所有代碼:

圖片

  • 方法演示:https://dub.sh/felafax-demo
  • 代碼倉庫:https://github.com/felafax/felafax

加載模型,并把模型參數(shù)分片

處理像 LLaMA 405B 這樣的超大模型,需要在多個(gè)設(shè)備之間高效地進(jìn)行參數(shù)分片。以下是如何通過 JAX 實(shí)現(xiàn)這一點(diǎn)的。

在 JAX 中進(jìn)行參數(shù)分片

為了將巨大的 LLaMA 405B 模型高效地分布到 8 張 AMD GPU 上,需要使用 JAX 的設(shè)備網(wǎng)格(device mesh)功能。

部署代碼:https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69

JAX 的設(shè)備網(wǎng)格可以幫助我們把可用的設(shè)備組織成一個(gè)網(wǎng)格,讓我們可以指定如何把模型的參數(shù)和計(jì)算分配到不同的 GPU 上。

在本文的設(shè)置中,需要?jiǎng)?chuàng)建一個(gè)形狀為(1, 8, 1)的網(wǎng)格,并將軸分別命名為數(shù)據(jù)并行(dp)、全分片數(shù)據(jù)并行(fsdp)和模型并行(mp)。然后,為模型的每個(gè)張量定義特定的分片規(guī)則,指定這些維度如何沿著這些網(wǎng)格軸進(jìn)行分片。

DEVICES = jax.devices () 
DEVICE_COUNT = len (DEVICES) 
DEVICE_MESH = mesh_utils.create_device_mesh ((1, 8, 1)) 
MESH = Mesh (devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))

可視化分片

可以使用以下代碼來可視化分片結(jié)果,從而方便地驗(yàn)證分片規(guī)則是否按預(yù)期應(yīng)用。

jax.debug.visualize_array_sharding

分片規(guī)則

模型不同組件的分片規(guī)則如下所示:

  • 參數(shù)如何分片:

參數(shù)要在 8 個(gè) GPU 之間分配。例如,LM head(lm_head/kernel)張量有兩個(gè)軸,按照 PS ("fsdp", "mp") 進(jìn)行分片。在本例中是 8 和 1,因此可以看到該張量在第一個(gè)軸上沿著 8 個(gè) GPU 被拆分。

  • Non-Replicated 參數(shù):

沒有任何分片規(guī)范的參數(shù)會(huì)在所有設(shè)備上進(jìn)行復(fù)制。例如,層歸一化(attention_norm/kernel 和 ffn_norm/kernel)沒有設(shè)置分片規(guī)范,是 PS (None)。

應(yīng)用分片函數(shù)

在加載模型時(shí),使用以下分片函數(shù)逐步對(duì)模型權(quán)重進(jìn)行分片:

def make_shard_and_gather_fns (partition_specs):
    def make_shard_fn (partition_spec):
        out_sharding = NamedSharding (mesh, partition_spec)
        def shard_fn (tensor):
            return jax.device_put (tensor, out_sharding).block_until_ready ()
        return shard_fn

    shard_fns = jax.tree_util.tree_map (make_shard_fn, partition_specs)
    return shard_fns

# Create shard functions based on partitioning rules
shard_fns = make_shard_and_gather_fns (partitioning_rules)

這使得我們能夠?qū)⒚總€(gè)參數(shù)放置在指定的設(shè)備上,并按照設(shè)定的分片進(jìn)行處理。

分片訓(xùn)練 Batch

最初,訓(xùn)練 Batch 是正常創(chuàng)建的,但在輸入模型之前,需要按照下面的代碼在 GPU 上進(jìn)行分片:

train_batch = jax.device_put ( train_batch, 
NamedSharding (self.mesh, PS ("dp", "fsdp")))

在這里,我們指定訓(xùn)練 Batch 應(yīng)該在 "dp" 和 "fsdp" 軸上進(jìn)行分片,在本例中分別對(duì)應(yīng)于被分成 1 和 8 份,如果把結(jié)果可視化出來,如下所示:

分片前:

圖片

在調(diào)用  jax.device_put 之后:

圖片

加入 LoRA

LoRA 通過將權(quán)重更新分解為低秩矩陣,減少了可訓(xùn)練參數(shù)的數(shù)量,這對(duì)于微調(diào)大型模型特別有效。以下是在 AMD GPU 上微調(diào) Llama 3.1-405 的 LoRA 的要點(diǎn):

  • 將 LoRA 參數(shù)(lora_a 和 lora_b)與主模型參數(shù)分開。
  • 使用 jax.lax.stop_gradient (kernel) 來防止對(duì)主模型權(quán)重的更新。
  • 使用 lax.dot_general 進(jìn)行快速、精確控制的矩陣運(yùn)算。
  • LoRA 輸出在添加到主輸出之前會(huì)被縮放為 (self.lora_alpha/self.lora_rank)。

LoRADense 層

在此設(shè)定一個(gè)自定義的 LoRADense 層,該層集成了 LoRA 參數(shù):

class LoRADense (nn.Module):
    features: int
    lora_rank: int = 8
    lora_alpha: float = 16.0
@nn.compact
def __call__(self, inputs: Any) -> Any:
# Original kernel parameter (frozen)
        kernel = self.param ('kernel', ...)
        y = lax.dot_general (inputs, jax.lax.stop_gradient (kernel), ...)
# LoRA parameters (trainable)
        lora_a = self.variable ('lora_params', 'lora_a', ..., ...)
        lora_b = self.variable ('lora_params', 'lora_b', ..., ...)
# Compute LoRA output
        lora_output = lax.dot_general (inputs, lora_a.value, ...)
        lora_output = lax.dot_general (lora_output, lora_b.value, ...)
# Combine original output with LoRA modifications
        y += (self.lora_alpha/self.lora_rank) * lora_output




        return y.astype (self.dtype)

分片 LoRA 參數(shù)

為了高效地在設(shè)備之間分配 LoRA 參數(shù),我們也通過 JAX 設(shè)定了分片規(guī)則,這確保了 LoRA 參數(shù)與主模型參數(shù)的分片一致,優(yōu)化了內(nèi)存使用和計(jì)算效率。

LoRA A matrices (lora_a)

LoRA A 矩陣(lora_a)

  • 分片規(guī)則:PS ("fsdp", "mp")
  • 可視化結(jié)果:如下圖所示,lora_a 參數(shù)被分片為 (8, 1),這意味著第一個(gè)軸在 8 個(gè)設(shè)備上進(jìn)行分片("fsdp" 軸),而第二個(gè)軸未進(jìn)行分片。

圖片

LoRA B 矩陣(lora_b)

  • 分片規(guī)則:PS ("mp", "fsdp")
  • 可視化結(jié)果:如下圖所示,lora_b 參數(shù)被分片為 (1, 8),這意味著第二個(gè)軸在 8 個(gè)設(shè)備上進(jìn)行分片(fsdp 軸),而第一個(gè)軸未進(jìn)行分片。

圖片

這種分片策略優(yōu)化了參數(shù)的分配,減少了通信開銷,并在訓(xùn)練過程中增強(qiáng)了并行性。它確保每個(gè)設(shè)備僅持有一部分 LoRA 參數(shù),使得大模型如 LLaMA 405B 的高效擴(kuò)展成為可能。

僅更新 LoRA 參數(shù) 

為了優(yōu)化訓(xùn)練,在微調(diào) LLaMA 405B 模型,只計(jì)算 LoRA 參數(shù)的梯度,保持主模型參數(shù)不變。這個(gè)方法減少了內(nèi)存使用,并加速了訓(xùn)練,因?yàn)橹桓螺^少的參數(shù)。可以移步 GitHub 倉庫,查看實(shí)現(xiàn)細(xì)節(jié)。

在訓(xùn)練過程中,每一步都涉及將一批輸入數(shù)據(jù)通過模型進(jìn)行處理。由于只有 LoRA 參數(shù)是可訓(xùn)練的,因此模型的預(yù)測和計(jì)算的損失僅依賴于這些參數(shù),然后對(duì) LoRA 參數(shù)進(jìn)行反向傳播。只更新這些參數(shù)簡化了訓(xùn)練過程,使得在多個(gè) GPU 上高效微調(diào)像 LLaMA 405B 這樣的大型模型成為可能。

更多研究細(xì)節(jié),請參考原博客。

責(zé)任編輯:張燕妮 來源: 機(jī)器之心
相關(guān)推薦

2024-07-24 13:58:25

2024-08-16 14:00:00

2024-08-02 14:53:00

2024-07-24 13:18:17

2023-06-07 08:22:59

LLM微調(diào)技術(shù)

2024-04-15 12:50:00

大型語言模型ReFT

2023-06-14 12:08:51

2023-08-13 07:44:18

GPU模型英偉達(dá)

2023-06-28 21:47:54

2024-12-25 13:33:18

2024-07-23 09:20:35

2024-09-09 07:46:16

2025-04-10 07:59:51

2024-04-29 06:46:50

2023-10-20 17:53:05

2024-12-30 00:01:00

多模態(tài)大模型Python

2024-09-06 13:00:29

2024-07-24 09:20:45

2024-03-25 08:00:00

點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)