如果想要在某個(gè)模型基礎(chǔ)上做全參數(shù)微調(diào),需要多少顯存?
全參數(shù)微調(diào)(Full Parameter Fine-Tuning)的顯存需求取決于多個(gè)因素,包括模型的大小、數(shù)據(jù)的批量大?。˙atch Size)、優(yōu)化器的狀態(tài)存儲(chǔ)以及是否使用混合精度訓(xùn)練等。以下是一個(gè)詳細(xì)的分析:
模型參數(shù)大小
模型參數(shù)顯存占用:模型的每個(gè)參數(shù)在顯存中占用一定的空間。通常,單精度浮點(diǎn)數(shù)(FP32)占用4字節(jié),半精度浮點(diǎn)數(shù)(FP16)占用2字節(jié)。
計(jì)算公式:
模型參數(shù)顯存=模型參數(shù)數(shù)量×每個(gè)參數(shù)占用的字節(jié)數(shù)
示例:
如果模型有1.5億個(gè)參數(shù)(如BERT-Base),使用FP32精度,顯存占用為:
梯度存儲(chǔ)
在反向傳播中,每個(gè)參數(shù)的梯度也需要存儲(chǔ)在顯存中。
計(jì)算公式:
梯度顯存=模型參數(shù)數(shù)量×每個(gè)參數(shù)占用的字節(jié)數(shù)
示例:
對(duì)于上述BERT-Base模型(FP32),梯度顯存占用為:
優(yōu)化器狀態(tài)
常用的優(yōu)化器(如Adam)會(huì)為每個(gè)參數(shù)存儲(chǔ)額外的狀態(tài)(如動(dòng)量和方差估計(jì))。
不同優(yōu)化器的狀態(tài)倍數(shù)如下:
AdamW (2 states): 8 Bytes per parameter
AdamW (bitsandbytes Quantized): 2 Bytes per parameter
SGD (1 state): 4 Bytes per parameter
計(jì)算公式:
優(yōu)化器狀態(tài)顯存=模型參數(shù)數(shù)量×每個(gè)參數(shù)占用的字節(jié)數(shù)×優(yōu)化器狀態(tài)倍數(shù)
示例:
對(duì)于BERT-Base模型(FP32),優(yōu)化器狀態(tài)顯存占用為:
激活值和臨時(shí)變量
在前向和反向傳播過程中,網(wǎng)絡(luò)的激活值(中間層輸出)和臨時(shí)變量也會(huì)占用顯存。
估算公式:
激活值顯存≈模型參數(shù)數(shù)量×每個(gè)參數(shù)占用的字節(jié)數(shù)×2
示例:
對(duì)于BERT-Base模型(FP32),激活值顯存占用為:
批量大?。˙atch Size)
批量大小會(huì)顯著影響顯存占用。每個(gè)樣本的輸入、輸出和中間激活值都需要存儲(chǔ)。
估算公式:
Batch Size顯存=Batch Size×(輸入大小+輸出大小+中間激活值大小)
示例:
假設(shè)輸入為512個(gè)token的文本,每個(gè)token的嵌入維度為768(BERT-Base),Batch Size為32,則輸入顯存占用為:
總結(jié)公式
綜合以上因素,全參數(shù)微調(diào)的顯存需求估算公式為:
總顯存需求=(模型參數(shù)顯存+梯度顯存+優(yōu)化器狀態(tài)顯存+激活值顯存)×精度倍數(shù)+Batch Size顯存
示例:BERT-Base全參數(shù)微調(diào)(FP32)
- 模型參數(shù)顯存:600MB
- 梯度顯存:600MB
- 優(yōu)化器狀態(tài)顯存:1200MB
- 激活值顯存:1200MB
- Batch Size顯存:假設(shè)為100MB(根據(jù)輸入大小和Batch Size估算)
最終總顯存需求:
600+600+1200+1200+100=3700MB≈3.7GB