如何對(duì)機(jī)器學(xué)習(xí)代碼進(jìn)行單元測(cè)試?
目前,關(guān)于神經(jīng)網(wǎng)絡(luò)代碼,并沒(méi)有一個(gè)特別完善的單元測(cè)試的在線(xiàn)教程。甚至像 OpenAI 這樣的站點(diǎn),也只能靠 盯著每一行看來(lái)思考哪里錯(cuò)了來(lái)尋找 bug。很明顯,大多數(shù)人沒(méi)有那樣的時(shí)間,并且也討厭這么做。所以希望這篇教程能幫助你開(kāi)始穩(wěn)健的測(cè)試系統(tǒng)。
首先來(lái)看一個(gè)簡(jiǎn)單的例子,嘗試找出以下代碼的 bug。
看出來(lái)了嗎?網(wǎng)絡(luò)并沒(méi)有實(shí)際融合(stacking)。寫(xiě)這段代碼時(shí),只是復(fù)制、粘貼了 slim.conv2d(…) 這行,修改了核(kernel)大小,忘記修改實(shí)際的輸入。
這個(gè)實(shí)際上是作者一周前剛剛碰到的狀況,很尷尬,但是也是重要的一個(gè)教訓(xùn)!這些 bug 很難發(fā)現(xiàn),有以下原因。
- 這些代碼不會(huì)崩潰,不會(huì)拋出異常,甚至不會(huì)變慢。
- 這個(gè)網(wǎng)絡(luò)仍然能訓(xùn)練,并且損失(loss)也會(huì)下降。
- 運(yùn)行多個(gè)小時(shí)后,值回歸到很差的結(jié)果,讓人抓耳撓腮不知如何修復(fù)。
只有最終的驗(yàn)證錯(cuò)誤這一條線(xiàn)索情況下,必須回顧整個(gè)網(wǎng)絡(luò)架構(gòu)才能找到問(wèn)題所在。很明顯,你需要需要一個(gè)更好的處理方式。
比起在運(yùn)行了很多天的訓(xùn)練后才發(fā)現(xiàn),我們?nèi)绾翁崆邦A(yù)防呢?這里可以明顯注意到,層(layers)的值并沒(méi)有到達(dá)函數(shù)外的任何張量(tensors)。在有損失和優(yōu)化器情況下,如果這些張量從未被優(yōu)化,它們會(huì)保持默認(rèn)值。
因此,只需要比較值在訓(xùn)練步驟前后有沒(méi)有發(fā)生變化,我們就可以發(fā)現(xiàn)這種情況。
哇。只需要短短 15 行不到的代碼,就能保證至少所有創(chuàng)建的變量都被訓(xùn)練到了。
這個(gè)測(cè)試,簡(jiǎn)單但是卻很有用?,F(xiàn)在問(wèn)題修復(fù)了,讓我們來(lái)嘗試添加批量標(biāo)準(zhǔn)化??茨隳芊裼醚劬闯?bug 來(lái)。
發(fā)現(xiàn)了嗎?這個(gè) bug 很巧妙。在 tensorflow 中,batch_norm 的 is_training 默認(rèn)值是 False,所以在訓(xùn)練過(guò)程中添加這行代碼,會(huì)導(dǎo)致輸入無(wú)法標(biāo)準(zhǔn)化!幸虧,我們剛剛添加的那個(gè)單元測(cè)試會(huì)立即捕捉到這個(gè)問(wèn)題?。? 天前,它剛剛幫助我捕捉到這個(gè)問(wèn)題。)
讓我們看另外一個(gè)例子。這是我從 reddit 帖子中看來(lái)的。我們不會(huì)太深入原帖,簡(jiǎn)單的說(shuō),發(fā)帖的人想要?jiǎng)?chuàng)建一個(gè)分類(lèi)器,輸出的范圍在 0 到 1 之間??纯茨隳芊窨闯瞿睦锊粚?duì)。
發(fā)現(xiàn)問(wèn)題了嗎?這個(gè)問(wèn)題很難發(fā)現(xiàn),結(jié)果非常難以理解。簡(jiǎn)單的說(shuō),因?yàn)轭A(yù)測(cè)只有單個(gè)輸出值,應(yīng)用了 softmax 交叉熵函數(shù)后,損失就會(huì)永遠(yuǎn)是 0 了。
最簡(jiǎn)單的發(fā)現(xiàn)這個(gè)問(wèn)題的測(cè)試方式,就是保證損失永遠(yuǎn)不等于 0。
我們***個(gè)實(shí)現(xiàn)的測(cè)試,也能發(fā)現(xiàn)這種錯(cuò)誤,但是要反向檢查:保證只訓(xùn)練需要訓(xùn)練的變量。就生成式對(duì)抗網(wǎng)絡(luò)(GAN)為例,一個(gè)常見(jiàn)的 bug 就是在優(yōu)化過(guò)程中不小心忘記設(shè)置需要訓(xùn)練哪個(gè)變量。這樣的代碼隨處可見(jiàn)。
這段代碼***的問(wèn)題是,優(yōu)化器默認(rèn)會(huì)優(yōu)化所有的變量。在像生成式對(duì)抗網(wǎng)絡(luò)這樣高級(jí)的結(jié)構(gòu)中,這意味著遙遙無(wú)期的訓(xùn)練時(shí)間。然而只需要一個(gè)簡(jiǎn)單測(cè)試,就可以檢查到這種錯(cuò)誤:
也可以對(duì)判定模型(discriminator)寫(xiě)一個(gè)同類(lèi)型的測(cè)試。同樣的測(cè)試,也可以應(yīng)用來(lái)加強(qiáng)大量其他的學(xué)習(xí)算法。很多演員評(píng)判家(actor-critic)模型,有不同的網(wǎng)絡(luò)需要用不同的損失來(lái)優(yōu)化。
這里列出一些作者推薦的測(cè)試模式。
- 確保輸入的確定性。如果發(fā)現(xiàn)一個(gè)詭異的失敗測(cè)試,但是卻再也無(wú)法重現(xiàn),將會(huì)是很糟糕的事情。在特別需要隨機(jī)輸入的場(chǎng)景下,確保用了同一個(gè)隨機(jī)數(shù)種子。這樣出現(xiàn)了失敗后,可以再次以同樣的輸入重現(xiàn)它。
- 確保測(cè)試很精簡(jiǎn)。不要用同一個(gè)單元測(cè)試檢查回歸訓(xùn)練和檢查一個(gè)驗(yàn)證集合。這樣做只是浪費(fèi)時(shí)間。
- 確保每次測(cè)試時(shí)都重置了圖。
作為總結(jié),這些黑盒算法仍然有大量方法來(lái)測(cè)試!花一個(gè)小時(shí)寫(xiě)一個(gè)簡(jiǎn)單的測(cè)試,可以節(jié)約成天的重新運(yùn)行時(shí)間,并且大大提升你的研究能力。天才的想法,永遠(yuǎn)不要因?yàn)橐粋€(gè)充滿(mǎn) bug 的實(shí)現(xiàn)而無(wú)法成為現(xiàn)實(shí)。
這篇文章列出的測(cè)試遠(yuǎn)遠(yuǎn)沒(méi)有完備,但是是一個(gè)很好的起步!如果你發(fā)現(xiàn)有其他的建議或者某種特定類(lèi)型的測(cè)試,請(qǐng)?jiān)?twitter 上給我消息!我很樂(lè)意寫(xiě)這篇文章的續(xù)集。
文章中所有的觀(guān)點(diǎn),僅代表作者的個(gè)人經(jīng)驗(yàn),并沒(méi)有 Google 的支持、贊助。
查看英文原文
https://medium.com/@keeper6928/how-to-unit-test-machine-learning-code-57cf6fd81765