自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

如何判斷LSTM模型中的過(guò)擬合與欠擬合

開(kāi)發(fā) 開(kāi)發(fā)工具
判斷長(zhǎng)短期記憶模型在序列預(yù)測(cè)問(wèn)題上是否表現(xiàn)良好可能是一件困難的事。也許你會(huì)得到一個(gè)不錯(cuò)的模型技術(shù)得分,但了解模型是較好的擬合,還是欠擬合/過(guò)擬合,以及模型在不同的配置條件下能否實(shí)現(xiàn)更好的性能是非常重要的。

在本教程中,你將發(fā)現(xiàn)如何診斷 LSTM 模型在序列預(yù)測(cè)問(wèn)題上的擬合度。完成教程之后,你將了解:

  • 如何收集 LSTM 模型的訓(xùn)練歷史并為其畫(huà)圖。
  • 如何判別一個(gè)欠擬合、較好擬合和過(guò)擬合的模型。
  • 如何通過(guò)平均多次模型運(yùn)行來(lái)開(kāi)發(fā)更魯棒的診斷方法。

讓我們開(kāi)始吧。

1. Keras 中的訓(xùn)練歷史

你可以通過(guò)回顧模型的性能隨時(shí)間的變化來(lái)更多地了解模型行為。

LSTM 模型通過(guò)調(diào)用 fit() 函數(shù)進(jìn)行訓(xùn)練。這個(gè)函數(shù)會(huì)返回一個(gè)叫作 history 的變量,該變量包含損失函數(shù)的軌跡,以及在模型編譯過(guò)程中被標(biāo)記出來(lái)的任何一個(gè)度量指標(biāo)。這些得分會(huì)在每一個(gè) epoch 的***被記錄下來(lái)。

  1. ...  
  2. history = model.fit ( ... ) 

例如,如果你的模型被編譯用來(lái)優(yōu)化 log loss(binary_crossentropy),并且要在每一個(gè) epoch 中衡量準(zhǔn)確率,那么,log loss 和準(zhǔn)確率將會(huì)在每一個(gè)訓(xùn)練 epoch 的歷史記錄中被計(jì)算出,并記錄下來(lái)。

每一個(gè)得分都可以通過(guò)由調(diào)用 fit() 得到的歷史記錄中的一個(gè) key 進(jìn)行訪(fǎng)問(wèn)。默認(rèn)情況下,擬合模型時(shí)優(yōu)化過(guò)的損失函數(shù)為「loss」,準(zhǔn)確率為「acc」。

  1. model.com pile ( loss='binary_crossentropy'optimizer='adam'metrics= [ 'accuracy' ] ) 
  2.  
  3. history = model.fit ( X, Y, epochs=100 ) 
  4.  
  5. print ( history.history [ 'loss' ] ) 
  6.  
  7. print ( history.history [ 'acc' ] ) 

Keras 還允許在擬合模型時(shí)指定獨(dú)立的驗(yàn)證數(shù)據(jù)集,該數(shù)據(jù)集也可以使用同樣的損失函數(shù)和度量指標(biāo)進(jìn)行評(píng)估。

該功能可以通過(guò)在 fit() 中設(shè)置 validation_split 參數(shù)來(lái)啟用,以將訓(xùn)練數(shù)據(jù)分割出一部分作為驗(yàn)證數(shù)據(jù)集。

  1. history = model.fit ( X, Y, epochs=100validation_split=0.33 ) 

該功能也可以通過(guò)設(shè)置 validation_data 參數(shù),并向其傳遞 X 和 Y 數(shù)據(jù)集元組來(lái)執(zhí)行。

  1. history = model.fit ( X, Y, epochs=100validation_data= ( valX, valY ) ) 

在驗(yàn)證數(shù)據(jù)集上計(jì)算得到的度量指標(biāo)會(huì)使用相同的命名,只是會(huì)附加一個(gè)「val_」前綴。

  1. ... 
  2. model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy']) 
  3. history =model.fit(X,Y,epochs=100,validation_split=0.33) 
  4. print(history.history['loss']) 
  5. print(history.history['acc']) 
  6. print(history.history['val_loss']) 
  7. print(history.history['val_acc']) 

2. 診斷圖

