無(wú)限生成視頻,還能規(guī)劃決策,擴(kuò)散強(qiáng)制整合下一token預(yù)測(cè)與全序列擴(kuò)散
近日,MIT CSAIL 的一個(gè)研究團(tuán)隊(duì)(一作為 MIT 在讀博士陳博遠(yuǎn))成功地將全序列擴(kuò)散模型與下一 token 模型的強(qiáng)大能力統(tǒng)合到了一起,提出了一種訓(xùn)練和采樣范式:Diffusion Forcing(DF)。
- 論文標(biāo)題:Diffusion Forcing:Next-token Prediction Meets Full-Sequence Diffusion
- 論文地址:https://arxiv.org/pdf/2407.01392
- 項(xiàng)目網(wǎng)站:https://boyuan.space/diffusion-forcing
- 代碼地址:https://github.com/buoyancy99/diffusion-forcing?
如下所示,擴(kuò)散強(qiáng)制在一致性和穩(wěn)定性方面都明顯勝過(guò)全序列擴(kuò)散和教師強(qiáng)制這兩種方法。
在該框架中,每個(gè) token 都關(guān)聯(lián)了一個(gè)隨機(jī)的、獨(dú)立的噪聲水平,并且可使用一種共享的下一 token 預(yù)測(cè)模型或下幾 token 預(yù)測(cè)模型根據(jù)任意的、獨(dú)立的、每 token 的方案對(duì) token 進(jìn)行去噪。
該方法的研究靈感來(lái)自這一觀察:對(duì) token 加噪聲的過(guò)程就是一種形式的部分掩碼過(guò)程 —— 零噪聲就意味著未對(duì) token 加掩碼,而完整噪聲則是完全掩蔽 token。因此,DF 可強(qiáng)迫模型學(xué)習(xí)去除任何可變有噪聲 token 集合的掩碼(圖 2)。
與此同時(shí),通過(guò)將預(yù)測(cè)方法參數(shù)化為多個(gè)下一 token 預(yù)測(cè)模型的組合,該系統(tǒng)可以靈活地生成不同長(zhǎng)度的序列,并以組合方式泛化到新的軌跡(圖 1)。
該團(tuán)隊(duì)將用于序列生成的 DF 實(shí)現(xiàn)成了因果擴(kuò)散強(qiáng)制(Causal Diffusion Forcing/CDF),其中未來(lái) token 通過(guò)一個(gè)因果架構(gòu)依賴于過(guò)去 token。他們訓(xùn)練該模型一次性去噪序列的所有 token(其中每個(gè) token 都有獨(dú)立的噪聲水平)。
在采樣期間,CDF 會(huì)將一個(gè)高斯噪聲幀序列逐漸地去噪成潔凈的樣本,其中不同幀在每個(gè)去噪步驟可能會(huì)有不同的噪聲水平。類似于下一 token 預(yù)測(cè)模型,CDF 可以生成長(zhǎng)度可變的序列;不同于下一 token 預(yù)測(cè),CDF 的表現(xiàn)非常穩(wěn)定 —— 不管是預(yù)測(cè)接下來(lái)的一個(gè) token,還是未來(lái)的數(shù)千 token,甚至是連續(xù) token。
此外,類似于全序列擴(kuò)散,它也可接收引導(dǎo),從而實(shí)現(xiàn)高獎(jiǎng)勵(lì)生成。通過(guò)協(xié)同利用因果關(guān)系、靈活的范圍和可變?cè)肼曊{(diào)度,CDF 能實(shí)現(xiàn)一項(xiàng)新功能:蒙特卡洛樹(shù)引導(dǎo)(MCTG)。相比于非因果全序列擴(kuò)散模型,MCTG 能極大提升高獎(jiǎng)勵(lì)生成的采樣率。圖 1 給出了這些能力的概況。
Diffusion Forcing(擴(kuò)散強(qiáng)制)
1、將加噪過(guò)程視為部分掩碼
首先,我們可以將任意 token 集合(不管是否為序列)視為一個(gè)通過(guò) t 索引的有序集合。那么,使用教師強(qiáng)制(teacher forcing)訓(xùn)練下一 token 預(yù)測(cè)便可被解釋成掩蔽掉時(shí)間 t 的每個(gè) token x_t 基于過(guò)去 x_{1:t?1} 預(yù)測(cè)它們。
對(duì)于序列,可將這種操作描述成:沿時(shí)間軸執(zhí)行掩碼。我們可以將全序列前向擴(kuò)散(即逐漸向數(shù)據(jù)
添加噪聲的過(guò)程)看作一種部分掩碼(partial masking),這可被稱為「沿噪聲軸執(zhí)行掩碼)。
事實(shí)上,在 K 步加噪之后,
(大概)就是白噪聲了,不再有任何有關(guān)原數(shù)據(jù)的信息。
如圖 2 所示,該團(tuán)隊(duì)建立了一個(gè)統(tǒng)一視角來(lái)看待沿這兩個(gè)軸的掩碼。
2、擴(kuò)散強(qiáng)制:不同 token 的噪聲水平不同
擴(kuò)散強(qiáng)制(DF)框架可用于訓(xùn)練和采樣任意序列長(zhǎng)度的有噪聲 token
,其中的關(guān)鍵在于每個(gè) token 的噪聲水平 k_t 會(huì)隨時(shí)間步驟而變化。
這篇論文關(guān)注的重點(diǎn)是時(shí)間序列數(shù)據(jù),因此他們通過(guò)一種因果架構(gòu)實(shí)例化了 DF,并由此得到了因果擴(kuò)散強(qiáng)制(CDF)。簡(jiǎn)單來(lái)說(shuō),這是使用基礎(chǔ)循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)獲得的一種最小實(shí)現(xiàn)。
權(quán)重為 θ 的 RNN 維護(hù)著獲悉過(guò)去 token 影響的隱藏狀態(tài) z_t,其會(huì)通過(guò)一個(gè)循環(huán)層根據(jù)動(dòng)態(tài)
而演化。當(dāng)獲得輸入噪聲觀察
時(shí),就以馬爾可夫方式更新該隱藏狀態(tài)。
當(dāng) k_t=0 時(shí),這就是貝葉斯過(guò)濾中的后驗(yàn)更新;而當(dāng) k_t= K(純?cè)肼?、無(wú)信息)時(shí),這就等價(jià)于建模貝葉斯過(guò)濾中的「后驗(yàn)分布」p_θ(z_t | z_{t?1})。
給定隱藏狀態(tài) z_t,觀察模型 p_θ(x_t^0 | z_t) 的目標(biāo)是預(yù)測(cè) x_t;這個(gè)單元的輸入 - 輸出行為與標(biāo)準(zhǔn)的條件擴(kuò)散模型一樣:以條件變量 z_{t?1} 和有噪聲 token 為輸入,預(yù)測(cè)無(wú)噪聲的 x_t=x_t^0,并由此間接地通過(guò)仿射重新參數(shù)化預(yù)測(cè)噪聲 ε^{k_t}。因此,我們就可以直接使用經(jīng)典的擴(kuò)散目標(biāo)來(lái)訓(xùn)練(因果)擴(kuò)散強(qiáng)制。根據(jù)噪聲預(yù)測(cè)結(jié)果 ε_(tái)θ,可以對(duì)上述單元進(jìn)行參數(shù)化。然后,通過(guò)最小化以下?lián)p失來(lái)找到參數(shù) θ:
算法 1 給出了偽代碼。重點(diǎn)在于,該損失捕獲了貝葉斯過(guò)濾和條件擴(kuò)散的關(guān)鍵元素。該團(tuán)隊(duì)也進(jìn)一步重新推斷了用于擴(kuò)散強(qiáng)制的擴(kuò)散模型訓(xùn)練中的常用技術(shù),詳見(jiàn)原論文的附錄部分。他們也得出了一個(gè)非正式的定理。
定理 3.1(非正式)。擴(kuò)散強(qiáng)制訓(xùn)練流程(算法 1)是在期望對(duì)數(shù)似然
上優(yōu)化證據(jù)下限(ELBO)的重新加權(quán),其中期望值會(huì)在噪聲水平上平均,而
是根據(jù)前向過(guò)程加噪。此外,在適當(dāng)條件下,優(yōu)化 (3.1) 式還可以同時(shí)最大化所有噪聲水平序列的似然下限。
擴(kuò)散強(qiáng)制采樣和所得到的能力
算法 2 描述了采樣過(guò)程,其定義是:在二維的 M × T 網(wǎng)格 K ∈ [K]^{M×T} 上指定噪聲調(diào)度;其中列對(duì)應(yīng)于時(shí)間步驟 t,m 索引的行則決定了噪聲水平。
為了生成長(zhǎng)度為 T 的整個(gè)序列,先將 token x_{1:T} 初始化為白噪聲,對(duì)應(yīng)于噪聲水平 k = K。然后沿著網(wǎng)格逐行向下迭代,并從左到右逐列去噪,直到噪聲水平達(dá)到 K。到最后一行 m = 0 時(shí),token 的噪聲已清理干凈,即噪聲水平為 K_{0,t} ≡ 0。
這個(gè)采樣范式會(huì)帶來(lái)如下新能力:
- 讓自回歸生成變得穩(wěn)定
- 保持未來(lái)的不確定
- 長(zhǎng)期引導(dǎo)能力
將擴(kuò)散強(qiáng)制用于靈活的序列決策
擴(kuò)散強(qiáng)制的新能力也帶來(lái)了新的可能性。該團(tuán)隊(duì)基于此為序列決策(SDM)設(shè)計(jì)了一種全新框架,并且將其成功應(yīng)用到了機(jī)器人和自主智能體領(lǐng)域。
首先,定義一個(gè)馬爾可夫決策過(guò)程,該過(guò)程具有動(dòng)態(tài) p (s_{t+1}|s_t, a_t)、觀察 p (o_t|s_t) 和獎(jiǎng)勵(lì) p (r_t|s_t, a_t)。這里的目標(biāo)是訓(xùn)練一個(gè)策略 π(a_t|o_{1:t}),使得軌跡
的預(yù)期累積獎(jiǎng)勵(lì)最大化。這里分配 token x_t = [a_t, r_t, o_{t+1}]。一條軌跡就是一個(gè)序列 x_{1:T},其長(zhǎng)度可能是可變的;訓(xùn)練方式則如算法 1 所示。
在執(zhí)行過(guò)程的每一步 t,都有一個(gè)隱藏狀態(tài) z_{t-1} 總結(jié)過(guò)去的無(wú)噪聲 token x_{1:t-1}?;谶@個(gè)隱藏狀態(tài),根據(jù)算法 2 采樣一個(gè)規(guī)劃
,其中
包含預(yù)測(cè)的動(dòng)作、獎(jiǎng)勵(lì)和觀察。H 是一個(gè)前向觀察窗口,類似于模型預(yù)測(cè)控制中的未來(lái)預(yù)測(cè)。在采用了規(guī)劃的動(dòng)作之后,環(huán)境會(huì)得到一個(gè)獎(jiǎng)勵(lì)和下一個(gè)觀察,從而得到下一個(gè) token。其中隱藏狀態(tài)可以根據(jù)后驗(yàn) p_θ(z_t|z_{t?1}, x_t, 0) 獲得更新。
該框架既可以作為策略,也可以作為規(guī)劃器,其優(yōu)勢(shì)包括:
- 具有靈活的規(guī)劃范圍
- 可實(shí)現(xiàn)靈活的獎(jiǎng)勵(lì)引導(dǎo)
- 能實(shí)現(xiàn)蒙特卡洛樹(shù)引導(dǎo)(MCTG),從而實(shí)現(xiàn)未來(lái)不確定性
實(shí)驗(yàn)
該團(tuán)隊(duì)評(píng)估了擴(kuò)散強(qiáng)制作為生成序列模型的優(yōu)勢(shì),其中涉及視頻和時(shí)間序列預(yù)測(cè)、規(guī)劃和模仿學(xué)習(xí)等多種應(yīng)用。
視頻預(yù)測(cè):一致且穩(wěn)定的序列生成和無(wú)限展開(kāi)
針對(duì)視頻生成式建模任務(wù),他們基于 Minecraft 游戲視頻和 DMLab 導(dǎo)航為因果擴(kuò)散強(qiáng)制訓(xùn)練了一個(gè)卷積 RNN 實(shí)現(xiàn)。
圖 3 展示了擴(kuò)散強(qiáng)制與基準(zhǔn)的定性結(jié)果。
可以看到,擴(kuò)散強(qiáng)制能穩(wěn)定地展開(kāi),甚至能超過(guò)其訓(xùn)練范圍;而教師強(qiáng)制和全序列擴(kuò)散基準(zhǔn)會(huì)很快發(fā)散。
擴(kuò)散規(guī)劃:MCTG、因果不確定性、靈活的范圍控制
擴(kuò)散強(qiáng)制的能力能為決策帶來(lái)獨(dú)有的好處。該團(tuán)隊(duì)使用一種標(biāo)準(zhǔn)的離線強(qiáng)化學(xué)習(xí)框架 D4RL 評(píng)估了新提出的決策框架。
表 1 給出了定性和定量的評(píng)估結(jié)果??梢钥吹?,擴(kuò)散強(qiáng)制在全部 6 個(gè)環(huán)境中都優(yōu)于 Diffuser 和所有基準(zhǔn)。
可控的序列組合生成
該團(tuán)隊(duì)發(fā)現(xiàn),僅需修改采樣方案,就可以靈活地組合訓(xùn)練時(shí)間觀察到的序列的子序列。
他們使用一個(gè) 2D 軌跡數(shù)據(jù)集進(jìn)行了實(shí)驗(yàn):在一個(gè)方形平面上,所有軌跡都是始于一角并最終到達(dá)對(duì)角,形成一種十字形。
如上圖 1 所示,當(dāng)不需要組合行為時(shí),可讓 DF 保持完整記憶,復(fù)制十字形的分布。當(dāng)需要組合時(shí),可讓模型使用 MPC 無(wú)記憶地生成更短的規(guī)劃,從而實(shí)現(xiàn)對(duì)這個(gè)十字形的子軌跡的縫合,得到 V 形軌跡。
機(jī)器人:長(zhǎng)范圍模仿學(xué)習(xí)和穩(wěn)健的視覺(jué)運(yùn)動(dòng)控制
擴(kuò)散強(qiáng)制也為真實(shí)機(jī)器人的視覺(jué)運(yùn)動(dòng)控制帶來(lái)了新的機(jī)會(huì)。
模仿學(xué)習(xí)是一種常用的機(jī)器人操控技術(shù),即學(xué)習(xí)專家演示的觀察到動(dòng)作的映射。但是,缺乏記憶往往會(huì)讓模仿學(xué)習(xí)難以完成長(zhǎng)范圍的任務(wù)。DF 不僅能緩解這個(gè)短板,還能讓模仿學(xué)習(xí)更穩(wěn)健。
使用記憶進(jìn)行模仿學(xué)習(xí)。通過(guò)遙控 Franka 機(jī)器人,該團(tuán)隊(duì)收集了一個(gè)視頻和動(dòng)作數(shù)據(jù)集。如圖 4 所示,任務(wù)就是利用第三個(gè)位置交換蘋(píng)果和橘子的位置。水果的初始位置是隨機(jī)的,因此可能的目標(biāo)狀態(tài)有兩個(gè)。
此外,當(dāng)?shù)谌齻€(gè)位置有一個(gè)水果時(shí),就無(wú)法通過(guò)當(dāng)前觀察推斷出所需結(jié)果 —— 策略必須記住初始配置才能決定移動(dòng)哪個(gè)水果。不同于常用的行為克隆方法,DF 可以自然地將記憶整合進(jìn)自己的隱藏狀態(tài)中。結(jié)果發(fā)現(xiàn),DF 能實(shí)現(xiàn) 80% 的成功率,而擴(kuò)散策略(當(dāng)前最佳的無(wú)記憶模仿學(xué)習(xí)算法)卻失敗了。
此外,DF 還能更穩(wěn)健地應(yīng)對(duì)噪聲并助益機(jī)器人預(yù)訓(xùn)練。
時(shí)間序列預(yù)測(cè):擴(kuò)散強(qiáng)制是一種優(yōu)秀的通用序列模型
對(duì)于多變量時(shí)間序列預(yù)測(cè)任務(wù),該團(tuán)隊(duì)的研究表明 DF 足以與之前的擴(kuò)散模型和基于 Transformer 的模型媲美。
更多技術(shù)細(xì)節(jié)和實(shí)驗(yàn)結(jié)果請(qǐng)參閱原論文。
本文轉(zhuǎn)自機(jī)器之心 ,作者:機(jī)器之心
