SFT loss計(jì)算的那些坑,完美避開?。?!
?SFT 可以說是 LLM 的基本操作了,如果只是想把 SFT 跑起來是非常簡(jiǎn)單的,只需要構(gòu)造 input_ids 和 labels,然后就可以把訓(xùn)練跑起來。然而,這樣的訓(xùn)練效率實(shí)際上非常低。
所以在訓(xùn)練時(shí),通常有兩個(gè)加速方法:
- 多輪合并
- packing
無論是哪種方法,加速后都需要保證 loss 和原來是等價(jià)的。本文主要介紹這兩種加速方法,以及 loss 計(jì)算時(shí)遇到的問題。
1.多輪合并
假設(shè)我們有一個(gè)對(duì)話,其中 user 和 bot 交互了 3 輪,我們可以構(gòu)建三個(gè)樣本:
input_ids 就是對(duì)應(yīng)的 token id,labels 輸入部分(白色)使用 -100,輸出部分(綠色)使用 input_ids。
這樣計(jì)算的 loss 可以表示為:
其中 l_i 表示第 i 個(gè)樣本的 loss,n_i 表示第 i 個(gè)樣本輸出的 token 數(shù)量 (對(duì)應(yīng)綠色部分)。
這樣除了訓(xùn)練比較慢,沒有什么別的問題。因?yàn)椴煌瑯颖局g有很多重復(fù)計(jì)算的前綴,實(shí)際上這部分計(jì)算一次就行。
2.加速計(jì)算
如果將三輪三個(gè)樣本合并成一個(gè)樣本,可以嘗試這種構(gòu)造形式。
因?yàn)榇嬖?causal attention mask,所以每個(gè) token 只能看到前面的 token,計(jì)算上和之前是等價(jià)的。
但是這樣有一個(gè)坑:如果還是按照剛才的方式構(gòu)建 input_ids 和 labels (白色用-100,綠色用input_ids)loss 計(jì)算是有問題的。
pytorch CrossEntropyLoss 計(jì)算 loss 按照下面的方法,默認(rèn)是"mean"。
所以我們會(huì)得到這樣的 loss:
當(dāng)不同輪次的輸出長(zhǎng)度不同時(shí),這種 loss 和剛才的不等價(jià)。多輪對(duì)話中輸出較短的權(quán)重被降低了,輸出較長(zhǎng)的被提高了。所以結(jié)果就是短輸出的數(shù)據(jù)訓(xùn)練不夠充分。
3.Packing
假設(shè)我們有兩個(gè)對(duì)話,第一個(gè)是單輪對(duì)話,第二個(gè)是三輪對(duì)話。
正確的 loss:
其中 l_ij 表示第 i 個(gè)樣本第 j 輪對(duì)話的 loss,n_ij 同理。
問題:真實(shí)場(chǎng)景中的訓(xùn)練集文本長(zhǎng)度長(zhǎng)短不一,Padding 后矩陣非常稀疏,只有不到一半是有效計(jì)算。
加速計(jì)算:
將所有樣本拼接成一條,并且加入 attention mask 保證后面的樣本看不見前面的 token。
比如在 flash attention 中,可以調(diào)用 flash_attn_varlen_qkvpacked_func,并傳入 cu_seqlens 參數(shù)。
和之前一樣,如果不修改 loss 計(jì)算方法,packing 的樣本之間會(huì)存在因?yàn)殚L(zhǎng)度不同,導(dǎo)致訓(xùn)練不充分的問題。
4.正確方法
一般情況下,loss 計(jì)算會(huì)經(jīng)歷三次平均:
- micro batch 維度,分母是這個(gè) micro batch 中的所有 label 不是 -100 的 token 數(shù)
- DP 維度,分母是 DP size (和GPU數(shù)量相關(guān))
- 梯度累加維度,分母是梯度累加數(shù)
我們這里要做的就是禁用這三個(gè)平均,統(tǒng)一用這個(gè) global batch 的對(duì)話輪數(shù)作為分母。
在新版 megatron 框架中,開啟開關(guān) --calculate-per-token-loss 即可禁用 DP 和梯度累加的平均,然后修改 loss_func。
每個(gè) micro batch 都需要返回這個(gè) micro batch 的輪數(shù),最后框架會(huì)自動(dòng)將所有輪數(shù)求和,作為分母。對(duì)于分子,需要除以這個(gè)輪次的token 數(shù)。
正確實(shí)現(xiàn)代碼如下(loss_token_num, turn_num 是在構(gòu)建 data 的時(shí)候構(gòu)建的):
def loss_func(output_tensor, loss_mask, loss_token_num, turn_num):
losses = output_tensor.view(-1).float()
loss_mask = loss_mask.view(-1).float()
loss_token_num = loss_token_num.view(-1).float()
# label: [-100, -100, a, a, a, -100, b, b, -100, -100, c, c, c, -100, -100]
# loss_mask: [0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0]
# losses: [a0, a1, a2, a3, a4, b0, b1, b2, c0, c1, c2, c3, c4, d0, d1]
# losses * loss_mask = [0, 0, a2, a3, a4, 0, b1, b2, 0, 0, c2, c3, c4, 0, 0]
# loss_token_num: [3, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 3, 3, 1, 1]
# losses * loss_mask / loss_token_num = [0, 0, a2/3, a3/3, a4/3, 0, b1/2, b2/2, 0, 0, c2/3, c3/3, c4/3, 0, 0]
# sum = 1/3 (a2 + a3 + a4) + 1/2 (b1 + b2) + 1/3 (c2 + c3 + c4)
loss = torch.sum(losses * loss_mask / loss_token_num)
loss_and_turn_num = torch.cat([loss.view(1), turn_num.view(1)])
# Reduce loss for logging.
loss_and_turn_num = loss_and_turn_num.clone().detach()
torch.distributed.all_reduce(loss_and_turn_num, group=mpu.get_data_parallel_group())
# 新版返回結(jié)構(gòu),開啟 calculate_per_token_loss 開關(guān)后,返回三個(gè)值
# 第一個(gè)是反向傳播實(shí)際使用的 loss, 所有 packing 的 loss 求和
# 第二個(gè)是 turn_num, 優(yōu)化器狀態(tài)更新時(shí)會(huì)使用對(duì)這個(gè)值求和然后縮放梯度
# 第三個(gè)是用于日志打印的 loss, 包含兩個(gè)值,第一個(gè)是所有 loss 求和作為分子,第二個(gè)是所有 turn_num 求和作為分母
return loss, turn_num, {"lm loss": (loss_and_turn_num[0], loss_and_turn_num[1])}
5.總結(jié)
在 SFT 時(shí),如果要加速,需要注意:
- 不同樣本之間是等價(jià)的;
- 不同輪次之間也是等價(jià)的。
在合并多輪 / packing 時(shí),需要修改 loss 計(jì)算方法,為每個(gè) token 設(shè)置正確的權(quán)重,并且關(guān)閉 DP / 梯度累加的平均。?
