填補(bǔ)領(lǐng)域空白!TerDiT:首次探索大規(guī)模DiT模型量化問(wèn)題
論文鏈接:https://arxiv.org/pdf/2405.14854
項(xiàng)目鏈接:https://github.com/Lucky-Lance/TerDiT
最近在大規(guī)模預(yù)訓(xùn)練的文本到圖像擴(kuò)散模型方面的發(fā)展顯著提高了高保真圖像的生成能力,特別是基于transformer架構(gòu)的擴(kuò)散模型(DiTs)的出現(xiàn)。在這些擴(kuò)散模型中,擴(kuò)散transformer展示了卓越的圖像生成能力,降低了FID分?jǐn)?shù)并提高了可擴(kuò)展性。然而,由于其龐大的參數(shù)數(shù)量,部署大規(guī)模的DiT模型可能會(huì)非常昂貴。盡管現(xiàn)有研究已經(jīng)探索了擴(kuò)散模型的高效部署技術(shù),如模型量化,但關(guān)于基于DiT模型的研究仍然很少。為了解決這一研究空白,本文提出了TerDiT,一種面向量化感知訓(xùn)練(QAT)和高效部署的基于
transformer的三值化擴(kuò)散模型方案。本文專注于DiT網(wǎng)絡(luò)的三值化化,并將模型規(guī)模從600M擴(kuò)展到4.2B。本文的工作為大規(guī)模DiT模型的高效部署策略探索做出了貢獻(xiàn),證明了從頭訓(xùn)練極低比特的擴(kuò)散transformer模型的可行性,同時(shí)在圖像生成能力上保持與全精度模型相當(dāng)?shù)母?jìng)爭(zhēng)力。
介紹
大規(guī)模預(yù)訓(xùn)練的文本到圖像擴(kuò)散模型的進(jìn)步已經(jīng)成功生成了復(fù)雜且高度保真于輸入條件的圖像。值得注意的是,基于transformer架構(gòu)的擴(kuò)散模型(DiTs)的出現(xiàn)代表了該研究領(lǐng)域的一個(gè)重要進(jìn)展。與其他擴(kuò)散模型相比,擴(kuò)散transformer展示了在更高計(jì)算量下實(shí)現(xiàn)更低FID分?jǐn)?shù)的能力。最新的研究突出了擴(kuò)散transformer架構(gòu)在圖像生成方面的顯著能力,例如在Stable Diffusion 3方法中展示的成果,以及在視頻生成方面的出色表現(xiàn),如Sora所展示的工作。
鑒于擴(kuò)散transformer模型的出色表現(xiàn),研究人員現(xiàn)在越來(lái)越多地深入研究這些視覺(jué)模型的擴(kuò)展規(guī)律,這與大語(yǔ)言模型相似。例如,Stable Diffusion 3提供了一系列參數(shù)規(guī)模的訓(xùn)練DiT模型,從8億到80億。此外,有研究人員推測(cè)Sora可能擁有大約30億個(gè)參數(shù)。由于這些模型擁有龐大的參數(shù)數(shù)量,部署這些DiT模型往往成本高昂,尤其是在某些終端設(shè)備上。
為了應(yīng)對(duì)部署難題,最近已經(jīng)有一些關(guān)于高效部署擴(kuò)散模型的研究工作,其中大多數(shù)集中在模型量化方面。然而,據(jù)本文所知,目前的研究仍然存在兩個(gè)主要缺陷。首先,雖然量化基于 U-Net 的擴(kuò)散模型已經(jīng)受到了很多關(guān)注,但對(duì)于基于transformer的擴(kuò)散模型的量化方法探索仍然非常有限。其次,目前文獻(xiàn)中的大多數(shù)主流方法主要依賴于后訓(xùn)練量化 (PTQ) 技術(shù)來(lái)進(jìn)行模型量化,這會(huì)導(dǎo)致不可接受的性能下降,特別是在極低比特寬度(例如2比特和1比特)下。然而,神經(jīng)網(wǎng)絡(luò)的極低比特量化非常重要,因?yàn)樗梢燥@著減少部署所需的計(jì)算資源,尤其對(duì)于具有巨大參數(shù)規(guī)模的模型。在本文的研究過(guò)程中,本文發(fā)現(xiàn)目前還沒(méi)有研究考慮 DiT 模型的極低比特量化。
為了解決這些缺陷,本文提出利用量化感知訓(xùn)練(QAT)技術(shù)對(duì)大規(guī)模 DiT 模型進(jìn)行極低比特量化。在大規(guī)模模型領(lǐng)域,低比特 QAT 方法已在大語(yǔ)言模型(LLM)領(lǐng)域進(jìn)行了討論。最近的研究表明,從頭開(kāi)始訓(xùn)練具有極低比特參數(shù)(例如二進(jìn)制和三進(jìn)制)的大語(yǔ)言模型,也可以達(dá)到與全精度模型相當(dāng)?shù)母?jìng)爭(zhēng)性能。這表明大規(guī)模模型中仍然存在顯著的精度冗余,并且暗示了 QAT 方案對(duì)于大規(guī)模 DiT 模型的可行性。
在本文中,本文主要關(guān)注三值權(quán)重網(wǎng)絡(luò),并提供了 TerDiT,這是本文所知的首個(gè)用于 DiT 的量化方案。本文的方法實(shí)現(xiàn)了三值擴(kuò)散transformer模型的量化感知訓(xùn)練(僅限權(quán)重)和高效部署。與 LLM 和 CNN 中線性層的簡(jiǎn)單量化不同,本文發(fā)現(xiàn)直接對(duì) DiT 模塊中的 adaLN 模塊進(jìn)行權(quán)重三值化,會(huì)導(dǎo)致歸一化層中的大尺寸尺度和偏移值(由于權(quán)重量化和梯度近似),這與全精度模型相比,導(dǎo)致收斂速度較慢和模型性能較差。因此,本文提出了一種 adaLN 的變體,通過(guò)在 adaLN 模塊的三值線性層之后應(yīng)用 RMS Norm,有效地緩解了這一訓(xùn)練問(wèn)題。
通過(guò)這種修改,本文將三值 DiT 模型的參數(shù)規(guī)模從 600M(DiT-XL/2的規(guī)模)擴(kuò)展到 4.2B(Large-DiT-4.2B的規(guī)模),發(fā)現(xiàn)具有更多參數(shù)的模型能夠收斂到更好的結(jié)果。本文進(jìn)一步采用現(xiàn)有的2-bit CUDA 內(nèi)核來(lái)部署訓(xùn)練后的三值 DiT 模型,使模型checkpoint 大小減少了十倍以上,推理內(nèi)存消耗減少了約六倍,同時(shí)實(shí)現(xiàn)了具有競(jìng)爭(zhēng)力的生成質(zhì)量。主要貢獻(xiàn)總結(jié)如下:
- 受低位 LLMs 量化感知訓(xùn)練方案的啟發(fā),本文研究了針對(duì)三值 DiT 模型的 QAT 方法,并引入了 DiT 特定的改進(jìn)以獲得更好的訓(xùn)練效果,這在 DiT 文獻(xiàn)中尚未被探索。
- 本文將三值DiT模型的參數(shù)規(guī)模從600M擴(kuò)展到4.2B,并基于現(xiàn)有的2-bit CUDA內(nèi)核在GPU上部署了訓(xùn)練后的三值DiT模型,使得4.2B DiT模型的推理內(nèi)存消耗小于3GB。
- 與全精度模型在ImageNet基準(zhǔn)測(cè)試(圖像生成)中的對(duì)比評(píng)估結(jié)果展示了本文提出的TerDiT方案的有效性。
本文的研究是首次嘗試探索DiT模型的量化問(wèn)題。本文專注于量化感知訓(xùn)練和大規(guī)模三值DiT模型的高效部署,為未來(lái)研究在極低比特精度下部署DiT模型提供了寶貴的見(jiàn)解。
相關(guān)工作
擴(kuò)散模型。 近年來(lái),擴(kuò)散模型因其生成高質(zhì)量圖像的能力和多種應(yīng)用潛力而受到了廣泛關(guān)注。擴(kuò)散模型的概念最早由提出,該研究提出了一種通過(guò)學(xué)習(xí)逆向擴(kuò)散過(guò)程的生成模型。這項(xiàng)工作為該領(lǐng)域的后續(xù)研究奠定了基礎(chǔ)。[1]進(jìn)一步擴(kuò)展了這一思想,提出了去噪擴(kuò)散概率模型(DDPMs),這類模型已成為圖像生成任務(wù)中的熱門(mén)選擇。DDPMs被應(yīng)用于廣泛的領(lǐng)域,包括無(wú)條件圖像生成、圖像修復(fù)和圖像超分辨率。此外,擴(kuò)散模型還被用于文本到圖像的合成,如DALL-E模型和Imagen模型所展示的那樣,這些模型展示了擴(kuò)散模型從文本描述中生成高度逼真和多樣化圖像的能力。進(jìn)一步地,擴(kuò)散模型還被擴(kuò)展到其他模態(tài),如音頻合成和視頻生成,展示了其在多模態(tài)應(yīng)用中的多樣性和潛力。
擴(kuò)散模型的量化。 近年來(lái),為了提高擴(kuò)散模型的效率,研究人員對(duì)擴(kuò)散模型的量化進(jìn)行了研究。后訓(xùn)練量化(PTQ)方法,如[9, 11, 13, 14, 15]中提出的方法,在量化時(shí)間和數(shù)據(jù)使用方面具有優(yōu)勢(shì)。然而,當(dāng)這些方法應(yīng)用于低比特設(shè)置時(shí),通常會(huì)導(dǎo)致性能不佳。為了解決這個(gè)問(wèn)題,[31]提出了將量化感知低秩適配器(QALoRA)與PTQ方法相結(jié)合,從而提高評(píng)估結(jié)果。作為PTQ的替代方案,量化感知訓(xùn)練(QAT)方法專門(mén)用于低比特?cái)U(kuò)散模型的量化。盡管這些QAT方法有效,但目前僅限于小規(guī)模的基于U-Net的擴(kuò)散模型,這揭示了在大規(guī)模DiT模型上應(yīng)用QAT的研究空白。進(jìn)一步探索適用于大規(guī)模DiT模型的極低比特寬度QAT技術(shù),可能會(huì)帶來(lái)更大的效率提升,并在資源受限的環(huán)境中有效部署擴(kuò)散模型。
三值權(quán)重網(wǎng)絡(luò)。 三值權(quán)重網(wǎng)絡(luò)作為一種內(nèi)存高效和計(jì)算高效的網(wǎng)絡(luò)結(jié)構(gòu),已經(jīng)引起了廣泛關(guān)注,其在推理內(nèi)存使用上的顯著減少潛力尤為突出。在專用硬件的支持下,三值權(quán)重網(wǎng)絡(luò)還可以提供顯著的計(jì)算加速。在量化方法中,三值權(quán)重網(wǎng)絡(luò)備受關(guān)注,主要有兩種方法:僅權(quán)重量化和權(quán)重-激活量化。在僅權(quán)重量化中,如[35]所述,僅對(duì)權(quán)重進(jìn)行三值量化。而權(quán)重-激活量化,如[36, 37]所述,則對(duì)權(quán)重和激活值同時(shí)進(jìn)行三值量化。近期研究表明,三值權(quán)重網(wǎng)絡(luò)在訓(xùn)練大語(yǔ)言模型方面具有可行性,并且其結(jié)果可與全精度模型相媲美。基于這些進(jìn)展,本文的工作首次引入了針對(duì)三值DiT模型的量化感知訓(xùn)練和高效部署方案。通過(guò)在DiT模型中利用三值量化的優(yōu)勢(shì),本文旨在推動(dòng)效率的極限,并使強(qiáng)大的擴(kuò)散模型在資源受限的環(huán)境中得以部署,從而為實(shí)際應(yīng)用開(kāi)辟新的可能性。
TerDiT
TerDiT,這是一個(gè)用于進(jìn)行僅權(quán)重量化感知訓(xùn)練和高效部署大規(guī)模三值DiT模型的框架。本文首先簡(jiǎn)要回顧擴(kuò)散transformer(DiT)模型。然后,基于之前開(kāi)源的Large-DiT,闡述了量化函數(shù)和量化感知訓(xùn)練方案,并進(jìn)行特定于QAT的模型結(jié)構(gòu)改進(jìn)以優(yōu)化網(wǎng)絡(luò)訓(xùn)練,并介紹了三值部署方案。
擴(kuò)散transformer模型
擴(kuò)散transformer(Diffusion Transformer)。擴(kuò)散transformer(DiT)是一種架構(gòu),它用操作潛在patches的transformer替代了擴(kuò)散模型中常用的U-Net骨干結(jié)構(gòu)。類似于下圖2(C)中展示的視覺(jué)transformer(ViT)架構(gòu),DiT首先將空間輸入劃分為一系列tokens,然后通過(guò)一系列transformer塊(下圖2(B))進(jìn)行去噪處理。為了處理額外的條件信息(例如噪聲時(shí)間步t、類別標(biāo)簽l、自然語(yǔ)言輸入),DiT利用自適應(yīng)歸一化模塊(adaLNZero)將這些額外的條件輸入插入到transformer塊中。在最后一個(gè)transformer塊之后,應(yīng)用標(biāo)準(zhǔn)線性解碼器來(lái)預(yù)測(cè)最終的噪聲和協(xié)方差。DiT模型的訓(xùn)練方式與基于U-Net的擴(kuò)散模型相同。
DiT中的AdaLN模塊。DiT與傳統(tǒng)ViT的主要區(qū)別在于需要注入條件信息以進(jìn)行圖像生成。DiT在每個(gè)transformer塊中使用零初始化的自適應(yīng)層歸一化(adaLN-Zero)模塊,如上圖2(B)紅色部分所示,該模塊根據(jù)輸入條件c計(jì)算維度級(jí)的縮放和偏移值。
AdaLN 是 DiT 模型中的一個(gè)重要組件,其效果已被證明優(yōu)于交叉注意力和上下文條件方法。在 DiT 架構(gòu)中,AdaLN 模塊集成了一個(gè)包含大量參數(shù)的 MLP 層,占模型總參數(shù)的約 10% 到 20%。在 TerDiT 的訓(xùn)練過(guò)程中,本文觀察到直接對(duì)該模塊進(jìn)行權(quán)重三值化會(huì)導(dǎo)致不理想的訓(xùn)練結(jié)果。
模型量化
如上文所示,理解DiT模型的擴(kuò)展規(guī)律越來(lái)越受到關(guān)注,這對(duì)于開(kāi)發(fā)和優(yōu)化大語(yǔ)言模型(LLM)至關(guān)重要。在最近的探索中,Large-DiT成功地將模型參數(shù)從600M擴(kuò)展到7B,結(jié)合了LLaMA和DiT的方法。結(jié)果表明,參數(shù)擴(kuò)展可以潛在地提升模型性能,并加快標(biāo)簽條件的ImageNet生成任務(wù)的收斂速度。受此啟發(fā),本文提出進(jìn)一步研究DiT模型的三值化,這可以緩解部署大規(guī)模DiT模型相關(guān)的挑戰(zhàn)。在本小節(jié)中,本文介紹量化函數(shù)和量化感知訓(xùn)練方案。
TerDiT 是一種僅對(duì)權(quán)重進(jìn)行量化的方案,本文不對(duì)激活值進(jìn)行量化。
量化感知訓(xùn)練方案。 基于上述設(shè)計(jì)的量化函數(shù),本文從頭開(kāi)始訓(xùn)練一個(gè) DiT 模型,利用直接傳遞估計(jì)器(STE),允許梯度通過(guò)不可微分的網(wǎng)絡(luò)組件傳播。在整個(gè)訓(xùn)練過(guò)程中,本文保留網(wǎng)絡(luò)的全精度參數(shù)。對(duì)于每一步訓(xùn)練,通過(guò)前向傳播中的三值量化函數(shù)從全精度參數(shù)計(jì)算出三值權(quán)重,并在反向傳播中將三值權(quán)重的梯度直接應(yīng)用于全精度參數(shù)進(jìn)行參數(shù)更新。
然而,本文發(fā)現(xiàn)收斂速度非常慢。即使經(jīng)過(guò)多次訓(xùn)練迭代,損失值也無(wú)法降低到合理范圍。本文認(rèn)為這個(gè)問(wèn)題可能源于三值線性層通常會(huì)導(dǎo)致較大的激活值,并提出在接下來(lái)的小節(jié)中通過(guò)針對(duì) QAT(量化感知訓(xùn)練)特定的模型結(jié)構(gòu)改進(jìn)來(lái)解決這個(gè)問(wèn)題。
QAT特定模型結(jié)構(gòu)改進(jìn)
通過(guò)對(duì)三值線性層的輸出應(yīng)用層歸一化,可以緩解由三值線性權(quán)重帶來(lái)的大激活值問(wèn)題。本文在三值線性層之后添加了一個(gè)RMS歸一化層(類似于LLaMA),并獲得了激活值分布(如下圖3左側(cè)所示)。在這種情況下,激活值在通過(guò)歸一化層后被縮放到一個(gè)合理范圍,從而導(dǎo)致更穩(wěn)定的訓(xùn)練行為。這一觀察結(jié)果也與[17]中的結(jié)論一致,其中在每個(gè)量化線性層的激活量化之前應(yīng)用了層歸一化函數(shù)。
RMS歸一化的AdaLN模塊。 基于上述見(jiàn)解,本文分析了DiT模型以改進(jìn)QAT特定的模型結(jié)構(gòu)。在標(biāo)準(zhǔn)的ViT Transformer塊中,層歸一化應(yīng)用于每個(gè)自注意力層和前饋層。DiT塊中的自注意力層和前饋層也是如此,這有助于適當(dāng)?shù)乜s放激活值范圍。然而,由于上文中介紹的AdaLN模塊的存在,DiT塊與傳統(tǒng)的Transformer塊有所不同。值得注意的是,這個(gè)模塊沒(méi)有應(yīng)用層歸一化。在全精度訓(xùn)練的情況下,缺乏層歸一化并不會(huì)產(chǎn)生顯著影響。然而,對(duì)于三值DiT網(wǎng)絡(luò)來(lái)說(shuō),其缺失可能會(huì)導(dǎo)致adaLN(歸一化)模塊中的維度尺度和偏移值過(guò)大,從而對(duì)模型訓(xùn)練產(chǎn)生不良影響。為了解決這個(gè)問(wèn)題,本文在每個(gè)三值DiT塊的AdaLN模塊的MLP層之后引入了RMS歸一化:
最終的TerDiT模型結(jié)構(gòu)如上圖2(A)所示。這個(gè)小的修改可以帶來(lái)更快的收斂速度和更低的訓(xùn)練損失,從而在定量和定性評(píng)估中取得更好的結(jié)果。為了更好地展示這一效果,在原文附錄中分析了模型訓(xùn)練后引入或不引入RMS Norm的實(shí)際激活分布。
部署方案
在訓(xùn)練了DiT模型之后,本文發(fā)現(xiàn)目前沒(méi)有有效的開(kāi)源三值網(wǎng)絡(luò)部署解決方案。在這種情況下,本文使用2位實(shí)現(xiàn)來(lái)部署訓(xùn)練好的網(wǎng)絡(luò)。具體來(lái)說(shuō),本文使用文獻(xiàn)[44]提供的??pack_2bit_u8()?
??函數(shù),將三值線性權(quán)重打包成??int8?
??值(4個(gè)三值數(shù)打包成一個(gè)??int8?
??數(shù))。在DiT模型的推斷過(guò)程中,本文即時(shí)調(diào)用相應(yīng)的??unpack_2bit_u8()?
?函數(shù),將打包的2位數(shù)字恢復(fù)為浮點(diǎn)數(shù)值,然后進(jìn)行后續(xù)計(jì)算。添加解包操作會(huì)減慢推斷過(guò)程,但本文相信,隨著對(duì)模型三值化研究的深入,將會(huì)有更多硬件支持來(lái)加速推斷過(guò)程。
實(shí)驗(yàn)
在本節(jié)中,本文進(jìn)行了一系列實(shí)驗(yàn)來(lái)評(píng)估本文提出的TerDiT。本文展示了主要的評(píng)估結(jié)果,進(jìn)行了部署效率比較,并說(shuō)明了RMS Normalized adaLN模塊的有效性。本文的DiT實(shí)現(xiàn)基于開(kāi)源代碼Large-DiT-ImageNet4。本文分別對(duì)具有600M(DiT-XL/2的大?。┖?.2B(Large-DiT-4.2B的大?。﹨?shù)的三值DiT模型進(jìn)行了實(shí)驗(yàn)。
主要評(píng)價(jià)結(jié)果
本文在本小節(jié)中提供了TerDiT模型的定量和定性評(píng)估結(jié)果。據(jù)本文所知,目前尚無(wú)關(guān)于擴(kuò)散transformer模型量化的已發(fā)表工作,因此本文主要在本小節(jié)中將其與具有代表性的全精度擴(kuò)散模型進(jìn)行比較。
關(guān)于TerDiT基線的備注。 據(jù)本文所知,目前仍沒(méi)有研究DiT模型量化的工作。除了在本小節(jié)中與全精度模型進(jìn)行比較外,本文還在其他小節(jié)中建立了一些基線進(jìn)行比較。對(duì)于QAT基線,本文直接訓(xùn)練了在Sec. 4.3中的adaLN模塊中沒(méi)有RMS Norm的三值DiT模型。為了與現(xiàn)有的PTQ方法進(jìn)行比較,本文對(duì)預(yù)訓(xùn)練模型進(jìn)行了4位權(quán)重量化,使用與TerDiT相同的一組參數(shù),結(jié)果發(fā)現(xiàn)它們無(wú)法生成可視的圖像。
實(shí)驗(yàn)設(shè)置。 按照原始DiT論文的評(píng)估設(shè)置,本文在ImageNet數(shù)據(jù)集上訓(xùn)練了600M和4.2B的三值DiT模型。由于計(jì)算資源的限制,本文在256×256分辨率下訓(xùn)練和評(píng)估模型,但本文認(rèn)為評(píng)估結(jié)果已經(jīng)具有很強(qiáng)的代表性。本文將TerDiT與一系列全精度擴(kuò)散模型進(jìn)行比較,并報(bào)告FID、sFID、Inception Score、Precision和Recall(50k生成圖像),參考[48]。本文還提供了訓(xùn)練階段的總圖像數(shù)量(百萬(wàn)),如[23]所示,以進(jìn)一步了解不同生成模型的收斂速度。
定量結(jié)果分析。 評(píng)估結(jié)果列在下表1中。TerDiT是針對(duì)DiT模型的QAT方案,因此在所有全精度模型中,本文特別關(guān)注DiT-XL/2(675M)和Large-DiT-4.2B。在沒(méi)有分類器自由指導(dǎo)的情況下,TerDiT-4.2B在測(cè)試結(jié)果上與DiT-XL/2非常相似(使用的訓(xùn)練圖像數(shù)量要少得多)。在有分類器自由指導(dǎo)(cfg=1.5)的情況下,TerDiT-4.2B-G的表現(xiàn)優(yōu)于LDM-G,同時(shí)與兩個(gè)全精度DiT結(jié)構(gòu)模型相比僅帶來(lái)了非常輕微的性能下降。此外,TerDiT-4.2B-G的評(píng)估結(jié)果優(yōu)于TerDiT-600M-G,這表明參數(shù)更多的模型在量化后會(huì)帶來(lái)更小的性能下降。
為了直觀地展示TerDiT的有效性,本文在下圖4中展示了一些定性比較結(jié)果,涉及TerDiT-4.2B、DiT-XL/2和Large-DiT4.2B。從視覺(jué)感知的角度來(lái)看,TerDiT生成的圖像與全精度模型生成的圖像之間沒(méi)有顯著差異。
部署效率對(duì)比
部署效率的提升是本文提出TerDiT方案的動(dòng)機(jī)。在本小節(jié)中,本文對(duì)TerDiT-600M/4.2B、DiT-XL/2和Large-DiT-4.2B進(jìn)行了比較,以討論TerDiT在實(shí)際部署中所能帶來(lái)的效率提升。下表2展示了四種DiT模型的checkpoint 大小。本文還記錄了在單個(gè)A100-80G GPU上,總擴(kuò)散采樣循環(huán)(步數(shù)=250)的內(nèi)存使用情況和推理時(shí)間。
從表格中可以看出,TerDiT大大減少了checkpoint 大小和內(nèi)存使用。4.2B三值化DiT模型的checkpoint 大小和內(nèi)存使用顯著小于Large-DiT-4.2B,甚至比DiT-XL/2還要小。這為在終端設(shè)備(如手機(jī))上部署模型帶來(lái)了顯著優(yōu)勢(shì)。盡管由于需要解包操作,本文觀察到推理速度較慢,但本文相信,隨著未來(lái)硬件支持的提升,三值化權(quán)重網(wǎng)絡(luò)的計(jì)算優(yōu)勢(shì)將會(huì)得到充分展示。
RMS歸一化AdaLN模塊的討論
TerDiT對(duì)DiT模型結(jié)構(gòu)的主要修改是在adaLN模塊中的MLP之后增加了RMS Norm。在這一部分,本文與基線三值化模型進(jìn)行比較,以展示RMS Norm對(duì)訓(xùn)練過(guò)程和訓(xùn)練結(jié)果的影響。
實(shí)驗(yàn)設(shè)置。 本文在ImageNet數(shù)據(jù)集上以256×256分辨率訓(xùn)練具有600M和4.2B參數(shù)的三值化DiT模型。對(duì)于每種參數(shù)規(guī)模,本文訓(xùn)練了兩個(gè)模型,一個(gè)在adaLN模塊中使用了RMS Norm,另一個(gè)則沒(méi)有(本文的基線模型)。本文記錄了訓(xùn)練過(guò)程中的損失曲線,并每100k訓(xùn)練步測(cè)量一次FID-50k分?jǐn)?shù)(cfg=1.5)。
結(jié)果分析: 訓(xùn)練損失和評(píng)估得分分別顯示在下圖5和下圖6中。如圖所示,使用RMS Normalized adaLN模塊進(jìn)行訓(xùn)練將導(dǎo)致更快的收斂速度和更低的FID分?jǐn)?shù)。另一個(gè)觀察結(jié)果是,參數(shù)更多的模型相比參數(shù)較少的模型能實(shí)現(xiàn)更快且更好的訓(xùn)練效果。這在一定程度上也反映了三值化DiT模型的擴(kuò)展規(guī)律。
討論和未來(lái)展望
本文在成功的大語(yǔ)言模型低比特訓(xùn)練方法的基礎(chǔ)上,提出了針對(duì)大規(guī)模三值化DiT模型的量化感知訓(xùn)練(QAT)和高效部署方法。在ImageNet數(shù)據(jù)集(256×256)上的競(jìng)爭(zhēng)性評(píng)估結(jié)果證明了從頭開(kāi)始訓(xùn)練大型三值化DiT模型的可行性,同時(shí)實(shí)現(xiàn)了與全精度模型相當(dāng)?shù)慕Y(jié)果。據(jù)本文所知,這是首個(gè)關(guān)于DiT模型量化的研究。
雖然本文認(rèn)為這項(xiàng)工作為DiT模型的低比特量化提供了有價(jià)值的見(jiàn)解,但它仍然存在一些局限性。首先,訓(xùn)練三值化DiT比全精度網(wǎng)絡(luò)更不穩(wěn)定且耗時(shí)。在本文中,盡管本文討論了通過(guò)添加歸一化方法來(lái)使訓(xùn)練更穩(wěn)定,但相較于訓(xùn)練全精度網(wǎng)絡(luò)(如Large-DiT-4.2B),它仍然更耗時(shí),這將在更廣泛的背景下導(dǎo)致模型訓(xùn)練期間二氧化碳排放量的增加。
其次,由于計(jì)算資源的限制,本文沒(méi)有進(jìn)行ImageNet 512×512實(shí)驗(yàn),也沒(méi)有進(jìn)行文本到圖像生成任務(wù)的實(shí)驗(yàn)。然而,本文相信ImageNet 256×256基準(zhǔn)上的評(píng)估結(jié)果已經(jīng)相當(dāng)具有代表性。剩余的任務(wù)將留待本文未來(lái)的工作中進(jìn)行。本文希望本文的工作可以減少圖像生成模型的部署需求,并能激勵(lì)社區(qū)在未來(lái)加入本文,共同促進(jìn)這一研究領(lǐng)域的更廣泛發(fā)展。
本文轉(zhuǎn)自AI生成未來(lái) ,作者:Xudong Lu等
