達摩院開源半監(jiān)督學習框架Dash,刷新多項SOTA
一、研究背景
監(jiān)督學習(Supervised Learning)?
我們知道模型訓練的目的其實是學習一個預測函數,在數學上,這可以刻畫成一個學習從數據 (X) 到標注 (y) 的映射函數。監(jiān)督學習就是一種最常用的模型訓練方法,其效果的提升依賴于大量的且進行了很好標注的訓練數據,也就是所謂的大量帶標簽數據 ((X,y))。但是標注數據往往需要大量的人力物力等等,因此效果提升的同時也會帶來成本過高的問題。在實際應用中經常遇到的情況是有少量標注數據和大量未標注數據,由此引出的半監(jiān)督學習也越來越引起科學工作者的注意。
半監(jiān)督學習(Semi-Supervised Learning)?
半監(jiān)督學習同時對少量標注數據和大量未標注數據進行學習,其目的是借助無標簽數據來提高模型的精度。比如 self-training 就是一種很常見的半監(jiān)督學習方法,其具體流程是對于標注數據 (X, y) 學習數據從 X 到 y 的映射,同時利用學習得到的模型對未標注數據 X 預測出一個偽標簽,通過對偽標簽數據 (X,
)進一步進行監(jiān)督學習來幫助模型進行更好的收斂和精度提高。
核心解決問題?
現有的半監(jiān)督學習框架對無標簽數據的利用大致可以分為兩種,一是全部參與訓練,二是用一個固定的閾值卡出置信度較高的樣本進行訓練 (比如 FixMatch)。由于半監(jiān)督學習對未標注數據的利用依賴于當前模型預測的偽標簽,所以偽標簽的正確與否會給模型的訓練帶來較大的影響,好的預測結果有助于模型的收斂和對新的模式的學習,差的預測結果則會干擾模型的訓練。所以我們認為:不是所有的無標簽樣本都是必須的!
二、論文 & 代碼
- 論文鏈接:https://proceedings.mlr.press/v139/xu21e/xu21e.pdf
- 代碼地址:https://github.com/idstcv/Dash
- 技術應用:https://modelscope.cn/models/damo/cv_manual_face-liveness_flrgb/summary
這篇論文創(chuàng)新性地提出用動態(tài)閾值(dynamic threshold)的方式篩選無標簽樣本進行半監(jiān)督學習(semi-supervised learning,SSL)的方法,我們改造了半監(jiān)督學習的訓練框架,在訓練過程中對無標簽樣本的選擇策略進行了改進,通過動態(tài)變化的閾值來選擇更有效的無標簽樣本進行訓練。Dash 是一個通用策略,可以輕松與現有的半監(jiān)督學習方法集成。實驗方面,我們在 CIFAR-10, CIFAR-100, STL-10 和 SVHN 等標準數據集上充分驗證了其有效性。理論方面,論文從非凸優(yōu)化的角度證明了 Dash 算法的收斂性質。
三、方法
Fixmatch 訓練框架?
在引出我們的方法 Dash 之前,我們介紹一下 Google 提出的 FixMatch 算法,一種利用固定閾值選擇無標簽樣本的半監(jiān)督學習方法。FixMatch 訓練框架是之前的 SOTA 解決方案。整個學習框架的重點可以歸納為以下幾點:
1、對于無標簽數據經過弱數據增強(水平翻轉、偏移等)得到的樣本通過當前的模型得到預測值
2、對于無標簽數據經過強數據增強(RA or CTA)得到的樣本通過當前的模型得到預測值
3、把具有高置信度的弱數據增強的結果,通過 one hot 的方式形成偽標簽
,然后用
和 X 經過強數據增強得到的預測值
進行模型的訓練。
fixmatch 的優(yōu)點是用弱增強數據進行偽標簽的預測,增加了偽標簽預測的準確性,并在訓練過程中用固定的閾值 0.95(對應 loss 為 0.0513) 選取高置信度(閾值大于等于 0.95,也就是 loss 小于等于 0.0513)的預測樣本生成偽標簽,進一步穩(wěn)定了訓練過程。
Dash 訓練框架?
針對全部選擇偽標簽和用固定閾值選擇偽標簽的問題,我們創(chuàng)新性地提出用動態(tài)閾值來進行樣本篩選的策略。即動態(tài)閾值 是隨 t 衰減的
式中 C=1.0001,是有標簽數據在第一個 epoch 之后 loss 的平均值,我們選擇那些
的無標簽樣本參與梯度回傳。下圖展示了不同
值下的閾值
的變化曲線。可以看到參數
控制了閾值曲線的下降速率。
的變化曲線類似于模擬訓練模型時損失函數下降的趨勢。
下圖對比了訓練過程中的 FixMath 和 Dash 選擇的正確樣本數和錯誤樣本數隨訓練進行的變化情況(使用的數據集是 cifar100)。從圖中可以很清楚地看到,對比 FixMatch,Dash 可以選取更多正確 label 的樣本,同時選擇更少的錯誤 label 的樣本,從而最終有助于提高訓練模型的精度。
我們的算法可以總結為如下 Algorithm 1。Dash 是一個通用策略,可以輕松與現有的半監(jiān)督學習方法集成。為了方便,在本文的實驗中我們主要將 Dash 與 FixMatch 集成。更多理論證明詳見論文。
四、結果
我們在半監(jiān)督學習常用數據集:CIFAR-10,CIFAR-100,STL-10 和 SVHN 上進行了算法的驗證。結果分別如下:
可以看到我們的方法在多個實驗設置上都取得了比 SOTA 更好的結果,其中需要說明的是針對 CIFAR-100 400label 的實驗,ReMixMatch 用了 data align 的額外 trick 取得了更好的結果,在 Dash 中加入 data align 的 trick 之后可以取得 43.31% 的錯誤率,低于 ReMixMatch 44.28% 的錯誤率。
五、應用
實際面向任務域的模型研發(fā)過程中,該半監(jiān)督 Dash 框架經常會被應用到。接下來給大家介紹下我們研發(fā)的各個域上的開源免費模型,歡迎大家體驗、下載(大部分手機端即可體驗):
- ?https://modelscope.cn/models/damo/cv_resnet50_face-detection_retinaface/summary?
- ?https://modelscope.cn/models/damo/cv_resnet101_face-detection_cvpr22papermogface/summary?
- ?https://modelscope.cn/models/damo/cv_manual_face-detection_tinymog/summary?
- ?https://modelscope.cn/models/damo/cv_manual_face-detection_ulfd/summary?
- ?https://modelscope.cn/models/damo/cv_manual_face-detection_mtcnn/summary?
- ?https://modelscope.cn/models/damo/cv_resnet_face-recognition_facemask/summary?
- ?https://modelscope.cn/models/damo/cv_ir50_face-recognition_arcface/summary?
- ?https://modelscope.cn/models/damo/cv_manual_face-liveness_flir/summary?
- ?https://modelscope.cn/models/damo/cv_manual_face-liveness_flrgb/summary?
- ?https://modelscope.cn/models/damo/cv_manual_facial-landmark-confidence_flcm/summary?
- ?https://modelscope.cn/models/damo/cv_vgg19_facial-expression-recognition_fer/summary?
- ?https://modelscope.cn/models/damo/cv_resnet34_face-attribute-recognition_fairface/summary?