幾行代碼穩(wěn)定UNet ! 中山大學(xué)等提出ScaleLong擴(kuò)散模型:從質(zhì)疑Scaling到成為Scaling
在標(biāo)準(zhǔn)的UNet結(jié)構(gòu)中,long skip connection上的scaling系數(shù)一般為1。
然而,在一些著名的擴(kuò)散模型工作中,比如Imagen, Score-based generative model,以及SR3等等,它們都設(shè)置了,并發(fā)現(xiàn)這樣的設(shè)置可以有效加速擴(kuò)散模型的訓(xùn)練。
質(zhì)疑Scaling然而,Imagen等模型對skip connection的Scaling操作在原論文中并沒有具體的分析,只是說這樣設(shè)置有助于加速擴(kuò)散模型的訓(xùn)練。
首先,這種經(jīng)驗上的展示,讓我們并搞不清楚到底這種設(shè)置發(fā)揮了什么作用?
另外,我們也不清楚是否只能設(shè)置,還是說可以使用其他的常數(shù)?
不同位置的skip connection的「地位」一樣嗎,為什么使用一樣的常數(shù)?
對此,作者有非常多的問號……
圖片
理解Scaling
一般來說,和ResNet以及Transformer結(jié)構(gòu)相比,UNet在實際使用中「深度」并不深,不太容易出現(xiàn)其他「深」神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)常見的梯度消失等優(yōu)化問題。
另外,由于UNet結(jié)構(gòu)的特殊性,淺層的特征通過long skip connection與深層的位置相連接,從而進(jìn)一步避免了梯度消失等問題。
那么反過來想,這樣的結(jié)構(gòu)如果稍不注意,會不會導(dǎo)致梯度過猛、參數(shù)(特征)由于更新導(dǎo)致震蕩的問題?
圖片
通過對擴(kuò)散模型任務(wù)在訓(xùn)練過程中特征和參數(shù)的可視化,可以發(fā)現(xiàn),確實存在不穩(wěn)定現(xiàn)象。
參數(shù)(特征)的不穩(wěn)定,影響了梯度,接著又反過來影響參數(shù)更新。最終這個過程對性能有較大的不良干擾的風(fēng)險。因此需要想辦法去控制這種不穩(wěn)定性。
進(jìn)一步的,對于擴(kuò)散模型。UNet的輸入是一個帶噪圖像,如果要求模型能從中準(zhǔn)確預(yù)測出加入的噪聲,這需要模型對輸入有很強(qiáng)的抵御額外擾動的魯棒性。
論文:https://arxiv.org/abs/2310.13545
代碼:https://github.com/sail-sg/ScaleLong
研究人員發(fā)現(xiàn)上述這些問題,可以在Long skip connection上進(jìn)行Scaling來進(jìn)行統(tǒng)一地緩解。
從定理3.1來看,中間層特征的震蕩范圍(上下界的寬度)正相關(guān)于scaling系數(shù)的平方和。適當(dāng)?shù)膕caling系數(shù)有助于緩解特征不穩(wěn)定。
不過需要注意的是,如果直接讓scaling系數(shù)設(shè)置為0,確實最佳地緩解了震蕩。(手動狗頭)
但是UNet退化為無skip的情況的話,不穩(wěn)定問題是解決了,但是表征能力也沒了。這是模型穩(wěn)定性和表征能力的trade-off。
圖片
類似地,從參數(shù)梯度的角度。定理3.3也揭示了scaling系數(shù)對梯度量級的控制。
圖片
進(jìn)一步地,定理3.4還揭示了long skip connection上的scaling還可以影響模型對輸入擾動的魯棒上界,提升擴(kuò)散模型對輸入擾動的穩(wěn)定性。
成為Scaling
通過上述的分析,我們清楚了Long skip connection上進(jìn)行scaling對穩(wěn)定模型訓(xùn)練的重要性,也適用于上述的分析。
接下來,我們將分析怎么樣的scaling可以有更好的性能,畢竟上述分析只能說明scaling有好處,但不能確定怎么樣的scaling最好或者較好。
一種簡單的方式是為long skip connection引入可學(xué)習(xí)的模塊來自適應(yīng)地調(diào)整scaling,這種方法稱為Learnable Scaling (LS) Method。我們采用類似SENet的結(jié)構(gòu),即如下所示(此處考慮的是代碼整理得非常好的U-ViT結(jié)構(gòu),贊?。?/span>
圖片
從本文的結(jié)果來看,LS確實可以有效地穩(wěn)定擴(kuò)散模型的訓(xùn)練!進(jìn)一步地,我們嘗試可視化LS中學(xué)習(xí)到的系數(shù)。
如下圖所示,我們會發(fā)現(xiàn)這些系數(shù)呈現(xiàn)出一種指數(shù)下降的趨勢(注意這里第一個long skip connection是指連接UNet首尾兩端的connection),且第一個系數(shù)幾乎接近于1,這個現(xiàn)象也很amazing!
圖片
基于這一系列觀察(更多的細(xì)節(jié)請查閱論文),我們進(jìn)一步提出了Constant Scaling (CS) Method,即無需可學(xué)習(xí)參數(shù)的:
CS策略和最初的使用的scaling操作一樣無需額外參數(shù),從而幾乎沒有太多的額外計算消耗。
雖然CS在大多數(shù)時候沒有LS在穩(wěn)定訓(xùn)練上表現(xiàn)好,不過對于已有的策略來說,還是值得一試。
上述CS和LS的實現(xiàn)均非常簡潔,僅僅需要若干行代碼即可。針對各(hua)式(li)各(hu)樣(shao)的UNet結(jié)構(gòu)可能需要對齊一下特征維度。(手動狗頭+1)
最近,一些后續(xù)工作,比如FreeU、SCEdit等工作也揭示了skip connection上scaling的重要性,歡迎大家試用和推廣。