CVPR 2024 | 分割一切模型SAM泛化能力差?域適應策略給解決了
引言
大語言模型(LLMs)的成功激發(fā)了計算機視覺領域探索分割基礎模型的興趣。這些基礎分割模型通常通過 Prompt Engineer 來進行 zero/few 圖像分割。其中,Segment Anything Model(SAM)是最先進的圖像分割基礎模型。
圖 SAM 在多個下游任務上表現不佳
但是最近的研究表明,SAM 在多種下游任務中并非具有很強的魯棒性與泛化性,例如在醫(yī)學圖像、偽裝物體、添加干擾的自然圖像等領域表現較差。這可能是由于訓練數據集與下游的測試數據集之間存在較大的域差異(Domain Shift)所致。因此,一個非常重要的問題是,如何設計域自適應方案,使 SAM 在面對現實世界和多樣化的下游任務中更加魯棒?
將預訓練好的 SAM 適應到下游任務主要面臨三個挑戰(zhàn):
- 首先,傳統(tǒng)的無監(jiān)督域自適應范式需要源數據集和目標數據集,由于隱私和計算成本較為不可行。
- 其次,對于域適應,更新所有權重通常性能更好,同時也受到了昂貴的內存成本的限制。
- 最后,SAM 可以針對不同種類、不同顆粒度的提示 Prompt,展現出多樣化的分割能力,因此當缺乏下游任務的提示信息時,無監(jiān)督適應將非常具有挑戰(zhàn)性。
圖 1 SAM 在大規(guī)模數據集上進行預訓練,但存在泛化性問題。我們使用弱監(jiān)督的方式在各種下游任務上對 SAM 進行自適應
為了應對上述挑戰(zhàn),我們提出了一種具有錨點正則化和低秩微調的弱監(jiān)督自訓練架構,以提高自適應的魯棒性和計算效率。
具體而言,我們首先采用無源域的自訓練策略,從而避免對源數據的依賴。自訓練產生偽標簽,用于監(jiān)督模型的更新,但是容易受到錯誤偽標簽的影響,我們引入凍結的 source model 作為錨定網絡,以規(guī)范模型的更新。
為了進一步減少更新完整模型權重的高計算成本,我們對編碼器應用低秩權重分解,并通過低秩快捷路徑進行反向傳播。
最后,為了進一步提高無源域自適應的效果,我們在目標域引入了弱監(jiān)督(weak supervise),例如稀疏的點注釋,以提供更強的域適應信息,同時這種弱監(jiān)督與 SAM 中的提示編碼器自然兼容。
借助弱監(jiān)督作為 Prompt,我們獲得了更局部、更明確的自訓練偽標簽。經過調整的模型在多個下游任務上表現出了更強的泛化能力。
我們總結本工作的貢獻如下:
1. 我們受到 SAM 在下游任務中泛化問題的啟發(fā),提出了一種與任務無關且無需源數據的解決方案,通過自訓練來適應 SAM。
2. 我們利用弱監(jiān)督,包括 box、point 等標簽,以提高自適應效果。這些弱監(jiān)督標簽與 SAM 的提示編碼器完全兼容。
3. 我們對 5 種類型的下游實例分割任務進行了大量實驗,證明了所提出的弱監(jiān)督自適應方法的有效性。
- 論文地址:https://arxiv.org/pdf/2312.03502.pdf
- 項目地址:https://github.com/Zhang-Haojie/WeSAM
- 論文標題:Improving the Generalization of Segmentation Foundation Model under Distribution Shift via Weakly Supervised Adaptation
方法
方法介紹分為四個部分:
- Segment Anything 模型
- 基于自訓練的自適應框架
- 弱監(jiān)督如何幫助實現有效的自訓練
- 低秩權重更新
1.Segment Anything Model
SAM 主要由三個組件構成:圖像編碼器(ImageEncoder)、提示編碼器(PromptEncoder)、和解碼器(MaskDecoder)。
圖像編碼器使用 MAE 進行預訓練,整個 SAM 在擁有 11 億標注的訓練集 SA-1B 上進一步進行微調,訓練時使用了 Focal loss 和 Dice loss 的組合。推理時,測試圖片 x 首先由圖像編碼器進行編碼,然后給定提示 Prompt,輕量級的解碼器將進行三個級別的預測。
2.Source-Free 域適應自訓練
圖 2 所提出的具有錨定網絡正則化和對比損失正則化的自訓練架構
針對未提供標記的目標數據集 DT={xi} 和預訓練的分割模型。我們采用了 student-teacher 架構進行自訓練。如圖 2 所示,我們維護三個編碼器網絡,即 anchor model、student model、teacher model,其中 student 和 teacher model 共享權重。
具體來說,對于每個樣本 xi,應用一個隨機的弱數據增強作為 anchor 和 teacher model 的輸入,應用一個隨機的強數據增強作為 student model 的輸入,三個編碼器網絡編碼產生三個特征圖。
在解碼器網絡中,給定一定數量 Np 的提示 prompt,例如 box、point 或 coarse mask,將推理出一組實例分割的 masks。
基于以上知識,我們下面詳細闡述用于自訓練的三組優(yōu)化目標。
1) Student-Teacher 自訓練
我們首先使用與訓練 SAM 時相同的損失函數作為自訓練優(yōu)化目標來更新 student/teacher model。自訓練廣泛應用于半監(jiān)督學習,最近還被證明了對無源域自適應非常有效。具體而言,我們使用 teacher model 產生的預測結果,作為偽標簽(Pseudo label),并使用 Focal loss 和 Dice loss 來監(jiān)督 student 的輸出。
2) Anchor 損失用于魯棒正則化
僅使用自訓練損失進行網絡訓練容易受到 teacher 網絡預測的錯誤偽標簽積累的影響,即所謂的確認偏差。觀察也表明,僅使用自訓練長時間迭代后性能會下降?,F有的無源域自適應方法通常采用額外的約束來防止自訓練的負面影響,例如對預測進行均勻分布。
我們通過 anchor 損失來進行正則化,如公式 3 所示,分別最小化了 anchor model 與 student/teacher model 之間的 Dice loss。凍結的 anchor model 作為從源域(source domain)繼承的知識,不鼓勵源模型和自訓練更新模型之間出現過大的偏差,可以防止模型崩潰。
3) 對比損失正則化編碼器特征空間
圖 3 兩個分支下的對比損失
以上兩個訓練目標是在解碼器的輸出空間中執(zhí)行的。實驗部分揭示出,更新編碼器網絡是適應 SAM 最有效的方法,因此有必要直接對從編碼器網絡輸出的特征應用正則化。具體如圖 3 所示,我們根據 anchor 和 teacher 分支中預測 mask 從特征圖中裁剪出每個實例的特征
。
我們進一步定義對比損失中的正負樣本對,正樣本對是由兩個分支中使用相同的 prompt 對應的實例特征構建,而負樣本對是由不同 prompt 對應的實例特征來構建的。最終的對比損失如下所示,其中
是溫度系數。
4) 總損失
我們將上述三個損失函數組合成最終的 Source-Free 自適應損失。
3. 自訓練的 Prompt 生成
SAM 分割需要 Prompt 輸入來指示出待分割的目標對象,但可能會存在顆粒度模糊的問題。Prompt 工程可以以完全自動化的方式實現,也可以通過人工交互實現。
1) 完全自動生成 Prompt
我們首先使用網格密集采樣點作為 prompt 輸入,通過 Anchor model 生成初始階段分割的 masks,剔除 IoU 和穩(wěn)定性得分低的 mask,然后進行非極大值抑制來獲得分割結果。接下來從最終的 masks 中產生一組固定的 prompts,作為所有三個分支的 prompt 輸入。因此,三個網絡分割輸出的 mask 長度相同,并且具有精確的一對一對應關系。
2) 弱監(jiān)督作為 Prompt
盡管可以通過在圖像上使用網格采樣獲得 prompts,并過濾掉質量低和重復的 mask 來進行自動分割。但這些分割質量相對較差,可能包含許多誤報預測,并且顆粒度不明確。由此產生的 prompt 質量參差不齊,使得自訓練效果較差。
因此,借鑒先前的弱監(jiān)督域自適應工作,我們提出使用三種弱監(jiān)督方式,包括邊界框 box、稀疏點標注 point 和粗分割多邊形 coarse mask。在 SAM 中,這些弱監(jiān)督方式與 prompt 輸入完美匹配,可以無縫集成弱監(jiān)督以適應 SAM。
4. 低秩權重更新
基礎模型龐大的編碼器網絡使得更新所有模型的權重變得異常困難。然而,許多現有研究表明,更新編碼器網絡權重是調整預訓練模型的有效方法。
為了能夠更加有效且低成本地更新編碼器網絡,我們選擇了一種計算友好的低秩更新方法。對于編碼器網絡中的每個權重 θ,我們使用低秩近似 ω = AB,并設定一個壓縮率 r。只有 A 和 B 通過反向傳播進行更新以減少內存占用。在推理階段,通過將低秩近似和原始權重組合來重構權重,即 θ = θ + AB。
實驗
在實驗中,我們提供了與最先進方法的詳細比較和定性結果。最后,我們分析了各個部分的有效性以及網絡的具體設計。
1. 數據集
在這項工作中,我們對五種不同類型的下游分割任務進行評估,其中一些與 SA-1B 存在明顯的分布偏移。數據集涵蓋了清晰的自然圖像、添加干擾的自然圖像、醫(yī)學圖像、偽裝物體和機器人圖像,總計 10 種。
數據劃分:每個下游數據集被劃分為互不重疊的訓練集和測試集。
表 1 中列出了每種類型下游任務所評估的數據集,以及訓練和測試數據集的劃分。
2. 實驗細節(jié)
Segment-Anything 模型:由于內存限制,我們采用 ViT-B 作為編碼器網絡。采用標準提示編碼器和 mask 解碼器。
Prompt 生成:訓練和評估階段的 Prompt 輸入均是由從實例分割 GT mask 計算而來,模擬人類交互作為弱監(jiān)督。
具體來說,我們從整個 GT mask 的最小邊界框中提取 box。Point 是通過在 GT mask 內隨機選擇 5 個正樣本點和 5 個 mask 外的負樣本點創(chuàng)建的。Coarse mask 是通過將多邊形擬合到 GT mask 來模擬的。
3. 實驗結果
表 2、3、4、5 分別是在添加干擾的自然圖像、清晰的自然圖像、醫(yī)學圖像、偽裝物體數據集上的測試結果,完整的實驗結果可以在論文中找到。實驗證明了我們的方案在幾乎所有的下游分割數據集上都優(yōu)于預訓練的 SAM 和最先進的域適應方案。
4. 可視化結果
部分可視化結果如圖 4 所示,更多的可視化結果可以在論文中找到。
圖 4 部分實例的可視化結果
5. 消融實驗和額外分析
我們在 COCO 數據集上分析了三個自訓練優(yōu)化目標各自的有效性,如表 7 所示。表 7 中,我們還分析了所提出方法在不使用任何弱監(jiān)督信息時進行自適應的效果。
我們分析了訓練和測試使用不同類別的 prompt 的性能差異,如表 8 所示。實驗表明我們的方案在 cross-prompt 條件下依然表現良好。
此外,我們還分析了優(yōu)化不同模塊,包括解碼器、LayerNorm 和不同的 finetune 方案以及他們的組合的實驗結果,實驗證明了 finetune 編碼器的 LoRA 方案效果最佳。
總結
盡管視覺基礎模型可以在分割任務上表現出色,但其在下游任務中仍會存在性能不佳的情況。我們研究了 Segment-Anything 模型在多個下游圖像分割任務中的泛化能力,并提出了一種基于錨點正則化和低秩微調的自訓練方法。該方法無需訪問源數據集、內存成本低、與弱監(jiān)督自然兼容,可以顯著提升自適應效果。經過廣泛的實驗驗證,結果表明我們提出的域適應方法可以顯著改善 SAM 在各種分布遷移下的泛化能力。
本文轉自機器之心 ,作者:機器之心
