騰訊混元、北大發(fā)現(xiàn)Scaling law「浪涌現(xiàn)象」,解決學(xué)習(xí)率調(diào)參難題
過去十年間,基于隨機(jī)梯度下降(SGD)的深度學(xué)習(xí)模型在許多領(lǐng)域都取得了極大的成功。與此同時各式各樣的 SGD 替代品也如雨后春筍般涌現(xiàn)。在這些眾多替代品中,Adam 及其變種最受追捧。無論是 SGD,還是 Adam,亦或是其他優(yōu)化器,最核心的超參數(shù)非 Learning rate 莫屬。因此如何調(diào)整好 Leanring rate 是煉丹師們從一開始就必學(xué)的技能。
從直覺上講,影響 Learning rate 取值的重要因素是 Batch size。不知你在學(xué)習(xí)煉丹術(shù)時,是否遇到或者思考過入如下問題:
- 我的 Batch size 增加一倍,Learning rate 該怎么調(diào)整?
- 網(wǎng)上有說 Batch size 和 Learning rate 是線性放縮,也有說是平方根放縮,到底該按照哪個調(diào)整?
- 為什么我按照網(wǎng)上說的經(jīng)驗關(guān)系調(diào)整之后效果反而變差了?
針對上述問題,騰訊混元聯(lián)合北京大學(xué)基于現(xiàn)有科研基礎(chǔ)和實(shí)際業(yè)務(wù)需求,在進(jìn)行了大量理論分析和實(shí)驗驗證后發(fā)布了關(guān)于 Batch size 和 Learning rate 放縮關(guān)系的調(diào)參指南:
- 論文:Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling
- 論文地址:https://arxiv.org/pdf/2405.14578
1. 當(dāng)使用 SGD 風(fēng)格的優(yōu)化器時,應(yīng)當(dāng)采用 OpenAI 2018 年給出的結(jié)論(https://arxiv.org/pdf/1812.06162):
2. 但是當(dāng)使用 Adam 風(fēng)格的優(yōu)化器時,需要按照如下放縮規(guī)律:
其中和 B 分別代表 Learning rate 和 Batch size,而
與 OpenAI 2020 年 Scaling law 論文(https://arxiv.org/pdf/2001.08361)中的
對應(yīng)。從上面結(jié)論不難發(fā)現(xiàn),當(dāng)
時,社區(qū)中廣為流傳的線性放縮和平方根放縮在一定范圍內(nèi)都是正確的,并且分別對應(yīng)使用 SGD 風(fēng)格和 Adam 風(fēng)格優(yōu)化器的情況。
一、居然要降低學(xué)習(xí)率?
如果仔細(xì)觀察 Adam 風(fēng)格優(yōu)化器放縮規(guī)律的表達(dá)式子會發(fā)現(xiàn),當(dāng) Batch size 超過后,隨著 Batch size 增加最優(yōu)的 Learning rate 反而是下降的!這樣的結(jié)論似乎有點(diǎn)反常,但是仔細(xì)思考之后又覺得是合理的。首先我們回顧一下 Adam 的更新形式,梯度的一階動量除以二階動量的平方根:
(更詳細(xì)的討論參考原文中的附錄 A)。與 SGD 直接采用 G 進(jìn)行參數(shù)更新相比,將更快的進(jìn)入飽和區(qū)間,例如,假設(shè) G 的均值是正實(shí)數(shù),隨著 Batch size 增加
估計為正數(shù)時,再增加估計的準(zhǔn)確度對
的結(jié)果也毫無影響了。因此當(dāng) Batch size 超過
時,增加的信息不足以抵消
帶來的噪聲影響,從而導(dǎo)致此次的更新不再那么確信,以至于需要降低學(xué)習(xí)率。
二、觀察到的下降區(qū)間
為了檢驗理論的正確性,需要從實(shí)驗中觀察到最優(yōu)學(xué)習(xí)率的 “下降區(qū)間”。既然從上一節(jié)的分析中發(fā)現(xiàn),使用 Adam 優(yōu)化器時 Batch size 超過就會導(dǎo)致最優(yōu)學(xué)習(xí)率下降,那么只要確定出
取值,然后在通過網(wǎng)格搜索打點(diǎn)觀察就可以了。雖然從形式上
計算很困難,但是幸運(yùn)的是基于 OpenAI 關(guān)算于訓(xùn)練時間和樣本效率的定量結(jié)論中我們可以估算出
的取值(更詳細(xì)的討論參考原文中的附錄 G)。
上面展示了 CNN 在 FashionMNIST 上的學(xué)習(xí)率 “下降區(qū)間”。左圖為通過 OpenAI 定量公式估算的(左圖直線斜率的負(fù)數(shù),右圖紅色豎直虛線),右圖中黃色五角星代表不同 Batch size 下的最優(yōu) Learning rate 取值,青色實(shí)線為我們的理論預(yù)估曲線。
以及 Resnet18 在 TinyImagenet,和 DistilGPT2 在 Eli5Category 上也觀察到了類似現(xiàn)象。
三、浪涌現(xiàn)象
前面我們從理論和實(shí)驗上都發(fā)現(xiàn)了,在使用 Adam 風(fēng)格優(yōu)化器時最優(yōu)學(xué)習(xí)率曲線就像一朵 “浪花” 一樣隨著 Batch size 增加會先升高后下降。同時結(jié)合 OpenAI scaling law 的結(jié)論,隨著訓(xùn)練進(jìn)行會逐漸變大。我們理論預(yù)測并實(shí)驗證明了隨著訓(xùn)練進(jìn)行 “浪花” 逐漸向著大 Batch size 方向涌動:
四、理論發(fā)現(xiàn)
前面討論過 Adam 風(fēng)格的優(yōu)化器在進(jìn)行參數(shù)更新時采用類似的形式。雖然此形式看起來很簡單,但是由于推導(dǎo)過程涉及到對更新量均值和方差的考量,所以我們在處理的時候做了一個假設(shè)和一個近似:
1. 假設(shè)每個樣本的參數(shù) i 的梯度服從均值為,方差為
的高斯分布
2. 通過 sigmoid-style 函數(shù)對高斯誤差函數(shù)進(jìn)行數(shù)值近似
當(dāng)時,完整的 Scaling law 形式近似為:
其中
,H 為海森矩陣。
當(dāng)時:
表明,Batch size 無限大時最優(yōu)學(xué)習(xí)率趨于一個飽和值。
五、應(yīng)用
我們在騰訊 Angel 大模型訓(xùn)練框架中集成了上述理論成果,并在騰訊混元大模型訓(xùn)練任務(wù)中對理論進(jìn)行進(jìn)一步驗證,未來將服務(wù)于各種大模型訓(xùn)練場景。
感謝閱讀,更多詳細(xì)內(nèi)容,請參考原文。