自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

SFT loss計(jì)算的那些坑,完美避開?。?!

發(fā)布于 2024-12-11 10:48
瀏覽
0收藏

?SFT 可以說是 LLM 的基本操作了,如果只是想把 SFT 跑起來是非常簡(jiǎn)單的,只需要構(gòu)造 input_ids 和 labels,然后就可以把訓(xùn)練跑起來。然而,這樣的訓(xùn)練效率實(shí)際上非常低。

所以在訓(xùn)練時(shí),通常有兩個(gè)加速方法:

  • 多輪合并
  • packing

無論是哪種方法,加速后都需要保證 loss 和原來是等價(jià)的。本文主要介紹這兩種加速方法,以及 loss 計(jì)算時(shí)遇到的問題。

1.多輪合并

假設(shè)我們有一個(gè)對(duì)話,其中 user 和 bot 交互了 3 輪,我們可以構(gòu)建三個(gè)樣本:

SFT loss計(jì)算的那些坑,完美避開?。。?AI.x社區(qū)

input_ids 就是對(duì)應(yīng)的 token id,labels 輸入部分(白色)使用 -100,輸出部分(綠色)使用 input_ids。

這樣計(jì)算的 loss 可以表示為:

SFT loss計(jì)算的那些坑,完美避開?。。?AI.x社區(qū)

其中 l_i 表示第 i 個(gè)樣本的 loss,n_i 表示第 i 個(gè)樣本輸出的 token 數(shù)量 (對(duì)應(yīng)綠色部分)。

這樣除了訓(xùn)練比較慢,沒有什么別的問題。因?yàn)椴煌瑯颖局g有很多重復(fù)計(jì)算的前綴,實(shí)際上這部分計(jì)算一次就行。

2.加速計(jì)算

SFT loss計(jì)算的那些坑,完美避開?。。?AI.x社區(qū)

如果將三輪三個(gè)樣本合并成一個(gè)樣本,可以嘗試這種構(gòu)造形式。

因?yàn)榇嬖?causal attention mask,所以每個(gè) token 只能看到前面的 token,計(jì)算上和之前是等價(jià)的。

但是這樣有一個(gè)坑:如果還是按照剛才的方式構(gòu)建 input_ids 和 labels (白色用-100,綠色用input_ids)loss 計(jì)算是有問題的。

pytorch CrossEntropyLoss 計(jì)算 loss 按照下面的方法,默認(rèn)是"mean"。

SFT loss計(jì)算的那些坑,完美避開!??!-AI.x社區(qū)

所以我們會(huì)得到這樣的 loss:

SFT loss計(jì)算的那些坑,完美避開?。。?AI.x社區(qū)

當(dāng)不同輪次的輸出長(zhǎng)度不同時(shí),這種 loss 和剛才的不等價(jià)。多輪對(duì)話中輸出較短的權(quán)重被降低了,輸出較長(zhǎng)的被提高了。所以結(jié)果就是短輸出的數(shù)據(jù)訓(xùn)練不夠充分。

3.Packing

假設(shè)我們有兩個(gè)對(duì)話,第一個(gè)是單輪對(duì)話,第二個(gè)是三輪對(duì)話。

SFT loss計(jì)算的那些坑,完美避開?。?!-AI.x社區(qū)

正確的 loss:

SFT loss計(jì)算的那些坑,完美避開!??!-AI.x社區(qū)

其中 l_ij 表示第 i 個(gè)樣本第 j 輪對(duì)話的 loss,n_ij 同理。

問題:真實(shí)場(chǎng)景中的訓(xùn)練集文本長(zhǎng)度長(zhǎng)短不一,Padding 后矩陣非常稀疏,只有不到一半是有效計(jì)算。

加速計(jì)算:

SFT loss計(jì)算的那些坑,完美避開!?。?AI.x社區(qū)

將所有樣本拼接成一條,并且加入 attention mask 保證后面的樣本看不見前面的 token。

比如在 flash attention 中,可以調(diào)用 flash_attn_varlen_qkvpacked_func,并傳入 cu_seqlens 參數(shù)。

