超贊??!3D可視化工具透視神經(jīng)網(wǎng)絡(luò)內(nèi)部
哈嘍,大家好。
你有沒(méi)有想過(guò),我們編寫(xiě)的神經(jīng)網(wǎng)絡(luò),內(nèi)部究竟是什么樣子的?
加拿大蒙特利爾一家公司開(kāi)發(fā)一個(gè)3D可視化工具 —— Zetane Engine,幫助我們解決了這個(gè)問(wèn)題。
只要在Zetane Engine打開(kāi)一個(gè)深度學(xué)習(xí)模型,便可以看到網(wǎng)絡(luò)中任何一層,并顯示特征圖。
為了演示Zetane Engine?的用戶,我搭建了AlexNet?網(wǎng)絡(luò),在Fashion-MNIST數(shù)據(jù)集上訓(xùn)練了一個(gè) 10 個(gè)類(lèi)別的分類(lèi)器。
網(wǎng)絡(luò)架構(gòu)如下:
tf.keras.layers.Conv2D(filters=96, kernel_size=11, strides=4, input_shape=(224,224,3), activation='relu'),
tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
tf.keras.layers.Conv2D(filters=256, kernel_size=5, padding='same', activation='relu'),
tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
tf.keras.layers.Conv2D(filters=384, kernel_size=3, padding='same',activation='relu'),
tf.keras.layers.Conv2D(filters=384, kernel_size=3, padding='same',activation='relu'),
tf.keras.layers.Conv2D(filters=256, kernel_size=3, padding='same',activation='relu'),
tf.keras.layers.MaxPool2D(pool_size=3, strides=2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(4096, activation='relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(4096, activation='relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(10)
網(wǎng)絡(luò)使用Keras?搭建,AlexNet模型非常簡(jiǎn)單,包含5個(gè)卷積層和3個(gè)全連接層。
訓(xùn)練樣本如下:
樣本對(duì)應(yīng)的 10 類(lèi)別如下:
訓(xùn)練 50 個(gè) epoch,模型的準(zhǔn)確度是 92%,將模型保存為alexnet.h5?,用Zetane Engine打開(kāi)
可以看到AlexNet的網(wǎng)絡(luò)結(jié)構(gòu)。
我們輸入一張褲子圖片,查看第一個(gè)卷積層輸出的特征圖
第一個(gè)卷積層
圖片經(jīng)過(guò)第一個(gè)卷積層后,輸出的特征圖肉眼能明顯辨識(shí)出是褲子。
我們?cè)倏聪陆?jīng)過(guò)更深層的卷積網(wǎng)絡(luò)之后,會(huì)是什么樣子
深度卷積
中間第2、3層明顯可以看出是在提取邊緣特征,不用類(lèi)別的物體的邊緣特征是不同的,并且邊緣特征相比原圖表達(dá)能力更強(qiáng),相當(dāng)于是原圖更抽象一級(jí)的特征,不過(guò)這里還是可以看出來(lái)是褲子。
但到了第4、5層,特征更抽象了,肉眼已經(jīng)看不出是褲子了,當(dāng)然也說(shuō)明模型學(xué)習(xí)能力更強(qiáng)了。
簡(jiǎn)單總結(jié)下,神經(jīng)網(wǎng)絡(luò)從淺層到深層,學(xué)習(xí)的特征越來(lái)越抽象,學(xué)習(xí)能力也越來(lái)越強(qiáng)。
AlexNet網(wǎng)絡(luò)除了有卷積層,還有池化層,我們也可以看下特征經(jīng)過(guò)池化層的效果
顏色越明亮,代表權(quán)重越高。從上圖可以看到最大池化層能強(qiáng)化重要特征,發(fā)揮去噪、降維的作用。
另外,你可能會(huì)主要到網(wǎng)絡(luò)上每個(gè)節(jié)點(diǎn)的前后都有一些白色圓點(diǎn)組成的方塊。
左邊代表該節(jié)點(diǎn)輸入特征和權(quán)重,右邊代表輸出的特征。點(diǎn)擊它們可以看到不同視角的特征圖
三維視角
二維視角
標(biāo)注卷積結(jié)果的平面圖
卷積結(jié)果的平面圖
尤其對(duì)于網(wǎng)絡(luò)的最后一個(gè)節(jié)點(diǎn),它的輸出是預(yù)測(cè)結(jié)果
它輸出了長(zhǎng)度為 10 的特征向量,即:預(yù)測(cè)圖片屬于哪個(gè)類(lèi)別的權(quán)重??梢钥吹綑?quán)重最大的是類(lèi)別1?,類(lèi)別1?對(duì)應(yīng)的是褲子,所以模型的預(yù)測(cè)結(jié)果是正確的。