Alex Graves新作貝葉斯流網(wǎng)絡(luò),解決離散數(shù)據(jù)生成問(wèn)題,滿論文都是數(shù)學(xué)公式
近來(lái),大規(guī)模神經(jīng)網(wǎng)絡(luò)徹底改變了生成式模型,使模型具有前所未有的捕捉許多變量之間復(fù)雜關(guān)系的能力,例如建立高分辨率圖像中所有像素的聯(lián)合模型。
大多數(shù)神經(jīng)網(wǎng)絡(luò)(包括自回歸模型、基于流的模型、深度 VAE 和擴(kuò)散模型)表達(dá)能力的關(guān)鍵在于,它們編碼的聯(lián)合分布被分解為一系列步驟,從而避免了「維數(shù)災(zāi)難(curse of dimensionality)」。也就是說(shuō),它們將難題分解成多個(gè)簡(jiǎn)單問(wèn)題來(lái)解決。
自回歸網(wǎng)絡(luò)目前是語(yǔ)言建模領(lǐng)域的 SOTA 方法,并且通常在自然排序的離散數(shù)據(jù)上表現(xiàn)良好。然而,事實(shí)證明自回歸網(wǎng)絡(luò)在圖像生成等領(lǐng)域效果較差,因?yàn)檫@些領(lǐng)域的數(shù)據(jù)是連續(xù)的,并且變量之間不存在自然順序。自回歸模型還有一個(gè)缺點(diǎn)是,生成樣本需要與數(shù)據(jù)中變量一樣多的網(wǎng)絡(luò)更新。擴(kuò)散模型是一種應(yīng)用于圖像生成的有效替代框架,但傳輸過(guò)程會(huì)變得更加復(fù)雜。
然而,當(dāng)數(shù)據(jù)是離散的,擴(kuò)散模型的性能仍不及自回歸模型。最近,機(jī)器學(xué)習(xí)領(lǐng)域知名研究者、神經(jīng)圖靈機(jī)(NTM)提出者和可微神經(jīng)計(jì)算機(jī)的創(chuàng)造者之一 Alex Graves 以第一作者的身份發(fā)表了一篇新論文,提出了一種新型生成模型 —— 貝葉斯流網(wǎng)絡(luò)(Bayesian Flow Networks,BFN)。與擴(kuò)散模型不同的是,BFN 對(duì)數(shù)據(jù)分布的參數(shù)進(jìn)行操作,而不是對(duì)數(shù)據(jù)本身的噪聲版本進(jìn)行操作。這確保了生成過(guò)程是完全連續(xù)且可微的,即使數(shù)據(jù)是離散的。
論文地址:https://arxiv.org/abs/2308.07037
論文一作 Alex Graves,他是圖靈獎(jiǎng)得主 Geoffrey Hinton 的學(xué)生。
BFN 方法會(huì)根據(jù)噪聲數(shù)據(jù)樣本使用貝葉斯推斷修改一組獨(dú)立分布的參數(shù),然后將其作為輸入傳遞給神經(jīng)網(wǎng)絡(luò),該神經(jīng)網(wǎng)絡(luò)會(huì)輸出一個(gè)相互依賴(lài)的分布,然后從簡(jiǎn)單的先驗(yàn)開(kāi)始并迭代更新上述兩個(gè)分布,產(chǎn)生一種類(lèi)似于擴(kuò)散模型逆過(guò)程的生成過(guò)程,但 BFN 在概念上更簡(jiǎn)單,因?yàn)椴恍枰跋蜻^(guò)程。
BFN 的整體概覽如下圖 1 所示。在每一步中,消息發(fā)送者(Sender)Alice 都會(huì)向消息接收者(Receiver)Bob 發(fā)送一條消息,包含關(guān)于數(shù)據(jù)的一些信息。
其中,Bob 會(huì)嘗試猜測(cè)消息是什么:他猜測(cè)得越好,傳輸消息所需的比特?cái)?shù)就越少。收到消息后,Bob 使用剛剛獲得的信息來(lái)改進(jìn)對(duì)下一條消息的猜測(cè)。
重復(fù)該過(guò)程,每一步的預(yù)測(cè)都會(huì)得到改進(jìn)。傳輸成本之和是完整文本序列的負(fù)對(duì)數(shù)概率,通過(guò)最大似然訓(xùn)練進(jìn)行損失函數(shù)最小化。這也是 Alice 使用算術(shù)編碼將片段傳輸給 Bob 所需的最小位數(shù)。因此,用最大似然擬合自回歸模型與訓(xùn)練數(shù)據(jù)壓縮之間存在直接的對(duì)應(yīng)關(guān)系。
上述傳輸過(guò)程定義了一個(gè) n 步損失函數(shù),通過(guò)將 n 擴(kuò)展到∞,就能推廣到連續(xù)時(shí)間。連續(xù)時(shí)間損失函數(shù)在數(shù)學(xué)上比離散時(shí)間損失函數(shù)更簡(jiǎn)單、易于計(jì)算。經(jīng)過(guò)連續(xù)時(shí)間損失訓(xùn)練的 BFN 可以在推斷和采樣期間運(yùn)行任意數(shù)量的離散步驟,并且性能隨著步驟數(shù)量的增加而提升。
總的來(lái)說(shuō),BFN 結(jié)合了貝葉斯推斷和深度學(xué)習(xí)的優(yōu)勢(shì),前者為單個(gè)變量提供了一種極佳的數(shù)學(xué)方法,后者則擅長(zhǎng)整合多個(gè)相關(guān)變量的信息。
LSTM 提出者和奠基者 Sepp Hochreiter 表示:「貝葉斯流網(wǎng)絡(luò) (BFN) 作為擴(kuò)散模型的替代者,它更新的兩個(gè)分布過(guò)程可看作是一個(gè)生成過(guò)程,就像沒(méi)有前向傳遞的擴(kuò)散模型一樣。實(shí)驗(yàn)顯示,在 text8 字符級(jí)語(yǔ)言建模上優(yōu)于離散擴(kuò)散。」
論文作者之一 Rupesh Kumar Srivastava 表示,「這項(xiàng)研究使得我們可以通過(guò)選擇合適的分布,輕松地將 BFN 框架適應(yīng)于連續(xù)和離散數(shù)據(jù),并且在 MNIST、CIFAR-10 和 text8 任務(wù)上得到了很好的結(jié)果?!?/span>
貝葉斯流網(wǎng)絡(luò)
接下來(lái)我們介紹一下貝葉斯流網(wǎng)絡(luò)(Bayesian Flow Networks,BFN)的基本數(shù)學(xué)形式。本節(jié)都是公式推導(dǎo),大家可以參考原論文了解更詳細(xì)的信息。
輸入分布和 Sender 分布:給定 D 維數(shù)據(jù),
為因式輸入分布
的參數(shù),則輸入分布公式如下:
經(jīng)過(guò)一系列變換后,得到 Sender 分布公式:
輸出分布數(shù)據(jù)傳輸過(guò)程中,輸入?yún)?shù) θ 與過(guò)程時(shí)間 t 一起作為輸入傳遞給神經(jīng)網(wǎng)絡(luò) Ψ,然后網(wǎng)絡(luò)輸出一個(gè)向量,得到輸出分布:
與輸入分布不同,輸出分布可以利用上下文信息,例如圖像中的周?chē)袼鼗蛭谋局械南嚓P(guān)單詞。
Receiver 分布給定 Sender 分布和輸出分布, Receiver 分布可以表述為:
由上式可得,Receiver 分布有兩個(gè)不確定來(lái)源,即 Sender 分布和輸出分布。
貝葉斯更新
對(duì)于給定的參數(shù) θ,參數(shù)更新的方式如下所示,其中 y 為 Sender 樣本, α 為準(zhǔn)確率:
得到貝葉斯更新分布:
本文認(rèn)為,從某種意義上講,準(zhǔn)確率 α 是可以相加的,從而得到總的貝葉斯更新分布公式:
通過(guò)執(zhí)行無(wú)限多的傳輸步驟,貝葉斯更新過(guò)程可以推廣到連續(xù)時(shí)間。假設(shè) t ∈ [0, 1] 為處理時(shí)間,α(t) > 0 為時(shí)間 t 的準(zhǔn)確率,得到準(zhǔn)確率時(shí)間表:
貝葉斯流分布
給定先驗(yàn)參數(shù) θ_0、貝葉斯更新分布以及準(zhǔn)確率時(shí)間表 β(t), 貝葉斯流分布可以表示為
損失函數(shù)
損失函數(shù)定義為如下方式:
其中,
L (x) 可以推導(dǎo)為變分自編碼器(VAE)的損失函數(shù),經(jīng)過(guò)一系列變化,損失函數(shù)表述為:
根據(jù)損失函數(shù)(16),該研究又推導(dǎo)出了離散損失:
以及連續(xù)時(shí)間損失:
實(shí)驗(yàn)
該研究在以下生成基準(zhǔn)上評(píng)估了 BFN 網(wǎng)絡(luò),包括 CIFAR-10(32×32 8 位彩色圖像)、動(dòng)態(tài)二值化 MNIST(28×28 手寫(xiě)數(shù)字的二值化圖像)以及 text8(長(zhǎng)度 256 個(gè)字符序列,大小為 27 個(gè)字母)。
動(dòng)態(tài)二值化 MNIST
從表 1 可以看出,BFN 在沒(méi)有數(shù)據(jù)增強(qiáng)的情況下達(dá)到該任務(wù)最好的性能。
下圖為 MNIST 損失曲線:表明對(duì)于二進(jìn)制數(shù)據(jù),準(zhǔn)確率時(shí)間表不是最優(yōu)的。
CIFAR-10
該研究在 CIFAR-10 上進(jìn)行了兩組生成建模實(shí)驗(yàn),一組 bit-depth 為 8 ,對(duì)應(yīng)于顏色通道有 256 個(gè)離散 bin,另一組 bit-depth 為 4 ,對(duì)應(yīng)于顏色通道為 16 個(gè) bin。
表 3 顯示,對(duì)于 16 bins,離散損失比連續(xù)損失提供了更好的性能,并且訓(xùn)練時(shí)間也快得多。這一結(jié)果對(duì)應(yīng)了這樣一個(gè)假設(shè),即 bin 相對(duì)較低時(shí),使用離散損失進(jìn)行訓(xùn)練是最有益的。此外,對(duì)于 16 和 256 個(gè) bin,當(dāng)步數(shù) n 較低(例如 10 或 25)時(shí),離散訓(xùn)練會(huì)給出更好的結(jié)果。然而,在 256 個(gè) bin 上,連續(xù)損失比離散損失具有更好的性能。
圖 15 顯示,使用 16 個(gè) bin 進(jìn)行離散訓(xùn)練比使用 256 個(gè) bin 進(jìn)行離散訓(xùn)練可提供更好的樣本質(zhì)量。
TEXT8
表 4 顯示,BFN 在 text8 測(cè)試集上產(chǎn)生了 1.41 BPC,這比其他文獻(xiàn)中發(fā)現(xiàn)的所有離散擴(kuò)散模型都要好,并且接近最佳模型 MAC(1.40 BPC)。
表 5 顯示,對(duì)于步數(shù) n 的減少,BFN 的性能還是相當(dāng)穩(wěn)健的,只需 100 步即可達(dá)到 1.43 BPC。通過(guò)離散時(shí)間損失訓(xùn)練可能會(huì)改善這個(gè)結(jié)果。