如何正確定義測(cè)試階段訓(xùn)練?順序推理和域適應(yīng)聚類方法
域適應(yīng)是解決遷移學(xué)習(xí)的重要方法,當(dāng)前域適應(yīng)當(dāng)法依賴原域和目標(biāo)域數(shù)據(jù)進(jìn)行同步訓(xùn)練。當(dāng)源域數(shù)據(jù)不可得,同時(shí)目標(biāo)域數(shù)據(jù)不完全可見時(shí),測(cè)試階段訓(xùn)練(Test- Time Training)成為新的域適應(yīng)方法。當(dāng)前針對(duì) Test-Time Training(TTT)的研究廣泛利用了自監(jiān)督學(xué)習(xí)、對(duì)比學(xué)習(xí)、自訓(xùn)練等方法,然而,如何定義真實(shí)環(huán)境下的 TTT 卻被經(jīng)常忽略,以至于不同方法間缺乏可比性。
近日,華南理工、A*STAR 團(tuán)隊(duì)和鵬城實(shí)驗(yàn)室聯(lián)合提出了針對(duì) TTT 問題的系統(tǒng)性分類準(zhǔn)則,通過區(qū)分方法是否具備順序推理能力(Sequential Inference)和是否需要修改源域訓(xùn)練目標(biāo),對(duì)當(dāng)前方法做了詳細(xì)分類。同時(shí),提出了基于目標(biāo)域數(shù)據(jù)定錨聚類(Anchored Clustering)的方法,在多種 TTT 分類下取得了最高的分類準(zhǔn)確率,本文對(duì) TTT 的后續(xù)研究指明了正確的方向,避免了實(shí)驗(yàn)設(shè)置混淆帶來的結(jié)果不可比問題。研究論文已被 NeurIPS 2022 接收。
- 論文:https://arxiv.org/abs/2206.02721
- 代碼:https://github.com/Gorilla-Lab-SCUT/TTAC
一、引言
深度學(xué)習(xí)的成功主要?dú)w功于大量的標(biāo)注數(shù)據(jù)和訓(xùn)練集與測(cè)試集獨(dú)立同分布的假設(shè)。在一般情況下,需要在合成數(shù)據(jù)上訓(xùn)練,然后在真實(shí)數(shù)據(jù)上測(cè)試時(shí),以上假設(shè)就沒辦法滿足,這也被稱為域偏移。為了緩解這個(gè)問題,域適應(yīng) (Domain Adaptation, DA) 誕生了?,F(xiàn)有的 DA 工作要么需要在訓(xùn)練期間訪問源域和目標(biāo)域的數(shù)據(jù),要么同時(shí)在多個(gè)域進(jìn)行訓(xùn)練。前者需要模型在做適應(yīng) (Adaptation) 訓(xùn)練期間總是能訪問到源域數(shù)據(jù),而后者需要更加昂貴的計(jì)算量。為了降低對(duì)源域數(shù)據(jù)的依賴,由于隱私問題或者存儲(chǔ)開銷不能訪問源域數(shù)據(jù),無需源域數(shù)據(jù)的域適應(yīng) (Source-Free Domain Adaptation, SFDA) 解決無法訪問源域數(shù)據(jù)的域適應(yīng)問題。作者發(fā)現(xiàn) SFDA 需要在整個(gè)目標(biāo)數(shù)據(jù)集上訓(xùn)練多個(gè)輪次才能達(dá)到收斂,在面對(duì)流式數(shù)據(jù)需要及時(shí)做出推斷預(yù)測(cè)的時(shí)候 SFDA 無法解決此類問題。這種面對(duì)流式數(shù)據(jù)需要及時(shí)適應(yīng)并做出推斷預(yù)測(cè)的更現(xiàn)實(shí)的設(shè)定,被稱為測(cè)試時(shí)訓(xùn)練 (Test-Time Training, TTT) 或測(cè)試時(shí)適應(yīng)(Test-Time Adaptation, TTA)。
作者注意到在社區(qū)里對(duì) TTT 的定義存在混亂從而導(dǎo)致比較的不公平。論文以兩個(gè)關(guān)鍵的因素對(duì)現(xiàn)有的 TTT 方法進(jìn)行分類:
- 對(duì)于數(shù)據(jù)是流式出現(xiàn)的并需要對(duì)當(dāng)前出現(xiàn)的數(shù)據(jù)作出及時(shí)預(yù)測(cè)的,稱之為單輪適應(yīng)協(xié)議(One-Pass Adaptation);對(duì)于其他不符合以上設(shè)定的稱為多輪適應(yīng)協(xié)議(Multi-Pass Adaptation),模型可能需要在整個(gè)測(cè)試集上進(jìn)行多輪次的更新后,再進(jìn)行從頭到尾的推斷預(yù)測(cè)。
- 根據(jù)是否需要修改源域的訓(xùn)練損失方程,比如引入額外的自監(jiān)督分支以達(dá)到更有效的 TTT。
這篇論文的目標(biāo)是解決最現(xiàn)實(shí)和最具挑戰(zhàn)性的 TTT 協(xié)議,即單輪適應(yīng)并無需修改訓(xùn)練損失方程。這個(gè)設(shè)定類似于 TENT[1]提出的 TTA,但不限于使用來自源域的輕量級(jí)信息,如特征的統(tǒng)計(jì)量。鑒于 TTT 在測(cè)試時(shí)高效適應(yīng)的目標(biāo),該假設(shè)在計(jì)算上是高效的,并大大提高了 TTT 的性能。作者將這個(gè)新的 TTT 協(xié)議命名為順序測(cè)試時(shí)訓(xùn)練(sequential Test Time Training, sTTT)。
除了以上對(duì)不同 TTT 方法的分類外,論文還提出了兩個(gè)技術(shù)讓 sTTT 更加有效和準(zhǔn)確:
- 論文提出了測(cè)試時(shí)錨定聚類 (Test-Time Anchored Clustering, TTAC) 方法。
- 為了降低錯(cuò)誤偽標(biāo)簽對(duì)聚類更新的影響,論文根據(jù)網(wǎng)絡(luò)對(duì)樣本的預(yù)測(cè)穩(wěn)定性和自信度對(duì)偽標(biāo)簽進(jìn)行過濾。
二、方法介紹
論文分了四部分來闡述所提出的方法,分別是 1)介紹測(cè)試時(shí)訓(xùn)練 (TTT) 的錨定聚類模塊,如圖 1 中的 Anchored Clustering 部分;2)介紹用于過濾偽標(biāo)簽的一些策略,如圖 1 中的 Pseudo Label Filter 部分;3)不同于 TTT++[2]中的使用 L2 距離來衡量兩個(gè)分布的距離,作者使用了 KL 散度來度量兩個(gè)全局特征分布間的距離;4)介紹在測(cè)試時(shí)訓(xùn)練 (TTT) 過程的特征統(tǒng)計(jì)量的有效更新迭代方法。最后第五小節(jié)給出了整個(gè)算法的過程代碼。
第一部分 在錨定聚類里,作者首先使用混合高斯對(duì)目標(biāo)域的特征進(jìn)行建模,其中每個(gè)高斯分量代表一個(gè)被發(fā)現(xiàn)的聚類。然后,作者使用源域中每個(gè)類別的分布作為目標(biāo)域分布的錨點(diǎn)來進(jìn)行匹配。通過這種方式,測(cè)試數(shù)據(jù)特征可以同時(shí)形成集群,并且集群與源域類別相關(guān)聯(lián),從而達(dá)到了對(duì)目標(biāo)域的推廣。概述來說就是,將源域和目標(biāo)域的特征分別根據(jù)類別信息建模成:
然后通過 KL 散度度量兩個(gè)混合高斯分布的距離,并通過減少 KL 散度來達(dá)到兩個(gè)域特征的匹配。可是,在兩個(gè)混合高斯分布上直接求解 KL 散度并沒有閉式解,這導(dǎo)致了無法使用有效的梯度優(yōu)化方法。在這篇論文中,作者在源域和目標(biāo)域中分配相同數(shù)量的集群,每個(gè)目標(biāo)域集群被分配給一個(gè)源域集群,這樣就可以將整個(gè)混合高斯的 KL 散度求解變成了各對(duì)高斯之間的 KL 散度之和。如下式:
上式的閉式解形式為:
在公式 2 中,源域集群的參數(shù)可以線下收集完,而且由于只用到了輕量化統(tǒng)計(jì)數(shù)據(jù),所以不會(huì)導(dǎo)致隱私泄漏問題且只使用了少量的計(jì)算和存儲(chǔ)開銷。對(duì)于目標(biāo)域的變量,涉及到了偽標(biāo)簽的使用,作者為此設(shè)計(jì)了一套有效的且輕量的偽標(biāo)簽過濾策略。
第二部分 偽標(biāo)簽過濾的策略主要分為兩部分:
1)時(shí)序上一致性預(yù)測(cè)的過濾:
2)根據(jù)后驗(yàn)概率的過濾:
最后,使用過濾后的樣本來求解目標(biāo)域集群的統(tǒng)計(jì)量:
第三部分 由于在錨定聚類中,部分被濾除的樣本并沒有參與目標(biāo)域的估計(jì)。作者還對(duì)所有測(cè)試樣本進(jìn)行全局特征對(duì)齊,類似錨定聚類中對(duì)集群的做法,這里將所有樣本看作一個(gè)整體的集群,在源域和目標(biāo)域分別定義
然后再次以最小化 KL 散度為目標(biāo)對(duì)齊全局特征分布:
第四部分 以上三部分都在介紹一些域?qū)R的手段,但在 TTT 過程中,想要估計(jì)一個(gè)目標(biāo)域的分布是不簡單的,因?yàn)槲覀儫o法觀測(cè)整個(gè)目標(biāo)域的數(shù)據(jù)。在前沿的工作中,TTT++[2]使用了一個(gè)特征隊(duì)列來存儲(chǔ)過去的部分樣本,來計(jì)算一個(gè)局部分布來估計(jì)整體分布。但這樣不但帶來了內(nèi)存開銷還導(dǎo)致了精度與內(nèi)存之間的 trade off。在這篇論文中,作者提出了迭代更新統(tǒng)計(jì)量的方式來緩解內(nèi)存開銷。具體的迭代更新式子如下:
總的來說,整個(gè)算法如下算法 1 所示:
三、實(shí)驗(yàn)結(jié)果
正如引言部分所說,這篇論文中作者非常注重不同 TTT 策略下的不同方法的公平比較。作者將所有 TTT 方法根據(jù)以下兩個(gè)關(guān)鍵因素來分類:1)是否單輪適應(yīng)協(xié)議 (One-Pass Adaptation) 和 2)修改源域的訓(xùn)練損失方程,分別記為 Y/N 表示需要或不需要修改源域訓(xùn)練方程,O/M 表示單輪適應(yīng)或多輪適應(yīng)。除此之外,作者在 6 個(gè)基準(zhǔn)的數(shù)據(jù)集上進(jìn)行了充分的對(duì)比實(shí)驗(yàn)和一些進(jìn)一步的分析。
如表一所示,TTT++[2]同時(shí)出現(xiàn)在了 N-O 和 Y-O 的協(xié)議下,是因?yàn)?TTT++[2]擁有一個(gè)額外的自監(jiān)督分支,我們?cè)?N-O 協(xié)議下將不添加自監(jiān)督分支的損失,而在 Y-O 下可以正常使用此分子的損失。TTAC 在 Y-O 下也是使用了跟 TTT++[2]一樣的自監(jiān)督分支。從表中可以看到,在所有的 TTT 協(xié)議下所有數(shù)據(jù)集下,TTAC 均取得到最優(yōu)的結(jié)果;在 CIFAR10-C 和 CIFAR100-C 數(shù)據(jù)集上,TTAC 都取得了 3% 以上的提升。從表 2 - 表 5 分別是 ImageNet-C、CIFAR10.1、VisDA 上的數(shù)據(jù),TTAC 均取到了最優(yōu)的結(jié)果。
此外,作者在多個(gè) TTT 協(xié)議下同時(shí)做了嚴(yán)格的消融實(shí)驗(yàn),清晰地看出了每個(gè)部件的作用,如表 6 所示。首先從 L2 Dist 和 KLD 的對(duì)比中,可以看出使用 KL 散度來衡量兩個(gè)分布具有更優(yōu)的效果;其次,發(fā)現(xiàn)如果單單使用 Anchored Clustering 或單獨(dú)使用偽標(biāo)簽監(jiān)督提升只有 14%,但如果結(jié)合了 Anchored Cluster 和 Pseudo Label Filter 就可以看到性能顯著提高 29.15% -> 11.33%。這也可以看出每個(gè)部件的必要性和有效的結(jié)合。
最后,作者在正文的尾部從五個(gè)維度對(duì) TTAC 展開了充分的分析,分別是 sTTT (N-O)下的累計(jì)表現(xiàn)、TTAC 特征的 TSNE 可視化、源域無關(guān)的 TTT 分析、測(cè)試樣本隊(duì)列和更新輪次的分析、以 wall-clock 時(shí)間度量計(jì)算開銷。還有更多有趣的證明和分析會(huì)展示在文章的附錄中。
四、總結(jié)
本文只是粗糙地介紹了 TTAC 這篇工作的貢獻(xiàn)點(diǎn):對(duì)已有 TTT 方法的分類比較、提出的方法、以及各個(gè) TTT 協(xié)議分類下的實(shí)驗(yàn)。論文和附錄中會(huì)有更加詳細(xì)的討論和分析。我們希望這項(xiàng)工作能夠?yàn)?TTT 方法提供一個(gè)公平的基準(zhǔn),未來的研究應(yīng)該在各自的協(xié)議內(nèi)進(jìn)行比較。