逐字生成非最優(yōu)?試試逐「塊」生成!Block Diffusion打通了自回歸與擴(kuò)散
去年初,OpenAI 的視頻生成模型 Sora 帶火了擴(kuò)散模型。
如今,擴(kuò)散模型被廣泛用于生成圖像和視頻,并在生成文本或生物序列等離散數(shù)據(jù)方面變得越來(lái)越有效。從技術(shù)上講,與自回歸模型相比,擴(kuò)散模型具有加速生成和提高模型輸出可控性的潛力。
目前,離散擴(kuò)散模型目前面臨至少三個(gè)限制。首先,在聊天系統(tǒng)等應(yīng)用中,模型必須生成任意長(zhǎng)度的輸出序列(例如對(duì)用戶問(wèn)題的回答)。但是,大多數(shù)最新的擴(kuò)散架構(gòu)僅能生成固定長(zhǎng)度的向量。其次,離散擴(kuò)散模型在生成過(guò)程中使用雙向上下文,因此無(wú)法使用 KV 緩存重用以前的計(jì)算,這會(huì)降低推理效率。第三,以困惑度等標(biāo)準(zhǔn)指標(biāo)衡量的離散擴(kuò)散模型,質(zhì)量落后于自回歸方法,進(jìn)一步限制了其適用性。
本文中,來(lái)自 Cornell Tech、斯坦福大學(xué)、Cohere 的研究者提出通過(guò)塊離散去噪擴(kuò)散語(yǔ)言模型(Block Discrete Denoising Diffusion Language Models,BD3-LMs)來(lái)解決以上限制,該模型在擴(kuò)散和自回歸模型之間進(jìn)行插值。
具體來(lái)講,塊擴(kuò)散模型(也是半自回歸模型)定義了離散隨機(jī)變量塊的自回歸概率分布,而給定先前塊的條件概率由離散去噪擴(kuò)散模型指定。
- 論文標(biāo)題:Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models
- 論文地址:https://arxiv.org/pdf/2503.09573
- 項(xiàng)目主頁(yè):https://m-arriola.com/bd3lms/
下圖為 Block Diffusion 與自回歸、擴(kuò)散模型的生成效果對(duì)比:
研究者表示,開(kāi)發(fā)有效的 BD3-LM 面臨以下兩個(gè)挑戰(zhàn):一是使用神經(jīng)網(wǎng)絡(luò)的一次標(biāo)準(zhǔn)前向傳遞無(wú)法有效地計(jì)算塊擴(kuò)散模型的訓(xùn)練目標(biāo),需要開(kāi)發(fā)專門(mén)的算法。二是擴(kuò)散目標(biāo)梯度的高方差阻礙了訓(xùn)練,導(dǎo)致 BD3-LM 即使在塊大小為 1 的情況下(當(dāng)兩個(gè)模型等效時(shí))也表現(xiàn)不佳。
因此,研究者推導(dǎo)出梯度方差的估計(jì)量,并證明它是自回歸和擴(kuò)散之間困惑度差距的關(guān)鍵因素。然后,他們提出了自定義噪聲過(guò)程,以實(shí)現(xiàn)最小化梯度方差并進(jìn)一步縮小困惑度差距。
實(shí)驗(yàn)部分,研究者在多個(gè)語(yǔ)言建?;鶞?zhǔn)上評(píng)估了 BD3-LM,并證明它們能夠生成任意長(zhǎng)度的序列,包括超出其訓(xùn)練上下文的長(zhǎng)度。此外,BD3-LM 在離散擴(kuò)散模型中實(shí)現(xiàn)了新的 SOTA 困惑度。與對(duì)嵌入進(jìn)行高斯擴(kuò)散的替代半自回歸方法相比,本文離散方法實(shí)現(xiàn)了易于處理的似然估計(jì),并在少一個(gè)數(shù)量級(jí)生成步驟的情況下,生成的樣本在困惑度方面得到了改進(jìn)。
論文一作 Marianne Arriola 發(fā)推稱,擴(kuò)散語(yǔ)言模型在并行文本生成領(lǐng)域正在崛起,但與自回歸模型相比,它們存在質(zhì)量、固定長(zhǎng)度限制和缺乏 KV 緩存等問(wèn)題。本文 Block Diffusion 將自回歸和擴(kuò)散模型結(jié)合了起來(lái),實(shí)現(xiàn)了兩全其美。
BD3-LMs 模型概覽
研究者結(jié)合建模范式,從自回歸模型中獲得更好的似然估計(jì)和靈活的長(zhǎng)度生成,并從擴(kuò)散模型中獲得了快速的并行生成效果。
塊擴(kuò)散似然
研究者提出了一個(gè)建??蚣埽摽蚣軐?duì) token 塊進(jìn)行自回歸建模,并在每個(gè)塊內(nèi)執(zhí)行擴(kuò)散操作。他們對(duì)長(zhǎng)度為 L′ 的 B 個(gè)塊進(jìn)行似然分解,如下所示:
每個(gè) pθ(x^b|x^<b) 都使用包含 L′個(gè) token 的塊上的離散擴(kuò)散 ELBO 進(jìn)行建模,并通過(guò)優(yōu)化以下似然邊界來(lái)獲得原則性學(xué)習(xí)目標(biāo) L_BD (x,θ):
研究者使用簡(jiǎn)單的離散擴(kuò)散參數(shù)化對(duì)每個(gè)塊的似然進(jìn)行建模,最終目標(biāo)是對(duì)交叉熵項(xiàng)進(jìn)行加權(quán)總和:
高效的訓(xùn)練與采樣算法
簡(jiǎn)單來(lái)說(shuō),研究者想要通過(guò)在一個(gè) loop 中應(yīng)用B 次來(lái)計(jì)算 logits。不過(guò),他們只需要兩次前向傳遞。第一次傳遞分別預(yù)計(jì)算完整序列 x 的鍵和值 K^1:B、V^1:B,在第二次前向傳遞中使用
同時(shí)計(jì)算所有塊的去噪預(yù)測(cè)。
為了從 BD3-LM 中采樣,研究者以先前采樣的塊為條件,一次生成一個(gè)塊。生成塊后,他們緩存其鍵和值,類似于 AR。同時(shí)在每個(gè)塊的 T 個(gè)采樣步下,使用任何擴(kuò)散采樣流程 SAMPLE中進(jìn)行采樣。來(lái)從條件分布 pθ
算法 1(塊擴(kuò)散訓(xùn)練)和算法 2(塊擴(kuò)散采樣)分別如下圖(左)和(右)所示。
BD3-LM 訓(xùn)練和采樣算法。
理解擴(kuò)散模型與自回歸模型之間的似然差距
案例研究:?jiǎn)?Token 生成
該研究中的塊擴(kuò)散參數(shù)化在期望上等同于自回歸負(fù)對(duì)數(shù)似然 (NLL),特別是在 L′=1 的極限情況下。令人驚訝的是,當(dāng)在 LM1B 數(shù)據(jù)集上訓(xùn)練兩種模型時(shí),研究發(fā)現(xiàn)塊擴(kuò)散模型 (L′=1) 與自回歸模型之間存在兩點(diǎn)困惑度差距。研究確定擴(kuò)散目標(biāo)的高訓(xùn)練方差是導(dǎo)致這一困惑度差距的原因。
在離散擴(kuò)散 ELBO 下進(jìn)行訓(xùn)練時(shí),存在高方差。
高方差訓(xùn)練導(dǎo)致的擴(kuò)散差距
直觀來(lái)說(shuō),如果采樣的掩碼率過(guò)低,重構(gòu) x 會(huì)變得容易,這不能提供有用的學(xué)習(xí)信號(hào)。如果掩碼全部?jī)?nèi)容,最優(yōu)的重構(gòu)就是數(shù)據(jù)分布中每個(gè)標(biāo)記的邊際概率,這很容易學(xué)習(xí),同樣也沒(méi)有用處。
研究需要找到能夠最小化擴(kuò)散目標(biāo)引起的訓(xùn)練方差,并進(jìn)一步減少困惑度差距的噪聲調(diào)度方案。
基于數(shù)據(jù)的低方差訓(xùn)練噪聲調(diào)度
為了避免導(dǎo)致高方差訓(xùn)練的掩碼率,研究者在「裁剪的』掩碼率下來(lái)訓(xùn)練 BD3-LMs。通過(guò)降低訓(xùn)練方差,研究者在均勻采樣的掩碼率評(píng)估下改善了似然度。
由于最佳掩碼率可能會(huì)根據(jù)塊大小 L′的不同而變化,他們?cè)谟?xùn)練期間自適應(yīng)地學(xué)習(xí) β,ω。在實(shí)踐中,研究者在每個(gè)驗(yàn)證步驟后(經(jīng)過(guò) 5K 次梯度更新)使用網(wǎng)格搜索來(lái)優(yōu)化。
在下文中,研究者展示了針對(duì)每個(gè)塊大小優(yōu)化噪聲調(diào)度可以減少損失估計(jì)器的方差,并與其他替代調(diào)度方案相比實(shí)現(xiàn)最佳困惑度。
實(shí)驗(yàn)結(jié)果
似然評(píng)估
BD3-LMs 在擴(kuò)散模型中實(shí)現(xiàn)了最先進(jìn)的似然水平。研究表明,通過(guò)調(diào)整塊長(zhǎng)度 L′,BD3-LMs 可以在擴(kuò)散和自回歸似然之間實(shí)現(xiàn)插值。
在 OWT 上測(cè)試針對(duì) 262B 標(biāo)記訓(xùn)練的模型的困惑度 (PPL; ↓)。
任意長(zhǎng)度序列生成
許多現(xiàn)有擴(kuò)散語(yǔ)言模型的一個(gè)主要缺點(diǎn)是,它們無(wú)法生成超過(guò)訓(xùn)練時(shí)選擇的輸出上下文長(zhǎng)度的完整文檔。例如,OpenWebText 包含最長(zhǎng)達(dá) 131K tokens 的文檔,而離散擴(kuò)散模型 SEDD(Lou 等人)僅限于生成 1024 tokens。研究表明,BD3-LMs 能夠通過(guò)解碼任意數(shù)量的塊來(lái)生成可變長(zhǎng)度的文檔。
從在 OWT 上訓(xùn)練的模型中抽樣 500 個(gè)文檔得出的生成長(zhǎng)度統(tǒng)計(jì)信息。
研究者評(píng)估了 BD3-LMs 在變長(zhǎng)序列上的生成質(zhì)量,使用相同數(shù)量的生成步驟(NFEs)比較了所有方法。他們用 GPT2-Large 模型測(cè)量生成序列的困惑度。結(jié)果表明,與之前所有的擴(kuò)散方法相比,BD3-LMs 實(shí)現(xiàn)了最佳的生成困惑度。
300 個(gè)可變長(zhǎng)度樣本的生成困惑度 (Gen. PPL;↓) 和功能評(píng)估次數(shù) (NFE;↓)。所有模型都在 OWT 上進(jìn)行訓(xùn)練,上下文長(zhǎng)度為 L = 1024,并使用核采樣。
對(duì)于 MDLM,研究者使用了其分塊解碼技術(shù)(該技術(shù)不同于 BD3-LMs 中的分塊擴(kuò)散訓(xùn)練)處理 L=2048 的序列。研究者還與 SSD-LM(Han 等人提出)進(jìn)行了比較,后者是一種替代性的分塊自回歸方法(也稱為半自回歸),它對(duì)詞嵌入執(zhí)行高斯擴(kuò)散,但無(wú)法進(jìn)行似然估計(jì)。該研究的離散方法使用比其他方法少一個(gè)數(shù)量級(jí)的生成步驟,產(chǎn)生了具有更好生成困惑度的樣本。
更多細(xì)節(jié)請(qǐng)參閱原論文。