COLM 2024:一種新的深度學習架構——Monotone Deep Boltzmann Machines 原創(chuàng)
在深度學習的世界里,Boltzmann機器是一種很有趣的模型,通過概率來理解數(shù)據(jù)。想象一下,我們有很多變量,它們之間的關系就像一張復雜的網(wǎng)。Boltzmann機器就是試圖描述這些變量之間的概率關系。它有不同的版本,比如深Boltzmann機器(DBM)和受限Boltzmann機器(RBM)。
RBM是一種比較常用的形式,它避免了模型同一層內(nèi)的連接,這樣可以使用更高效的基于塊的近似推理方法。但是,我們不禁要問,除了這種限制,還有沒有其他的限制方式也能讓推理變得高效呢?這就引出了我們今天要介紹的新模型——單調(diào)深度Boltzmann機器(Monotone Deep Boltzmann Machines,mDBM)。
一、背景知識
1.平衡模型及其收斂性
首先要提到的是深度平衡模型(DEQ)。它是由Bai等人提出的,就像一個神奇的公式,可以模擬一個無限深度的網(wǎng)絡。后來,Winston和Kolter又提出了一種參數(shù)化的DEQ(monDEQ),它可以保證收斂到一個獨特的固定點。這就像給模型找到了一個穩(wěn)定的“家”,讓它不會亂跑。
2.馬爾可夫隨機場(MRF)及其變體
MRF是一種基于能量的模型,Boltzmann機器就是它的一種形式。其中,RBM是比較成功的變體,它通過特定的能量函數(shù)定義來避免層內(nèi)連接。但是我們的mDBM不一樣,它允許層內(nèi)連接,更加靈活和強大。
3.并行和收斂的平均場
平均場更新通常使用坐標上升算法在局部收斂。但是有很多研究都在嘗試并行化更新。比如Kr?henbühl和Koltun提出的方法,還有Baqué等人的方法。我們的mDBM在這方面也有自己的創(chuàng)新,它可以保證在并行更新時收斂到一個全局最優(yōu)的平均場固定點。
圖1:不同玻爾茲曼機的神經(jīng)網(wǎng)絡拓撲結構。一般情況是一個完全圖(紅色虛線是玻爾茲曼機中有但受限玻爾茲曼機中沒有的邊的子集)。
二、單調(diào)深Boltzmann機器(mDBM)的奧秘
1.單調(diào)參數(shù)化
- 我們是如何讓mDBM變得獨特的呢?首先是參數(shù)化。我們通過一種特殊的方式來定義模型中的成對勢,使得它們滿足一定的單調(diào)性條件。就像給模型的各個部分設定了規(guī)則,讓它們按照我們想要的方式“行動”。
- 具體來說,我們定義了一些矩陣和運算,比如通過對矩陣A的處理來得到成對勢Φ。這樣的參數(shù)化既保證了Φ矩陣的空心性,又保證了單調(diào)性。
2.平均場推理作為單調(diào)DEQ
- 接下來,我們把平均場推理和單調(diào)DEQ聯(lián)系起來。平均場推理是為了近似條件分布,我們發(fā)現(xiàn),在一定條件下,這個平均場固定點可以看作是一個類似DEQ的固定點。這就像找到了兩個不同領域之間的橋梁,讓我們可以更好地理解和處理模型。
- 而且,我們還證明了在特定條件下,對于任何輸入,都存在一個獨特的、全局最優(yōu)的平均場分布的固定點。
3.實際建模考慮
- 在實際中,當我們用mDBM來建模時,需要考慮很多細節(jié)。比如變量可能代表深度學習架構中的隱藏單元,我們不能直接表示矩陣A,而是要找到一種方法來計算與A相關的乘法。
- 我們通常會把隱藏單元分成不同的集合,通過卷積層等方式來參數(shù)化A。這樣可以處理各種復雜的網(wǎng)絡結構,比如卷積、全連接層和跳躍連接等。
4.高效并行求解平均場固定點
- 雖然我們保證了單調(diào)性,但是簡單的迭代不一定能收斂到我們想要的解。所以我們需要使用阻尼迭代。
- 這個阻尼迭代可以在網(wǎng)絡的所有變量上并行進行,不需要像傳統(tǒng)的平均場推理那樣使用坐標下降方法。但是計算這個阻尼迭代中的近鄰算子并不容易,我們通過一系列的定理和方法來解決這個問題,包括找到一種數(shù)值穩(wěn)定的計算方法。
圖2:一種可能的深度卷積玻爾茲曼機的示意圖,其中單調(diào)性結構仍然可以被強制實施
三、訓練mDBM的技巧
1.損失函數(shù)的選擇
- 在訓練mDBM時,我們要選擇合適的方法。概率模型通常通過近似似然最大化來訓練,但是對于我們的mDBM,直接使用平均場近似來訓練參數(shù)可能不是最好的方法。
- 我們發(fā)現(xiàn),基于邊際的損失函數(shù)是一個更好的選擇。就像給模型一個明確的目標,讓它朝著正確的方向?qū)W習。
2.具體訓練過程
- 給定一個樣本,我們首先要解決平均場推理問題,找到隱藏狀態(tài)的估計值。然后我們可以計算預測值和真實值之間的損失函數(shù),通過這個損失函數(shù)來更新模型的參數(shù)。
- 在計算梯度時,我們要注意一些細節(jié),比如通過隱函數(shù)定理來處理一些復雜的導數(shù)關系。而且,由于單調(diào)性約束,我們可能還需要對輸出邊際進行一些處理,比如使用一個可學習的溫度參數(shù)來調(diào)整。
四、實驗評估:mDBM的實力展示
1.在MNIST數(shù)據(jù)集上的表現(xiàn)
- 我們在MNIST數(shù)據(jù)集上測試了mDBM。我們進行了聯(lián)合像素插補和分類任務,隨機掩蓋一部分像素,然后讓模型預測缺失的像素和圖像的類別。
- 結果顯示,mDBM的測試分類準確率達到了92.95%,而傳統(tǒng)的深RBM只有64.23%。而且從像素插補的效果來看,mDBM也遠遠優(yōu)于RBM。我們還比較了不同比例像素被掩蓋時的情況,mDBM在各種情況下都表現(xiàn)出了優(yōu)勢。
- 我們還測試了mDBM在一些特殊任務上的表現(xiàn),比如隨機掩蓋14X14的補丁,mDBM也能很好地收斂和預測。
2.在CIFAR - 10數(shù)據(jù)集上的表現(xiàn)
- 在CIFAR - 10數(shù)據(jù)集上,我們同樣進行了圖像像素插補和標簽預測任務。當50%的像素被觀察時,mDBM模型獲得了58%的測試準確率,并且能夠有效地插補缺失的像素。
- 與深RBM相比,mDBM在插補誤差等方面也表現(xiàn)出了優(yōu)勢。
3.與其他推理方法的比較
- 我們還比較了mDBM的平均場推理方法與其他一些方法,比如Kr?henbühl和Koltun以及Baqué等人提出的方法。
- 實驗結果顯示,我們的方法在收斂速度上更快,而且能夠保證收斂到真正的平均場固定點,而其他方法可能存在不收斂或者收斂到錯誤的固定點的問題。
圖3:使用mDBM和深度RBM對CIFAR - 10進行像素填充
圖4:僅將上半部分展示給模型時,使用mDBM對MNIST和CIFAR10進行像素填充。左:觀測圖像;中:填充結果;右:原始圖像。
五、未來方向:探索更多可能
1.單調(diào)性條件的改進
目前定理3.1中的單調(diào)性條件只是充分條件,不是必要條件。如果我們能改進這個條件,可能會讓我們的模型更加靈活和強大。
2.模型的單調(diào)性調(diào)整
我們可以思考是否可以使用一個負的參數(shù)m來讓模型“有界非單調(diào)”,同時還能保持良好的收斂性質(zhì)。
3.聯(lián)合概率建模
我們的模型目前只學習條件概率,是否可以讓它更高效地學習聯(lián)合概率呢?這是一個值得探索的方向。
4.模型的擴展和優(yōu)化
雖然我們已經(jīng)有了一個比較高效的實現(xiàn),但是與一些常見的非線性函數(shù)相比,還是比較慢。我們需要找到一種方法來更高效地擴展mDBM。
5.其他概率模型與DEQ框架的聯(lián)系
我們可以探索更多的概率模型是否也可以在DEQ框架內(nèi)表達,就像發(fā)現(xiàn)更多的寶藏等待我們?nèi)ネ诰颉?/p>
六、結論
在這篇文章中,我們介紹了單調(diào)深Boltzmann機器(mDBM)。它是一種很有潛力的深度學習模型,通過獨特的參數(shù)化和推理方法,在處理圖像數(shù)據(jù)等任務上表現(xiàn)出了很好的性能。我們還討論了它的訓練方法和未來的發(fā)展方向。希望這個新的模型能在深度學習的領域中開辟出一片新的天地,讓我們更好地理解和處理數(shù)據(jù)。
本文轉(zhuǎn)載自公眾號AIGC最前線 作者:實習小畢
原文鏈接:??https://mp.weixin.qq.com/s/RF6TeDxJIA0YXRCqz2QX-g??