LSTM 模型的訓(xùn)練歷史可用于診斷模型行為。你可以使用 Matplotlib 庫(kù)來(lái)進(jìn)行性能的可視化,你可以將訓(xùn)練損失和測(cè)試損失都畫(huà)出來(lái)以作比較,如下所示:

  1. frommatplotlib importpyplot 
  2. ... 
  3. history =model.fit(X,Y,epochs=100,validation_data=(valX,valY)) 
  4. pyplot.plot(history.history['loss']) 
  5. pyplot.plot(history.history['val_loss']) 
  6. pyplot.title('model train vs validation loss') 
  7. pyplot.ylabel('loss') 
  8. pyplot.xlabel('epoch') 
  9. pyplot.legend(['train','validation'],loc='upper right'
  10. pyplot.show() 

創(chuàng)建并檢查這些圖有助于啟發(fā)你找到新的有可能優(yōu)化模型性能的配置。

接下來(lái),我們來(lái)看一些例子。我們將從損失最小化的角度考慮在訓(xùn)練集和驗(yàn)證集上的建模技巧。

3. 欠擬合實(shí)例

欠擬合模型就是在訓(xùn)練集上表現(xiàn)良好而在測(cè)試集上性能較差的模型。

這個(gè)可以通過(guò)以下情況來(lái)診斷:訓(xùn)練的損失曲線(xiàn)低于驗(yàn)證的損失曲線(xiàn),并且驗(yàn)證集中的損失函數(shù)表現(xiàn)出了有可能被優(yōu)化的趨勢(shì)。

下面是一個(gè)人為設(shè)計(jì)的小的欠擬合 LSTM 模型。

  1. fromkeras.models importSequential 
  2. fromkeras.layers importDense 
  3. fromkeras.layers importLSTM 
  4. frommatplotlib importpyplot 
  5. fromnumpy importarray 
  6. # return training data 
  7. defget_train(): 
  8. seq =[[0.0,0.1],[0.1,0.2],[0.2,0.3],[0.3,0.4],[0.4,0.5]] 
  9. seq =array(seq) 
  10. X,y =seq[:,0],seq[:,1] 
  11. XX =X.reshape((len(X),1,1)) 
  12. returnX,y 
  13. # return validation data 
  14. defget_val(): 
  15. seq =[[0.5,0.6],[0.6,0.7],[0.7,0.8],[0.8,0.9],[0.9,1.0]] 
  16. seq =array(seq) 
  17. X,y =seq[:,0],seq[:,1] 
  18. XX =X.reshape((len(X),1,1)) 
  19. returnX,y 
  20. # define model 
  21. model.add(LSTM(10,input_shape=(1,1))) 
  22. model.add(Dense(1,activation='linear')) 
  23. # compile model 
  24. model.compile(loss='mse',optimizer='adam'
  25. # fit model 
  26. X,y =get_train() 
  27. valX,valY =get_val() 
  28. history =model.fit(X,y,epochs=100,validation_data=(valX,valY),shuffle=False
  29. # plot train and validation loss 
  30. pyplot.plot(history.history['loss']) 
  31. pyplot.plot(history.history['val_loss']) 
  32. pyplot.title('model train vs validation loss') 
  33. pyplot.ylabel('loss') 
  34. pyplot.xlabel('epoch') 
  35. pyplot.legend(['train','validation'],loc='upper right'
  36. pyplot.show() 

運(yùn)行這個(gè)實(shí)例會(huì)產(chǎn)生一個(gè)訓(xùn)練損失和驗(yàn)證損失圖,該圖顯示欠擬合模型特點(diǎn)。在這個(gè)案例中,模型性能可能隨著訓(xùn)練 epoch 的增加而有所改善。

欠擬合模型的診斷圖

欠擬合模型的診斷圖

另外,如果模型在訓(xùn)練集上的性能比驗(yàn)證集上的性能好,并且模型性能曲線(xiàn)已經(jīng)平穩(wěn)了,那么這個(gè)模型也可能欠擬合。下面就是一個(gè)缺乏足夠的記憶單元的欠擬合模型的例子。

  1. fromkeras.models importSequential 
  2. fromkeras.layers importDense 
  3. fromkeras.layers importLSTM 
  4. frommatplotlib importpyplot 
  5. fromnumpy importarray 
  6. # return training data 
  7. defget_train(): 
  8. seq =[[0.0,0.1],[0.1,0.2],[0.2,0.3],[0.3,0.4],[0.4,0.5]] 
  9. seq =array(seq) 
  10. X,y =seq[:,0],seq[:,1] 
  11. XX =X.reshape((5,1,1)) 
  12. returnX,y 
  13. # return validation data 
  14. defget_val(): 
  15. seq =[[0.5,0.6],[0.6,0.7],[0.7,0.8],[0.8,0.9],[0.9,1.0]] 
  16. seq =array(seq) 
  17. X,y =seq[:,0],seq[:,1] 
  18. XX =X.reshape((len(X),1,1)) 
  19. returnX,y 
  20. # define model 
  21. model.add(LSTM(1,input_shape=(1,1))) 
  22. model.add(Dense(1,activation='linear')) 
  23. # compile model 
  24. model.compile(loss='mae',optimizer='sgd'
  25. # fit model 
  26. X,y =get_train() 
  27. valX,valY =get_val() 
  28. history =model.fit(X,y,epochs=300,validation_data=(valX,valY),shuffle=False
  29. # plot train and validation loss 
  30. pyplot.plot(history.history['loss']) 
  31. pyplot.plot(history.history['val_loss']) 
  32. pyplot.title('model train vs validation loss') 
  33. pyplot.ylabel('loss') 
  34. pyplot.xlabel('epoch') 
  35. pyplot.legend(['train','validation'],loc='upper right'
  36. pyplot.show() 

運(yùn)行這個(gè)實(shí)例會(huì)展示出一個(gè)存儲(chǔ)不足的欠擬合模型的特點(diǎn)。

在這個(gè)案例中,模型的性能也許會(huì)隨著模型的容量增加而得到改善,例如隱藏層中記憶單元的數(shù)目或者隱藏層的數(shù)目增加。

欠擬合模型的狀態(tài)診斷線(xiàn)圖

欠擬合模型的狀態(tài)診斷線(xiàn)圖

4. 良好擬合實(shí)例

良好擬合的模型就是模型的性能在訓(xùn)練集和驗(yàn)證集上都比較好。

這可以通過(guò)訓(xùn)練損失和驗(yàn)證損失都下降并且穩(wěn)定在同一個(gè)點(diǎn)進(jìn)行診斷。

下面的小例子描述的就是一個(gè)良好擬合的 LSTM 模型。

  1. fromkeras.models importSequential 
  2. fromkeras.layers importDense 
  3. fromkeras.layers importLSTM 
  4. frommatplotlib importpyplot 
  5. fromnumpy importarray 
  6. # return training data 
  7. defget_train(): 
  8. seq =[[0.0,0.1],[0.1,0.2],[0.2,0.3],[0.3,0.4],[0.4,0.5]] 
  9. seq =array(seq) 
  10. X,y =seq[:,0],seq[:,1] 
  11. XX =X.reshape((5,1,1)) 
  12. returnX,y 
  13. # return validation data 
  14. defget_val(): 
  15. seq =[[0.5,0.6],[0.6,0.7],[0.7,0.8],[0.8,0.9],[0.9,1.0]] 
  16. seq =array(seq) 
  17. X,y =seq[:,0],seq[:,1] 
  18. XX =X.reshape((len(X),1,1)) 
  19. returnX,y 
  20. # define model 
  21. model.add(LSTM(10,input_shape=(1,1))) 
  22. model.add(Dense(1,activation='linear')) 
  23. # compile model 
  24. model.compile(loss='mse',optimizer='adam'
  25. # fit model 
  26. X,y =get_train() 
  27. valX,valY =get_val() 
  28. history =model.fit(X,y,epochs=800,validation_data=(valX,valY),shuffle=False
  29. # plot train and validation loss 
  30. pyplot.plot(history.history['loss']) 
  31. pyplot.plot(history.history['val_loss']) 
  32. pyplot.title('model train vs validation loss') 
  33. pyplot.ylabel('loss') 
  34. pyplot.xlabel('epoch') 
  35. pyplot.legend(['train','validation'],loc='upper right'
  36. pyplot.show() 

運(yùn)行這個(gè)實(shí)例可以創(chuàng)建一個(gè)線(xiàn)圖,圖中訓(xùn)練損失和驗(yàn)證損失出現(xiàn)重合。

理想情況下,我們都希望模型盡可能是這樣,盡管面對(duì)大量數(shù)據(jù)的挑戰(zhàn),這似乎不太可能。

良好擬合模型的診斷線(xiàn)圖

良好擬合模型的診斷線(xiàn)圖

5. 過(guò)擬合實(shí)例

過(guò)擬合模型即在訓(xùn)練集上性能良好且在某一點(diǎn)后持續(xù)增長(zhǎng),而在驗(yàn)證集上的性能到達(dá)某一點(diǎn)然后開(kāi)始下降的模型。

這可以通過(guò)線(xiàn)圖來(lái)診斷,圖中訓(xùn)練損失持續(xù)下降,驗(yàn)證損失下降到拐點(diǎn)開(kāi)始上升。

下面這個(gè)實(shí)例就是一個(gè)過(guò)擬合 LSTM 模型。

  1. fromkeras.models importSequential 
  2. fromkeras.layers importDense 
  3. fromkeras.layers importLSTM 
  4. frommatplotlib importpyplot 
  5. fromnumpy importarray 
  6. # return training data 
  7. defget_train(): 
  8. seq =[[0.0,0.1],[0.1,0.2],[0.2,0.3],[0.3,0.4],[0.4,0.5]] 
  9. seq =array(seq) 
  10. X,y =seq[:,0],seq[:,1] 
  11. XX =X.reshape((5,1,1)) 
  12. returnX,y 
  13. # return validation data 
  14. defget_val(): 
  15. seq =[[0.5,0.6],[0.6,0.7],[0.7,0.8],[0.8,0.9],[0.9,1.0]] 
  16. seq =array(seq) 
  17. X,y =seq[:,0],seq[:,1] 
  18. XX =X.reshape((len(X),1,1)) 
  19. returnX,y 
  20. # define model 
  21. model.add(LSTM(10,input_shape=(1,1))) 
  22. model.add(Dense(1,activation='linear')) 
  23. # compile model 
  24. model.compile(loss='mse',optimizer='adam'
  25. # fit model 
  26. X,y =get_train() 
  27. valX,valY =get_val() 
  28. history =model.fit(X,y,epochs=1200,validation_data=(valX,valY),shuffle=False
  29. # plot train and validation loss 
  30. pyplot.plot(history.history['loss'][500:]) 
  31. pyplot.plot(history.history['val_loss'][500:]) 
  32. pyplot.title('model train vs validation loss') 
  33. pyplot.ylabel('loss') 
  34. pyplot.xlabel('epoch') 
  35. pyplot.legend(['train','validation'],loc='upper right'
  36. pyplot.show() 

運(yùn)行這個(gè)實(shí)例會(huì)創(chuàng)建一個(gè)展示過(guò)擬合模型在驗(yàn)證集中出現(xiàn)拐點(diǎn)的曲線(xiàn)圖。

這也許是進(jìn)行太多訓(xùn)練 epoch 的信號(hào)。

在這個(gè)案例中,模型會(huì)在拐點(diǎn)處停止訓(xùn)練。另外,訓(xùn)練樣本的數(shù)目可能會(huì)增加。

過(guò)擬合模型的診斷線(xiàn)圖

過(guò)擬合模型的診斷線(xiàn)圖

6. 多次運(yùn)行實(shí)例

LSTM 是隨機(jī)的,這意味著每次運(yùn)行時(shí)都會(huì)得到一個(gè)不同的診斷圖。

多次重復(fù)診斷運(yùn)行很有用(如 5、10、30)。每次運(yùn)行的訓(xùn)練軌跡和驗(yàn)證軌跡都可以被繪制出來(lái),以更魯棒的方式記錄模型隨著時(shí)間的行為軌跡。

以下實(shí)例多次運(yùn)行同樣的實(shí)驗(yàn),然后繪制每次運(yùn)行的訓(xùn)練損失和驗(yàn)證損失軌跡。

  1. fromkeras.models importSequential 
  2. fromkeras.layers importDense 
  3. fromkeras.layers importLSTM 
  4. frommatplotlib importpyplot 
  5. fromnumpy importarray 
  6. frompandas importDataFrame 
  7. # return training data 
  8. defget_train(): 
  9. seq =[[0.0,0.1],[0.1,0.2],[0.2,0.3],[0.3,0.4],[0.4,0.5]] 
  10. seq =array(seq) 
  11. X,y =seq[:,0],seq[:,1] 
  12. XX =X.reshape((5,1,1)) 
  13. returnX,y 
  14. # return validation data 
  15. defget_val(): 
  16. seq =[[0.5,0.6],[0.6,0.7],[0.7,0.8],[0.8,0.9],[0.9,1.0]] 
  17. seq =array(seq) 
  18. X,y =seq[:,0],seq[:,1] 
  19. XX =X.reshape((len(X),1,1)) 
  20. returnX,y 
  21. # collect data across multiple repeats 
  22. train =DataFrame() 
  23. val =DataFrame() 
  24. fori inrange(5): 
  25. # define model 
  26. model.add(LSTM(10,input_shape=(1,1))) 
  27. model.add(Dense(1,activation='linear')) 
  28. # compile model 
  29. model.compile(loss='mse',optimizer='adam'
  30. X,y =get_train() 
  31. valX,valY =get_val() 
  32. # fit model 
  33. history =model.fit(X,y,epochs=300,validation_data=(valX,valY),shuffle=False
  34. # story history 
  35. train[str(i)]=history.history['loss'] 
  36. val[str(i)]=history.history['val_loss'] 
  37. # plot train and validation loss across multiple runs 
  38. pyplot.plot(train,color='blue',label='train'
  39. pyplot.plot(val,color='orange',label='validation'
  40. pyplot.title('model train vs validation loss') 
  41. pyplot.ylabel('loss') 
  42. pyplot.xlabel('epoch') 
  43. pyplot.show() 

從下圖中,我們可以在 5 次運(yùn)行中看到欠擬合模型的通常趨勢(shì),該案例強(qiáng)有力地證明增加訓(xùn)練 epoch 次數(shù)的有效性。

模型多次運(yùn)行的診斷線(xiàn)圖

模型多次運(yùn)行的診斷線(xiàn)圖

擴(kuò)展閱讀

如果你想更深入地了解這方面的內(nèi)容,這一部分提供了更豐富的資源。

  • Keras 的歷史回調(diào) API(History Callback Keras API,https://keras.io/callbacks/#history)
  • 維基百科中關(guān)于機(jī)器學(xué)習(xí)的學(xué)習(xí)曲線(xiàn)(Learning Curve in Machine Learning on Wikipedia,https://en.wikipedia.org/wiki/Learning_curve#In_machine_learning)
  • 維基百科上關(guān)于過(guò)擬合的描述(Overfitting on Wikipedia,https://en.wikipedia.org/wiki/Overfitting)

原文:https://machinelearningmastery.com/diagnose-overfitting-underfitting-lstm-models/

【本文是51CTO專(zhuān)欄機(jī)構(gòu)“機(jī)器之心”的原創(chuàng)譯文,微信公眾號(hào)“機(jī)器之心( id: almosthuman2014)”】

 

戳這里,看該作者更多好文

責(zé)任編輯:趙寧寧 來(lái)源: 51CTO專(zhuān)欄
相關(guān)推薦

2024-04-29 14:54:36

機(jī)器學(xué)習(xí)過(guò)擬合模型人工智能

2022-08-10 15:56:40

機(jī)器學(xué)習(xí)算法深度學(xué)習(xí)

2023-03-06 14:12:47

深度學(xué)習(xí)

2020-06-05 08:38:39

python散點(diǎn)圖擬合

2021-01-20 15:30:25

模型人工智能深度學(xué)習(xí)

2019-12-20 09:15:48

神經(jīng)網(wǎng)絡(luò)數(shù)據(jù)圖形

2020-07-14 10:40:49

Keras權(quán)重約束神經(jīng)網(wǎng)絡(luò)

2022-09-25 23:19:01

機(jī)器學(xué)習(xí)決策樹(shù)Python

2025-03-07 08:50:00

AI生成技術(shù)

2024-08-27 15:33:57

2025-01-20 09:21:00

2023-10-30 10:29:50

C++最小二乘法

2023-04-10 15:48:41

代碼計(jì)算

2025-01-09 09:29:57

2020-06-05 18:57:41

BiLSTMCRF算法

2018-07-03 09:12:23

深度學(xué)習(xí)正則化Python

2022-05-30 15:44:33

模型訓(xùn)練GAN

2024-12-26 00:34:47

2025-03-21 13:25:14

點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)