ICLR 2025 Spotlight | 慕尼黑工業(yè)大學(xué)&北京大學(xué):邁向無沖突訓(xùn)練的ConFIG方法
本文由慕尼黑工業(yè)大學(xué)與北京大學(xué)聯(lián)合團(tuán)隊(duì)撰寫。第一作者劉強(qiáng)為慕尼黑工業(yè)大學(xué)博士生。第二作者楚夢(mèng)渝為北京大學(xué)助理教授,專注于物理增強(qiáng)的深度學(xué)習(xí)算法,以提升數(shù)值模擬的靈活性及模型的準(zhǔn)確性和泛化性。通訊作者 Nils Thuerey 教授(慕尼黑工業(yè)大學(xué))長(zhǎng)期研究深度學(xué)習(xí)與物理模擬,尤其是流體動(dòng)力學(xué)模擬的結(jié)合,并曾因高效流動(dòng)特效模擬技術(shù)獲奧斯卡技術(shù)獎(jiǎng)。目前,其團(tuán)隊(duì)重點(diǎn)關(guān)注可微物理模擬及物理應(yīng)用中的先進(jìn)生成式模型。
在深度學(xué)習(xí)的多個(gè)應(yīng)用場(chǎng)景中,聯(lián)合優(yōu)化多個(gè)損失項(xiàng)是一個(gè)普遍的問題。典型的例子包括物理信息神經(jīng)網(wǎng)絡(luò)(Physics-Informed Neural Networks, PINNs)、多任務(wù)學(xué)習(xí)(Multi-Task Learning, MTL)和連續(xù)學(xué)習(xí)(Continual Learning, CL)。然而,不同損失項(xiàng)的梯度方向往往相互沖突,導(dǎo)致優(yōu)化過程陷入局部最優(yōu)甚至訓(xùn)練失敗。
目前,主流的方法通常通過調(diào)整損失權(quán)重來緩解沖突。例如在物理信息神經(jīng)網(wǎng)絡(luò)中,許多研究從數(shù)值剛度、損失的收斂速度差異和神經(jīng)網(wǎng)絡(luò)的初始化角度提出了許多權(quán)重方法。然而,盡管這些方法聲稱具有更高的解的精度,但目前對(duì)于最優(yōu)的加權(quán)策略尚無共識(shí)。
針對(duì)這一問題,來自慕尼黑工業(yè)大學(xué)和北京大學(xué)的聯(lián)合研究團(tuán)隊(duì)提出了 ConFIG(Conflict-Free Inverse Gradients,無沖突逆梯度)方法,為多損失項(xiàng)優(yōu)化提供了一種穩(wěn)定、高效的優(yōu)化策略。ConFIG 提供了一種優(yōu)化梯度,能夠防止由于沖突導(dǎo)致優(yōu)化陷入某個(gè)特定損失項(xiàng)的局部最小值。ConFIG 方法可以在數(shù)學(xué)上證明其收斂特性并具有以下特點(diǎn):
- 最終更新梯度
與所有損失項(xiàng)的優(yōu)化梯度均不沖突。
在每個(gè)特定損失梯度上的投影長(zhǎng)度是均勻的,可以確保所有損失項(xiàng)以相同速率進(jìn)行優(yōu)化。
長(zhǎng)度可以根據(jù)損失項(xiàng)之間的沖突程度自適應(yīng)調(diào)整。
此外,ConFIG 方法還引入了一種基于動(dòng)量的變種。通過計(jì)算并緩存每個(gè)損失項(xiàng)梯度的動(dòng)量,可以避免在每次訓(xùn)練迭代中計(jì)算所有損失項(xiàng)的梯度。結(jié)果表明,基于動(dòng)量的 ConFIG 方法在顯著降低訓(xùn)練成本的同時(shí)保證了優(yōu)化的精度。
想深入了解 ConFIG 的技術(shù)細(xì)節(jié)?我們已經(jīng)為你準(zhǔn)備好了完整的論文、項(xiàng)目主頁和代碼倉庫!
- 論文地址:https://arxiv.org/abs/2408.11104
- 項(xiàng)目主頁:https://tum-pbs.github.io/ConFIG/
- GitHub: https://github.com/tum-pbs/ConFIG
ConFIG: 無沖突逆梯度方法
目標(biāo):給定個(gè)損失函數(shù)
,其對(duì)應(yīng)梯度為
。我們希望找到一個(gè)優(yōu)化方向
,使其滿足:
。即所有損失項(xiàng)在該方向上都能減少,從而避免梯度沖突。
無沖突優(yōu)化區(qū)間
假設(shè)存在一個(gè)無沖突更新梯度,我們可以引入一個(gè)新的矢量。由于
是一個(gè)無沖突梯度,
應(yīng)為一個(gè)正向分量矢量。同樣地,我們也可以預(yù)先定義一個(gè)正向分量矢量
,然后直接通過矩陣的逆運(yùn)算求得無沖突更新梯度
,即
。通過給定不同的正向分量矢量
,我們得到由一系列不同
組成的無沖突優(yōu)化區(qū)間。
確定唯一優(yōu)化梯度
盡管通過簡(jiǎn)單求逆可以獲得一個(gè)無沖突更新區(qū)間,我們需要進(jìn)一步確定唯一的無沖突梯度用于優(yōu)化。在 ConFIG 方法中,我們從方向和幅度兩個(gè)方面進(jìn)一步限定了最終用于優(yōu)化更新的梯度:
- 具體優(yōu)化方向:相比于直接求解梯度矩陣的逆,ConFIG 方法求解了歸一化梯度矩陣的逆,即
,其中
表示第
個(gè)梯度向量的單位向量??梢宰C明,變換后
矢量的每個(gè)分量代表了每個(gè)梯度
與最終更新梯度
之間的余弦相似度。因此,通過設(shè)定
分量的不同值可以直接控制最終更新梯度對(duì)于每個(gè)損失梯度的優(yōu)化速率。在 ConFIG 中,
被設(shè)定為單位矢量以確保每個(gè)損失具有相同的優(yōu)化強(qiáng)度從而避免某些損失項(xiàng)的優(yōu)化被忽略。
- 優(yōu)化梯度大小:此外,ConFIG 方法還根據(jù)梯度沖突程度調(diào)整步長(zhǎng)。當(dāng)梯度方向較一致時(shí),加快更新;當(dāng)梯度沖突嚴(yán)重時(shí),減小更新幅度:
, 其中
為每個(gè)梯度與最終更新方向之間的余弦相似度。
ConFIG 方法獲得最終無沖突優(yōu)化方向的計(jì)算過程可以總結(jié)為:
原論文中給出了上述 ConFIG 更新收斂性的嚴(yán)格證明。同時(shí),我們還可以證明只要參數(shù)空間的維度大于損失項(xiàng)的個(gè)數(shù),ConFIG 運(yùn)算中的逆運(yùn)算總是可行的。
M-ConFIG: 結(jié)合動(dòng)量加速訓(xùn)練
ConFIG 方法引入了矩陣的逆運(yùn)算,這將帶來額外的計(jì)算成本。然而與計(jì)算每個(gè)損失的梯度帶來的計(jì)算成本,其并不顯著。在包括 ConFIG 在內(nèi)的基于梯度的方法中,總是需要額外的反向傳播步驟獲得每個(gè)梯度相對(duì)于訓(xùn)練參數(shù)的梯度。這使得基于梯度的方法的計(jì)算成本顯著高于標(biāo)準(zhǔn)優(yōu)化過程和基于權(quán)重的方法。為此,我們引入了 M-ConFIG 方法,使用動(dòng)量加速優(yōu)化:
- 使用梯度的動(dòng)量(指數(shù)移動(dòng)平均)代替梯度進(jìn)行 ConFIG 運(yùn)算。
- 在每次優(yōu)化迭代中,僅對(duì)一個(gè)或部分損失進(jìn)行反向傳播以更新動(dòng)量。其它損失項(xiàng)的動(dòng)量采用之前迭代步的歷史值。
在實(shí)際應(yīng)用中,M-ConFIG 的計(jì)算成本往往低于標(biāo)準(zhǔn)更新過程或基于權(quán)重的方法。這是由于反向傳播一個(gè)子損失往往要比反向傳播總損失
更快。這在物理信息神經(jīng)網(wǎng)絡(luò)中尤為明顯,因?yàn)檫吔缟系牟蓸狱c(diǎn)通常遠(yuǎn)少于計(jì)算域內(nèi)的采樣點(diǎn)。在我們的實(shí)際測(cè)試中,M-ConFIG 的平均計(jì)算成本為基于權(quán)重方法的 0.56 倍。
結(jié)果:更快的收斂,更優(yōu)的預(yù)測(cè)
物理信息神經(jīng)網(wǎng)絡(luò)
在物理信息神經(jīng)網(wǎng)絡(luò)中,用神經(jīng)網(wǎng)絡(luò)的自動(dòng)微分來近似偏微分方程的時(shí)空間導(dǎo)數(shù)。偏微分方程的殘差項(xiàng)與邊界條件和初始條件被視作不同的損失項(xiàng)在訓(xùn)練過程中進(jìn)行聯(lián)合優(yōu)化。我們?cè)诙鄠€(gè)經(jīng)典的物理神經(jīng)信息網(wǎng)絡(luò)中測(cè)試了 ConFIG 方法的表現(xiàn)。
結(jié)果顯示,在相同訓(xùn)練迭代次數(shù)下,ConFIG 方法是唯一一個(gè)相比于標(biāo)準(zhǔn) Adam 方法始終獲得正向提升的方法。對(duì)每個(gè)損失項(xiàng)變化的單獨(dú)分析表明,ConFIG 方法在略微提高 PDE 訓(xùn)練殘差的同時(shí)大幅降低了邊界和初始條件損失
,實(shí)現(xiàn)了 PDE 訓(xùn)練精度的整體提升。
相同迭代步數(shù)下不同方法在 PINNs 測(cè)試中相比于 Adam 優(yōu)化器的相對(duì)性能提升
不同損失項(xiàng)隨著訓(xùn)練周期的變化情況
在實(shí)際應(yīng)用中,相同訓(xùn)練時(shí)間下的模型準(zhǔn)確性可能更為重要。M-ConFIG 方法通過使用動(dòng)量近似梯度帶來的運(yùn)算速度提升可以使其充分發(fā)揮潛力。在相同訓(xùn)練時(shí)間內(nèi),M-ConFIG 方法的測(cè)試結(jié)果優(yōu)于其他所有方法,甚至高于常規(guī)的 ConFIG 方法。
此外,我們還在最具有挑戰(zhàn)性的三維 Beltrami 流動(dòng)中進(jìn)一步延長(zhǎng)訓(xùn)練時(shí)間來更加深入地了解 M-ConFIG 方法的性能。結(jié)果表明,M-ConFIG 方法并非僅在優(yōu)化初始階段帶來顯著的性能改善,而是在整個(gè)優(yōu)化過程中都持續(xù)改善優(yōu)化的過程。
相同訓(xùn)練時(shí)間下不同方法在 PINNs 測(cè)試中相比于 Adam 優(yōu)化器的相對(duì)性能提升
三維 Beltrami 流動(dòng)案例中預(yù)測(cè)誤差隨著訓(xùn)練時(shí)間的變化
多任務(wù)學(xué)習(xí)
我們還測(cè)試了 ConFIG 方法在多任務(wù)學(xué)習(xí)(MTL)方面的表現(xiàn)。我們采用經(jīng)典的 CelebA 數(shù)據(jù)集,其包含 20 萬張人臉圖像并標(biāo)注了 40 種不同的面部二元屬性。對(duì)每張人像面部屬性的學(xué)習(xí)是一個(gè)非常有挑戰(zhàn)的 40 項(xiàng)損失的多任務(wù)學(xué)習(xí)。
實(shí)驗(yàn)結(jié)果表明,ConFIG 方法或 M-ConFIG 方法在平均 F1 分?jǐn)?shù)、平均排名
中均表現(xiàn)最佳。其中,對(duì)于 M-ConFIG 方法,我們?cè)谝淮蔚懈?30 個(gè)動(dòng)量而不僅更新一個(gè)動(dòng)量。這是因?yàn)楫?dāng)任務(wù)數(shù)量增加時(shí),單個(gè)動(dòng)量更新時(shí)間的間隔較長(zhǎng),歷史動(dòng)量信息難以準(zhǔn)確捕捉梯度的變化。動(dòng)量信息的滯后會(huì)逐漸抵消 M-ConFIG 方法更高訓(xùn)練效率帶來的性能提升。
在我們的測(cè)試中,當(dāng)任務(wù)數(shù)量等于 10 時(shí),M-ConFIG 方法在相同訓(xùn)練時(shí)間下的性能就已經(jīng)弱于 ConFIG 方法。增加單次迭代過程中的動(dòng)量更新次數(shù)可以顯著緩解這種性能下降。在標(biāo)準(zhǔn)的 40 任務(wù) CelebA 訓(xùn)練中將動(dòng)量更新次數(shù)提升到 20 時(shí),M-ConFIG 方法的性能已經(jīng)接近 ConFIG 方法,而訓(xùn)練時(shí)間僅為 ConFIG 方法的 56%。當(dāng)更新步數(shù)達(dá)到 30 時(shí),其性能甚至可以優(yōu)于 ConFIG 方法。
ConFIG 方法在 CelebA 人臉屬性數(shù)據(jù)集中的表現(xiàn)
結(jié)論
在本研究中,我們提出了 ConFIG 方法來解決不同損失項(xiàng)之間的訓(xùn)練沖突。ConFIG 方法通過確保最終更新梯度與每個(gè)子梯度之間的正點(diǎn)積來確保無沖突學(xué)習(xí)。此外,我們還發(fā)展了一種基于動(dòng)量的方法,用交替更新的動(dòng)量代替梯度,顯著提升了訓(xùn)練效率。ConFIG 方法有望為眾多包含多個(gè)損失項(xiàng)的深度學(xué)習(xí)任務(wù)帶來巨大的性能提升。