CV開啟大模型時(shí)代!谷歌發(fā)布史上最大ViT:220億參數(shù),視覺感知力直逼人類
Transformer無(wú)疑是促進(jìn)自然語(yǔ)言處理領(lǐng)域繁榮的最大功臣,也是GPT-4等大規(guī)模語(yǔ)言模型的基礎(chǔ)架構(gòu)。
不過相比語(yǔ)言模型動(dòng)輒成千上萬(wàn)億的參數(shù)量,計(jì)算機(jī)視覺領(lǐng)域吃到Transformer的紅利就沒那么多了,目前最大的視覺Transformer模型ViT-e的參數(shù)量還只有40億參數(shù)。
最近谷歌發(fā)布了一篇論文,研究人員提出了一種能夠高效且穩(wěn)定訓(xùn)練大規(guī)模Vision Transformers(ViT)模型的方法,成功將ViT的參數(shù)量提升到220億。
論文鏈接:https://arxiv.org/abs/2302.05442
為了實(shí)現(xiàn)模型的擴(kuò)展,ViT-22B結(jié)合了其他語(yǔ)言模型(如PaLM模型)的思路,使用 QK 歸一化改進(jìn)了訓(xùn)練穩(wěn)定性,提出了一種異步并行線性操作(asynchronous parallel linear operations)的新方法提升訓(xùn)練效率,并且能夠在硬件效率更高的Cloud TPU上進(jìn)行訓(xùn)練。
在對(duì)ViT-22B模型進(jìn)行實(shí)驗(yàn)以評(píng)估下游任務(wù)性能時(shí),ViT-22B也表現(xiàn)出類似大規(guī)模語(yǔ)言模型的能力,即隨著模型規(guī)模的擴(kuò)大,性能也在不斷提升。
ViT-22B 還可以應(yīng)用于PaLM-e中,與語(yǔ)言模型結(jié)合后的大模型可以顯著提升機(jī)器人任務(wù)的技術(shù)水平。
研究人員還進(jìn)一步觀察到規(guī)模帶來(lái)的其他優(yōu)勢(shì),包括更好地平衡公平性和性能,在形狀/紋理偏見方面與人類視覺感知的一致性,以及更好的穩(wěn)健性。
模型架構(gòu)
ViT-22B 是一個(gè)基于Transformer架構(gòu)的模型,和原版ViT架構(gòu)相比,研究人員主要做了三處修改以提升訓(xùn)練效率和訓(xùn)練穩(wěn)定性。
并行層(parallel layers)
ViT-22B并行執(zhí)行注意力塊和MLP塊,而在原版Transformer中為順序執(zhí)行。
PaLM模型的訓(xùn)練也采用了這種方法,可以將大模型的訓(xùn)練速度提高15%,并且性能沒有下降。
query/key (QK) normalization
在擴(kuò)展ViT的過程中,研究人員在80億參數(shù)量的模型中觀察到,在訓(xùn)練幾千步之后訓(xùn)練損失開始發(fā)散(divergence),主要是由于注意力logits的數(shù)值過大引起的不穩(wěn)定性,導(dǎo)致零熵的注意力權(quán)重(幾乎one-hot)。
為了解決這個(gè)問題,研究人員在點(diǎn)乘注意力計(jì)算之前對(duì)Query和Key使用LayerNorm
在80億參數(shù)模型上的實(shí)驗(yàn)結(jié)果如下圖所示,歸一化可以緩解發(fā)散問題。
刪除QKV投影和LayerNorms上的偏置項(xiàng)
和PaLM模型一樣,ViT-22B從QKV投影中刪除了偏置項(xiàng),并且在所有LayerNorms中都沒有偏置項(xiàng)(bias)和centering,使得硬件利用率提高了3%,并且質(zhì)量沒有下降。
不過與PaLM不同的是,ViT-22B對(duì)(內(nèi)部和外部)MLP稠密連接層使用了偏置項(xiàng),可以觀察到質(zhì)量得到了改善,并且速度也沒有下降。
ViT-22B的編碼器模塊中,嵌入層,包括抽取patches、線性投影和額外的位置嵌入都與原始ViT中使用的相同,并且使用多頭注意力pooling來(lái)聚合每個(gè)頭中的per-token表征。
ViT-22B的patch尺寸為14×14,圖像的分辨率為224×224(通過inception crop和隨機(jī)水平翻轉(zhuǎn)進(jìn)行預(yù)處理)。
異步并聯(lián)線性運(yùn)算(asynchronous parallel linear operations)
大規(guī)模的模型還需要分片(sharding),即將模型參數(shù)分布在不同的計(jì)算設(shè)備中,除此之外,研究人員還把激活(acctivations,輸入的中間表征)也進(jìn)行分片。
因?yàn)檩斎牒途仃嚤旧矶际欠植荚诟鞣N設(shè)備上的,即使是像矩陣乘法這樣簡(jiǎn)單的操作也需要特別小心。
研究人員開發(fā)了一種稱為異步并行線性運(yùn)算的方法,可以在矩陣乘法單元(在TPU 中占據(jù)絕大多數(shù)計(jì)算能力的單元)中計(jì)算時(shí),同時(shí)對(duì)設(shè)備之間的激活和權(quán)值進(jìn)行通信。
異步方法最小化了等待傳入通信的時(shí)間,從而提高了設(shè)備效率。
異步并行線性運(yùn)算的目標(biāo)是計(jì)算矩陣乘法 y = Ax,但矩陣 A 和激活 x 都分布在不同的設(shè)備上,需要通過跨設(shè)備的重疊通信和計(jì)算來(lái)實(shí)現(xiàn)這一點(diǎn)。矩陣 A 在設(shè)備之間進(jìn)行列分片(column-shard),每個(gè)矩陣包含一個(gè)連續(xù)的切片,每個(gè)塊表示為 Aij,更多細(xì)節(jié)請(qǐng)看原始論文。
實(shí)驗(yàn)結(jié)果
為了說(shuō)明ViT-22B學(xué)習(xí)到的表征非常豐富,研究人員使用LiT-tuning訓(xùn)練一個(gè)文本模型來(lái)生成一些表征用來(lái)對(duì)齊文本和圖像。
下面是用Parti 和 Imagen 生成的分布外(out-of-distribution)圖像得到的實(shí)驗(yàn)結(jié)果,可以看到ViT-22B的zero-shot圖像分類泛化能力非常強(qiáng),僅從web上爬取的自然圖像就能識(shí)別出沒見過的物體和場(chǎng)景。
論文中還討論了ViT-22B在視頻分類、深度估計(jì)和語(yǔ)義分割任務(wù)上的效果。
與人類目標(biāo)識(shí)別對(duì)齊
為了驗(yàn)證 ViT-22B 分類決策與人類分類決策的一致性,研究人員對(duì) ViT-22B 進(jìn)行了微調(diào),對(duì)分布外(OOD)數(shù)據(jù)集的不同分辨率進(jìn)行了微調(diào),其中人類比較數(shù)據(jù)可通過model-vs-human toolbox獲得。
該工具箱主要衡量三個(gè)關(guān)鍵指標(biāo): 模型如何處理失真(準(zhǔn)確性) ?人和模型的精度(精度差)有什么不同?人和模型的錯(cuò)誤模式(錯(cuò)誤一致性)有多相似?
形狀偏差評(píng)估(值越大代表更多的形狀偏差)。許多視覺模型具有低形狀/高紋理偏差,而在 ImageNet 上進(jìn)行微調(diào)的 ViT-22B具有迄今為止在 ML 模型中記錄的最高形狀偏差,更接近于人類形狀偏見
實(shí)驗(yàn)結(jié)果顯示,雖然并非所有的微調(diào)解決方案都表現(xiàn)得很好,但 ViT-22B 變體在所有三個(gè)指標(biāo)上都達(dá)到了新高。
此外,ViT-22B 模型在視覺模型中也有最高的形狀偏差記錄。這意味著他們主要使用目標(biāo)的形狀,而不是目標(biāo)的紋理來(lái)進(jìn)行分類決策,策略結(jié)果類似于人類的感知(其形狀偏差為96%)。
標(biāo)準(zhǔn)模型(例如,ResNet-50有20-30% 的形狀偏差)通常根據(jù)紋理來(lái)分類,而高形狀偏差的模型則傾向于關(guān)注形狀(下圖識(shí)別為貓),盡管人類和模型的感知之間仍然存在許多差異,但是 ViT-22B 顯示出與人類視覺對(duì)象識(shí)別更多的相似性。
貓還是大象?車還是鐘?鳥還是自行車?具有某個(gè)物體的形狀和另一個(gè)不同物體紋理的圖像,可用于測(cè)量形狀/紋理偏差
分布外(out-of-distribution)性能
測(cè)量 OOD 數(shù)據(jù)集的性能有助于評(píng)估模型泛化性。
在這個(gè)實(shí)驗(yàn)中,研究人員構(gòu)建了從 JFT 到 ImageNet 的標(biāo)簽映射,以及從 ImageNet 到不同的分布外數(shù)據(jù)集(如 ObjectNet)的標(biāo)簽映射。
對(duì)這些數(shù)據(jù)進(jìn)行預(yù)訓(xùn)練后的結(jié)果如下圖所示,然后在 ImageNet 上對(duì)模型進(jìn)行完全微調(diào)。
可以觀察到縮放 Vision Transformers 可以提高 OOD 性能: 即使 ImageNet 的精度達(dá)到飽和,也可以看到 ObjectNet 上從 ViT-e 換成 ViT-22B 模型可以顯著提升性能。
線性探測(cè)Linear Probe
線性探測(cè)是一種將單個(gè)線性層置于凍結(jié)模型之上的技術(shù),與完全微調(diào)相比,這種方法的訓(xùn)練成本更低,設(shè)置起來(lái)也更容易。
在 ImageNet 上訓(xùn)練的線性探測(cè)結(jié)果,在 ImageNet-Real,ImageNet-v2,ObjectNet,ImageNet-R 和 ImageNet-A 數(shù)據(jù)集上評(píng)估,提供高分辨率微調(diào) ViT-e/14作為參考
從結(jié)果中可以觀察到,ViT-22B 的線性探測(cè)性能接近于使用高分辨率圖像對(duì)較小模型進(jìn)行全面微調(diào)的最先進(jìn)水平,其中具有較高分辨率的訓(xùn)練通常要昂貴得多,但可以在許多任務(wù)上取得更好的結(jié)果。
蒸餾
利用蒸餾法,可以將較大模型的知識(shí)轉(zhuǎn)化為較小模型的知識(shí),可以提升成本更高、運(yùn)行速度更慢的大模型的運(yùn)行效率。
從實(shí)驗(yàn)結(jié)果中可以發(fā)現(xiàn),ViT-22B 的知識(shí)可以遷移到更小的模型,如 ViT-B/16和 ViT-L/16,并在同等模型尺寸下在ImageNet上刷新了性能記錄。
公平性與偏見
機(jī)器學(xué)習(xí)模型容易受到意想不到的不公平偏見的影響,例如找到錯(cuò)誤的相關(guān)性或者在各個(gè)子群體之間存在性能差距,研究人員發(fā)現(xiàn),擴(kuò)大模型規(guī)模有助于緩解這些問題。
首先,規(guī)模是一個(gè)有前景的權(quán)衡方式,即使模型經(jīng)過訓(xùn)練后再進(jìn)行后處理,將其人口平等(demographic parity)水平控制在規(guī)定的、可容忍的水平之下,性能也會(huì)隨著規(guī)模的增加而提高。
上圖: 去偏前 CelebA 中每個(gè)子組的精度。下圖: y 軸顯示了在這個(gè)例子中突出顯示的兩個(gè)特定亞組(女性和男性)的表現(xiàn)的絕對(duì)差異。與較小的 ViT 模型相比,ViT-22B 在性能的差距很小。
更重要的是,這不僅適用于以準(zhǔn)確性衡量性能的情況,而且適用于其他度量,例如校準(zhǔn),即對(duì)模型估計(jì)概率的真實(shí)性的統(tǒng)計(jì)測(cè)量,所有子群的分類隨著規(guī)模的增大而趨于改善,并且ViT-22B 降低了各子群之間的性能差距。
結(jié)論
研究人員提出了一個(gè)目前最大的視覺Transformer模型 ViT-22B,包含220億參數(shù)。
通過對(duì)原始模型架構(gòu)進(jìn)行微小但關(guān)鍵的修改后,實(shí)現(xiàn)了更高的硬件利用率和訓(xùn)練穩(wěn)定性,從而得到了一個(gè)在幾個(gè)基準(zhǔn)測(cè)試上提高了模型的上限性能。
使用凍結(jié)模型生成嵌入,只需要在頂部訓(xùn)練幾層,即可獲得很好的性能,并且評(píng)估結(jié)果進(jìn)一步表明,與現(xiàn)有模型相比,ViT-22B 在形狀和紋理偏差方面顯示出與人類視知覺更多的相似性,并且在公平性和穩(wěn)健性方面提供了優(yōu)勢(shì)。