PyTorch內(nèi)存優(yōu)化的十種策略總結(jié):在有限資源環(huán)境下高效訓(xùn)練模型
在大規(guī)模深度學(xué)習(xí)模型訓(xùn)練過(guò)程中,GPU內(nèi)存容量往往成為制約因素,尤其是在訓(xùn)練大型語(yǔ)言模型(LLM)和視覺(jué)Transformer等現(xiàn)代架構(gòu)時(shí)。由于大多數(shù)研究者和開(kāi)發(fā)者無(wú)法使用配備海量GPU內(nèi)存的高端計(jì)算集群,因此掌握有效的內(nèi)存優(yōu)化技術(shù)變得尤為關(guān)鍵。本文將系統(tǒng)性地介紹多種內(nèi)存優(yōu)化策略,這些技術(shù)組合應(yīng)用可使模型訓(xùn)練的內(nèi)存消耗降低近20倍,同時(shí)不會(huì)損害模型性能和預(yù)測(cè)準(zhǔn)確率。以下大部分技術(shù)可以相互結(jié)合,以獲得更顯著的內(nèi)存效率提升。
1、自動(dòng)混合精度訓(xùn)練
混合精度訓(xùn)練是降低內(nèi)存占用的基礎(chǔ)且高效的方法,它充分利用16位(FP16)和32位(FP32)浮點(diǎn)格式的優(yōu)勢(shì)。
混合精度訓(xùn)練的核心思想是在大部分計(jì)算中使用較低精度執(zhí)行數(shù)學(xué)運(yùn)算,從而減少內(nèi)存帶寬和存儲(chǔ)需求,同時(shí)在計(jì)算的關(guān)鍵環(huán)節(jié)保持必要的精度。通過(guò)對(duì)激活值和梯度采用FP16格式,這些張量的內(nèi)存占用可減少約50%。然而某些特定的層或操作仍需要FP32格式以避免數(shù)值不穩(wěn)定。PyTorch對(duì)自動(dòng)混合精度(AMP)的原生支持大大簡(jiǎn)化了實(shí)現(xiàn)過(guò)程。
混合精度訓(xùn)練 與 低精度訓(xùn)練 有本質(zhì)區(qū)別
關(guān)于混合精度訓(xùn)練是否會(huì)影響模型準(zhǔn)確率的問(wèn)題,答案是否?;旌暇扔?xùn)練通過(guò)精心設(shè)計(jì)的計(jì)算流程保持了計(jì)算精度。
混合精度訓(xùn)練原理
混合精度訓(xùn)練通過(guò)結(jié)合16位(FP16)和32位(FP32)浮點(diǎn)格式來(lái)保持計(jì)算準(zhǔn)確性。使用16位精度計(jì)算梯度可顯著加快計(jì)算速度并減少內(nèi)存消耗,同時(shí)維持與32位分辨率相當(dāng)?shù)慕Y(jié)果質(zhì)量。這種方法在計(jì)算資源受限的環(huán)境中尤為有效。
"混合精度"一詞更準(zhǔn)確地描述了這一過(guò)程,因?yàn)椴⒎撬袇?shù)和操作都轉(zhuǎn)換為16位格式。實(shí)際訓(xùn)練過(guò)程中,32位和16位操作交替執(zhí)行,形成混合精度計(jì)算流程。
如上圖所示,該過(guò)程首先將權(quán)重轉(zhuǎn)換為低精度(FP16)以加速計(jì)算,然后計(jì)算梯度,接著將梯度轉(zhuǎn)回高精度(FP32)以確保數(shù)值穩(wěn)定性,最后使用這些適當(dāng)縮放的梯度更新原始權(quán)重。通過(guò)這種方式,混合精度訓(xùn)練可提高訓(xùn)練效率的同時(shí)保持網(wǎng)絡(luò)的整體精度和穩(wěn)定性。
使用torch.cuda.amp.autocast()可輕松實(shí)現(xiàn)混合精度訓(xùn)練,示例代碼如下:
import torch
from torch.cuda.amp import autocast, GradScaler
# Assume your model and optimizer have been defined elsewhere.
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()
for data, target in data_loader:
optimizer.zero_grad()
# Enable mixed precision
with autocast():
output = model(data)
loss = loss_fn(output, target)
# Scale the loss and backpropagate
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
2、低精度訓(xùn)練
除了混合精度訓(xùn)練,我們還可以嘗試使用完整的16位低精度格式進(jìn)行訓(xùn)練。由于16位浮點(diǎn)數(shù)的表示范圍限制,這種方法可能導(dǎo)致NaN值出現(xiàn)。為解決這一問(wèn)題,研究人員開(kāi)發(fā)了多種專用浮點(diǎn)格式。其中,Brain Floating Point(BF16)是Google為此專門(mén)開(kāi)發(fā)的一種廣受歡迎的格式。與標(biāo)準(zhǔn)FP16相比,BF16提供了更大的動(dòng)態(tài)范圍,能夠表示極大和極小的數(shù)值,使其更適合于深度學(xué)習(xí)應(yīng)用中可能遇到的多樣化數(shù)值情況。盡管較低精度可能在某些計(jì)算中影響精確度或?qū)е律崛胝`差,但在大多數(shù)深度學(xué)習(xí)應(yīng)用場(chǎng)景中,這種影響對(duì)模型性能的影響極小。
雖然BF16最初是為T(mén)PU設(shè)計(jì)的,但現(xiàn)在大多數(shù)現(xiàn)代GPU(Nvidia Ampere架構(gòu)及更高版本)也支持這種格式??梢酝ㄟ^(guò)以下方法檢查GPU是否支持BF16格式:
import torch
print(torch.cuda.is_bf16_supported()) # should print True
3、梯度檢查點(diǎn)
即便采用混合精度和低精度訓(xùn)練,大型模型在前向傳播過(guò)程中產(chǎn)生的大量中間張量仍會(huì)消耗大量?jī)?nèi)存。梯度檢查點(diǎn)(Gradient Checkpointing)技術(shù)通過(guò)在前向傳播過(guò)程中選擇性地僅存儲(chǔ)部分中間結(jié)果來(lái)解決這一問(wèn)題。在反向傳播過(guò)程中,系統(tǒng)會(huì)重新計(jì)算缺失的中間值,這雖然增加了計(jì)算成本,但可以顯著降低內(nèi)存需求。
通過(guò)戰(zhàn)略性地選擇需要設(shè)置檢查點(diǎn)的層,可以通過(guò)動(dòng)態(tài)重新計(jì)算激活值而非存儲(chǔ)它們來(lái)減少內(nèi)存使用。對(duì)于具有深層架構(gòu)的模型,中間激活值通常占據(jù)了內(nèi)存消耗的主要部分,此時(shí)這種權(quán)衡尤為有效。以下是梯度檢查點(diǎn)的實(shí)現(xiàn)示例:
import torch
from torch.utils.checkpoint import checkpoint
def checkpointed_segment(input_tensor):
# This function represents a portion of your model
# which will be recomputed during the backward pass.
# You can create a custom forward pass for this segment.
return model_segment(input_tensor)
# Instead of a conventional forward pass, wrap the segment with checkpoint.
output = checkpoint(checkpointed_segment, input_tensor)
采用此方法,在多數(shù)情況下可將激活值所需的內(nèi)存減少40-50%。盡管反向傳播現(xiàn)在包含額外的計(jì)算開(kāi)銷,但當(dāng)GPU內(nèi)存成為限制因素時(shí),這種權(quán)衡通常是合理的。
4、使用梯度累積降低批量大小
在嘗試上述方法后,一個(gè)自然的問(wèn)題是:
為何不直接減小批量大?。?/span>
雖然這確實(shí)是最直接的方法,但通常使用較小批量大小會(huì)導(dǎo)致預(yù)測(cè)性能下降。簡(jiǎn)單減小批量大小雖然能顯著降低內(nèi)存消耗,但往往會(huì)對(duì)模型準(zhǔn)確率產(chǎn)生不良影響。
如何在這兩者之間取得平衡?
梯度累積(Gradient Accumulation)正是為解決這一問(wèn)題而設(shè)計(jì)的技術(shù)。它允許在訓(xùn)練過(guò)程中虛擬增加批量大小,其核心原理是為較小的批量計(jì)算梯度,并在多次迭代中累積這些梯度(通常通過(guò)求和或平均),而不是在每個(gè)批次后立即更新模型權(quán)重。一旦累積的梯度達(dá)到目標(biāo)"虛擬"批量大小,才使用這些累積的梯度更新模型參數(shù)。
然而需要注意,這種技術(shù)的主要缺點(diǎn)是顯著增加了訓(xùn)練時(shí)間。
5、張量分片和分布式訓(xùn)練
對(duì)于即使應(yīng)用上述優(yōu)化后仍無(wú)法在單個(gè)GPU上容納的超大模型,完全分片數(shù)據(jù)并行(Fully Sharded Data Parallel, FSDP)技術(shù)提供了解決方案。FSDP將模型參數(shù)、梯度和優(yōu)化器狀態(tài)分片到多個(gè)GPU上,這不僅使得訓(xùn)練超大模型成為可能,還能通過(guò)更合理地分配通信開(kāi)銷提高訓(xùn)練效率。
FSDP不是在每個(gè)GPU上維護(hù)完整的模型副本,而是將模型參數(shù)分配到多個(gè)可用設(shè)備上。在執(zhí)行前向或反向傳播時(shí),系統(tǒng)僅將相關(guān)分片加載到內(nèi)存中。這種分片機(jī)制顯著降低了單個(gè)設(shè)備的內(nèi)存需求,與前述技術(shù)結(jié)合使用,在某些情況下可實(shí)現(xiàn)高達(dá)10倍的內(nèi)存降低效果。
FSDP可通過(guò)以下方式實(shí)現(xiàn):
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# Initialize your model and ensure it is on the correct device.
model = MyLargeModel().cuda()
# Wrap the model in FSDP for sharded training across GPUs.
fsdp_model = FSDP(model)
6、高效的數(shù)據(jù)加載
內(nèi)存優(yōu)化中常被忽視的一個(gè)方面是數(shù)據(jù)加載效率。雖然大部分優(yōu)化關(guān)注點(diǎn)集中在模型內(nèi)部結(jié)構(gòu)和計(jì)算過(guò)程,但低效的數(shù)據(jù)處理同樣可能造成不必要的瓶頸,影響內(nèi)存利用和計(jì)算速度。作為經(jīng)驗(yàn)法則,當(dāng)處理數(shù)據(jù)加載器時(shí),應(yīng)始終啟用Pinned Memory和配置適當(dāng)?shù)腗ultiple Workers,如下所示:
from torch.utils.data import DataLoader
# Create your dataset instance and then the DataLoader with pinned memory enabled.
train_loader = DataLoader(
dataset,
batch_size=64,
shuffle=True,
num_workers=4, # Adjust based on your CPU capabilities
pin_memory=True # Enables faster host-to-device transfers
)
7、使用原地操作
在處理張量時(shí),如果不謹(jǐn)慎管理,每個(gè)操作都可能創(chuàng)建新的張量對(duì)象。原地操作(In-place Operations)通過(guò)直接修改現(xiàn)有張量而非分配新張量,有助于減少內(nèi)存碎片和總體內(nèi)存占用。這種方式減少了臨時(shí)內(nèi)存分配,在迭代訓(xùn)練循環(huán)中尤為重要。示例如下:
import torch
x = torch.randn(100, 100, device='cuda')
y = torch.randn(100, 100, device='cuda')
# Using in-place addition
x.add_(y) # Here x is modified directly instead of creating a new tensor
8、激活和參數(shù)卸載
對(duì)于極大規(guī)模模型,即使應(yīng)用了所有上述技術(shù),由于大量中間激活值的存在,仍可能達(dá)到GPU內(nèi)存限制。激活和參數(shù)卸載(Activation and Parameter Offloading)技術(shù)通過(guò)將部分中間數(shù)據(jù)移動(dòng)到CPU內(nèi)存,為GPU內(nèi)存提供額外的緩解。
這種方法通過(guò)戰(zhàn)略性地將部分激活值和/或參數(shù)臨時(shí)卸載到主機(jī)內(nèi)存(CPU),僅在GPU內(nèi)存中保留關(guān)鍵計(jì)算所需的數(shù)據(jù)。雖然DeepSpeed、Fabric等專用框架可自動(dòng)管理這種數(shù)據(jù)移動(dòng),但也可以按如下方式實(shí)現(xiàn)自定義卸載邏輯:
def offload_activation(tensor):
# Move tensor to CPU to save GPU memory
return tensor.cpu()
def process_batch(data):
# Offload some activations explicitly
intermediate = model.layer1(data)
intermediate = offload_activation(intermediate)
intermediate = intermediate.cuda() # Move back when needed
output = model.layer2(intermediate)
return output
9、使用更精簡(jiǎn)的優(yōu)化器
各種優(yōu)化器在內(nèi)存消耗方面存在顯著差異。例如,廣泛使用的Adam優(yōu)化器為每個(gè)模型參數(shù)維護(hù)兩個(gè)額外狀態(tài)參數(shù)(動(dòng)量和方差),這意味著更多的內(nèi)存消耗。將Adam替換為無(wú)狀態(tài)優(yōu)化器(如SGD)可將參數(shù)數(shù)量減少近2/3,這在處理LLM等大型模型時(shí)尤為重要。
標(biāo)準(zhǔn)SGD的缺點(diǎn)是收斂特性較差。為彌補(bǔ)這一點(diǎn),可引入余弦退火學(xué)習(xí)率調(diào)度器以實(shí)現(xiàn)更好的收斂效果。實(shí)現(xiàn)示例:
# instead of this
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# use this
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
num_steps = NUM_EPOCHS * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=num_steps)
這種優(yōu)化可在保持模型準(zhǔn)確率達(dá)到約97%(取決于具體應(yīng)用)的同時(shí),顯著改善峰值內(nèi)存消耗。
10、進(jìn)階優(yōu)化技術(shù)
除上述基礎(chǔ)技術(shù)外,以下高級(jí)策略可進(jìn)一步優(yōu)化GPU內(nèi)存使用,充分發(fā)揮硬件潛能:
內(nèi)存分析和緩存管理
精確測(cè)量是有效優(yōu)化的前提。PyTorch提供了多種實(shí)用工具用于監(jiān)控GPU內(nèi)存使用情況:
import torch
# print a detailed report of current GPU memory usage and fragmentation
print(torch.cuda.memory_summary(device=None, abbreviated=False))
# free up cached memory that's no longer needed by PyTorch
torch.cuda.empty_cache()
使用TorchScript進(jìn)行JIT編譯
PyTorch的即時(shí)編譯器(JIT)能夠?qū)ython模型轉(zhuǎn)換為經(jīng)過(guò)優(yōu)化的、可序列化的TorchScript程序。這種轉(zhuǎn)換通過(guò)優(yōu)化內(nèi)核啟動(dòng)和減少運(yùn)行時(shí)開(kāi)銷,可帶來(lái)內(nèi)存和性能的雙重提升:
import torch
# Suppose `model` is an instance of your PyTorch network.
scripted_model = torch.jit.script(model)
# Now, you can run the scripted model just like before.
output = scripted_model(input_tensor)
這種編譯方式可顯著優(yōu)化模型運(yùn)行效率。
自定義內(nèi)核融合
編譯的另一項(xiàng)重要優(yōu)勢(shì)是能夠?qū)⒍鄠€(gè)操作融合到單個(gè)計(jì)算內(nèi)核中。內(nèi)核融合有助于減少內(nèi)存讀寫(xiě)操作,提高總體計(jì)算吞吐量:
使用torch.compile()進(jìn)行動(dòng)態(tài)內(nèi)存分配
進(jìn)一步利用編譯技術(shù),JIT編譯器可通過(guò)編譯時(shí)優(yōu)化改進(jìn)動(dòng)態(tài)內(nèi)存分配效率。結(jié)合跟蹤和計(jì)算圖優(yōu)化技術(shù),這種方法可在大型模型和Transformer架構(gòu)中實(shí)現(xiàn)更顯著的內(nèi)存和性能優(yōu)化。
總結(jié)
在GPU和云計(jì)算資源成本高昂的環(huán)境下,最大化利用現(xiàn)有計(jì)算資源至關(guān)重要。對(duì)于希望在有限計(jì)算資源條件下訓(xùn)練或微調(diào)大型模型(如LLM或視覺(jué)Transformer)的研究者和開(kāi)發(fā)者而言,掌握上述優(yōu)化技術(shù)尤為重要。本文介紹的這些策略代表了研究人員和專業(yè)人士在資源受限條件下進(jìn)行高效模型訓(xùn)練的常用方法。