如何判斷LSTM模型中的過(guò)擬合與欠擬合
在本教程中,你將發(fā)現(xiàn)如何診斷 LSTM 模型在序列預(yù)測(cè)問(wèn)題上的擬合度。完成教程之后,你將了解:
- 如何收集 LSTM 模型的訓(xùn)練歷史并為其畫(huà)圖。
- 如何判別一個(gè)欠擬合、較好擬合和過(guò)擬合的模型。
- 如何通過(guò)平均多次模型運(yùn)行來(lái)開(kāi)發(fā)更魯棒的診斷方法。
讓我們開(kāi)始吧。
1. Keras 中的訓(xùn)練歷史
你可以通過(guò)回顧模型的性能隨時(shí)間的變化來(lái)更多地了解模型行為。
LSTM 模型通過(guò)調(diào)用 fit() 函數(shù)進(jìn)行訓(xùn)練。這個(gè)函數(shù)會(huì)返回一個(gè)叫作 history 的變量,該變量包含損失函數(shù)的軌跡,以及在模型編譯過(guò)程中被標(biāo)記出來(lái)的任何一個(gè)度量指標(biāo)。這些得分會(huì)在每一個(gè) epoch 的***被記錄下來(lái)。
- ...
- history = model.fit ( ... )
例如,如果你的模型被編譯用來(lái)優(yōu)化 log loss(binary_crossentropy),并且要在每一個(gè) epoch 中衡量準(zhǔn)確率,那么,log loss 和準(zhǔn)確率將會(huì)在每一個(gè)訓(xùn)練 epoch 的歷史記錄中被計(jì)算出,并記錄下來(lái)。
每一個(gè)得分都可以通過(guò)由調(diào)用 fit() 得到的歷史記錄中的一個(gè) key 進(jìn)行訪(fǎng)問(wèn)。默認(rèn)情況下,擬合模型時(shí)優(yōu)化過(guò)的損失函數(shù)為「loss」,準(zhǔn)確率為「acc」。
- model.com pile ( loss='binary_crossentropy', optimizer='adam', metrics= [ 'accuracy' ] )
- history = model.fit ( X, Y, epochs=100 )
- print ( history.history [ 'loss' ] )
- print ( history.history [ 'acc' ] )
Keras 還允許在擬合模型時(shí)指定獨(dú)立的驗(yàn)證數(shù)據(jù)集,該數(shù)據(jù)集也可以使用同樣的損失函數(shù)和度量指標(biāo)進(jìn)行評(píng)估。
該功能可以通過(guò)在 fit() 中設(shè)置 validation_split 參數(shù)來(lái)啟用,以將訓(xùn)練數(shù)據(jù)分割出一部分作為驗(yàn)證數(shù)據(jù)集。
- history = model.fit ( X, Y, epochs=100, validation_split=0.33 )
該功能也可以通過(guò)設(shè)置 validation_data 參數(shù),并向其傳遞 X 和 Y 數(shù)據(jù)集元組來(lái)執(zhí)行。
- history = model.fit ( X, Y, epochs=100, validation_data= ( valX, valY ) )
在驗(yàn)證數(shù)據(jù)集上計(jì)算得到的度量指標(biāo)會(huì)使用相同的命名,只是會(huì)附加一個(gè)「val_」前綴。
- ...
- model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])
- history =model.fit(X,Y,epochs=100,validation_split=0.33)
- print(history.history['loss'])
- print(history.history['acc'])
- print(history.history['val_loss'])
- print(history.history['val_acc'])
2. 診斷圖
LSTM 模型的訓(xùn)練歷史可用于診斷模型行為。你可以使用 Matplotlib 庫(kù)來(lái)進(jìn)行性能的可視化,你可以將訓(xùn)練損失和測(cè)試損失都畫(huà)出來(lái)以作比較,如下所示:
- frommatplotlib importpyplot
- ...
- history =model.fit(X,Y,epochs=100,validation_data=(valX,valY))
- pyplot.plot(history.history['loss'])
- pyplot.plot(history.history['val_loss'])
- pyplot.title('model train vs validation loss')
- pyplot.ylabel('loss')
- pyplot.xlabel('epoch')
- pyplot.legend(['train','validation'],loc='upper right')
- pyplot.show()
創(chuàng)建并檢查這些圖有助于啟發(fā)你找到新的有可能優(yōu)化模型性能的配置。
接下來(lái),我們來(lái)看一些例子。我們將從損失最小化的角度考慮在訓(xùn)練集和驗(yàn)證集上的建模技巧。
3. 欠擬合實(shí)例
欠擬合模型就是在訓(xùn)練集上表現(xiàn)良好而在測(cè)試集上性能較差的模型。
這個(gè)可以通過(guò)以下情況來(lái)診斷:訓(xùn)練的損失曲線(xiàn)低于驗(yàn)證的損失曲線(xiàn),并且驗(yàn)證集中的損失函數(shù)表現(xiàn)出了有可能被優(yōu)化的趨勢(shì)。
下面是一個(gè)人為設(shè)計(jì)的小的欠擬合 LSTM 模型。
- fromkeras.models importSequential
- fromkeras.layers importDense
- fromkeras.layers importLSTM
- frommatplotlib importpyplot
- fromnumpy importarray
- # return training data
- defget_train():
- seq =[[0.0,0.1],[0.1,0.2],[0.2,0.3],[0.3,0.4],[0.4,0.5]]
- seq =array(seq)
- X,y =seq[:,0],seq[:,1]
- XX =X.reshape((len(X),1,1))
- returnX,y
- # return validation data
- defget_val():
- seq =[[0.5,0.6],[0.6,0.7],[0.7,0.8],[0.8,0.9],[0.9,1.0]]
- seq =array(seq)
- X,y =seq[:,0],seq[:,1]
- XX =X.reshape((len(X),1,1))
- returnX,y
- # define model
- model.add(LSTM(10,input_shape=(1,1)))
- model.add(Dense(1,activation='linear'))
- # compile model
- model.compile(loss='mse',optimizer='adam')
- # fit model
- X,y =get_train()
- valX,valY =get_val()
- history =model.fit(X,y,epochs=100,validation_data=(valX,valY),shuffle=False)
- # plot train and validation loss
- pyplot.plot(history.history['loss'])
- pyplot.plot(history.history['val_loss'])
- pyplot.title('model train vs validation loss')
- pyplot.ylabel('loss')
- pyplot.xlabel('epoch')
- pyplot.legend(['train','validation'],loc='upper right')
- pyplot.show()
運(yùn)行這個(gè)實(shí)例會(huì)產(chǎn)生一個(gè)訓(xùn)練損失和驗(yàn)證損失圖,該圖顯示欠擬合模型特點(diǎn)。在這個(gè)案例中,模型性能可能隨著訓(xùn)練 epoch 的增加而有所改善。
欠擬合模型的診斷圖
另外,如果模型在訓(xùn)練集上的性能比驗(yàn)證集上的性能好,并且模型性能曲線(xiàn)已經(jīng)平穩(wěn)了,那么這個(gè)模型也可能欠擬合。下面就是一個(gè)缺乏足夠的記憶單元的欠擬合模型的例子。
- fromkeras.models importSequential
- fromkeras.layers importDense
- fromkeras.layers importLSTM
- frommatplotlib importpyplot
- fromnumpy importarray
- # return training data
- defget_train():
- seq =[[0.0,0.1],[0.1,0.2],[0.2,0.3],[0.3,0.4],[0.4,0.5]]
- seq =array(seq)
- X,y =seq[:,0],seq[:,1]
- XX =X.reshape((5,1,1))
- returnX,y
- # return validation data
- defget_val():
- seq =[[0.5,0.6],[0.6,0.7],[0.7,0.8],[0.8,0.9],[0.9,1.0]]
- seq =array(seq)
- X,y =seq[:,0],seq[:,1]
- XX =X.reshape((len(X),1,1))
- returnX,y
- # define model
- model.add(LSTM(1,input_shape=(1,1)))
- model.add(Dense(1,activation='linear'))
- # compile model
- model.compile(loss='mae',optimizer='sgd')
- # fit model
- X,y =get_train()
- valX,valY =get_val()
- history =model.fit(X,y,epochs=300,validation_data=(valX,valY),shuffle=False)
- # plot train and validation loss
- pyplot.plot(history.history['loss'])
- pyplot.plot(history.history['val_loss'])
- pyplot.title('model train vs validation loss')
- pyplot.ylabel('loss')
- pyplot.xlabel('epoch')
- pyplot.legend(['train','validation'],loc='upper right')
- pyplot.show()
運(yùn)行這個(gè)實(shí)例會(huì)展示出一個(gè)存儲(chǔ)不足的欠擬合模型的特點(diǎn)。
在這個(gè)案例中,模型的性能也許會(huì)隨著模型的容量增加而得到改善,例如隱藏層中記憶單元的數(shù)目或者隱藏層的數(shù)目增加。
欠擬合模型的狀態(tài)診斷線(xiàn)圖
4. 良好擬合實(shí)例
良好擬合的模型就是模型的性能在訓(xùn)練集和驗(yàn)證集上都比較好。
這可以通過(guò)訓(xùn)練損失和驗(yàn)證損失都下降并且穩(wěn)定在同一個(gè)點(diǎn)進(jìn)行診斷。
下面的小例子描述的就是一個(gè)良好擬合的 LSTM 模型。
- fromkeras.models importSequential
- fromkeras.layers importDense
- fromkeras.layers importLSTM
- frommatplotlib importpyplot
- fromnumpy importarray
- # return training data
- defget_train():
- seq =[[0.0,0.1],[0.1,0.2],[0.2,0.3],[0.3,0.4],[0.4,0.5]]
- seq =array(seq)
- X,y =seq[:,0],seq[:,1]
- XX =X.reshape((5,1,1))
- returnX,y
- # return validation data
- defget_val():
- seq =[[0.5,0.6],[0.6,0.7],[0.7,0.8],[0.8,0.9],[0.9,1.0]]
- seq =array(seq)
- X,y =seq[:,0],seq[:,1]
- XX =X.reshape((len(X),1,1))
- returnX,y
- # define model
- model.add(LSTM(10,input_shape=(1,1)))
- model.add(Dense(1,activation='linear'))
- # compile model
- model.compile(loss='mse',optimizer='adam')
- # fit model
- X,y =get_train()
- valX,valY =get_val()
- history =model.fit(X,y,epochs=800,validation_data=(valX,valY),shuffle=False)
- # plot train and validation loss
- pyplot.plot(history.history['loss'])
- pyplot.plot(history.history['val_loss'])
- pyplot.title('model train vs validation loss')
- pyplot.ylabel('loss')
- pyplot.xlabel('epoch')
- pyplot.legend(['train','validation'],loc='upper right')
- pyplot.show()
運(yùn)行這個(gè)實(shí)例可以創(chuàng)建一個(gè)線(xiàn)圖,圖中訓(xùn)練損失和驗(yàn)證損失出現(xiàn)重合。
理想情況下,我們都希望模型盡可能是這樣,盡管面對(duì)大量數(shù)據(jù)的挑戰(zhàn),這似乎不太可能。
良好擬合模型的診斷線(xiàn)圖
5. 過(guò)擬合實(shí)例
過(guò)擬合模型即在訓(xùn)練集上性能良好且在某一點(diǎn)后持續(xù)增長(zhǎng),而在驗(yàn)證集上的性能到達(dá)某一點(diǎn)然后開(kāi)始下降的模型。
這可以通過(guò)線(xiàn)圖來(lái)診斷,圖中訓(xùn)練損失持續(xù)下降,驗(yàn)證損失下降到拐點(diǎn)開(kāi)始上升。
下面這個(gè)實(shí)例就是一個(gè)過(guò)擬合 LSTM 模型。
- fromkeras.models importSequential
- fromkeras.layers importDense
- fromkeras.layers importLSTM
- frommatplotlib importpyplot
- fromnumpy importarray
- # return training data
- defget_train():
- seq =[[0.0,0.1],[0.1,0.2],[0.2,0.3],[0.3,0.4],[0.4,0.5]]
- seq =array(seq)
- X,y =seq[:,0],seq[:,1]
- XX =X.reshape((5,1,1))
- returnX,y
- # return validation data
- defget_val():
- seq =[[0.5,0.6],[0.6,0.7],[0.7,0.8],[0.8,0.9],[0.9,1.0]]
- seq =array(seq)
- X,y =seq[:,0],seq[:,1]
- XX =X.reshape((len(X),1,1))
- returnX,y
- # define model
- model.add(LSTM(10,input_shape=(1,1)))
- model.add(Dense(1,activation='linear'))
- # compile model
- model.compile(loss='mse',optimizer='adam')
- # fit model
- X,y =get_train()
- valX,valY =get_val()
- history =model.fit(X,y,epochs=1200,validation_data=(valX,valY),shuffle=False)
- # plot train and validation loss
- pyplot.plot(history.history['loss'][500:])
- pyplot.plot(history.history['val_loss'][500:])
- pyplot.title('model train vs validation loss')
- pyplot.ylabel('loss')
- pyplot.xlabel('epoch')
- pyplot.legend(['train','validation'],loc='upper right')
- pyplot.show()
運(yùn)行這個(gè)實(shí)例會(huì)創(chuàng)建一個(gè)展示過(guò)擬合模型在驗(yàn)證集中出現(xiàn)拐點(diǎn)的曲線(xiàn)圖。
這也許是進(jìn)行太多訓(xùn)練 epoch 的信號(hào)。
在這個(gè)案例中,模型會(huì)在拐點(diǎn)處停止訓(xùn)練。另外,訓(xùn)練樣本的數(shù)目可能會(huì)增加。
過(guò)擬合模型的診斷線(xiàn)圖
6. 多次運(yùn)行實(shí)例
LSTM 是隨機(jī)的,這意味著每次運(yùn)行時(shí)都會(huì)得到一個(gè)不同的診斷圖。
多次重復(fù)診斷運(yùn)行很有用(如 5、10、30)。每次運(yùn)行的訓(xùn)練軌跡和驗(yàn)證軌跡都可以被繪制出來(lái),以更魯棒的方式記錄模型隨著時(shí)間的行為軌跡。
以下實(shí)例多次運(yùn)行同樣的實(shí)驗(yàn),然后繪制每次運(yùn)行的訓(xùn)練損失和驗(yàn)證損失軌跡。
- fromkeras.models importSequential
- fromkeras.layers importDense
- fromkeras.layers importLSTM
- frommatplotlib importpyplot
- fromnumpy importarray
- frompandas importDataFrame
- # return training data
- defget_train():
- seq =[[0.0,0.1],[0.1,0.2],[0.2,0.3],[0.3,0.4],[0.4,0.5]]
- seq =array(seq)
- X,y =seq[:,0],seq[:,1]
- XX =X.reshape((5,1,1))
- returnX,y
- # return validation data
- defget_val():
- seq =[[0.5,0.6],[0.6,0.7],[0.7,0.8],[0.8,0.9],[0.9,1.0]]
- seq =array(seq)
- X,y =seq[:,0],seq[:,1]
- XX =X.reshape((len(X),1,1))
- returnX,y
- # collect data across multiple repeats
- train =DataFrame()
- val =DataFrame()
- fori inrange(5):
- # define model
- model.add(LSTM(10,input_shape=(1,1)))
- model.add(Dense(1,activation='linear'))
- # compile model
- model.compile(loss='mse',optimizer='adam')
- X,y =get_train()
- valX,valY =get_val()
- # fit model
- history =model.fit(X,y,epochs=300,validation_data=(valX,valY),shuffle=False)
- # story history
- train[str(i)]=history.history['loss']
- val[str(i)]=history.history['val_loss']
- # plot train and validation loss across multiple runs
- pyplot.plot(train,color='blue',label='train')
- pyplot.plot(val,color='orange',label='validation')
- pyplot.title('model train vs validation loss')
- pyplot.ylabel('loss')
- pyplot.xlabel('epoch')
- pyplot.show()
從下圖中,我們可以在 5 次運(yùn)行中看到欠擬合模型的通常趨勢(shì),該案例強(qiáng)有力地證明增加訓(xùn)練 epoch 次數(shù)的有效性。
模型多次運(yùn)行的診斷線(xiàn)圖
擴(kuò)展閱讀
如果你想更深入地了解這方面的內(nèi)容,這一部分提供了更豐富的資源。
- Keras 的歷史回調(diào) API(History Callback Keras API,https://keras.io/callbacks/#history)
- 維基百科中關(guān)于機(jī)器學(xué)習(xí)的學(xué)習(xí)曲線(xiàn)(Learning Curve in Machine Learning on Wikipedia,https://en.wikipedia.org/wiki/Learning_curve#In_machine_learning)
- 維基百科上關(guān)于過(guò)擬合的描述(Overfitting on Wikipedia,https://en.wikipedia.org/wiki/Overfitting)
原文:https://machinelearningmastery.com/diagnose-overfitting-underfitting-lstm-models/
【本文是51CTO專(zhuān)欄機(jī)構(gòu)“機(jī)器之心”的原創(chuàng)譯文,微信公眾號(hào)“機(jī)器之心( id: almosthuman2014)”】