模態(tài)編碼器 | CLIP改進(jìn)之SigLIP,采用sigmoid損失的圖文預(yù)訓(xùn)練
DeepMind對(duì)CLIP改進(jìn)的一篇工作--SigLIP,發(fā)表在2023CVPR。
簡(jiǎn)單看下研究動(dòng)機(jī):傳統(tǒng)的對(duì)比學(xué)習(xí)方法如CLIP等依賴于 softmax 歸一化,這需要一個(gè)全局視角來(lái)計(jì)算成對(duì)相似度,從而限制了批處理大小的擴(kuò)展能力,并且在小批處理大小下表現(xiàn)不佳。因此本文提出了一個(gè)簡(jiǎn)單的成對(duì) Sigmoid 損失函數(shù)用于語(yǔ)言-圖像預(yù)訓(xùn)練(SigLIP)。
01、方法介紹
用于語(yǔ)言圖像預(yù)訓(xùn)練的Softmax損失
CLIP等經(jīng)典對(duì)比學(xué)習(xí)方法通常使用基于 softmax 的損失函數(shù)來(lái)對(duì)齊圖像和文本的表示。具體來(lái)說(shuō),給定一個(gè)包含圖像-文本對(duì)的小批量,對(duì)比學(xué)習(xí)的目標(biāo)是使匹配對(duì)
的嵌入對(duì)齊,同時(shí)將不匹配對(duì)
的嵌入推遠(yuǎn)。通常通過(guò)訓(xùn)練一個(gè)圖像模型f(?) 和一個(gè)文本模型g(?)實(shí)現(xiàn)。損失函數(shù)如下:
其中:
由于 softmax 損失的不對(duì)稱性,規(guī)范化分別在圖像和文本上獨(dú)立進(jìn)行。這種損失函數(shù)需要計(jì)算所有成對(duì)相似度的全局歸一化因子,在大規(guī)模批處理中會(huì)導(dǎo)致內(nèi)存和計(jì)算開銷。
用于語(yǔ)言圖像預(yù)訓(xùn)練的sigmoid損失
為了解決 softmax 損失的局限性,本文提出了一種基于 Sigmoid 的損失函數(shù)。Sigmoid 損失函數(shù)獨(dú)立處理每個(gè)圖像-文本對(duì),避免了全局歸一化的需要。具體來(lái)說(shuō),Sigmoid 損失函數(shù)將學(xué)習(xí)問(wèn)題轉(zhuǎn)化為標(biāo)準(zhǔn)的二分類問(wèn)題,其中匹配對(duì)標(biāo)記為正樣本,所有其他對(duì)
的嵌入推遠(yuǎn),標(biāo)記為負(fù)樣本。Sigmoid 損失函數(shù)的形式如下:
其中:是標(biāo)簽,
如果是匹配對(duì),則
,否則
。
是 Sigmoid 函數(shù),t和b是可學(xué)習(xí)的溫度和偏置參數(shù),初始化為t=log10 和b=?10。
初始化時(shí),由于負(fù)樣本數(shù)量遠(yuǎn)多于正樣本,損失函數(shù)會(huì)被負(fù)樣本主導(dǎo),導(dǎo)致初始優(yōu)化步驟較大。為了緩解這一問(wèn)題,引入了一個(gè)額外的可學(xué)習(xí)偏置項(xiàng)b,以確保訓(xùn)練從接近先驗(yàn)的地方開始,避免需要巨大的矯正步驟。
高效“分塊”實(shí)現(xiàn)
在對(duì)比訓(xùn)練中,通常利用數(shù)據(jù)并行性來(lái)加速計(jì)算。當(dāng)數(shù)據(jù)被分割到D個(gè)設(shè)備上時(shí),計(jì)算損失函數(shù)需要收集所有嵌入向量,這通常涉及到昂貴的全聚集操作,并且需要在內(nèi)存中構(gòu)建一個(gè)大規(guī)模的成對(duì)相似性矩陣,這在內(nèi)存消耗上是非常密集的。
- Sigmoid損失的優(yōu)勢(shì):sigmoid損失函數(shù)不需要計(jì)算全局歸一化因子,這使得它特別適合于內(nèi)存高效、快速且數(shù)值穩(wěn)定的實(shí)現(xiàn)。這種方法可以減少數(shù)據(jù)并行處理中的挑戰(zhàn)。
- 分塊實(shí)現(xiàn)方法:作者提出了一種“分塊”方法,避免了全聚集操作,并且減少了內(nèi)存使用。具體來(lái)說(shuō),每個(gè)設(shè)備的批量大小表示為
,損失函數(shù)被重新表述為:
D為設(shè)備數(shù),是sigmoid損失函數(shù)
02、計(jì)算過(guò)程
計(jì)算過(guò)程簡(jiǎn)單來(lái)說(shuō),首先計(jì)算與正對(duì)對(duì)應(yīng)的損失分量,然后對(duì)設(shè)備間的表示進(jìn)行排列,使得每個(gè)設(shè)備從其鄰近設(shè)備中獲取負(fù)樣本。每個(gè)設(shè)備獨(dú)立地計(jì)算其本地批量b的損失,并將所有設(shè)備的損失簡(jiǎn)單地求和。
著重介紹下SigLIP是如何進(jìn)行分布式訓(xùn)練的,假設(shè)全局的batch size為B,一共有個(gè)GPU,那么每個(gè)GPU上的batch size為
,可以將Sigmoid 損失函數(shù)表示公式的損失拆解為上述公式所示,在Fig 1展示了整個(gè)過(guò)程的示意圖,在初始化階段,以第一個(gè)GPU為例子,其所包含的樣本為:
此時(shí)GPU 1可以完成一次上述公式中C的計(jì)算,然后,交換GPU 1和GPU 2的文本編碼器特征向量,即:
此時(shí)GPU 2完成一次上述公式中的B計(jì)算,以此類推,直到GPU 1遍歷完所有樣本為止,其他GPU也是如此操作的,最終把所有卡上的損失聚集即可,也就是A計(jì)算。這個(gè)輪流交換不同GPU之間數(shù)據(jù)的操作,可以稱之為排列。不難發(fā)現(xiàn),整個(gè)過(guò)程的通信成本來(lái)自于排列,一共需要D?1次聚集操作即可完成一次排列。
而在CLIP中需要對(duì)圖文的編碼器特征都進(jìn)行聚集,因此需要2次全聚集
操作。如果全聚集
采用ring方法的話,那么一個(gè)全聚集
就是D?1次聚集
操作。由此得出一個(gè)SigLIP和CLIP的性能復(fù)雜度對(duì)比:
這種方法顯著減少了內(nèi)存消耗,從減少到
,并且由于b通常是固定的,通過(guò)增加加速器的數(shù)量來(lái)擴(kuò)展|B|。這種方法使得在相對(duì)較少的設(shè)備上訓(xùn)練超過(guò)一百萬(wàn)個(gè)批量大小的模型成為可能。
03、實(shí)驗(yàn)結(jié)果
batch_size影響
實(shí)驗(yàn)結(jié)果顯示,在batch size小于32k的時(shí)候,采用sigmoid的SigLIP的性能都會(huì)優(yōu)于采用softmax的CLIP。在batch size足夠大的時(shí)候,CLIP能夠追上,甚至超越SigLIP的表現(xiàn),但是最佳性能仍然是SigLIP@32k情況下得到,從實(shí)際應(yīng)用來(lái)看,采用SigLIP能用更少的資源更快的訓(xùn)練速度得到更好的性能。從SigLiT的實(shí)驗(yàn)來(lái)看,隨著batch size的尺度放大性能將會(huì)在32k batch size的情況下達(dá)到飽和,同時(shí)能觀察到SigLiT在不同batch size下持續(xù)優(yōu)于LiT。繼續(xù)提高batch size將不能帶來(lái)更好的收益,甚至?xí)屑?xì)微的性能下降。
訓(xùn)練時(shí)間影響
另外,作者在SigLiT的設(shè)置下,訓(xùn)練了更長(zhǎng)時(shí)間(也即是見了更多數(shù)據(jù)量),結(jié)果顯示,在超大batch size,如262k的情況下,對(duì)比較小batch size(如8k)提供更長(zhǎng)的訓(xùn)練時(shí)間的確能觀察到性能的較大提升。并且也能觀察到在超大batch size下,采用sigmoid和采用softmax的區(qū)別很小,但是在較小batch size(如8k)的情況下,差距明顯。
因此,在資源受限的情況下,采用SigLIP更劃算,所需資源少而且性能更強(qiáng)。同時(shí),這個(gè)實(shí)驗(yàn)也說(shuō)明了,超大的batch size并不意味著訓(xùn)練得更快,反而還需要更長(zhǎng)的訓(xùn)練時(shí)間。
多語(yǔ)言數(shù)據(jù)集表現(xiàn)
在多語(yǔ)言數(shù)據(jù)集上,性能同樣在32k batch size上達(dá)到了飽和
有限資源場(chǎng)景
在僅有四塊 TPU-v4 芯片的情況下使用 SigLiT 模型的訓(xùn)練效果如下,在四個(gè)芯片上訓(xùn)練一天的結(jié)果顯示,模型在ImageNet上的零樣本分類準(zhǔn)確率達(dá)到了79.7%,這在資源有限的情況下非常有競(jìng)爭(zhēng)力。使用ViT-g/14模型作為視覺塔和大型文本塔,可以在兩天內(nèi)以20k的批量大小進(jìn)行107k步的訓(xùn)練,進(jìn)一步將零樣本分類準(zhǔn)確率提高到84.5%。
另外還探討了在資源有限的情況下,特別是只有少量TPUv4芯片時(shí),如何有效地訓(xùn)練SigLIP模型。結(jié)果顯示,使用16個(gè)TPUv4芯片,以16k的批量大小訓(xùn)練了2.4B個(gè)樣本,僅靠微調(diào)并不能立即提高模型的性能。這與之前的研究一致,即微調(diào)圖像模型會(huì)降低視覺表示的質(zhì)量。然而,通過(guò)凍結(jié)預(yù)訓(xùn)練權(quán)重的權(quán)重衰減,可以獲得更好的結(jié)果。
兩個(gè)微調(diào)經(jīng)驗(yàn)
微調(diào)圖像模型會(huì)降低視覺表示的質(zhì)量。然而,通過(guò)凍結(jié)預(yù)訓(xùn)練權(quán)重的weight-decay,可以獲得更好的結(jié)果
隨著批量大小的增加,訓(xùn)練過(guò)程變得越來(lái)越不穩(wěn)定。實(shí)驗(yàn)發(fā)現(xiàn)降低Adam和AdaFactor中的β2參數(shù)值(從默認(rèn)的0.999降低到0.95)可以穩(wěn)定訓(xùn)練過(guò)程。
sigmoid損失函數(shù)中正負(fù)樣本比例的影響
對(duì)于sigmoid來(lái)說(shuō),它的loss是以pair為粒度計(jì)算的,positive和negative非常不平衡。以batch_size=16k為例,positive和negative的比率約為1:16k
文中嘗試了4種掩碼策略來(lái)保留最難的負(fù)樣本(損失最高)或最易的負(fù)樣本(損失最低)
- 隨機(jī)(Random):隨機(jī)選擇負(fù)對(duì)進(jìn)行掩碼。
- 難(Hard):保留最難的負(fù)對(duì)。
- 易(Easy):保留最易的負(fù)對(duì)。
- 難+匹配總對(duì)數(shù)(Hard + matching total pairs seen):在固定數(shù)量的步驟中掩碼樣本會(huì)減少訓(xùn)練期間看到的總對(duì)數(shù)。因此,在匹配對(duì)的設(shè)置中,作者增加了掩碼比例的步數(shù),以保持看到的對(duì)數(shù)不變。
不做matched pair的情況下,用3種mask方式均會(huì)造成精度下降。影響程度:easy>random>hard。
結(jié)果顯示,隨機(jī)移除負(fù)樣本以重新平衡會(huì)惡化性能。保留最易的樣本根本不起作用,而保留最難的負(fù)樣本幾乎可以保持質(zhì)量。通過(guò)保留最難的負(fù)樣本,模型的性能得到了維持,這表明在負(fù)樣本中,較難的樣本對(duì)學(xué)習(xí)更為重要。
另外,還觀察了學(xué)習(xí)到的偏置項(xiàng)的最終值以及正負(fù)樣本的平均logit值,發(fā)現(xiàn)當(dāng)正負(fù)樣本的imbalance減弱時(shí),learnable bias和pair的logit都在上升,說(shuō)明了預(yù)設(shè)的learnable bias起到了積極的作用。
數(shù)據(jù)噪聲對(duì)模型魯棒性的影響
通過(guò)以下五種方法“污染”訓(xùn)練噪聲:
- Image:以概率p將圖文對(duì)的圖片用均勻噪聲替換;
- Text:以概率p將圖文對(duì)的文本token序列用隨機(jī)采樣的等長(zhǎng)token序列替換;
- Batch alignment: 隨機(jī)將batch中的p%的sample的圖文pair進(jìn)行shuffle;
- Image & text: 同時(shí)進(jìn)行1.和2.
- Image, text & batch: 同時(shí)進(jìn)行3和4
從結(jié)果可見,sigmoid loss在“污染”數(shù)據(jù)的性能更好。
04、總結(jié)
總的來(lái)說(shuō),SigLIP是一個(gè)很棒的工作,作者采用sigmoid損失去取代對(duì)比學(xué)習(xí)中的softmax函數(shù),以更小的資源開銷帶來(lái)了更好的模型表現(xiàn),目前也被很多多模態(tài)大模型所采用,作為視覺端的編碼器。