想訓(xùn)練類Sora模型嗎?尤洋團(tuán)隊(duì)OpenDiT實(shí)現(xiàn)80%加速
作為 2024 開年王炸,Sora 的出現(xiàn)樹立了一個(gè)全新的追趕目標(biāo),每個(gè)文生視頻的研究者都想在最短的時(shí)間內(nèi)復(fù)現(xiàn) Sora 的效果。
根據(jù) OpenAI 披露的技術(shù)報(bào)告,Sora 的核心技術(shù)點(diǎn)之一是將視覺數(shù)據(jù)轉(zhuǎn)化為 patch 的統(tǒng)一表征形式,并通過 Transformer 和擴(kuò)散模型結(jié)合,展現(xiàn)了卓越的擴(kuò)展(scale)特性。在報(bào)告公布后,Sora 核心研發(fā)成員 William Peebles 和紐約大學(xué)計(jì)算機(jī)科學(xué)助理教授謝賽寧合著的論文《Scalable Diffusion Models with Transformers》就成了眾多研究者關(guān)注的重點(diǎn)。大家希望能以論文中提出的 DiT 架構(gòu)為突破口,探索復(fù)現(xiàn) Sora 的可行路徑。
最近,新加坡國立大學(xué)尤洋團(tuán)隊(duì)開源的一個(gè)名為 OpenDiT 的項(xiàng)目為訓(xùn)練和部署 DiT 模型打開了新思路。
OpenDiT 是一個(gè)易于使用、快速且內(nèi)存高效的系統(tǒng),專門用于提高 DiT 應(yīng)用程序的訓(xùn)練和推理效率,包括文本到視頻生成和文本到圖像生成。
項(xiàng)目地址:https://github.com/NUS-HPC-AI-Lab/OpenDiT
OpenDiT 方法介紹
OpenDiT 提供由 Colossal-AI 支持的 Diffusion Transformer (DiT) 的高性能實(shí)現(xiàn)。在訓(xùn)練時(shí),視頻和條件信息分別被輸入到相應(yīng)的編碼器中,作為DiT模型的輸入。隨后,通過擴(kuò)散方法進(jìn)行訓(xùn)練和參數(shù)更新,最終將更新后的參數(shù)同步至EMA(Exponential Moving Average)模型。推理階段則直接使用EMA模型,將條件信息作為輸入,從而生成對應(yīng)的結(jié)果。
圖源:https://www.zhihu.com/people/berkeley-you-yang
OpenDiT 利用了 ZeRO 并行策略,將 DiT 模型參數(shù)分布到多臺機(jī)器上,初步降低了顯存壓力。為了取得更好的性能與精度平衡,OpenDiT 還采用了混合精度的訓(xùn)練策略。具體而言,模型參數(shù)和優(yōu)化器使用 float32 進(jìn)行存儲,以確保更新的準(zhǔn)確性。在模型計(jì)算的過程中,研究團(tuán)隊(duì)為 DiT 模型設(shè)計(jì)了 float16 和 float32 的混合精度方法,以在維持模型精度的同時(shí)加速計(jì)算過程。
DiT 模型中使用的 EMA 方法是一種用于平滑模型參數(shù)更新的策略,可以有效提高模型的穩(wěn)定性和泛化能力。但是會額外產(chǎn)生一份參數(shù)的拷貝,增加了顯存的負(fù)擔(dān)。為了進(jìn)一步降低這部分顯存,研究團(tuán)隊(duì)將 EMA 模型分片,并分別存儲在不同的 GPU 上。在訓(xùn)練過程中,每個(gè) GPU 只需計(jì)算和存儲自己負(fù)責(zé)的部分 EMA 模型參數(shù),并在每次 step 后等待 ZeRO 完成更新后進(jìn)行同步更新。
FastSeq
在 DiT 等視覺生成模型領(lǐng)域,序列并行性對于有效的長序列訓(xùn)練和低延遲推理是必不可少的。
然而,DeepSpeed-Ulysses、Megatron-LM Sequence Parallelism 等現(xiàn)有方法在應(yīng)用于此類任務(wù)時(shí)面臨局限性 —— 要么是引入過多的序列通信,要么是在處理小規(guī)模序列并行時(shí)缺乏效率。
為此,研究團(tuán)隊(duì)提出了 FastSeq,一種適用于大序列和小規(guī)模并行的新型序列并行。FastSeq 通過為每個(gè) transformer 層僅使用兩個(gè)通信運(yùn)算符來最小化序列通信,利用 AllGather 來提高通信效率,并策略性地采用異步 ring 將 AllGather 通信與 qkv 計(jì)算重疊,進(jìn)一步優(yōu)化性能。
算子優(yōu)化
在 DiT 模型中引入 adaLN 模塊將條件信息融入視覺內(nèi)容,雖然這一操作對模型的性能提升至關(guān)重要,但也帶來了大量的逐元素操作,并且在模型中被頻繁調(diào)用,降低了整體的計(jì)算效率。為了解決這個(gè)問題,研究團(tuán)隊(duì)提出了高效的 Fused adaLN Kernel,將多次操作合并成一次,從而增加了計(jì)算效率,并且減少了視覺信息的 I/O 消耗。
圖源:https://www.zhihu.com/people/berkeley-you-yang
簡單來說,OpenDiT 具有以下性能優(yōu)勢:
1、在 GPU 上加速高達(dá) 80%,50%的內(nèi)存節(jié)省
- 設(shè)計(jì)了高效的算子,包括針對DiT設(shè)計(jì)的 Fused AdaLN,以及 FlashAttention、Fused Layernorm 和HybridAdam。
- 采用混合并行方法,包括 ZeRO、Gemini 和 DDP。對 ema 模型進(jìn)行分片也進(jìn)一步降低了內(nèi)存成本。
2、FastSeq:一種新穎的序列并行方法
- 專為類似 DiT 的工作負(fù)載而設(shè)計(jì),在這些應(yīng)用中,序列通常較長,但參數(shù)相比于 LLM 較小。
- 節(jié)點(diǎn)內(nèi)序列并行可節(jié)省高達(dá) 48% 的通信量。
- 打破單個(gè) GPU 的內(nèi)存限制,減少整體訓(xùn)練和推理時(shí)間。
3、易于使用
- 只需幾行代碼的修改,即可獲得巨大的性能提升。
- 用戶無需了解分布式訓(xùn)練的實(shí)現(xiàn)方式。
4、文本到圖像和文本到視頻生成完整 pipeline
- 研究人員和工程師可以輕松使用 OpenDiT pipeline 并將其應(yīng)用于實(shí)際應(yīng)用,而無需修改并行部分。
- 研究團(tuán)隊(duì)通過在 ImageNet 上進(jìn)行文本到圖像訓(xùn)練來驗(yàn)證 OpenDiT 的準(zhǔn)確性,并發(fā)布了檢查點(diǎn)(checkpoint)。
安裝與使用
要使用 OpenDiT,首先要安裝先決條件:
- Python >= 3.10
- PyTorch >= 1.13(建議使用 >2.0 版本)
- CUDA >= 11.6
建議使用 Anaconda 創(chuàng)建一個(gè)新環(huán)境(Python >= 3.10)來運(yùn)行示例:
conda create -n opendit pythnotallow=3.10 -y
conda activate opendit
安裝 ColossalAI:
git clone https://github.com/hpcaitech/ColossalAI.gitcd ColossalAI
git checkout adae123df3badfb15d044bd416f0cf29f250bc86
pip install -e .
安裝 OpenDiT:
git clone https://github.com/oahzxl/OpenDiTcd OpenDiT
pip install -e .
(可選但推薦)安裝庫以加快訓(xùn)練和推理速度:
# Install Triton for fused adaln kernel
pip install triton
# Install FlashAttention
pip install flash-attn
# Install apex for fused layernorm kernel
git clone https://github.com/NVIDIA/apex.gitcd apex
git checkout 741bdf50825a97664db08574981962d66436d16a
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-optinotallow=--cpp_ext" --config-settings "--build-optinotallow=--cuda_ext" ./--global-optinotallow="--cuda_ext" --global-optinotallow="--cpp_ext"
圖像生成
你可以通過執(zhí)行以下命令來訓(xùn)練 DiT 模型:
# Use script
bash train_img.sh# Use command line
torchrun --standalone --nproc_per_node=2 train.py \
--model DiT-XL/2 \
--batch_size 2
默認(rèn)禁用所有加速方法。以下是訓(xùn)練過程中一些關(guān)鍵要素的詳細(xì)信息:
- plugin: 支持 ColossalAI、zero2 和 ddp 使用的 booster 插件。默認(rèn)是 zero2,建議啟用 zero2。
- mixed_ precision:混合精度訓(xùn)練的數(shù)據(jù)類型,默認(rèn)是 fp16。
- grad_checkpoint: 是否啟用梯度檢查點(diǎn)。這節(jié)省了訓(xùn)練過程的內(nèi)存成本。默認(rèn)值為 False。建議在內(nèi)存足夠的情況下禁用它。
- enable_modulate_kernel: 是否啟用 modulate 內(nèi)核優(yōu)化,以加快訓(xùn)練過程。默認(rèn)值為 False,建議在 GPU < H100 時(shí)啟用它。
- enable_layernorm_kernel: 是否啟用 layernorm 內(nèi)核優(yōu)化,以加快訓(xùn)練過程。默認(rèn)值為 False,建議啟用它。
- enable_flashattn: 是否啟用 FlashAttention,以加快訓(xùn)練過程。默認(rèn)值為 False,建議啟用。
- sequence_parallel_size:序列并行度大小。當(dāng)設(shè)置值 > 1 時(shí)將啟用序列并行。默認(rèn)值為 1,如果內(nèi)存足夠,建議禁用它。
如果你想使用 DiT 模型進(jìn)行推理,可以運(yùn)行如下代碼,需要將檢查點(diǎn)路徑替換為你自己訓(xùn)練的模型。
# Use script
bash sample_img.sh# Use command line
python sample.py --model DiT-XL/2 --image_size 256 --ckpt ./model.pt
視頻生成
你可以通過執(zhí)行以下命令來訓(xùn)練視頻 DiT 模型:
# train with scipt
bash train_video.sh# train with command line
torchrun --standalone --nproc_per_node=2 train.py \
--model vDiT-XL/222 \
--use_video \
--data_path ./videos/demo.csv \
--batch_size 1 \
--num_frames 16 \
--image_size 256 \
--frame_interval 3
# preprocess
# our code read video from csv as the demo shows
# we provide a code to transfer ucf101 to csv format
python preprocess.py
使用 DiT 模型執(zhí)行視頻推理的代碼如下所示:
# Use script
bash sample_video.sh# Use command line
python sample.py \
--model vDiT-XL/222 \
--use_video \
--ckpt ckpt_path \
--num_frames 16 \
--image_size 256 \
--frame_interval 3
DiT 復(fù)現(xiàn)結(jié)果
為了驗(yàn)證 OpenDiT 的準(zhǔn)確性,研究團(tuán)隊(duì)使用 OpenDiT 的 origin 方法對 DiT 進(jìn)行了訓(xùn)練,在 ImageNet 上從頭開始訓(xùn)練模型,在 8xA100 上執(zhí)行 80k step。以下是經(jīng)過訓(xùn)練的 DiT 生成的一些結(jié)果:
損失也與 DiT 論文中列出的結(jié)果一致:
要復(fù)現(xiàn)上述結(jié)果,需要更改 train_img.py 中的數(shù)據(jù)集并執(zhí)行以下命令:
torchrun --standalone --nproc_per_node=8 train.py \
--model DiT-XL/2 \
--batch_size 180 \
--enable_layernorm_kernel \
--enable_flashattn \
--mixed_precision fp16
感興趣的讀者可以查看項(xiàng)目主頁,了解更多研究內(nèi)容。