如何在開放世界進(jìn)行測試段訓(xùn)練?基于動態(tài)原型擴(kuò)展的自訓(xùn)練方法
提高模型泛化能力是推動基于視覺的感知方法落地的重要基礎(chǔ),測試段訓(xùn)練和適應(yīng)(Test-Time Training/Adaptation)通過在測試段調(diào)整模型參數(shù)權(quán)重,將模型泛化至未知的目標(biāo)域數(shù)據(jù)分布段。現(xiàn)有 TTT/TTA 方法通常著眼于在閉環(huán)世界的目標(biāo)域數(shù)據(jù)下提高測試段訓(xùn)練性能。
可是,在諸多應(yīng)用場景中,目標(biāo)域容易受到強(qiáng)域外數(shù)據(jù) (Strong OOD) 數(shù)據(jù)的污染,例如不相關(guān)的語義類別數(shù)據(jù)。在該場景又可稱為開放世界測試段訓(xùn)練 (OWTTT),在該場景下,現(xiàn)有 TTT/TTA 通常將強(qiáng)域外數(shù)據(jù)強(qiáng)行分類至已知類別,從而最終干擾對如收到噪聲干擾圖像的弱域外數(shù)據(jù)(Weak OOD)的分辨能力。
近日,來自華南理工大學(xué)和 A*STAR 團(tuán)隊首次提出開放世界測試段訓(xùn)練的設(shè)定,并推出了針對開放世界測試段訓(xùn)練的方法。
- 論文:https://arxiv.org/abs/2308.09942
- 代碼:https://github.com/Yushu-Li/OWTTT
本文首先提出了一種自適應(yīng)閾值的強(qiáng)域外數(shù)據(jù)樣本過濾方法,提高了自訓(xùn)練 TTT 方法的在開放世界的魯棒性。該方法進(jìn)一步提出了一種基于動態(tài)擴(kuò)展原型來表征強(qiáng)域外樣本的方法,以改進(jìn)弱 / 強(qiáng)域外數(shù)據(jù)分離效果。最后,通過分布對齊來約束自訓(xùn)練。
本文的方法在 5 個不同的 OWTTT 基準(zhǔn)上實(shí)現(xiàn)了最優(yōu)的性能表現(xiàn),并為 TTT 的后續(xù)研究探索面向更加魯棒 TTT 方法的提供了新方向。研究已作為 Oral 論文被 ICCV 2023 接收。
引言
測試段訓(xùn)練(TTT)可以僅在推理階段訪問目標(biāo)域數(shù)據(jù),并對分布偏移的測試數(shù)據(jù)進(jìn)行即時推理。TTT 的成功已經(jīng)在許多人工選擇的合成損壞目標(biāo)域數(shù)據(jù)上得到證明。然而,現(xiàn)有的 TTT 方法的能力邊界尚未得到充分探索。
為促進(jìn)開放場景下的 TTT 應(yīng)用,研究的重點(diǎn)已轉(zhuǎn)移到調(diào)查 TTT 方法可能失敗的場景。人們在更現(xiàn)實(shí)的開放世界環(huán)境下開發(fā)穩(wěn)定和強(qiáng)大的 TTT 方法已經(jīng)做出了許多努力。而在本文工作中,我們深入研究了一個很常見但被忽略的開放世界場景,其中目標(biāo)域可能包含從顯著不同的環(huán)境中提取的測試數(shù)據(jù)分布,例如與源域不同的語義類別,或者只是隨機(jī)噪聲。
我們將上述測試數(shù)據(jù)稱為強(qiáng)分布外數(shù)據(jù)(strong OOD)。而在本工作中被稱為弱 OOD 數(shù)據(jù)則是分布偏移的測試數(shù)據(jù),例如常見的合成損壞。因此,現(xiàn)有工作缺乏對這種現(xiàn)實(shí)環(huán)境的研究促使我們探索提高開放世界測試段訓(xùn)練(OWTTT)的魯棒性,其中測試數(shù)據(jù)被強(qiáng) OOD 樣本污染。
圖 1 :現(xiàn)有的 TTT 方法在 OWTTT 設(shè)定下的評估結(jié)果
如圖 1 所示,我們首先對現(xiàn)有的 TTT 方法在 OWTTT 設(shè)定下進(jìn)行評估,發(fā)現(xiàn)通過自訓(xùn)練和分布對齊的 TTT 方法都會受到強(qiáng) OOD 樣本的影響。這些結(jié)果表明,應(yīng)用現(xiàn)有的 TTT 技術(shù)無法在開放世界中實(shí)現(xiàn)安全的測試時訓(xùn)練。我們將它們的失敗歸因于以下兩個原因。
- 基于自訓(xùn)練的 TTT 很難處理強(qiáng) OOD 樣本,因?yàn)樗仨殞y試樣本分配給已知的類別。盡管可以通過應(yīng)用半監(jiān)督學(xué)習(xí)中采用的閾值來過濾掉一些低置信度樣本,但仍然不能保證濾除所有強(qiáng) OOD 樣本。
- 當(dāng)計算強(qiáng) OOD 樣本來估計目標(biāo)域分布時,基于分布對齊的方法將會受到影響。全局分布對齊 [1] 和類別分布對齊 [2] 都可能受到影響,并導(dǎo)致特征分布對齊不準(zhǔn)確。
考慮到現(xiàn)有 TTT 方法失敗的潛在原因,我們提出了兩種技術(shù)相結(jié)合來提高自訓(xùn)練框架下開放世界 TTT 的魯棒性。
首先,我們在自訓(xùn)練的變體上構(gòu)建 TTT 的基線,即在目標(biāo)域中以源域原型作為聚類中心進(jìn)行聚類。為了減輕自訓(xùn)練受到錯誤偽標(biāo)簽的強(qiáng) OOD 的影響,我們設(shè)計了一種無超參數(shù)的方法來拒絕強(qiáng) OOD 樣本。
為了進(jìn)一步分離弱 OOD 樣本和強(qiáng) OOD 樣本的特征,我們允許原型池通過選擇孤立的強(qiáng) OOD 樣本擴(kuò)展。因此,自訓(xùn)練將允許強(qiáng) OOD 樣本圍繞新擴(kuò)展的強(qiáng) OOD 原型形成緊密的聚類。這將有利于源域和目標(biāo)域之間的分布對齊。我們進(jìn)一步提出通過全局分布對齊來規(guī)范自我訓(xùn)練,以降低確認(rèn)偏差的風(fēng)險。
最后,為了綜合開放世界的 TTT 場景,我們采用 CIFAR10-C、CIFAR100-C、ImageNet-C、VisDA-C、ImageNet-R、Tiny-ImageNet、MNIST 和 SVHN 數(shù)據(jù)集,并通過利用一個數(shù)據(jù)集為弱 OOD,其他為強(qiáng) OOD 建立基準(zhǔn)數(shù)據(jù)集。我們將此基準(zhǔn)稱為開放世界測試段訓(xùn)練基準(zhǔn),并希望這能鼓勵未來更多的工作關(guān)注更現(xiàn)實(shí)場景中測試段訓(xùn)練的穩(wěn)健性。
方法
論文分了四個部分來介紹所提出的方法。
1)概述開放世界下測試段訓(xùn)練任務(wù)的設(shè)定。
2)介紹了如何通過原型聚類實(shí)現(xiàn) TTT 以及如何擴(kuò)展原型以進(jìn)行開放世界測試時訓(xùn)練。
3)介紹了如何利用目標(biāo)域數(shù)據(jù)進(jìn)行動態(tài)原型擴(kuò)展。
4)引入分布對齊與原型聚類相結(jié)合,以實(shí)現(xiàn)強(qiáng)大的開放世界測試時訓(xùn)練。
圖 2 :方法概覽圖
任務(wù)設(shè)定
TTT 的目的是使源域預(yù)訓(xùn)練模型適應(yīng)目標(biāo)域,其中目標(biāo)域可能會相對于源域有分布遷移。在標(biāo)準(zhǔn)的封閉世界 TTT 中,源域和目標(biāo)域的標(biāo)簽空間是相同的。然而在開放世界 TTT 中,目標(biāo)域的標(biāo)簽空間包含源域的目標(biāo)空間,也就是說目標(biāo)域具有未見過的新語義類別。
為了避免 TTT 定義之間的混淆,我們采用 TTAC [2] 中提出的順序測試時間訓(xùn)練(sTTT)協(xié)議進(jìn)行評估。在 sTTT 協(xié)議下,測試樣本被順序測試,并在觀察到小批量測試樣本后進(jìn)行模型更新。對到達(dá)時間戳 t 的任何測試樣本的預(yù)測不會受到到達(dá) t+k(其 k 大于 0)的任何測試樣本的影響。
原型聚類
受到域適應(yīng)任務(wù)中使用聚類的工作啟發(fā) [3,4],我們將測試段訓(xùn)練視為發(fā)現(xiàn)目標(biāo)域數(shù)據(jù)中的簇結(jié)構(gòu)。通過將代表性原型識別為聚類中心,在目標(biāo)域中識別聚類結(jié)構(gòu),并鼓勵測試樣本嵌入到其中一個原型附近。原型聚類的目標(biāo)定義為最小化樣本與聚類中心余弦相似度的負(fù)對數(shù)似然損失,如下式所示。
我們開發(fā)了一種無超參數(shù)的方法來濾除強(qiáng) OOD 樣本,以避免調(diào)整模型權(quán)重的負(fù)面影響。具體來說,我們?yōu)槊總€測試樣本定義一個強(qiáng) OOD 分?jǐn)?shù) os 作為與源域原型的最高相似度,如下式所示。
圖 3 離群值呈雙峰分布
我們觀察到離群值服從雙峰分布,如圖 3 所示。因此,我們沒有指定固定閾值,而是將最佳閾值定義為分離兩種分布的的最佳值。具體來說,問題可以表述為將離群值分為兩個簇,最佳閾值將最小化中的簇內(nèi)方差。優(yōu)化下式可以通過以 0.01 的步長窮舉搜索從 0 到 1 的所有可能閾值來有效實(shí)現(xiàn)。
動態(tài)原型擴(kuò)展
擴(kuò)展強(qiáng) OOD 原型池需要同時考慮源域和強(qiáng) OOD 原型來評估測試樣本。為了從數(shù)據(jù)中動態(tài)估計簇的數(shù)量,之前的研究了類似的問題。確定性硬聚類算法 DP-means [5] 是通過測量數(shù)據(jù)點(diǎn)到已知聚類中心的距離而開發(fā)的,當(dāng)距離高于閾值時將初始化一個新聚類。DP-means 被證明相當(dāng)于優(yōu)化 K-means 目標(biāo),但對簇的數(shù)量有額外的懲罰,為動態(tài)原型擴(kuò)展提供了一個可行的解決方案。
為了減輕估計額外超參數(shù)的難度,我們首先定義一個測試樣本,其具有擴(kuò)展的強(qiáng) OOD 分?jǐn)?shù)作為與現(xiàn)有源域原型和強(qiáng) OOD 原型的最近距離,如下式。因此,測試高于此閾值的樣本將建立一個新的原型。為了避免添加附近的測試樣本,我們增量地重復(fù)此原型擴(kuò)展過程。
隨著其他強(qiáng) OOD 原型的確定,我們定義了用于測試樣本的原型聚類損失,并考慮了兩個因素。首先,分類為已知類的測試樣本應(yīng)該嵌入到更靠近原型的位置并遠(yuǎn)離其他原型,這定義了 K 類分類任務(wù)。其次,被分類為強(qiáng) OOD 原型的測試樣本應(yīng)該遠(yuǎn)離任何源域原型,這定義了 K+1 類分類任務(wù)。考慮到這些目標(biāo),我們將原型聚類損失定義為下式。
分布對齊約束
眾所周知,自訓(xùn)練容易受到錯誤偽標(biāo)簽的影響。目標(biāo)域由 OOD 樣本組成時,情況會更加惡化。為了降低失敗的風(fēng)險,我們進(jìn)一步將分布對齊 [1] 作為自我訓(xùn)練的正則化,如下式。
實(shí)驗(yàn)
我們在 5 個不同的 OWTTT 基準(zhǔn)數(shù)據(jù)集中進(jìn)行測試,包括人工合成的損壞數(shù)據(jù)集和風(fēng)格變化的數(shù)據(jù)集。實(shí)驗(yàn)主要使用了三個評價指標(biāo):弱 OOD 分類準(zhǔn)確率 ACCS、強(qiáng) OOD 分類準(zhǔn)確率 ACCN 和二者的調(diào)和平均數(shù) ACCH。
表 1 不同方法在 Cifar10-C 數(shù)據(jù)集的表現(xiàn)
表 2 不同方法在 Cifar100-C 數(shù)據(jù)集的表現(xiàn)
表 3 不同方法在 ImageNet-C 數(shù)據(jù)集的表現(xiàn)
表 4 不同方法在 ImageNet-R 數(shù)據(jù)集的表現(xiàn)
表 5 不同方法在 VisDA-C 數(shù)據(jù)集的表現(xiàn)
如上表所示,我們的方法在幾乎所有數(shù)據(jù)集上相較于目前最優(yōu)秀的方法都有比較大的提升,可以有效地識別強(qiáng) OOD 樣本,并減小其對弱 OOD 樣本分類的影響。我們的方法可以在開放世界的場景下實(shí)現(xiàn)更加魯棒的 TTT。
總結(jié)
本文首次提出了開放世界測試段訓(xùn)練(OWTTT)的問題和設(shè)定,指出現(xiàn)有的方法在處理含有和源域樣本有語義偏移的強(qiáng) OOD 樣本的目標(biāo)域數(shù)據(jù)時時會遇到困難,并提出一個基于動態(tài)原型擴(kuò)展的自訓(xùn)練的方法解決上述問題。我們希望這項(xiàng)工作能夠?yàn)?TTT 的后續(xù)研究探索面向更加魯棒的 TTT 方法提供新方向。