大模型SFT暗藏大陷阱?梯度累計bug造成大范圍影響 原創(chuàng)
在LLM的訓練時,由于顯存不足以支撐起大batch訓練,通常大家都會采用一種策略:梯度累計(gradient accumulate)。這種方法允許模型在多個batch的梯度回傳累計并求均值之后,再更新一次權重。這樣做相當于模擬了一個更大的批量大小,而實際上并沒有一次性處理那么多數據。這樣做的好處是,它可以減少內存的使用,因為不需要一次性加載所有數據到GPU上,同時也可以享受等價大batch帶來的訓練的穩(wěn)定性和模型的泛化能力。
但是近期大家發(fā)現(xiàn)了一個bug:對于幾乎所有使用了梯度累積策略的庫,包括Huggingface的一系列庫,都暗藏了一個bug,這個bug尤其在LLM的后訓練階段影響顯著:使用梯度累計并不一定等價于大batch訓練,會有非常明顯的精度損失!
???https://github.com/huggingface/trl/issues/2175???
如同上述issue描述的情況,圖中bs表示batch size即梯度大小, gas表示 gradient accumulate step即多少次梯度回傳累計后再更新一次模型權重。
對于LLM訓練而言,不像圖像任務有batch norm的影響,理論上,梯度累計在應等同于全批量訓練,但實際發(fā)現(xiàn)loss并不匹配。研究者通過公式和實驗證明,罪魁禍首是開源庫中使用基于平均交叉熵loss求和后進行梯度累計的實現(xiàn)導致了bug,這在輸出等長的訓練任務中并不影響(這也是為什么在CV任務和LLM預訓練階段,梯度累計沒有發(fā)生明顯性能損失,因為輸出通常是等長的)。 梯度累積后,過度重視短輸出序列的loss,而忽略長輸出序列的loss。
這個bug的數學推導也非常簡單:
我們首先注意到交叉熵損失的計算方法如下:
請注意,分母計算了未填充或未忽略(賦值為-100)的token的數量。首先,我們把它們設置為整個文檔的平均長度,以簡化我們的計算。
假設兩個batch的平均序列長度不等長,一個是m1,1個是m2,對于full batch情況:
對于梯度累計情況:
明顯看出在m1和m2不相等時,兩者是明顯不等價的。尤其是在其中一個序列長度明顯更長,另一個序列長度很短時,問題更加嚴重:比如m1=10,m2=1000時,會發(fā)現(xiàn)l2的loss大小會被壓縮,而l1的loss大小相對于full batch情況下會被嚴重放大。
這是因為不同batch的文本長度不同,導致的問題。在梯度累積中,我們需要將每個小批量梯度累積器按梯度累積步驟的數量進行縮放,以便我們得到期望的結果。
修復分母問題后重新實驗:
現(xiàn)在確實等價了,所有的訓練損失曲線都匹配上了!分母就是罪魁禍首!這意味著簡單地對每個梯度累積步驟進行平均是錯誤的,相反,我們必須事先推導出分母。
目前,這個bug已經引起了廣泛關注,不少開源庫包括huggingface系列正在針對這個問題進行修復。如果近期遇到SFT效果不佳的問題,可以關注是否踩到了這個坑,短期不要使用梯度累計,或在修復后及時更新,使用新版梯度累計算法。
本文轉載自公眾號思源數據科學 作者:思源Source
原文鏈接:??https://mp.weixin.qq.com/s/Za62RV9BDrbuoMERzodCUA??
