訓練的神經網絡不工作?一文帶你跨過這37個坑
近日,Slav Ivanov 在 Medium 上發(fā)表了一篇題為《37 Reasons why your Neural Network is not working》的文章,從四個方面(數據集、數據歸一化/增強、實現(xiàn)、訓練),對自己長久以來的神經網絡調試經驗做了 37 條總結,并穿插了不少出色的個人想法和思考,希望能幫助你跨過神經網絡訓練中的 37 個大坑。機器之心對該文進行了編譯。
神經網絡已經持續(xù)訓練了 12 個小時。它看起來很好:梯度在變化,損失也在下降。但是預測結果出來了:全部都是零值,全部都是背景,什么也檢測不到。我質問我的計算機:「我做錯了什么?」,它卻無法回答。
如果你的模型正在輸出垃圾(比如預測所有輸出的平均值,或者它的精確度真的很低),那么你從哪里開始檢查呢?
無法訓練神經網絡的原因有很多,因此通過總結諸多調試,作者發(fā)現(xiàn)有一些檢查是經常做的。這張列表匯總了作者的經驗以及***的想法,希望也對讀者有所幫助。
一、數據集問題
1. 檢查你的輸入數據
檢查饋送到網絡的輸入數據是否正確。例如,我不止一次混淆了圖像的寬度和高度。有時,我錯誤地令輸入數據全部為零,或者一遍遍地使用同一批數據執(zhí)行梯度下降。因此打印/顯示若干批量的輸入和目標輸出,并確保它們正確。
2. 嘗試隨機輸入
嘗試傳遞隨機數而不是真實數據,看看錯誤的產生方式是否相同。如果是,說明在某些時候你的網絡把數據轉化為了垃圾。試著逐層調試,并查看出錯的地方。
3. 檢查數據加載器
你的數據也許很好,但是讀取輸入數據到網絡的代碼可能有問題,所以我們應該在所有操作之前打印***層的輸入并進行檢查。
4. 確保輸入與輸出相關聯(lián)
檢查少許輸入樣本是否有正確的標簽,同樣也確保 shuffling 輸入樣本同樣對輸出標簽有效。
5. 輸入與輸出之間的關系是否太隨機?
相較于隨機的部分(可以認為股票價格也是這種情況),輸入與輸出之間的非隨機部分也許太小,即輸入與輸出的關聯(lián)度太低。沒有一個統(tǒng)一的方法來檢測它,因為這要看數據的性質。
6. 數據集中是否有太多的噪音?
我曾經遇到過這種情況,當我從一個食品網站抓取一個圖像數據集時,錯誤標簽太多以至于網絡無法學習。手動檢查一些輸入樣本并查看標簽是否大致正確。
7. Shuffle 數據集
如果你的數據集沒有被 shuffle,并且有特定的序列(按標簽排序),這可能給學習帶來不利影響。你可以 shuffle 數據集來避免它,并確保輸入和標簽都被重新排列。
8. 減少類別失衡
一張類別 B 圖像和 1000 張類別 A 圖像?如果是這種情況,那么你也許需要平衡你的損失函數或者嘗試其他解決類別失衡的方法。
9. 你有足夠的訓練實例嗎?
如果你在從頭開始訓練一個網絡(即不是調試),你很可能需要大量數據。對于圖像分類,每個類別你需要 1000 張圖像甚至更多。
10. 確保你采用的批量數據不是單一標簽
這可能發(fā)生在排序數據集中(即前 10000 個樣本屬于同一個分類)??赏ㄟ^ shuffle 數據集輕松修復。
11. 縮減批量大小
巨大的批量大小會降低模型的泛化能力(參閱:https://arxiv.org/abs/1609.04836)
二、數據歸一化/增強
12. 歸一化特征
你的輸入已經歸一化到零均值和單位方差了嗎?
13. 你是否應用了過量的數據增強?
數據增強有正則化效果(regularizing effect)。過量的數據增強,加上其它形式的正則化(權重 L2,中途退出效應等)可能會導致網絡欠擬合(underfit)。
14. 檢查你的預訓練模型的預處理過程
如果你正在使用一個已經預訓練過的模型,確保你現(xiàn)在正在使用的歸一化和預處理與之前訓練模型時的情況相同。例如,一個圖像像素應該在 [0, 1],[-1, 1] 或 [0, 255] 的范圍內嗎?
15. 檢查訓練、驗證、測試集的預處理
CS231n 指出了一個常見的陷阱:「任何預處理數據(例如數據均值)必須只在訓練數據上進行計算,然后再應用到驗證、測試數據中。例如計算均值,然后在整個數據集的每個圖像中都減去它,再把數據分發(fā)進訓練、驗證、測試集中,這是一個典型的錯誤?!勾送?,要在每一個樣本或批量(batch)中檢查不同的預處理。
三、實現(xiàn)的問題
16. 試著解決某一問題的更簡易的版本。
這將會有助于找到問題的根源究竟在哪里。例如,如果目標輸出是一個物體類別和坐標,那就試著把預測結果僅限制在物體類別當中(嘗試去掉坐標)。
17.「碰巧」尋找正確的損失
還是來源于 CS231n 的技巧:用小參數進行初始化,不使用正則化。例如,如果我們有 10 個類別,「碰巧」就意味著我們將會在 10% 的時間里得到正確類別,Softmax 損失是正確類別的負 log 概率: -ln(0.1) = 2.302。然后,試著增加正則化的強度,這樣應該會增加損失。
18. 檢查你的損失函數
如果你執(zhí)行的是你自己的損失函數,那么就要檢查錯誤,并且添加單元測試。通常情況下,損失可能會有些不正確,并且損害網絡的性能表現(xiàn)。
19. 核實損失輸入
如果你正在使用的是框架提供的損失函數,那么要確保你傳遞給它的東西是它所期望的。例如,在 PyTorch 中,我會混淆 NLLLoss 和 CrossEntropyLoss,因為一個需要 softmax 輸入,而另一個不需要。
20. 調整損失權重
如果你的損失由幾個更小的損失函數組成,那么確保它們每一個的相應幅值都是正確的。這可能會涉及到測試損失權重的不同組合。
21. 監(jiān)控其它指標
有時損失并不是衡量你的網絡是否被正確訓練的***預測器。如果可以的話,使用其它指標來幫助你,比如精度。
22. 測試任意的自定義層
你自己在網絡中實現(xiàn)過任意層嗎?檢查并且復核以確保它們的運行符合預期。
23. 檢查「冷凍」層或變量
檢查你是否無意中阻止了一些層或變量的梯度更新,這些層或變量本來應該是可學的。
24. 擴大網絡規(guī)模
可能你的網絡的表現(xiàn)力不足以采集目標函數。試著加入更多的層,或在全連層中增加更多的隱藏單元。
25. 檢查隱維度誤差
如果你的輸入看上去像(k,H,W)= (64, 64, 64),那么很容易錯過與錯誤維度相關的誤差。給輸入維度使用一些「奇怪」的數值(例如,每一個維度使用不同的質數),并且檢查它們是如何通過網絡傳播的。
26. 探索梯度檢查(Gradient checking)
如果你手動實現(xiàn)梯度下降,梯度檢查會確保你的反向傳播(backpropagation)能像預期中一樣工作。
四、訓練問題
27. 一個真正小的數據集
過擬合數據的一個小子集,并確保其工作。例如,僅使用 1 或 2 個實例訓練,并查看你的網絡是否學習了區(qū)分它們。然后再訓練每個分類的更多實例。
28. 檢查權重初始化
如果不確定,請使用 Xavier 或 He 初始化。同樣,初始化也許會給你帶來壞的局部最小值,因此嘗試不同的初始化,看看是否有效。
29. 改變你的超參數
或許你正在使用一個很糟糕的超參數集。如果可行,嘗試一下網格搜索。
30. 減少正則化
太多的正則化可致使網絡嚴重地欠擬合。減少正則化,比如 dropout、批規(guī)范、權重/偏差 L2 正則化等。在優(yōu)秀課程《編程人員的深度學習實戰(zhàn)》(http://course.fast.ai)中,Jeremy Howard 建議首先解決欠擬合。這意味著你充分地過擬合數據,并且只有在那時處理過擬合。
31. 給它一些時間
也許你的網絡需要更多的時間來訓練,在它能做出有意義的預測之前。如果你的損失在穩(wěn)步下降,那就再多訓練一會兒。
32. 從訓練模式轉換為測試模式
一些框架的層很像批規(guī)范、Dropout,而其他的層在訓練和測試時表現(xiàn)并不同。轉換到適當的模式有助于網絡更好地預測。
33. 可視化訓練
- 監(jiān)督每一層的激活值、權重和更新。確保它們的大小匹配。例如,參數更新的大小(權重和偏差)應該是 1-e3。
- 考慮可視化庫,比如 Tensorboard 和 Crayon。緊要時你也可以打印權重/偏差/激活值。
- 尋找平均值遠大于 0 的層激活。嘗試批規(guī)范或者 ELUs。
Deeplearning4j 指出了權重和偏差柱狀圖中的期望值:對于權重,一些時間之后這些柱狀圖應該有一個近似高斯的(正常)分布。對于偏差,這些柱狀圖通常會從 0 開始,并經常以近似高斯(這種情況的一個例外是 LSTM)結束。留意那些向 +/- ***發(fā)散的參數。留意那些變得很大的偏差。這有時可能發(fā)生在分類的輸出層,如果類別的分布不均勻。
- 檢查層更新,它們應該有一個高斯分布。
34. 嘗試不同的優(yōu)化器
優(yōu)化器的選擇不應當妨礙網絡的訓練,除非你選擇了一個特別糟糕的參數。但是,為任務選擇一個合適的優(yōu)化器非常有助于在最短的時間內獲得最多的訓練。描述你正在使用的算法的論文應當指定優(yōu)化器;如果沒有,我傾向于選擇 Adam 或者帶有動量的樸素 SGD。
35. 梯度爆炸、梯度消失
- 檢查隱蔽層的***情況,過大的值可能代表梯度爆炸。這時,梯度截斷(Gradient clipping)可能會有所幫助。
- 檢查隱蔽層的激活值。Deeplearning4j 中有一個很好的指導方針:「一個好的激活值標準差大約在 0.5 到 2.0 之間。明顯超過這一范圍可能就代表著激活值消失或爆炸?!?/li>
36. 增加、減少學習速率
- 低學習速率將會導致你的模型收斂很慢;
- 高學習速率將會在開始階段減少你的損失,但是可能會導致你很難找到一個好的解決方案。
- 試著把你當前的學習速率乘以 0.1 或 10。
37. 克服 NaNs
據我所知,在訓練 RNNs 時得到 NaN(Non-a-Number)是一個很大的問題。一些解決它的方法:
- 減小學習速率,尤其是如果你在前 100 次迭代中就得到了 NaNs。
- NaNs 的出現(xiàn)可能是由于用零作了除數,或用零或負數作了自然對數。
- Russell Stewart 對如何處理 NaNs 很有心得(http://russellsstewart.com/notes/0.html)。
- 嘗試逐層評估你的網絡,這樣就會看見 NaNs 到底出現(xiàn)在了哪里。
原文:https://medium.com/@slavivanov/4020854bd607
【本文是51CTO專欄機構“機器之心”的原創(chuàng)譯文,微信公眾號“機器之心( id: almosthuman2014)”】