華人博士生首次嘗試用兩個Transformer構(gòu)建一個GAN
最近,CV 研究者對 transformer 產(chǎn)生了極大的興趣并取得了不少突破。這表明,transformer 有可能成為計算機(jī)視覺任務(wù)(如分類、檢測和分割)的強(qiáng)大通用模型。
我們都很好奇:在計算機(jī)視覺領(lǐng)域,transformer 還能走多遠(yuǎn)?對于更加困難的視覺任務(wù),比如生成對抗網(wǎng)絡(luò) (GAN),transformer 表現(xiàn)又如何?
在這種好奇心的驅(qū)使下,德州大學(xué)奧斯汀分校的 Yifan Jiang、Zhangyang Wang,IBM Research 的 Shiyu Chang 等研究者進(jìn)行了第一次試驗性研究,構(gòu)建了一個只使用純 transformer 架構(gòu)、完全沒有卷積的 GAN,并將其命名為 TransGAN。與其它基于 transformer 的視覺模型相比,僅使用 transformer 構(gòu)建 GAN 似乎更具挑戰(zhàn)性,這是因為與分類等任務(wù)相比,真實圖像生成的門檻更高,而且 GAN 訓(xùn)練本身具有較高的不穩(wěn)定性。

- 論文鏈接:https://arxiv.org/pdf/2102.07074.pdf
- 代碼鏈接:https://github.com/VITA-Group/TransGAN
從結(jié)構(gòu)上來看,TransGAN 包括兩個部分:一個是內(nèi)存友好的基于 transformer 的生成器,該生成器可以逐步提高特征分辨率,同時降低嵌入維數(shù);另一個是基于 transformer 的 patch 級判別器。
研究者還發(fā)現(xiàn),TransGAN 顯著受益于數(shù)據(jù)增強(qiáng)(超過標(biāo)準(zhǔn)的 GAN)、生成器的多任務(wù)協(xié)同訓(xùn)練策略和強(qiáng)調(diào)自然圖像鄰域平滑的局部初始化自注意力。這些發(fā)現(xiàn)表明,TransGAN 可以有效地擴(kuò)展至更大的模型和具有更高分辨率的圖像數(shù)據(jù)集。
實驗結(jié)果表明,與當(dāng)前基于卷積骨干的 SOTA GAN 相比,表現(xiàn)最佳的 TransGAN 實現(xiàn)了極具競爭力的性能。具體來說,TransGAN 在 STL-10 上的 IS 評分為 10.10,F(xiàn)ID 為 25.32,實現(xiàn)了新的 SOTA。
該研究表明,對于卷積骨干以及許多專用模塊的依賴可能不是 GAN 所必需的,純 transformer 有足夠的能力生成圖像。
在該論文的相關(guān)討論中,有讀者調(diào)侃道,「attention is really becoming『all you need』.」

不過,也有部分研究者表達(dá)了自己的擔(dān)憂:在 transformer 席卷整個社區(qū)的大背景下,勢單力薄的小實驗室要怎么活下去?

如果 transformer 真的成為社區(qū)「剛需」,如何提升這類架構(gòu)的計算效率將成為一個棘手的研究問題。
基于純 Transformer 的 GAN
作為基礎(chǔ)塊的 Transformer 編碼器
研究者選擇將 Transformer 編碼器(Vaswani 等人,2017)作為基礎(chǔ)塊,并盡量進(jìn)行最小程度的改變。編碼器由兩個部件組成,第一個部件由一個多頭自注意力模塊構(gòu)造而成,第二個部件是具有 GELU 非線性的前饋 MLP(multiple-layer perceptron,多層感知器)。此外,研究者在兩個部件之前均應(yīng)用了層歸一化(Ba 等人,2016)。兩個部件也都使用了殘差連接。
內(nèi)存友好的生成器
NLP 中的 Transformer 將每個詞作為輸入(Devlin 等人,2018)。但是,如果以類似的方法通過堆疊 Transformer 編碼器來逐像素地生成圖像,則低分辨率圖像(如 32×32)也可能導(dǎo)致長序列(1024)以及更高昂的自注意力開銷。
所以,為了避免過高的開銷,研究者受到了基于 CNN 的 GAN 中常見設(shè)計理念的啟發(fā),在多個階段迭代地提升分辨率(Denton 等人,2015;Karras 等人,2017)。他們的策略是逐步增加輸入序列,并降低嵌入維數(shù)。
如下圖 1 左所示,研究者提出了包含多個階段的內(nèi)存友好、基于 Transformer 的生成器:

每個階段堆疊了數(shù)個編碼器塊(默認(rèn)為 5、2 和 2)。通過分段式設(shè)計,研究者逐步增加特征圖分辨率,直到其達(dá)到目標(biāo)分辨率 H_T×W_T。具體來說,該生成器以隨機(jī)噪聲作為其輸入,并通過一個 MLP 將隨機(jī)噪聲傳遞給長度為 H×W×C 的向量。該向量又變形為分辨率為 H×W 的特征圖(默認(rèn) H=W=8),每個點都是 C 維嵌入。然后,該特征圖被視為長度為 64 的 C 維 token 序列,并與可學(xué)得的位置編碼相結(jié)合。
與 BERT(Devlin 等人,2018)類似,該研究提出的 Transformer 編碼器以嵌入 token 作為輸入,并遞歸地計算每個 token 之間的匹配。為了合成分辨率更高的圖像,研究者在每個階段之后插入了一個由 reshaping 和 pixelshuffle 模塊組成的上采樣模塊。
具體操作上,上采樣模塊首先將 1D 序列的 token 嵌入變形為 2D 特征圖

