自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

大模型SFT暗藏大陷阱?梯度累計bug造成大范圍影響 原創(chuàng)

發(fā)布于 2024-11-5 13:15
瀏覽
0收藏

在LLM的訓練時,由于顯存不足以支撐起大batch訓練,通常大家都會采用一種策略:梯度累計(gradient accumulate)。這種方法允許模型在多個batch的梯度回傳累計并求均值之后,再更新一次權重。這樣做相當于模擬了一個更大的批量大小,而實際上并沒有一次性處理那么多數據。這樣做的好處是,它可以減少內存的使用,因為不需要一次性加載所有數據到GPU上,同時也可以享受等價大batch帶來的訓練的穩(wěn)定性和模型的泛化能力。

大模型SFT暗藏大陷阱?梯度累計bug造成大范圍影響-AI.x社區(qū)

但是近期大家發(fā)現(xiàn)了一個bug:對于幾乎所有使用了梯度累積策略的庫,包括Huggingface的一系列庫,都暗藏了一個bug,這個bug尤其在LLM的后訓練階段影響顯著:使用梯度累計并不一定等價于大batch訓練,會有非常明顯的精度損失!

???https://github.com/huggingface/trl/issues/2175???

大模型SFT暗藏大陷阱?梯度累計bug造成大范圍影響-AI.x社區(qū)

如同上述issue描述的情況,圖中bs表示batch size即梯度大小, gas表示 gradient accumulate step即多少次梯度回傳累計后再更新一次模型權重。

對于LLM訓練而言,不像圖像任務有batch norm的影響,理論上,梯度累計在應等同于全批量訓練,但實際發(fā)現(xiàn)loss并不匹配。研究者通過公式和實驗證明,罪魁禍首是開源庫中使用基于平均交叉熵loss求和后進行梯度累計的實現(xiàn)導致了bug,這在輸出等長的訓練任務中并不影響(這也是為什么在CV任務和LLM預訓練階段,梯度累計沒有發(fā)生明顯性能損失,因為輸出通常是等長的)。 梯度累積后,過度重視短輸出序列的loss,而忽略長輸出序列的loss。

這個bug的數學推導也非常簡單:

我們首先注意到交叉熵損失的計算方法如下:

大模型SFT暗藏大陷阱?梯度累計bug造成大范圍影響-AI.x社區(qū)

請注意,分母計算了未填充或未忽略(賦值為-100)的token的數量。首先,我們把它們設置為整個文檔的平均長度,以簡化我們的計算。

假設兩個batch的平均序列長度不等長,一個是m1,1個是m2,對于full batch情況:

大模型SFT暗藏大陷阱?梯度累計bug造成大范圍影響-AI.x社區(qū)

對于梯度累計情況:

大模型SFT暗藏大陷阱?梯度累計bug造成大范圍影響-AI.x社區(qū)

明顯看出在m1和m2不相等時,兩者是明顯不等價的。尤其是在其中一個序列長度明顯更長,另一個序列長度很短時,問題更加嚴重:比如m1=10,m2=1000時,會發(fā)現(xiàn)l2的loss大小會被壓縮,而l1的loss大小相對于full batch情況下會被嚴重放大。

這是因為不同batch的文本長度不同,導致的問題。在梯度累積中,我們需要將每個小批量梯度累積器按梯度累積步驟的數量進行縮放,以便我們得到期望的結果。

修復分母問題后重新實驗:

大模型SFT暗藏大陷阱?梯度累計bug造成大范圍影響-AI.x社區(qū)

現(xiàn)在確實等價了,所有的訓練損失曲線都匹配上了!分母就是罪魁禍首!這意味著簡單地對每個梯度累積步驟進行平均是錯誤的,相反,我們必須事先推導出分母。

目前,這個bug已經引起了廣泛關注,不少開源庫包括huggingface系列正在針對這個問題進行修復。如果近期遇到SFT效果不佳的問題,可以關注是否踩到了這個坑,短期不要使用梯度累計,或在修復后及時更新,使用新版梯度累計算法。


本文轉載自公眾號思源數據科學 作者:思源Source

原文鏈接:??https://mp.weixin.qq.com/s/Za62RV9BDrbuoMERzodCUA??


?著作權歸作者所有,如需轉載,請注明出處,否則將追究法律責任
收藏
回復
舉報
回復
相關推薦