理解并統(tǒng)一14種歸因算法,讓神經(jīng)網(wǎng)絡具有可解釋性
盡管 DNN 在各種實際應用中取得了廣泛的成功,但其過程通常被視為黑盒子,因為我們很難解釋 DNN 如何做出決定。缺乏可解釋性損害了 DNN 的可靠性,從而阻礙了它們在高風險任務中的廣泛應用,例如自動駕駛和 AI 醫(yī)療。因此,可解釋 DNN 引起了越來越多的關注。
作為解釋 DNN 的典型視角,歸因方法旨在計算每個輸入變量對網(wǎng)絡輸出的歸因 / 重要性 / 貢獻分數(shù)。例如,給定一個用于圖像分類的預訓練 DNN 和一個輸入圖像,每個輸入變量的屬性得分是指每個像素對分類置信度得分的數(shù)值影響。
盡管近年來研究者提出了許多歸因方法,但其中大多數(shù)都建立在不同的啟發(fā)式方法之上。目前還缺乏統(tǒng)一的理論視角來檢驗這些歸因方法的正確性,或者至少在數(shù)學上闡明其核心機制。
研究人員曾試圖統(tǒng)一不同的歸因方法,但這些研究只涵蓋了幾種方法。
本文中,我們提出了「統(tǒng)一解釋 14 種輸入單元重要性歸因算法的內在機理」。
論文地址:https://arxiv.org/pdf/2303.01506.pdf
其實無論是「12 種提升對抗遷移性的算法」,還是「14 種輸入單元重要性歸因算法」,都是工程性算法的重災區(qū)。在這兩大領域內,大部分算法都是經(jīng)驗性的,人們根據(jù)實驗經(jīng)驗或直覺認識,設計出一些似是而非的工程性算法。大部分研究沒有對 “究竟什么是輸入單元重要性” 做出嚴謹定義和理論論證,少數(shù)研究有一定的論證,但往往也很不完善。當然,“缺少嚴謹?shù)亩x和論證” 的問題充滿了整個人工智能領域,只是在這兩個方向上格外突出。
- 第一,在眾多經(jīng)驗性歸因算法充斥可解釋機器學習領域的環(huán)境下,我們希望證明 “所有 14 種歸因算法(解釋神經(jīng)網(wǎng)絡輸入單元重要性的算法)的內在機理,都可以表示為對神經(jīng)網(wǎng)絡所建模的交互效用的一種分配,不同歸因算法對應不同的交互效用分配比例”。這樣,雖然不同算法有著完全不同的設計著眼點(比如有些算法有提綱挈領的目標函數(shù),有些算法則是純粹的 pipeline),但是我們發(fā)現(xiàn)在數(shù)學上,這些算法都可以被我們納入到 “對交互效用的分配” 的敘事邏輯中來。
- 基于上面的交互效用分配框架,我們可以進一步為神經(jīng)網(wǎng)絡輸入單元重要性歸因算法提出三條評估準則,來衡量歸因算法所預測的輸入單元重要性值是否合理。
當然,我們的理論分析不只適用于 14 種歸因算法,理論上可以統(tǒng)一更多的類似研究。因為人力有限,這篇論文里我們僅僅討論 14 種算法。
研究的真正難點在于,不同的經(jīng)驗性歸因算法往往都是搭建在不同的直覺之上的,每篇論文都僅僅努力從各自的角度「自圓其說」,分別基于不同的直覺或角度來設計歸因算法,而缺少一套規(guī)范的數(shù)學語言來統(tǒng)一描述各種算法的本質。
算法回顧
在講數(shù)學以前,本文先從直覺層面簡單回顧之前的算法。
1. 基于梯度的歸因算法。這一類算法普遍認為,神經(jīng)網(wǎng)絡的輸出對每個輸入單元的梯度可以反映輸入單元的重要性。例如,Gradient*Input 算法將輸入單元的重要性建模為梯度與輸入單元值的逐元素乘積。考慮到梯度僅能反映輸入單元的局部重要性,Smooth Gradients 和 Integrated Gradients 算法將重要性建模為平均梯度與輸入單元值的逐元素乘積,其中這兩種方法中的平均梯度分別指輸入樣本鄰域內梯度的平均值或輸入樣本到基準點(baseline point)間線性插值點的梯度平均值。類似地,Grad-CAM 算法采用網(wǎng)絡輸出對每個 channel 中所有特征梯度的平均值,來計算重要性分數(shù)。進一步,Expected Gradients 算法認為,選擇單個基準點往往會導致有偏的歸因結果,從而提出將重要性建模為不同基準點下 Integrated Gradients 歸因結果的期望。
2. 基于逐層反向傳播的歸因算法。深度神經(jīng)網(wǎng)絡往往極為復雜,而每一層神經(jīng)網(wǎng)絡的結構相對簡單(比如深層特征通常是淺層特征的線性加和 + 非線性激活函數(shù)),便于分析淺層特征對深層特征的重要性。因此,這類算法通過估計中層特征的重要性,并將這些重要性逐層傳播直至輸入層,得到輸入單元的重要性。這一類算法包括 LRP-\epsilon, LRP-\alpha\beta, Deep Taylor, DeepLIFT Rescale, DeepLIFT RevealCancel, DeepShap 等。不同反向傳播算法間的根本區(qū)別在于,他們采用了不同的重要性逐層傳播規(guī)則。
3. 基于遮擋的歸因算法。這類算法根據(jù)遮擋某一輸入單元對模型輸出的影響,來推斷該輸入單元的重要性。例如,Occlusion-1(Occlusion-patch)算法將第 i 個像素(像素塊)的重要性建模為其它像素未被遮擋時,像素 i 未遮擋和遮擋兩種情況下的輸出改變量。Shapley value 算法則綜合考慮了其它像素的所有可能遮擋情況,并將重要性建模為不同遮擋情況下像素 i 對應輸出改變量的平均值。研究已證明,Shapley value 是唯一滿足 linearity, dummy, symmetry, efficiency 公理的歸因算法。
統(tǒng)一 14 種經(jīng)驗性歸因算法的內在機理
在深入研究多種經(jīng)驗性歸因算法后,我們不禁思考一個問題:在數(shù)學層面上,神經(jīng)網(wǎng)絡的歸因究竟在解決什么問題?在眾多經(jīng)驗性歸因算法的背后,是否蘊含著某種統(tǒng)一的數(shù)學建模與范式?為此,我們嘗試從歸因的定義出發(fā),著眼考慮上述問題。歸因,是指每一個輸入單元對神經(jīng)網(wǎng)絡輸出的重要性分數(shù) / 貢獻。那么,解決上述問題的關鍵在于,(1)在數(shù)學層面上建?!篙斎雴卧獙W(wǎng)絡輸出的影響機制」,(2)解釋眾多經(jīng)驗性歸因算法是如何利用該影響機制,來設計重要性歸因公式。
針對第一個關鍵點,我們研究發(fā)現(xiàn):每一個輸入單元往往通過兩種方式影響神經(jīng)網(wǎng)絡的輸出。一方面,某一個輸入單元無需依賴其他輸入單元,可獨立作用并影響網(wǎng)絡輸出,這類影響稱為 “獨立效應”。另一方面,一個輸入單元需要通過與其他輸入單元共同協(xié)作,形成某種模式,進而對網(wǎng)絡輸出產生影響,這類影響稱為 “交互效應”。我們理論證明了,神經(jīng)網(wǎng)絡的輸出可以嚴謹解構為不同輸入變量的獨立效應,以及不同集合內輸入變量間的交互效應。
其中,表示第 i 個輸入單元的獨立效應,
表示集合 S 內多個輸入單元間的交互效應。針對第二個關鍵點,我們探究發(fā)現(xiàn),所有 14 種現(xiàn)有經(jīng)驗性歸因算法的內在機理,都可以表示對上述獨立效用和交互效用的一種分配,而不同歸因算法按不同的比例來分配神經(jīng)網(wǎng)絡輸入單元的獨立效用和交互效用。具體地,令?
?表示第 i 個輸入單元的歸因分數(shù)。我們嚴格證明了,所有 14 種經(jīng)驗性歸因算法得到的
,都可以統(tǒng)一表示為下列數(shù)學范式(即獨立效用和交互效用的加權和):?
?
其中,反映了將第 j 個輸入單元的獨立效應分配給第 i 個輸入單元的比例,
表示將集合 S 內多個輸入單元間的交互效應分配給第 i 個輸入單元的比例。眾多歸因算法的 “根本區(qū)別” 在于,不同歸因算法對應著不同的分配比例
。
表 1 展示了十四種不同的歸因算法分別是如何對獨立效應與交互效應進行分配。
圖表 1. 十四種歸因算法均可以寫成獨立效應與交互效應加權和的數(shù)學范式。其中分別表示泰勒獨立效應和泰勒交互效應,滿足
,是對獨立效應
和
交互效的細化。
評價歸因算法可靠性的三大準則
在歸因解釋研究中,由于無從獲得 / 標注神經(jīng)網(wǎng)絡歸因解釋的真實值,人們無法從實證角度評價某一個歸因解釋算法的可靠性。“缺乏對歸因解釋算法可靠性的客觀評價標準” 這一根本缺陷,引發(fā)了學界對歸因解釋研究領域的廣泛批評與質疑。
而本研究中對歸因算法公共機理的揭示,使我們能在同一理論框架下,公平地評價和比較不同歸因算法的可靠性。具體地,我們提出了以下三條評估準則,以評價某一個歸因算法是否公平合理地分配獨立效應和交互效應。
(1)準則一:分配過程中涵蓋所有獨立效應和交互效應。當我們將神經(jīng)網(wǎng)絡輸出解構為獨立效應與交互效應后,可靠的歸因算法在分配過程中應盡可能涵蓋所有的獨立效應和交互效應。例如,對 I’m not happy 句子的歸因中,應涵蓋三個單詞 I’m, not, happy 的所有獨立效應,同時涵蓋 J (I’m, not), J (I’m, happy), J (not, happy), J (I’m, not, happy) 等所有可能的交互效應。
(2)準則二:避免將獨立效應和交互分配給無關的輸入單元。第 i 個輸入單元的獨立效應,只應分配給第 i 個輸入單元,而不應分配給其它輸入單元。類似地,集合 S 內輸入單元間的交互效應,只應分配給集合 S 內的輸入單元,而不應分配給集合 S 以外的輸入單元(未參與交互)。例如,not 和 happy 之間的交互效應,不應分配給單詞 I’m。
(3)準則三:完全分配。每個獨立效應(交互效應)應當完全分配給對應的輸入單元。換句話說,某一個獨立效應(交互效應)分配給所有對應輸入單元的歸因值,加起來應當恰好等于該獨立效應(交互效應)的值。例如,交互效應 J (not, happy) 會分配一部分效應(not, happy) 給單詞 not,同時分配一部分效應
(not, happy) 給單詞 happy。那么,分配比例應滿足
。
接著,我們采用這三條評估準則,評估了上述 14 種不同歸因算法(如表 2 所示)。我們發(fā)現(xiàn),Integrated Gradients, Expected Gradients, Shapley value, Deep Shap, DeepLIFT Rescale, DeepLIFT RevealCancel 這些算法滿足所有的可靠性準則。
表 2. 總結 14 種不同歸因算法是否滿足三條可靠性評估準則。
作者介紹
本文作者鄧輝琦,是中山大學應用數(shù)學專業(yè)的博士,博士期間曾在香港浸會大學和德州農工大學計算機系訪問學習,現(xiàn)于張拳石老師團隊進行博士后研究。研究方向主要為可信 / 可解釋機器學習,包括解釋深度神經(jīng)網(wǎng)絡的歸因重要性、解釋神經(jīng)網(wǎng)絡的表達能力等。
鄧輝琦前期做了很多工作。張老師只是在初期工作結束以后,幫她重新梳理了一遍理論,讓證明方式和體系更順暢一些。鄧輝琦畢業(yè)前論文不是很多,21 年末來張老師這邊以后,在博弈交互的體系下,一年多做了三個工作,包括(1)發(fā)現(xiàn)并理論解釋了神經(jīng)網(wǎng)絡普遍存在的表征瓶頸,即證明神經(jīng)網(wǎng)絡更不善于建模中等復雜度的交互表征。這一工作有幸被選為 ICLR 2022 oral 論文,審稿得分排名前五(得分 8 8 8 10)。(2)理論證明了貝葉斯網(wǎng)絡的概念表征趨勢,為解釋貝葉斯網(wǎng)絡的分類性能、泛化能力和對抗魯棒性提供了新的視角。(3)從理論層面上解釋了神經(jīng)網(wǎng)絡在訓練過程中對不同復雜度交互概念的學習能力。