,然后采用 pixelshuffle 模塊對 2D 特征圖的分辨率進(jìn)行上采樣處理,并下采樣嵌入維數(shù),最終得到輸出

。然后,2D 特征圖 X’_0 再次變形為嵌入 token 的 1D 序列,其中 token 數(shù)為 4HW,嵌入維數(shù)為 C/4。所以,在每個階段,分辨率(H, W)提升到兩倍,同時嵌入維數(shù) C 減少至輸入的四分之一。這一權(quán)衡(trade-off)策略緩和了內(nèi)存和計算量需求的激增。
研究者在多個階段重復(fù)上述流程,直到分辨率達(dá)到(H_T , W_T )。然后,他們將嵌入維數(shù)投影到 3,并得到 RGB 圖像。

用于判別器的 tokenized 輸入
與那些需要準(zhǔn)確合成每個像素的生成器不同,該研究提出的判別器只需要分辨真假圖像即可。這使得研究者可以在語義上將輸入圖像 tokenize 為更粗糙的 patch level(Dosovitskiy 等人,2020)。
如上圖 1 右所示,判別器以圖像的 patch 作為輸入。研究者將輸入圖像

分解為 8 × 8 個 patch,其中每個 patch 可被視為一個「詞」。然后,8 × 8 個 patch 通過一個線性 flatten 層轉(zhuǎn)化為 token 嵌入的 1D 序列,其中 token 數(shù) N = 8 × 8 = 64,嵌入維數(shù)為 C。再之后,研究者在 1D 序列的開頭添加了可學(xué)得位置編碼和一個 [cls] token。在通過 Transformer 編碼器后,分類 head 只使用 [cls] token 來輸出真假預(yù)測。
實驗
CIFAR-10 上的結(jié)果
研究者在 CIFAR-10 數(shù)據(jù)集上對比了 TransGAN 和近來基于卷積的 GAN 的研究,結(jié)果如下表 5 所示:

如上表 5 所示,TransGAN 優(yōu)于 AutoGAN (Gong 等人,2019) ,在 IS 評分方面也優(yōu)于許多競爭者,如 SN-GAN (Miyato 等人, 2018)、improving MMDGAN (Wang 等人,2018a)、MGAN (Hoang 等人,2018)。TransGAN 僅次于 Progressive GAN 和 StyleGAN v2。
對比 FID 結(jié)果,研究發(fā)現(xiàn),TransGAN 甚至優(yōu)于 Progressive GAN,而略低于 StyleGANv2 (Karras 等人,2020b)。在 CIFAR-10 上生成的可視化示例如下圖 4 所示:

STL-10 上的結(jié)果
研究者將 TransGAN 應(yīng)用于另一個流行的 48×48 分辨率的基準(zhǔn) STL-10。為了適應(yīng)目標(biāo)分辨率,該研究將第一階段的輸入特征圖從(8×8)=64 增加到(12×12)=144,然后將提出的 TransGAN-XL 與自動搜索的 ConvNets 和手工制作的 ConvNets 進(jìn)行了比較,結(jié)果下表 6 所示:

與 CIFAR-10 上的結(jié)果不同,該研究發(fā)現(xiàn),TransGAN 優(yōu)于所有當(dāng)前的模型,并在 IS 和 FID 得分方面達(dá)到新的 SOTA 性能。
高分辨率生成
由于 TransGAN 在標(biāo)準(zhǔn)基準(zhǔn) CIFAR-10 和 STL-10 上取得不錯的性能,研究者將 TransGAN 用于更具挑戰(zhàn)性的數(shù)據(jù)集 CelebA 64 × 64,結(jié)果如下表 10 所示:

TransGAN-XL 的 FID 評分為 12.23,這表明 TransGAN-XL 可適用于高分辨率任務(wù)??梢暬Y(jié)果如圖 4 所示。
局限性
雖然 TransGAN 已經(jīng)取得了不錯的成績,但與最好的手工設(shè)計的 GAN 相比,它還有很大的改進(jìn)空間。在論文的最后,作者指出了以下幾個具體的改進(jìn)方向:
- 對 G 和 D 進(jìn)行更加復(fù)雜的 tokenize 操作,如利用一些語義分組 (Wu et al., 2020)。
- 使用代理任務(wù)(pretext task)預(yù)訓(xùn)練 Transformer,這樣可能會改進(jìn)該研究中現(xiàn)有的 MT-CT。
- 更加強(qiáng)大的注意力形式,如 (Zhu 等人,2020)。
- 更有效的自注意力形式 (Wang 等人,2020;Choromanski 等人,2020),這不僅有助于提升模型效率,還能節(jié)省內(nèi)存開銷,從而有助于生成分辨率更高的圖像。
作者簡介
本文一作 Yifan Jiang 是德州大學(xué)奧斯汀分校電子與計算機(jī)工程系的一年級博士生(此前在德克薩斯 A&M 大學(xué)學(xué)習(xí)過一年),本科畢業(yè)于華中科技大學(xué),研究興趣集中在計算機(jī)視覺、深度學(xué)習(xí)等方向。目前,Yifan Jiang 主要從事神經(jīng)架構(gòu)搜索、視頻理解和高級表征學(xué)習(xí)領(lǐng)域的研究,師從德州大學(xué)奧斯汀分校電子與計算機(jī)工程系助理教授 Zhangyang Wang。
在本科期間,Yifan Jiang 曾在字節(jié)跳動 AI Lab 實習(xí)。今年夏天,他將進(jìn)入 Google Research 實習(xí)。
一作主頁:https://yifanjiang.net/