圖像領(lǐng)域再次與LLM一拍即合!idea撞車OpenAI強(qiáng)化微調(diào),西湖大學(xué)發(fā)布圖像鏈CoT
OpenAI最近推出了在大語(yǔ)言模型LLM上的強(qiáng)化微調(diào)(Reinforcement Finetuning,ReFT),能夠讓模型利用CoT進(jìn)行多步推理之后,通過(guò)強(qiáng)化學(xué)習(xí)讓最終輸出符合人類偏好。
無(wú)獨(dú)有偶,齊國(guó)君教授領(lǐng)導(dǎo)的MAPLE實(shí)驗(yàn)室在OpenAI發(fā)布會(huì)一周前公布的工作中也發(fā)現(xiàn)了圖像生成領(lǐng)域的主打方法擴(kuò)散模型和流模型中也存在類似的過(guò)程:模型從高斯噪聲開(kāi)始的多步去噪過(guò)程也類似一個(gè)思維鏈,逐步「思考」怎樣生成一張高質(zhì)量圖像,是一種圖像生成領(lǐng)域的「圖像鏈CoT」。
與OpenAI不謀而和的是,機(jī)器學(xué)習(xí)與感知(MAPLE)實(shí)驗(yàn)室認(rèn)為強(qiáng)化學(xué)習(xí)微調(diào)方法同樣可以用于優(yōu)化多步去噪的圖像生成過(guò)程,論文指出利用與人類獎(jiǎng)勵(lì)對(duì)齊的強(qiáng)化學(xué)習(xí)監(jiān)督訓(xùn)練,能夠讓擴(kuò)散模型和流匹配模型自適應(yīng)地調(diào)整推理過(guò)程中噪聲強(qiáng)度,用更少的步數(shù)生成高質(zhì)量圖像內(nèi)容。
圖片
論文地址:https://arxiv.org/abs/2412.01243
研究背景
擴(kuò)散和流匹配模型是當(dāng)前主流的圖像生成模型,從標(biāo)準(zhǔn)高斯分布中采樣的噪聲逐步變換為一張高質(zhì)量圖像。在訓(xùn)練時(shí),這些模型會(huì)單獨(dú)監(jiān)督每一個(gè)去噪步驟,使其具備能恢復(fù)原始圖像的能力;而在實(shí)際推理時(shí),模型則會(huì)事先指定若干個(gè)不同的擴(kuò)散時(shí)間,然后在這些時(shí)間上依次執(zhí)行多步去噪過(guò)程。
這一過(guò)程存在兩個(gè)問(wèn)題:
1. 經(jīng)典的擴(kuò)散模型訓(xùn)練方法只能保證每一步去噪能盡可能恢復(fù)出原始圖像,不能保證整個(gè)去噪過(guò)程得到的圖像符合人類的偏好;
2. 經(jīng)典的擴(kuò)散模型所有的圖片都采用了同樣的去噪策略和步數(shù);而顯然不同復(fù)雜度的圖像對(duì)于人類來(lái)說(shuō)生成難度是不一樣的。
如下圖所示,當(dāng)輸入不同長(zhǎng)度的prompt的時(shí)候,對(duì)應(yīng)的生成任務(wù)難度自然有所區(qū)別。那些僅包含簡(jiǎn)單的單個(gè)主體前景的圖像較為簡(jiǎn)單,只需要少量幾步就能生成不錯(cuò)的效果,而帶有精細(xì)細(xì)節(jié)的圖像則需要更多步數(shù),即經(jīng)過(guò)強(qiáng)化微調(diào)訓(xùn)練后的圖像生成模型就能自適應(yīng)地推理模型去噪過(guò)程,用盡可能少的步數(shù)生成更高質(zhì)量的圖像。
值得注意的是,類似于LLM對(duì)思維鏈進(jìn)行的動(dòng)態(tài)優(yōu)化,對(duì)擴(kuò)散模型時(shí)間進(jìn)行優(yōu)化的時(shí)候也需要?jiǎng)討B(tài)地進(jìn)行,而非僅僅依據(jù)輸入的prompt;換言之,優(yōu)化過(guò)程需要根據(jù)推理過(guò)程生成的「圖像鏈」來(lái)動(dòng)態(tài)一步步預(yù)測(cè)圖像鏈下一步的最優(yōu)去噪時(shí)間,從而保證圖像的生成質(zhì)量滿足reward指標(biāo)。
方法
MAPLE實(shí)驗(yàn)室認(rèn)為,要想讓模型在推理時(shí)用更少的步數(shù)生成更高質(zhì)量的圖像結(jié)果,需要用強(qiáng)化微調(diào)技術(shù)對(duì)多步去噪過(guò)程進(jìn)行整體監(jiān)督訓(xùn)練。既然圖像生成過(guò)程同樣也類似于LLM中的CoT:模型通過(guò)中間的去噪步驟「思考」生成圖像的內(nèi)容,并在最后一個(gè)去噪步驟給出高質(zhì)量的結(jié)果,也可以通過(guò)利用獎(jiǎng)勵(lì)模型評(píng)價(jià)整個(gè)過(guò)程生成的圖像質(zhì)量,通過(guò)強(qiáng)化微調(diào)使模型的輸出更符合人類偏好。
圖片
OpenAI的O1通過(guò)在輸出最終結(jié)果之前生成額外的token讓LLM能進(jìn)行額外的思考和推理,模型所需要做的最基本的決策是生成下一個(gè)token;而擴(kuò)散和流匹配模型的「思考」過(guò)程則是在生成最終圖像前,在不同噪聲強(qiáng)度對(duì)應(yīng)的擴(kuò)散時(shí)間(diffusion time)執(zhí)行多個(gè)額外的去噪步驟。為此,模型需要知道額外的「思考」步驟應(yīng)該在反向擴(kuò)散過(guò)程推進(jìn)到哪一個(gè)diffusion time的時(shí)候進(jìn)行。
為了實(shí)現(xiàn)這一目的,在網(wǎng)絡(luò)中引入了一個(gè)即插即用的時(shí)間預(yù)測(cè)模塊(Time Prediction Module, TPM)。這一模塊會(huì)預(yù)測(cè)在當(dāng)前這一個(gè)去噪步驟執(zhí)行完畢之后,模型應(yīng)當(dāng)在哪一個(gè)diffusion time下進(jìn)行下一步去噪。
具體而言,該模塊會(huì)同時(shí)取出去噪網(wǎng)絡(luò)第一層和最后一層的圖像特征,預(yù)測(cè)下一個(gè)去噪步驟時(shí)的噪聲強(qiáng)度會(huì)下降多少。模型的輸出策略是一個(gè)參數(shù)化的beta分布。
由于單峰的Beta分布要求α>1且β>1,研究人員對(duì)輸出進(jìn)行了重參數(shù)化,使其預(yù)測(cè)兩個(gè)實(shí)數(shù)a和b,并通過(guò)如下公式確定對(duì)應(yīng)的Beta分布,并采樣下一步的擴(kuò)散時(shí)間。
圖片
圖片
在強(qiáng)化微調(diào)的訓(xùn)練過(guò)程中,模型會(huì)在每一步按輸出的Beta分布隨機(jī)采樣下一個(gè)擴(kuò)散時(shí)間,并在對(duì)應(yīng)時(shí)間執(zhí)行下一個(gè)去噪步驟。直到擴(kuò)散時(shí)間非常接近0時(shí),可以認(rèn)為此時(shí)模型已經(jīng)可以近乎得到了干凈圖像,便終止去噪過(guò)程并輸出最終圖像結(jié)果。
通過(guò)上述過(guò)程,即可采樣到用于強(qiáng)化微調(diào)訓(xùn)練的一個(gè)決策軌跡樣本。而在推理過(guò)程中,模型會(huì)在每一個(gè)去噪步驟輸出的Beta分布中直接采樣眾數(shù)作為下一步對(duì)應(yīng)的擴(kuò)散時(shí)間,以確保一個(gè)確定性的推理策略。
設(shè)計(jì)獎(jiǎng)勵(lì)函數(shù)時(shí),為了鼓勵(lì)模型用更少的步數(shù)生成高質(zhì)量圖像,在獎(jiǎng)勵(lì)中綜合考慮了生成圖像質(zhì)量和去噪步數(shù)這兩個(gè)因素,研究人員選用了與人類偏好對(duì)齊的圖像評(píng)分模型ImageReward(IR)用以評(píng)價(jià)圖像質(zhì)量,并將這一獎(jiǎng)勵(lì)隨步數(shù)衰減至之前的去噪結(jié)果,并取平均作為整個(gè)去噪過(guò)程的獎(jiǎng)勵(lì)。這樣,生成所用的步數(shù)越多,最終獎(jiǎng)勵(lì)就越低。模型會(huì)在保持圖像質(zhì)量的前提下,盡可能地減少生成步數(shù)。
圖片
將整個(gè)多步去噪過(guò)程當(dāng)作一個(gè)動(dòng)作進(jìn)行整體優(yōu)化,并采用了無(wú)需值模型的強(qiáng)化學(xué)習(xí)優(yōu)化算法RLOO [1]更新TPM模塊參數(shù),訓(xùn)練損失如下所示:
圖片
在這一公式中,s代表強(qiáng)化學(xué)習(xí)中的狀態(tài),在擴(kuò)散模型的強(qiáng)化微調(diào)中是輸入的文本提詞和初始噪聲;y代表決策動(dòng)作,也即模型采樣的擴(kuò)散時(shí)間;
代表決策器,即網(wǎng)絡(luò)中A是由獎(jiǎng)勵(lì)歸一化之后的優(yōu)勢(shì)函數(shù),采用LEAVE-One-Out策略,基于一個(gè)Batch內(nèi)的樣本間獎(jiǎng)勵(lì)的差值計(jì)算優(yōu)勢(shì)函數(shù)。
通過(guò)強(qiáng)化微調(diào)訓(xùn)練,模型能根據(jù)輸入圖像自適應(yīng)地調(diào)節(jié)擴(kuò)散時(shí)間的衰減速度,在面對(duì)不同的生成任務(wù)時(shí)推理不同數(shù)量的去噪步數(shù)。對(duì)于簡(jiǎn)單的生成任務(wù)(較短的文本提詞、生成圖像物體少),推理過(guò)程能夠很快生成高質(zhì)量的圖像,噪聲強(qiáng)度衰減較快,模型只需要思考較少的額外步數(shù),就能得到滿意的結(jié)果;對(duì)于復(fù)雜的生成任務(wù)(長(zhǎng)文本提詞,圖像結(jié)構(gòu)復(fù)雜)則需要在擴(kuò)散時(shí)間上密集地進(jìn)行多步思考,用一個(gè)較長(zhǎng)的圖像鏈COT來(lái)生成符合用戶要求的圖片。
圖片
通過(guò)調(diào)節(jié)不同的γ值,模型能在圖像生成質(zhì)量和去噪推理的步數(shù)之間取得更好的平衡,僅需要更少的平均步數(shù)就能達(dá)到與原模型相同的性能。
圖片
同時(shí),強(qiáng)化微調(diào)的訓(xùn)練效率也十分驚人。正如OpenAI最少僅僅用幾十個(gè)例子就能讓LLM學(xué)會(huì)在自定義領(lǐng)域中推理一樣,強(qiáng)化微調(diào)圖像生成模型對(duì)數(shù)據(jù)的需求也很少。不需要真實(shí)圖像,只需要文本提詞就可以訓(xùn)練,利用不到10,000條文本提詞就能取得不錯(cuò)的明顯的模型提升。
經(jīng)強(qiáng)化微調(diào)后,模型的圖像生成質(zhì)量也比原模型提高了很多。可以看出,在僅僅用了原模型一半生成步數(shù)的情況下,無(wú)論是圖C中的筆記本鍵盤,圖D中的球棒還是圖F中的遙控器,該模型生成的結(jié)果都比原模型更加自然。
圖片
針對(duì)Stable Diffusion 3、Flux-dev等一系列最先進(jìn)的開(kāi)源圖像生成模型進(jìn)行了強(qiáng)化微調(diào)訓(xùn)練,發(fā)現(xiàn)訓(xùn)練后的模型普遍能減少平均約50%的模型推理步數(shù),而圖像質(zhì)量評(píng)價(jià)指標(biāo)總體保持不變,這說(shuō)明對(duì)于圖像生成模型而言,強(qiáng)化微調(diào)訓(xùn)練是一種通用的后訓(xùn)練(Post Training)方法。
圖片
結(jié)論
這篇報(bào)告介紹了由MAPLE實(shí)驗(yàn)室提出的,一種擴(kuò)散和流匹配模型的強(qiáng)化微調(diào)方法。該方法將多步去噪的圖像生成過(guò)程看作圖像生成領(lǐng)域的COT過(guò)程,通過(guò)將整個(gè)去噪過(guò)程的最終輸出與人類偏好對(duì)齊,實(shí)現(xiàn)了用更少的推理步數(shù)生成更高質(zhì)量圖像。
在多個(gè)開(kāi)源圖像生成模型上的實(shí)驗(yàn)結(jié)果表明,這種強(qiáng)化微調(diào)方法能在保持圖像質(zhì)量的同時(shí)顯著減少約50%推理步數(shù),微調(diào)后模型生成的圖像在視覺(jué)效果上也更加自然??梢钥闯觯瑥?qiáng)化微調(diào)技術(shù)在圖像生成模型中仍有進(jìn)一步應(yīng)用和提升的潛力,值得進(jìn)一步挖掘。
參考資料: