自回歸解碼加速64倍,谷歌提出圖像合成新模型MaskGIT
生成式 transformer 在合成高保真和高分辨率圖像方面得到了快速普及。但迄今為止最好的生成式 transformer 模型仍是將圖像視為一系列 token,并按照光柵掃描順序(即逐行)解碼圖像。然而這種策略既不是最優(yōu)的,也不高效。
近日,來自谷歌研究院的研究者提出了一種使用雙向 transformer 解碼器的新型圖像合成模型 MaskGIT。在訓(xùn)練期間,MaskGIT 通過關(guān)注各個(gè)方向的 token 來學(xué)習(xí)預(yù)測隨機(jī)掩碼 token。在推理階段,模型首先同時(shí)生成圖像的所有 token,然后以上一次生成為條件迭代地細(xì)化圖像。實(shí)驗(yàn)表明,MaskGIT 在 ImageNet 數(shù)據(jù)集上顯著優(yōu)于 SOTA transformer 模型,并將自回歸解碼的速度提高了 64 倍。
論文地址:https://arxiv.org/abs/2202.04200
此外,該研究還表明 MaskGIT 可以輕松擴(kuò)展到各種圖像編輯任務(wù),例如修復(fù)、外推和圖像處理。
相關(guān)研究
先前的模型 VQVAE 提出分兩個(gè)階段在潛在空間中生成圖像。
第一個(gè)階段稱為 tokenization,其中嘗試將圖像壓縮到離散的潛在空間中,這一階段主要包含三個(gè)部分:
一個(gè)編碼器 E ,負(fù)責(zé)學(xué)習(xí)將圖像 x∈ tokenize 成潛在嵌入 E(x);一個(gè)用于最近鄰查找 codebook
,以將嵌入量化為視覺 token;一個(gè)解碼器 G,它根據(jù)視覺 token e 預(yù)測重建圖像
。
第二個(gè)階段首先使用深度自回歸模型預(yù)測視覺 token 的潛在先驗(yàn),然后使用第一階段的解碼器將 token 序列映射到圖像像素中。
這種兩階段范式是很有效的,因此幾種常用的方法都遵循了這種范式,例如 DALL-E、VQGAN。其中,VQGAN 在第一階段增加了對抗性損失和感知損失以提高圖像保真度。
MaskGIT
上述使用兩階段范式的方法由于仍然采用自回歸模型,因此第二階段的解碼時(shí)間與 token 序列長度成比例。而本研究的目標(biāo)是設(shè)計(jì)一種利用并行解碼和雙向生成的新圖像合成范式,遵循上述兩階段方案并改進(jìn)第二階段。第一階段采用與 VQGAN 模型相同的設(shè)置,并將潛在的改進(jìn)留給未來工作的 tokenization 步驟;對于第二階段,研究者提出通過掩碼視覺 token 建模(Masked Visual Token Modeling,MVTM 學(xué)習(xí)雙向 transformer。
訓(xùn)練中的 MVTM
該研究用表示將圖像輸入到 VQ 編碼器獲得的潛在 token,其中 N 是重構(gòu)后的 token 矩陣的長度,
是對應(yīng)的二進(jìn)制掩碼。在訓(xùn)練期間,該研究采樣 token 的子集,并用一個(gè)特殊的 [MASK] token 替代它們。如果 m_i=1,就用 [MASK] 取代 token y_i;如果 m_i=0,y_i 保留。
采樣過程由掩碼調(diào)度函數(shù)(mask scheduling function) 進(jìn)行參數(shù)化,然后按照如下步驟:
首先從 0 到 1 采樣一個(gè)比率,然后在 Y 中統(tǒng)一選擇 個(gè) token 來放置掩碼,其中 N 是長度。掩碼調(diào)度顯著影響了圖像的生成質(zhì)量。
迭代解碼
在自回歸解碼中,token 是根據(jù)先前生成的輸出順序生成的。這個(gè)過程是不可并行的,而圖像的 token 長度通常比語言長得多,因此速度非常慢。該研究提出了一種新型解碼方法,其中圖像中的所有 token 都是同時(shí)并行生成的,這基于 MTVM 的雙向自注意力。
理論上講,該模型能夠推斷出所有 token 并在單次傳遞中生成整個(gè)圖像,但訓(xùn)練任務(wù)的不一致給該研究帶來了挑戰(zhàn)。為了在推理時(shí)生成圖像,該研究從一個(gè)空白 canvas 開始,所有 token 都被掩碼,即。該研究提出的迭代解碼方法,每次迭代的算法運(yùn)行步驟如下:
1. 預(yù)測2. 采樣3. 掩碼調(diào)度4. 掩碼
掩碼設(shè)計(jì)
研究者發(fā)現(xiàn)圖像的生成質(zhì)量受到掩碼設(shè)計(jì)的顯著影響。該方法通過一個(gè)掩碼調(diào)度函數(shù)對掩碼過程進(jìn)行建模,該函數(shù)負(fù)責(zé)計(jì)算給定潛在 token 的掩碼比率。在推理期間,函數(shù)
用
的輸入代表解碼的進(jìn)度;在訓(xùn)練期間,該研究在 [0,1) 中隨機(jī)采樣一個(gè)比率 r 來模擬各種解碼場景。
實(shí)驗(yàn)
該研究從質(zhì)量、效率和靈活性方面對 MaskGIT 在圖像生成方面進(jìn)行了實(shí)驗(yàn)評估。
類條件圖像合成
該研究在 ImageNet 256 X 256 和 ImageNet 512 X 512 上評估了 MaskGIT 模型在類條件(class-conditional)圖像合成任務(wù)上的性能,主要結(jié)果如下表 1 所示。
質(zhì)量。在 ImageNet 256 X 256 上,不使用任何特殊的采樣策略,MaskGIT 在 FID 和 IS 方面都顯著優(yōu)于 VQGAN。
速度。該研究通過評估每個(gè)模型生成樣本所需的步驟數(shù)(前向傳遞)來評估模型速度。如表 1 所示,在所有基于非 GAN 的模型中,MaskGIT 在兩種分辨率上所需的步驟最少。
為了進(jìn)一步證實(shí) MaskGIT 和自回歸模型之間的速度差異,該研究對 MaskGIT 和 VQGAN 的解碼過程進(jìn)行了運(yùn)行時(shí)比較。如下圖 4 所示,MaskGIT 將 VQGAN 顯著加速了 30-64 倍,隨著圖像分辨率(以及輸入 token 長度)的增加,加速變得更加明顯。
多樣性。除了樣本質(zhì)量外,該研究還將分類準(zhǔn)確率得分 (CAS) 和 Precision/Recall 作為評估樣本多樣性的兩個(gè)指標(biāo)。與 BigGAN 的樣本相比,MaskGIT 的樣本更加多樣化,具有更多種光照、姿態(tài)、規(guī)模和語境,如下圖 5 所示。
圖像編輯應(yīng)用
該研究展示了 MaskGIT 在三個(gè)圖像編輯任務(wù)上的直接應(yīng)用:類條件圖像編輯、圖像修復(fù)和圖像擴(kuò)展(outpainting)。如果將任務(wù)看作對初始二進(jìn)制掩碼 M MaskGIT 在其迭代解碼中使用約束,那么這三個(gè)任務(wù)幾乎都可以輕松地轉(zhuǎn)換為 MaskGIT 可以處理的任務(wù)。
該研究表明,無需修改架構(gòu)或任何特定于任務(wù)的訓(xùn)練,MaskGIT 就能夠在所有三個(gè)應(yīng)用程序上產(chǎn)生非常優(yōu)秀的結(jié)果。此外,MaskGIT 在圖像修復(fù)和擴(kuò)展方面獲得了與專用模型相當(dāng)?shù)男阅堋?/p>
在類條件圖像編輯任務(wù)上,該研究定義了一個(gè)新的類條件圖像編輯任務(wù)來展示 MaskGIT 的靈活性。模型在給定類的邊界框內(nèi)重新生成特定內(nèi)容,同時(shí)保留語境,即框外的內(nèi)容。由于違背了預(yù)測順序,因此自回歸方法是不可行的。
然而,對于 MaskGIT,如果將邊界框區(qū)域視為迭代解碼算法的初始掩碼的輸入,這個(gè)問題就迎刃而解了。下圖 6 給出了一些示例結(jié)果。
表 2 比較了幾種方法的定量結(jié)果。MaskGIT 在 FID 和 IS 中均以顯著優(yōu)勢擊敗 DeepFill 和 HiFill,同時(shí)獲得接近 SOTA 修復(fù)方法 CoModGAN 的分?jǐn)?shù)。
如下圖 7 所示,MaskGIT 還能夠在給定相同輸入和不同種子的情況下合成不同的結(jié)果。
消融實(shí)驗(yàn)
為了驗(yàn)證新設(shè)計(jì)的效用,該研究在 ImageNet 256×256 的默認(rèn)設(shè)置上進(jìn)行了消融實(shí)驗(yàn)。MaskGIT 的一個(gè)關(guān)鍵設(shè)計(jì)是用于訓(xùn)練和迭代解碼的掩碼調(diào)度函數(shù),實(shí)驗(yàn)結(jié)果如下表 3 和圖 8 所示。
值得注意的是,如圖 8 所示,在相同的設(shè)置下,更多的迭代不一定更好:隨著迭代次數(shù) T 的增加,除了對數(shù)函數(shù)在整個(gè)過程中都表現(xiàn)不佳以外,其他所有函數(shù)都達(dá)到了一個(gè)「sweet spot」位置,即模型的性能在再次惡化之前達(dá)到峰值。