和之前一樣,如果不修改 loss 計(jì)算方法,packing 的樣本之間會(huì)存在因?yàn)殚L(zhǎng)度不同,導(dǎo)致訓(xùn)練不充分的問題。

4.正確方法

一般情況下,loss 計(jì)算會(huì)經(jīng)歷三次平均:

  • micro batch 維度,分母是這個(gè) micro batch 中的所有 label 不是 -100 的 token 數(shù)
  • DP 維度,分母是 DP size (和GPU數(shù)量相關(guān))
  • 梯度累加維度,分母是梯度累加數(shù)

我們這里要做的就是禁用這三個(gè)平均,統(tǒng)一用這個(gè) global batch 的對(duì)話輪數(shù)作為分母。

在新版 megatron 框架中,開啟開關(guān) --calculate-per-token-loss 即可禁用 DP 和梯度累加的平均,然后修改 loss_func。

每個(gè) micro batch 都需要返回這個(gè) micro batch 的輪數(shù),最后框架會(huì)自動(dòng)將所有輪數(shù)求和,作為分母。對(duì)于分子,需要除以這個(gè)輪次的token 數(shù)。

正確實(shí)現(xiàn)代碼如下(loss_token_num, turn_num 是在構(gòu)建 data 的時(shí)候構(gòu)建的):

def loss_func(output_tensor, loss_mask, loss_token_num, turn_num):
    losses = output_tensor.view(-1).float()
    loss_mask = loss_mask.view(-1).float()
    loss_token_num = loss_token_num.view(-1).float()
    # label: [-100, -100, a, a, a, -100, b, b, -100, -100, c, c, c, -100, -100]
    # loss_mask: [0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0]
    # losses: [a0, a1, a2, a3, a4, b0, b1, b2, c0, c1, c2, c3, c4, d0, d1]
    # losses * loss_mask = [0, 0, a2, a3, a4, 0, b1, b2, 0, 0, c2, c3, c4, 0, 0]
    # loss_token_num: [3, 3, 3, 3, 3, 2, 2, 2, 3, 3, 3, 3, 3, 1, 1]
    # losses * loss_mask / loss_token_num = [0, 0, a2/3, a3/3, a4/3, 0, b1/2, b2/2, 0, 0, c2/3, c3/3, c4/3, 0, 0]
    # sum = 1/3 (a2 + a3 + a4) + 1/2 (b1 + b2) + 1/3 (c2 + c3 + c4)
    loss = torch.sum(losses * loss_mask / loss_token_num)
    loss_and_turn_num = torch.cat([loss.view(1), turn_num.view(1)])
    # Reduce loss for logging.
    loss_and_turn_num = loss_and_turn_num.clone().detach()
    torch.distributed.all_reduce(loss_and_turn_num, group=mpu.get_data_parallel_group())
    # 新版返回結(jié)構(gòu),開啟 calculate_per_token_loss 開關(guān)后,返回三個(gè)值
    # 第一個(gè)是反向傳播實(shí)際使用的 loss, 所有 packing 的 loss 求和
    # 第二個(gè)是 turn_num, 優(yōu)化器狀態(tài)更新時(shí)會(huì)使用對(duì)這個(gè)值求和然后縮放梯度
    # 第三個(gè)是用于日志打印的 loss, 包含兩個(gè)值,第一個(gè)是所有 loss 求和作為分子,第二個(gè)是所有 turn_num 求和作為分母
    return loss, turn_num, {"lm loss": (loss_and_turn_num[0], loss_and_turn_num[1])}

5.總結(jié)

在 SFT 時(shí),如果要加速,需要注意:

  • 不同樣本之間是等價(jià)的;
  • 不同輪次之間也是等價(jià)的。

在合并多輪 / packing 時(shí),需要修改 loss 計(jì)算方法,為每個(gè) token 設(shè)置正確的權(quán)重,并且關(guān)閉 DP / 梯度累加的平均。?

本文轉(zhuǎn)載自??丁師兄大模型??,作者:Ethan Yan ????

標(biāo)簽
收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