如何讓等變神經(jīng)網(wǎng)絡(luò)可解釋性更強(qiáng)?試試將它分解成「簡單表示」
神經(jīng)網(wǎng)絡(luò)是一種靈活且強(qiáng)大的函數(shù)近似方法。而許多應(yīng)用都需要學(xué)習(xí)一個相對于某種對稱性不變或等變的函數(shù)。圖像識別便是一個典型示例 —— 當(dāng)圖像發(fā)生平移時,情況不會發(fā)生變化。等變神經(jīng)網(wǎng)絡(luò)(equivariant neural network)可為學(xué)習(xí)這些不變或等變函數(shù)提供一個靈活的框架。
而要研究等變神經(jīng)網(wǎng)絡(luò),可使用表示論(representation theory)這種數(shù)學(xué)工具。(請注意,「表示」這一數(shù)學(xué)概念不同于機(jī)器學(xué)習(xí)領(lǐng)域中的「表征」的典型含義。本論文僅使用該術(shù)語的數(shù)學(xué)意義。)
近日,Joel Gibson、Daniel Tubbenhauer 和 Geordie Williamson 三位研究者對等變神經(jīng)網(wǎng)絡(luò)進(jìn)行了探索,并研究了分段線性表示論在其中的作用。
- 論文標(biāo)題:Equivariant neural networks and piecewise linear representation theory
- 論文地址:https://arxiv.org/pdf/2408.00949
在表示論中,簡單表示(simple representation)是指該理論的不可約簡的原子。在解決問題時,表示論的一個主要策略是將該問題分解成簡單表示,然后分別基于這些基本片段研究該問題。但對等變神經(jīng)網(wǎng)絡(luò)而言,這一策略并不奏效:它們的非線性性質(zhì)允許簡單表示之間發(fā)生互動,而線性世界無法做到這一點(diǎn)。
但是,該團(tuán)隊又論證表明:將等變神經(jīng)網(wǎng)絡(luò)的層分解成簡單表示依然能帶來好處。然后很自然地,他們又進(jìn)一步研究了簡單表示之間的分段線性映射和分段線性表示論。具體來說,這種分解成簡單表示的過程能為神經(jīng)網(wǎng)絡(luò)的層構(gòu)建一個新的基礎(chǔ),這是對傅立葉變換的泛化。
該團(tuán)隊表示:「我們希望這種新基礎(chǔ)能為理解和解讀等變神經(jīng)網(wǎng)絡(luò)提供一個有用的工具?!?/span>
該論文證明了什么?
在介紹該論文的主要結(jié)果之前,我們先來看一個簡單卻非平凡的示例。
以一個小型的簡單神經(jīng)網(wǎng)絡(luò)為例:
其中每個節(jié)點(diǎn)都是 ? 的一個副本,每個箭頭都標(biāo)記了一個權(quán)重 w,并且層之間的每個線性映射的結(jié)果都由一個非線性激活函數(shù) ?? 組成,然后再進(jìn)入下一層。
為了構(gòu)建等變神經(jīng)網(wǎng)絡(luò),可將 ? 和 w 替換成具有更多對稱性的更復(fù)雜對象。比如可以這樣替換:
其可被描述為:
不過,要想在計算機(jī)上真正實(shí)現(xiàn)這個結(jié)構(gòu),卻根本不可能,但這里先忽略這一點(diǎn)。
現(xiàn)在暫時假設(shè)函數(shù)是周期性的,周期為 2π。當(dāng)用傅里葉級數(shù)展開神經(jīng)網(wǎng)絡(luò)時,我們很自然就會問發(fā)生了什么。在傅里葉理論中,卷積算子會在傅里葉基中變成對角。因此,為了理解信號流過上述神經(jīng)網(wǎng)絡(luò)的方式,還需要理解激活函數(shù)在基頻上的工作方式。
一個基本卻關(guān)鍵的觀察是:??(sin (x)) 的傅里葉級數(shù)僅涉及較高共振頻率的項:
(這里展示了當(dāng) ?? 是 ReLU 時,??(sin (x)) 的前幾個傅里葉級數(shù)項。)這與我們撥動吉他琴弦時發(fā)生的情況非常相似:一個音符具有與所彈奏音符相對應(yīng)的基頻,以及更高的頻率(泛音,類似于上面底部的三張圖片),它們結(jié)合在一起形成了吉他獨(dú)特的音色。
該團(tuán)隊的研究表明:一般情況下,在等變神經(jīng)網(wǎng)絡(luò)中,信息會從更低共振頻率流向更高共振頻率,但反之則不然:
這對等變神經(jīng)網(wǎng)絡(luò)有兩個具體影響:
- 等變神經(jīng)網(wǎng)絡(luò)的大部分復(fù)雜性都出現(xiàn)在高頻區(qū),
- 如果想學(xué)習(xí)一個低頻函數(shù),那么可以忽略神經(jīng)網(wǎng)絡(luò)中與高頻相對應(yīng)的大部分。
舉個例子,如果使用典型的流式示意圖(稱為交互圖 /interaction graph)表示,一個基于(8 階循環(huán)群)構(gòu)建的等變神經(jīng)網(wǎng)絡(luò)是這樣的:
其中的節(jié)點(diǎn)是 C_8 的簡單表示,節(jié)點(diǎn)中的值表示生成器的動作。在此圖中,「低頻」簡單表示位于頂部,信息從低頻流向高頻。這意味著在大型網(wǎng)絡(luò)中,高頻將占據(jù)主導(dǎo)地位。
主要貢獻(xiàn)
該團(tuán)隊做出了一些重要的理論貢獻(xiàn),主要包括:
- 他們指出將等變神經(jīng)網(wǎng)絡(luò)分解成簡單表示是有意義且有用的。
- 他們論證表明等變神經(jīng)網(wǎng)絡(luò)必須通過置換表示構(gòu)建。
- 他們證明分段線性(但并非線性)的等變映射的存在受控于類似于伽羅瓦理論的正規(guī)子群。
- 他們計算了一些示例,展示了理論的豐富性,即使在循環(huán)群等「簡單」示例中也是如此。
等變神經(jīng)網(wǎng)絡(luò)和分段線性表示
該團(tuán)隊在論文中首先簡要介紹了表示論和神經(jīng)網(wǎng)絡(luò)的基礎(chǔ)知識,這里受限于篇幅,我們略過不表,詳見原論文。我們僅重點(diǎn)介紹有關(guān)等變神經(jīng)網(wǎng)絡(luò)和分段線性表示的研究成果。
等變神經(jīng)網(wǎng)絡(luò):一個示例
這篇論文的出發(fā)點(diǎn)是:學(xué)習(xí)關(guān)于某種對稱性的等變映射是有用的。舉些例子:
- 圖像識別結(jié)果通常不會隨平移變化,比如識別圖像中的「冰淇淋」時與冰淇淋所在的位置無關(guān);
- 文本轉(zhuǎn)語音時,「冰淇淋」這個詞不管在文本中的什么位置,都應(yīng)該生成一樣的音頻;
- 工程學(xué)和應(yīng)用數(shù)學(xué)領(lǐng)域的許多問題都需要分析點(diǎn)云。這里,人們感興趣的通常是對點(diǎn)云集合的質(zhì)量評估,而與順序無關(guān)。換句話說,這樣的問題不會隨點(diǎn)的排列順序變化而變化。因此,這里的學(xué)習(xí)問題在對稱群下是不變的。
為了解釋構(gòu)建等變神經(jīng)網(wǎng)絡(luò)的方式,該團(tuán)隊使用了一個基于卷積神經(jīng)網(wǎng)絡(luò)的簡單示例,其要處理一張帶周期性的圖像。
這里,這張周期性圖像可表示成一個 n × n 的網(wǎng)格,其中每個點(diǎn)都是一個實(shí)數(shù)。如果設(shè)定 n=10,再將這些實(shí)數(shù)表示成灰度值,則可得到如下所示的圖像:
我們可以在這張圖上下左右進(jìn)行重復(fù),使之具有周期性,也就相當(dāng)于這張圖在一個環(huán)面上。令 C_n = ?/n? 為 n 階循環(huán)群,C^2_n = C_n × C_n。用數(shù)學(xué)術(shù)語來說,一張周期性圖像是從群 C^2_n 到 ? 的映射的 ? 向量空間的一個元素:。在這個周期性圖像的模型中,V 是一個「C^2_n 表示」。事實(shí)上,給定 (a, b) ∈ C^2_n 和 ?? ∈ V,可通過移動坐標(biāo)得到一張新的周期性圖像:
- ((a, b)?f)(x, y) = f (x + a, y + b)
也就是說,平移周期性圖像會得到新的周期性圖像,例如:
得到等變神經(jīng)網(wǎng)絡(luò)的一個關(guān)鍵觀察是:從 V 到 V 的所有線性映射的 ? 向量空間的維度為 n^4,而所有 C^2_n 表示線性映射的 ? 向量空間的維度為 n^2。
下面來看一個 C^2_n 等變映射。對于,可通過一個卷積型公式得到 C^2_n 等變映射 V → V:
舉個例子,如果令 c = 1/4 ((1, 0) + (0, 1) + (?1, 0) + (0, ?1))。則 c??? 是周期性圖像且其像素 (a, b) 處的值是其相鄰像素 (a+1, b)、(a, b+1)、(a?1, b) 和 (a, b?1) 的值的平均值。用圖像表示即為:
更一般地,不同 c 的卷積可對應(yīng)圖像處理中廣泛使用的各種映射。
現(xiàn)在,就可以定義這種情況下的 C^2_n 等變神經(jīng)網(wǎng)絡(luò)了。其結(jié)構(gòu)如下:
其中每個箭頭都是一個卷積。此外,W 通常是 ? 或 V。上圖是一張卷積神經(jīng)網(wǎng)絡(luò)的(經(jīng)過簡化的)圖像,而該網(wǎng)絡(luò)在機(jī)器學(xué)習(xí)領(lǐng)域具有重要地位。對于該網(wǎng)絡(luò)的構(gòu)建方式,值得注意的主要概念是:
- 此神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)會迫使得到的映射 V → W 為等變映射。
- 所有權(quán)重的空間比傳統(tǒng)的(全連接)神經(jīng)網(wǎng)絡(luò)小得多。在實(shí)踐中,這意味著等變神經(jīng)網(wǎng)絡(luò)所能處理的樣本比「原始」神經(jīng)網(wǎng)絡(luò)所能處理的大得多。(這一現(xiàn)象也被機(jī)器學(xué)習(xí)研究者稱為權(quán)重共享。)
該團(tuán)隊還指出上圖隱式地包含了激活圖,而他們最喜歡的選擇是 ReLU。這意味著神經(jīng)網(wǎng)絡(luò)的組成成分實(shí)際上是分段線性映射。因此,為了將上述的第二個主要觀察(通過將問題分解成簡單表示來簡化問題)用于等變神經(jīng)網(wǎng)絡(luò),很自然就需要研究分段線性表示論。
等變神經(jīng)網(wǎng)絡(luò)
下面將給出等變神經(jīng)網(wǎng)絡(luò)的定義。該定義基于前述示例。
令 G 為一個有限群。Fun (X, ?) 是有限群 G 的置換表示(permutation representation)。
定義:等變神經(jīng)網(wǎng)絡(luò)是一種神經(jīng)網(wǎng)絡(luò),其每一層都是置換表示的直接和,且所有線性映射都是 G 等變映射。如圖所示:
(這里,綠色、藍(lán)色和紅色點(diǎn)分別表示輸入、隱藏層和輸出層,perm 表示一個置換表示,它們并不一定相等。和普通的原始神經(jīng)網(wǎng)絡(luò)一樣,這里也假設(shè)始終會有一個固定的激活函數(shù),其會在每個隱藏層中被逐個應(yīng)用到分量上。)
最后舉個例子,這是一個基于點(diǎn)云的等變神經(jīng)網(wǎng)絡(luò),而點(diǎn)云是指 ?^d 中 n 個不可區(qū)分的點(diǎn)構(gòu)成的集合。這里 n 和 d 為自然數(shù)。在這種情況下,有限群 G 便為 S_n,即在 n 個字母上的對稱群,并且其輸入層由 (?^d)^n = (?^n)^d 給定,而我們可以將其看作是 d 個置換模塊 Fun ({1, ..., n}, ?) 的副本。如果將 Fun ({1, ..., n}, ?) 寫成 n,則可將典型的等變神經(jīng)網(wǎng)絡(luò)表示成:
(這里 d=3 且有 2 層隱藏層。)這里的線性映射應(yīng)當(dāng)是 S_n 等變映射,而我們可以基于下述引理很快確定出可能的映射。
引理:對于有限 G 集合 X 和 Y,有,其中 Fun_G (X × Y, ?) 表示 G 不變函數(shù) X×Y →?。
根據(jù)該引理,,并且 G = S_n 有兩條由對角及其補(bǔ)集(complement)給出的軌道。因此,存在一個二維的等變映射空間 n→n,并且這與 n 無關(guān)。(在機(jī)器學(xué)習(xí)領(lǐng)域,這種形式的 S_n 的等變神經(jīng)網(wǎng)絡(luò)也被稱為深度網(wǎng)絡(luò)。)
為了更詳細(xì)地理解等變神經(jīng)網(wǎng)絡(luò)以及相關(guān)的分段線性表示論的定義、證明和分析,請參閱原論文。