更通用、有效,螞蟻自研優(yōu)化器WSAM入選KDD Oral
深度神經(jīng)網(wǎng)絡(luò)(DNNs)的泛化能力與極值點的平坦程度密切相關(guān),因此出現(xiàn)了 Sharpness-Aware Minimization (SAM) 算法來尋找更平坦的極值點以提高泛化能力。本文重新審視 SAM 的損失函數(shù),提出了一種更通用、有效的方法 WSAM,通過將平坦程度作為正則化項來改善訓(xùn)練極值點的平坦度。通過在各種公開數(shù)據(jù)集上的實驗表明,與原始優(yōu)化器、SAM 及其變體相比,WSAM 在絕大多數(shù)情形都實現(xiàn)了更好的泛化性能。WSAM 在螞蟻內(nèi)部數(shù)字支付、數(shù)字金融等多個場景也被普遍采用并取得了顯著效果。該文被 KDD '23 接收為 Oral Paper。
- 論文地址:https://arxiv.org/pdf/2305.15817.pdf
- 代碼地址:https://github.com/intelligent-machine-learning/dlrover/tree/master/atorch/atorch/optimizers
隨著深度學(xué)習(xí)技術(shù)的發(fā)展,高度過參數(shù)化的 DNNs 在 CV 和 NLP 等各種機器學(xué)習(xí)場景下取得了巨大的成功。雖然過度參數(shù)化的模型容易過擬合訓(xùn)練數(shù)據(jù),但它們通常具有良好的泛化能力。泛化的奧秘受到越來越多的關(guān)注,已成為深度學(xué)習(xí)領(lǐng)域的熱門研究課題。
最近的研究表明,泛化能力與極值點的平坦程度密切相關(guān),即損失函數(shù)“地貌”中平坦的極值點可以實現(xiàn)更小的泛化誤差。Sharpness-Aware Minimization (SAM) [1] 是一種用于尋找更平坦極值點的技術(shù),是當(dāng)前最有前途的技術(shù)方向之一。它廣泛應(yīng)用于各個領(lǐng)域,如 CV、NLP 和 bi-level learning,并在這些領(lǐng)域明顯優(yōu)于原先最先進的方法。
為了探索更平坦的最小值,SAM 定義損失函數(shù) L 在 w 處的平坦程度如下:
GSAM [2] 證明了 是局部極值點 Hessian 矩陣最大特征值的近似,表明
確實是平坦(陡峭)程度的有效度量。然而
只能用于尋找更平坦的區(qū)域而不是最小值點,這可能導(dǎo)致?lián)p失函數(shù)收斂到損失值依然很大的點(雖然周圍區(qū)域很平坦)。因此,SAM 采用
,即
作為損失函數(shù)。它可以視為在
和
之間尋找更平坦的表面和更小損失值的折衷方案,在這里兩者被賦予了同等的權(quán)重。
本文重新思考了 的構(gòu)建,將
視為正則化項。我們開發(fā)了一個更通用、有效的算法,稱為 WSAM(Weighted Sharpness-Aware Minimization),其損失函數(shù)加入了一個加權(quán)平坦度項
作為正則項,其中超參數(shù)
控制了平坦度的權(quán)重。在方法介紹章節(jié),我們演示了如何通過
來指導(dǎo)損失函數(shù)找到更平坦或更小的極值點。我們的關(guān)鍵貢獻可以總結(jié)如下。
- 我們提出 WSAM,將平坦度視為正則化項,并在不同任務(wù)之間給予不同的權(quán)重。我們提出一個“權(quán)重解耦”技術(shù)來處理更新公式中的正則化項,旨在精確反映當(dāng)前步驟的平坦度。當(dāng)基礎(chǔ)優(yōu)化器不是 SGD 時,如 SGDM 和 Adam,WSAM 在形式上與 SAM 有顯著差異。消融實驗表明,這種技術(shù)在大多數(shù)情況下可以提升效果。
- 我們在公開數(shù)據(jù)集上驗證了 WSAM 在常見任務(wù)中的有效性。實驗結(jié)果表明,與 SAM 及其變體相比,WSAM 在絕大多數(shù)情形都有著更好的泛化性能。
預(yù)備知識
SAM 是解決由公式(1)定義的 的極小極大最優(yōu)化問題的一種技術(shù)。
首先,SAM 使用圍繞 w 的一階泰勒展開來近似內(nèi)層的最大化問題,即、
其次,SAM 通過采用 的近似梯度來更新 w ,即
其中第二個近似是為了加速計算。其他基于梯度的優(yōu)化器(稱為基礎(chǔ)優(yōu)化器)可以納入 SAM 的通用框架中,具體見Algorithm 1。通過改變 Algorithm 1 中的 和
,我們可以獲得不同的基礎(chǔ)優(yōu)化器,例如 SGD、SGDM 和 Adam,參見 Tab. 1。請注意,當(dāng)基礎(chǔ)優(yōu)化器為 SGD 時,Algorithm 1 回退到 SAM 論文 [1] 中的原始 SAM。
方法介紹
WSAM 的設(shè)計細節(jié)
在此,我們給出的正式定義,它由一個常規(guī)損失和一個平坦度項組成。由公式(1),我們有
其中 。當(dāng)
=0 時,
退化為常規(guī)損失;當(dāng)
=1/2 時,
等價于
;當(dāng)
>1/2 時,
更注重平坦度,因此與 SAM 相比更容易找到具有較小曲率而非較小損失值的點;反之亦然。
包含不同基礎(chǔ)優(yōu)化器的 WSAM 的通用框架可以通過選擇不同的 和
來實現(xiàn),見 Algorithm 2。例如,當(dāng)
和
時,我們得到基礎(chǔ)優(yōu)化器為 SGD 的 WSAM,見 Algorithm 3。在此,我們采用了一種“權(quán)重解耦”技術(shù),即
平坦度項不是與基礎(chǔ)優(yōu)化器集成用于計算梯度和更新權(quán)重,而是獨立計算(Algorithm 2 第 7 行的最后一項)。這樣,正則化的效果只反映了當(dāng)前步驟的平坦度,而沒有額外的信息。為了進行比較,Algorithm 4 給出了沒有“權(quán)重解耦”(稱為 Coupled-WSAM)的 WSAM。例如,如果基礎(chǔ)優(yōu)化器是 SGDM,則 Coupled-WSAM 的正則化項是平坦度的指數(shù)移動平均值。如實驗章節(jié)所示,“權(quán)重解耦”可以在大多數(shù)情況下改善泛化表現(xiàn)。
Fig. 1 展示了不同取值下的 WSAM 更新過程。當(dāng)
<1/2 時,
介于
和
之間,并隨著
增大逐漸偏離
。
簡單示例
為了更好地說明 WSAM 中 γ 的效果和優(yōu)勢,我們設(shè)置了一個二維簡單示例。如 Fig. 2 所示,損失函數(shù)在左下角有一個相對不平坦的極值點(位置:(-16.8, 12.8),損失值:0.28),在右上角有一個平坦的極值點(位置:(19.8, 29.9),損失值:0.36)。損失函數(shù)定義為: ,這里
是單變量高斯模型與兩個正態(tài)分布之間的 KL 散度,即
,其中
和
。
我們使用動量為 0.9 的 SGDM 作為基礎(chǔ)優(yōu)化器,并對 SAM 和 WSAM 設(shè)置=2 。從初始點 (-6, 10) 開始,使用學(xué)習(xí)率為 5 在 150 步內(nèi)優(yōu)化損失函數(shù)。SAM 收斂到損失值更低但更不平坦的極值點,
=0.6的 WSAM 也類似。然而,
=0.95 使得損失函數(shù)收斂到平坦的極值點,說明更強的平坦度正則化發(fā)揮了作用。
實驗
我們在各種任務(wù)上進行了實驗,以驗證 WSAM 的有效性。
圖像分類
我們首先研究了 WSAM 在 Cifar10 和 Cifar100 數(shù)據(jù)集上從零開始訓(xùn)練模型的效果。我們選擇的模型包括 ResNet18 和WideResNet-28-10。我們使用預(yù)定義的批大小在 Cifar10 和 Cifar100 上訓(xùn)練模型,ResNet18 和 WideResNet-28-10 分別為 128,256。這里使用的基礎(chǔ)優(yōu)化器是動量為 0.9 的 SGDM。按照 SAM [1] 的設(shè)置,每個基礎(chǔ)優(yōu)化器跑的 epoch 數(shù)是 SAM 類優(yōu)化器的兩倍。我們對兩種模型都進行了 400 個 epoch 的訓(xùn)練(SAM 類優(yōu)化器為 200 個 epoch),并使用 cosine scheduler 來衰減學(xué)習(xí)率。這里我們沒有使用其他高級數(shù)據(jù)增強方法,例如 cutout 和 AutoAugment。
對于兩種模型,我們使用聯(lián)合網(wǎng)格搜索確定基礎(chǔ)優(yōu)化器的學(xué)習(xí)率和權(quán)重衰減系數(shù),并將它們保持不變用于接下來的 SAM 類優(yōu)化器實驗。學(xué)習(xí)率和權(quán)重衰減系數(shù)的搜索范圍分別為 {0.05, 0.1} 和 {1e-4, 5e-4, 1e-3}。由于所有 SAM 類優(yōu)化器都有一個超參數(shù)(鄰域大小),我們接下來在 SAM 優(yōu)化器上搜索最佳的
并將相同的值用于其他 SAM 類優(yōu)化器。
的搜索范圍為 {0.01, 0.02, 0.05, 0.1, 0.2, 0.5}。最后,我們對其他 SAM 類優(yōu)化器各自獨有的超參進行搜索,搜索范圍來自各自原始文章的推薦范圍。對于 GSAM [2],我們在 {0.01, 0.02, 0.03, 0.1, 0.2, 0.3} 范圍內(nèi)搜索。對于 ESAM [3],我們在 {0.4, 0.5, 0.6} 范圍內(nèi)搜索
,在 {0.4, 0.5, 0.6} 范圍內(nèi)搜索
,在 {0.4, 0.5, 0.6} 范圍內(nèi)搜索
。對于 WSAM,我們在 {0.5, 0.6, 0.7, 0.8, 0.82, 0.84, 0.86, 0.88, 0.9, 0.92, 0.94, 0.96} 范圍內(nèi)搜索
。我們使用不同的隨機種子重復(fù)實驗 5 次,計算了平均誤差和標準差。我們在單卡 NVIDIA A100 GPU 上進行實驗。每個模型的優(yōu)化器超參總結(jié)在 Tab. 3 中。
Tab. 2 給出了在不同優(yōu)化器下,ResNet18、WRN-28-10 在 Cifar10 和 Cifar100 上測試集的 top-1 錯誤率。相比基礎(chǔ)優(yōu)化器,SAM 類優(yōu)化器顯著提升了效果,同時,WSAM 又顯著優(yōu)于其他 SAM 類優(yōu)化器。
ImageNet 上的額外訓(xùn)練
我們進一步在 ImageNet 數(shù)據(jù)集上使用 Data-Efficient Image Transformers 網(wǎng)絡(luò)結(jié)構(gòu)進行實驗。我們恢復(fù)了一個預(yù)訓(xùn)練的 DeiT-base checkpoint,然后繼續(xù)訓(xùn)練三個 epoch。模型使用批大小 256 進行訓(xùn)練,基礎(chǔ)優(yōu)化器為動量 0.9 的 SGDM,權(quán)重衰減系數(shù)為 1e-4,學(xué)習(xí)率為 1e-5。我們在四卡 NVIDIA A100 GPU 重復(fù)跑 5 次并計算平均誤差和標準差。
我們在 {0.05, 0.1, 0.5, 1.0,? , 6.0} 中搜索 SAM 的最佳。最佳的
=5.5 被直接用于其他 SAM 類優(yōu)化器。之后,我們在{0.01, 0.02, 0.03, 0.1, 0.2, 0.3}中搜索 GSAM 的最佳
,并在 0.80 到 0.98 之間以 0.02 的步長搜索WSAM 的最佳
。
模型的初始 top-1 錯誤率為 18.2%,在進行了三個額外的 epoch 之后,錯誤率如 Tab. 4 所示。我們沒有發(fā)現(xiàn)三個 SAM-like 優(yōu)化器之間有明顯的差異,但它們都優(yōu)于基礎(chǔ)優(yōu)化器,表明它們可以找到更平坦的極值點并具有更好的泛化能力。
標簽噪聲的魯棒性
如先前的研究 [1, 4, 5] 所示,SAM 類優(yōu)化器在訓(xùn)練集存在標簽噪聲時表現(xiàn)出良好的魯棒性。在這里,我們將 WSAM 的魯棒性與 SAM、ESAM 和 GSAM 進行了比較。我們在 Cifar10 數(shù)據(jù)集上訓(xùn)練 ResNet18 200 個 epoch,并注入對稱標簽噪聲,噪聲水平為 20%、40%、60% 和 80%。我們使用具有 0.9 動量的 SGDM 作為基礎(chǔ)優(yōu)化器,批大小為 128,學(xué)習(xí)率為 0.05,權(quán)重衰減系數(shù)為 1e-3,并使用 cosine scheduler 衰減學(xué)習(xí)率。針對每個標簽噪聲水平,我們在 {0.01, 0.02, 0.05, 0.1, 0.2, 0.5} 范圍內(nèi)對 SAM 進行網(wǎng)格搜索,確定通用的值。然后,我們單獨搜索其他優(yōu)化器特定的超參數(shù),以找到最優(yōu)泛化性能。我們在 Tab. 5 中列出了復(fù)現(xiàn)我們結(jié)果所需的超參數(shù)。我們在 Tab. 6 中給出了魯棒性測試的結(jié)果,WSAM 通常比 SAM、ESAM 和 GSAM 都具有更好的魯棒性。
探索幾何結(jié)構(gòu)的影響
SAM 類優(yōu)化器可以與 ASAM [4] 和 Fisher SAM [5] 等技術(shù)相結(jié)合,以自適應(yīng)地調(diào)整探索鄰域的形狀。我們在 Cifar10 上對 WRN-28-10 進行實驗,比較 SAM 和 WSAM 在分別使用自適應(yīng)和 Fisher 信息方法時的表現(xiàn),以了解探索區(qū)域的幾何結(jié)構(gòu)如何影響 SAM 類優(yōu)化器的泛化性能。
除了和
之外的參數(shù),我們復(fù)用了圖像分類中的配置。根據(jù)先前的研究 [4, 5],ASAM 和 Fisher SAM 的
通常較大。我們在 {0.1, 0.5, 1.0,…, 6.0} 中搜索最佳的
,ASAM 和 Fisher SAM 最佳的
均為 5.0。之后,我們在 0.80 到 0.94 之間以 0.02 的步長搜索 WSAM 的最佳
,兩種方法最佳
均為 0.88。
令人驚訝的是,如 Tab. 7 所示,即使在多個候選項中,基準的 WSAM 也表現(xiàn)出更好的泛化性。因此,我們建議直接使用具有固定的基準 WSAM 即可。
消融實驗
在本節(jié)中,我們進行消融實驗,以深入理解 WSAM 中“權(quán)重解耦”技術(shù)的重要性。如WSAM 的設(shè)計細節(jié)所述,我們將不帶“權(quán)重解耦”的 WSAM 變體(算法 4)Coupled-WSAM 與原始方法進行比較。
結(jié)果如 Tab. 8 所示。Coupled-WSAM 在大多數(shù)情況下比 SAM 產(chǎn)生更好的結(jié)果,WSAM 在大多數(shù)情況下進一步提升了效果,證明“權(quán)重解耦”技術(shù)的有效性。
極值點分析
在這里,我們通過比較 WSAM 和 SAM 優(yōu)化器找到的極值點之間的差異,進一步加深對 WSAM 優(yōu)化器的理解。極值點處的平坦(陡峭)度可通過 Hessian 矩陣的最大特征值來描述。特征值越大,越不平坦。我們使用 Power Iteration 算法來計算這個最大特征值。
Tab. 9 顯示了 SAM 和 WSAM 優(yōu)化器找到的極值點之間的差異。我們發(fā)現(xiàn),vanilla 優(yōu)化器找到的極值點具有更小的損失值但更不平坦,而 SAM 找到的極值點具有更大的損失值但更平坦,從而改善了泛化性能。有趣的是,WSAM 找到的極值點不僅損失值比 SAM 小得多,而且平坦度十分接近 SAM。這表明,在尋找極值點的過程中,WSAM 優(yōu)先確保更小的損失值,同時盡量搜尋到更平坦的區(qū)域。
超參敏感性
與 SAM 相比,WSAM 具有一個額外的超參數(shù),用于縮放平坦(陡峭)度項的大小。在這里,我們測試 WSAM 的泛化性能對該超參的敏感性。我們在 Cifar10 和 Cifar100 上使用 WSAM 對 ResNet18 和 WRN-28-10 模型進行了訓(xùn)練,使用了廣泛的
取值。如 Fig. 3 所示,結(jié)果表明 WSAM 對超參
的選擇不敏感。我們還發(fā)現(xiàn),WSAM 的最優(yōu)泛化性能幾乎總是在 0.8 到 0.95 之間。