測試時領(lǐng)域適應(yīng)的魯棒性得以保證,TRIBE在多真實場景下達到SOTA
測試時領(lǐng)域適應(yīng)(Test-Time Adaptation)的目的是使源域模型適應(yīng)推理階段的測試數(shù)據(jù),在適應(yīng)未知的圖像損壞領(lǐng)域取得了出色的效果。然而,當(dāng)前許多方法都缺乏對真實世界場景中測試數(shù)據(jù)流的考慮,例如:
- 測試數(shù)據(jù)流應(yīng)當(dāng)是時變分布(而非傳統(tǒng)領(lǐng)域適應(yīng)中的固定分布)
- 測試數(shù)據(jù)流可能存在局部類別相關(guān)性(而非完全獨立同分布采樣)
- 測試數(shù)據(jù)流在較長時間里仍表現(xiàn)全局類別不平衡
近日,華南理工、A*STAR 和港中大(深圳)團隊通過大量實驗證明,這些真實場景下的測試數(shù)據(jù)流會對現(xiàn)有方法帶來巨大挑戰(zhàn)。該團隊認為,最先進方法的失敗首先是由于不加區(qū)分地根據(jù)不平衡測試數(shù)據(jù)調(diào)整歸一化層造成的。
為此,研究團隊提出了一種創(chuàng)新的平衡批歸一化層 (Balanced BatchNorm Layer),以取代推理階段的常規(guī)批歸一化層。同時,他們發(fā)現(xiàn)僅靠自我訓(xùn)練(ST)在未知的測試數(shù)據(jù)流中進行學(xué)習(xí),容易造成過度適應(yīng)(偽標(biāo)簽類別不平衡、目標(biāo)域并非固定領(lǐng)域)而導(dǎo)致在領(lǐng)域不斷變化的情況下性能不佳。
因此,該團隊建議通過錨定損失 (Anchored Loss) 對模型更新進行正則化處理,從而改進持續(xù)領(lǐng)域轉(zhuǎn)移下的自我訓(xùn)練,有助于顯著提升模型的魯棒性。最終,模型 TRIBE 在四個數(shù)據(jù)集、多種真實世界測試數(shù)據(jù)流設(shè)定下穩(wěn)定達到 state-of-the-art 的表現(xiàn),并大幅度超越已有的先進方法。研究論文已被 AAAI 2024 接收。
論文鏈接:https://arxiv.org/abs/2309.14949
代碼鏈接:https://github.com/Gorilla-Lab-SCUT/TRIBE
引言
深度神經(jīng)網(wǎng)絡(luò)的成功依賴于將訓(xùn)練好的模型推廣到 i.i.d. 測試域的假設(shè)。然而,在實際應(yīng)用中,分布外測試數(shù)據(jù)的魯棒性,如不同的照明條件或惡劣天氣造成的視覺損壞,是一個需要關(guān)注的問題。最近的研究顯示,這種數(shù)據(jù)損失可能會嚴(yán)重影響預(yù)先訓(xùn)練好的模型的性能。重要的是,在部署前,測試數(shù)據(jù)的損壞(分布)通常是未知的,有時也不可預(yù)測。
因此,調(diào)整預(yù)訓(xùn)練模型以適應(yīng)推理階段的測試數(shù)據(jù)分布是一個值得價值的新課題,即測試時領(lǐng)域適 (TTA)。此前,TTA 主要通過分布對齊 (TTAC++, TTT++),自監(jiān)督訓(xùn)練 (AdaContrast) 和自訓(xùn)練 (Conjugate PL) 來實現(xiàn),這些方法在多種視覺損壞測試數(shù)據(jù)中都帶來了顯著的穩(wěn)健提升。
現(xiàn)有的測試時領(lǐng)域適應(yīng)(TTA)方法通?;谝恍﹪?yán)格的測試數(shù)據(jù)假設(shè),如穩(wěn)定的類別分布、樣本服從獨立同分布采樣以及固定的領(lǐng)域偏移。這些假設(shè)啟發(fā)了許多研究者去探究真實世界中的測試數(shù)據(jù)流,如 CoTTA、NOTE、SAR 和 RoTTA 等。
最近,對真實世界的 TTA 研究,如 SAR(ICLR 2023)和 RoTTA(CVPR 2023)主要關(guān)注局部類別不平衡和連續(xù)的領(lǐng)域偏移對 TTA 帶來的挑戰(zhàn)。局部類別不平衡通常是由于測試數(shù)據(jù)并非獨立同分布采樣而產(chǎn)生的。直接不加區(qū)分的領(lǐng)域適應(yīng)將導(dǎo)致有偏置的分布估計。
最近有研究提出了指數(shù)式更新批歸一化統(tǒng)計量(RoTTA)或?qū)嵗壟袆e更新批歸一化統(tǒng)計量(NOTE)來解決這個挑戰(zhàn)。其研究目標(biāo)是超越局部類不平衡的挑戰(zhàn),考慮到測試數(shù)據(jù)的總體分布可能嚴(yán)重失衡,類的分布也可能隨著時間的推移而變化。在下圖 1 中可以看到更具挑戰(zhàn)性的場景示意圖。
由于在推理階段之前,測試數(shù)據(jù)中的類別流行率未知,而且模型可能會通過盲目的測試時間調(diào)整偏向于多數(shù)類別,這使得現(xiàn)有的 TTA 方法變得無效。根據(jù)經(jīng)驗觀察,對于依靠當(dāng)前批數(shù)據(jù)來估計全局統(tǒng)計量來更新歸一化層的方法來說,這個問題變得尤為突出(BN, PL, TENT, CoTTA 等)。
這主要是由于:
1.當(dāng)前批數(shù)據(jù)會受到局部類別不平衡的影響帶來有偏置的整體分布估計;
2.從全局類別不平衡的整個測試數(shù)據(jù)中估計出單一的全局分布,全局分布很容易偏向多數(shù)類,導(dǎo)致內(nèi)部協(xié)變量偏移。
為了避免有偏差的批歸一化(BN),該團隊提出了一種平衡的批歸一化層(Balanced Batch Normalization Layer),即對每個單獨類別的分布進行建模,并從類別分布中提取全局分布。平衡的批歸一化層允許在局部和全局類別不平衡的測試數(shù)據(jù)流下得到分布的類平衡估計。
隨著時間的推移,領(lǐng)域轉(zhuǎn)移在現(xiàn)實世界的測試數(shù)據(jù)中經(jīng)常發(fā)生,例如照明 / 天氣條件的逐漸變化。這給現(xiàn)有的 TTA 方法帶來了另一個挑戰(zhàn),TTA 模型可能由于過度適應(yīng)到領(lǐng)域 A 而當(dāng)從領(lǐng)域 A 切換到領(lǐng)域 B 時出現(xiàn)矛盾。
為了緩解過度適應(yīng)到某個短時領(lǐng)域,CoTTA 隨機還原參數(shù),EATA 用 fisher information 對參數(shù)進行正則化約束。盡管如此,這些方法仍然沒有明確解決測試數(shù)據(jù)領(lǐng)域中層出不窮的挑戰(zhàn)。
本文在兩分支自訓(xùn)練架構(gòu)的基礎(chǔ)上引入了一個錨定網(wǎng)絡(luò)(Anchor Network)組成三網(wǎng)絡(luò)自訓(xùn)練模型(Tri-Net Self-Training)。錨定網(wǎng)絡(luò)是一個凍結(jié)的源模型,但允許通過測試樣本調(diào)整批歸一化層中的統(tǒng)計量而非參數(shù)。并提出了一個錨定損失利用錨定網(wǎng)絡(luò)的輸出來正則化教師模型的輸出以避免網(wǎng)絡(luò)過度適應(yīng)到局部分布中。
最終模型結(jié)合了三網(wǎng)絡(luò)自訓(xùn)練模型和平衡的批歸一化層(TRI-net self-training with BalancEd normalization, TRIBE)在較為寬泛的的可調(diào)節(jié)學(xué)習(xí)率的范圍里表現(xiàn)出一致的優(yōu)越性能。在四個數(shù)據(jù)集和多種真實世界數(shù)據(jù)流下顯示了大幅性能提升,展示了獨一檔的穩(wěn)定性和魯棒性。
方法介紹
論文方法分為三部分:
- 介紹真實世界下的 TTA 協(xié)議;
- 平衡的批歸一化;
- 三網(wǎng)絡(luò)自訓(xùn)練模型。
真實世界下的 TTA 協(xié)議
作者采用了數(shù)學(xué)概率模型對真實世界下具有局部類別不平衡和全局類別不平衡的測試數(shù)據(jù)流,以及隨著時間變化的領(lǐng)域分布進行了建模。如下圖 2 所示。
平衡的批歸一化
為了糾正不平衡測試數(shù)據(jù)對 BN 統(tǒng)計量產(chǎn)生的估計偏置,作者提出了一個平衡批歸一化層,該層為每個語義類分別維護了一對統(tǒng)計量,表示為:
為了更新類別統(tǒng)計量,作者在偽標(biāo)簽預(yù)測的幫助下應(yīng)用了高效的迭代更新方法,如下所示:
通過偽標(biāo)簽對各個類別數(shù)據(jù)的采樣點進行單獨統(tǒng)計,并通過下式重新得到類別平衡下的整體分布統(tǒng)計量,以此來對齊用類別平衡的源數(shù)據(jù)學(xué)習(xí)好的特征空間。
在某些特殊情況下,作者發(fā)現(xiàn)當(dāng)類別數(shù)量較多或偽標(biāo)簽準(zhǔn)確率較低 (accuracy<0.5) 的情況下,以上的類別獨立的更新策略效果沒那么明顯。因此,他們進一步用超參數(shù) γ 來融合類別無關(guān)更新策略和類別獨立更新策略,如下式:
通過進一步分析和觀察,作者發(fā)現(xiàn)當(dāng) γ=1 時,整個更新策略就退化成了 RoTTA 中的 RobustBN 的更新策略,當(dāng) γ=0 時是純粹的類別獨立的更新策略,因此,當(dāng) γ 取值 0~1 時可以適應(yīng)到各種情況下。
網(wǎng)絡(luò)自訓(xùn)練模型
作者在現(xiàn)有的學(xué)生 - 教師模型的基礎(chǔ)上,添加了一個錨定網(wǎng)絡(luò)分支,并引入了錨定損失來約束教師網(wǎng)絡(luò)的預(yù)測分布。這種設(shè)計受到了 TTAC++ 的啟發(fā)。TTAC++ 指出在測試數(shù)據(jù)流上僅靠自我訓(xùn)練會容易導(dǎo)致確認偏置的積累,這個問題在本文中的真實世界中的測試數(shù)據(jù)流上更加嚴(yán)重。TTAC++ 采用了從源域收集到的統(tǒng)計信息實現(xiàn)領(lǐng)域?qū)R正則化,但對于 Fully TTA 設(shè)定來說,這個源域信息不可收集。
同時,作者也收獲了另一個啟示,無監(jiān)督領(lǐng)域?qū)R的成功是基于兩個領(lǐng)域分布相對高重疊率的假設(shè)。因此,作者僅調(diào)整了 BN 統(tǒng)計量的凍結(jié)源域模型來對教師模型進行正則化,避免教師模型的預(yù)測分布偏離源模型的預(yù)測分布太遠(這破壞了之前的兩者分布高重合率的經(jīng)驗觀測)。大量實驗證明,本文中的發(fā)現(xiàn)與創(chuàng)新是正確的且魯棒的。以下是錨定損失的表達式:
下圖展示了 TRIBE 網(wǎng)絡(luò)的框架圖:
實驗部分
論文作者在 4 個數(shù)據(jù)集上,以兩種真實世界 TTA 協(xié)議為基準(zhǔn),對 TRIBE 進行了驗證。兩種真實世界 TTA 協(xié)議分別是全局類分布固定的 GLI-TTA-F 和全局類分布不固定的 GLI-TTA-V。
上表展示了 CIFAR10-C 數(shù)據(jù)集兩種協(xié)議不同不平衡系數(shù)下的表現(xiàn),可以得到以下結(jié)論:
1.只有 LAME, TTAC, NOTE, RoTTA 和論文提出的 TRIBE 超過了 TEST 的基準(zhǔn)線,表明了真實測試流下更加魯棒的 TTA 方法的必要性。
2.全局類別不平衡對現(xiàn)有的 TTA 方法帶來了巨大挑戰(zhàn),如先前的 SOTA 方法 RoTTA 在 I.F.=1 時表現(xiàn)為錯誤率 25.20% 但在 I.F.=200 時錯誤率升到了 32.45%,相比之下,TRIBE 能穩(wěn)定地展示相對較好的性能。
3. TRIBE 的一致性具有絕對優(yōu)勢,超越了先前的所有方法,并在全局類別平衡的設(shè)定下 (I.F.=1) 超越先前 SOTA (TTAC) 約 7%,在更加困難的全局類別不平衡 (I.F.=200) 的設(shè)定下獲得了約 13% 的性能提升。
4.從 I.F.=10 到 I.F.=200,其他 TTA 方法隨著不平衡度增加,呈現(xiàn)性能下跌的趨勢。而 TRIBE 能維持較為穩(wěn)定的性能表現(xiàn)。這歸因于引入了平衡批歸一化層,更好地考慮了嚴(yán)重的類別不平衡和錨定損失,這避免了跨不同領(lǐng)域的過度適應(yīng)。
更多數(shù)據(jù)集的結(jié)果可查閱論文原文。
此外,表 4 展示了詳細的模塊化消融,有以下幾個觀測性結(jié)論:
1.僅將 BN 替換成平衡批歸一化層 (Balanced BN),不更新任何模型參數(shù),只通過 forward 更新 BN 統(tǒng)計量,就能帶來 10.24% (44.62 -> 34.28) 的性能提升,并超越了 Robust BN 的錯誤率 41.97%。
2.Anchored Loss 結(jié)合 Self-Training,無論是在之前 BN 結(jié)構(gòu)下還是最新的 Balanced BN 結(jié)構(gòu)下,都得到了性能的提升,并超越了 EMA Model 的正則化效果。
本文的其余部分和長達 9 頁的附錄最終呈現(xiàn)了 17 個詳細表格結(jié)果,從多個維度展示了 TRIBE 的穩(wěn)定性、魯棒性和優(yōu)越性。附錄中也含有對平衡批歸一化層的更加詳細的理論推導(dǎo)和解釋。
總結(jié)和展望
為應(yīng)對真實世界中 non-i.i.d. 測試數(shù)據(jù)流、全局類不平衡和持續(xù)的領(lǐng)域轉(zhuǎn)移等諸多挑戰(zhàn),研究團隊深入探索了如何改進測試時領(lǐng)域適應(yīng)算法的魯棒性。為了適應(yīng)不平衡的測試數(shù)據(jù),作者提出了一個平衡批歸一化層(Balanced Batchnorm Layer),以實現(xiàn)對統(tǒng)計量的無偏估計,進而提出了一種包含學(xué)生網(wǎng)絡(luò)、教師網(wǎng)絡(luò)和錨定網(wǎng)絡(luò)的三層網(wǎng)絡(luò)結(jié)構(gòu),以規(guī)范基于自我訓(xùn)練的 TTA。
但本文仍然存在不足和改進的空間,由于大量的實驗和出發(fā)點都基于分類任務(wù)和 BN 模塊,因此對于其他任務(wù)和基于 Transformer 模型的適配程度仍然未知。這些問題值得后續(xù)工作進一步研究和探索。