如何保存和恢復(fù)TensorFlow訓(xùn)練的模型
如果深層神經(jīng)網(wǎng)絡(luò)模型的復(fù)雜度非常高的話,那么訓(xùn)練它可能需要相當(dāng)長(zhǎng)的一段時(shí)間,當(dāng)然這也取決于你擁有的數(shù)據(jù)量,運(yùn)行模型的硬件等等。在大多數(shù)情況下,你需要通過保存文件來保障你試驗(yàn)的穩(wěn)定性,防止如果中斷(或一個(gè)錯(cuò)誤),你能夠繼續(xù)從沒有錯(cuò)誤的地方開始。
更重要的是,對(duì)于任何深度學(xué)習(xí)的框架,像TensorFlow,在成功的訓(xùn)練之后,你需要重新使用模型的學(xué)習(xí)參數(shù)來完成對(duì)新數(shù)據(jù)的預(yù)測(cè)。
在這篇文章中,我們來看一下如何保存和恢復(fù)TensorFlow模型,我們?cè)诖私榻B一些最有用的方法,并提供一些例子。
1. 首先我們將快速介紹TensorFlow模型
TensorFlow的主要功能是通過張量來傳遞其基本數(shù)據(jù)結(jié)構(gòu)類似于NumPy中的多維數(shù)組,而圖表則表示數(shù)據(jù)計(jì)算。它是一個(gè)符號(hào)庫,這意味著定義圖形和張量將僅創(chuàng)建一個(gè)模型,而獲取張量的具體值和操作將在會(huì)話(session)中執(zhí)行,會(huì)話(session)一種在圖中執(zhí)行建模操作的機(jī)制。會(huì)話關(guān)閉時(shí),張量的任何具體值都會(huì)丟失,這也是運(yùn)行會(huì)話后將模型保存到文件的另一個(gè)原因。
通過示例可以幫助我們更容易理解,所以讓我們?yōu)槎S數(shù)據(jù)的線性回歸創(chuàng)建一個(gè)簡(jiǎn)單的TensorFlow模型。
首先,我們將導(dǎo)入我們的庫:
- import tensorflow as tf
- import numpy as np
- import matplotlib.pyplot as plt
- %matplotlib inline
下一步是創(chuàng)建模型。我們將生成一個(gè)模型,它將以以下的形式估算二次函數(shù)的水平和垂直位移:
- y = (x - h) ^ 2 + v
其中h是水平和v是垂直的變化。
以下是如何生成模型的過程(有關(guān)詳細(xì)信息,請(qǐng)參閱代碼中的注釋):
- # Clear the current graph in each run, to avoid variable duplication
- tf.reset_default_graph()
- # Create placeholders for the x and y points
- X = tf.placeholder("float")
- Y = tf.placeholder("float")
- # Initialize the two parameters that need to be learned
- h_est = tf.Variable(0.0, name='hor_estimate')
- v_est = tf.Variable(0.0, name='ver_estimate')
- # y_est holds the estimated values on y-axis
- y_est = tf.square(X - h_est) + v_est
- # Define a cost function as the squared distance between Y and y_est
- cost = (tf.pow(Y - y_est, 2))
- # The training operation for minimizing the cost function. The
- # learning rate is 0.001
- trainop = tf.train.GradientDescentOptimizer(0.001).minimize(cost)
在創(chuàng)建模型的過程中,我們需要有一個(gè)在會(huì)話中運(yùn)行的模型,并且傳遞一些真實(shí)的數(shù)據(jù)。我們生成一些二次數(shù)據(jù)(Quadratic data),并給他們添加噪聲。
- # Use some values for the horizontal and vertical shift
- h = 1
- v = -2
- # Generate training data with noise
- x_train = np.linspace(-2,4,201)
- noise = np.random.randn(*x_train.shape) * 0.4
- y_train = (x_train - h) ** 2 + v + noise
- # Visualize the data
- plt.rcParams['figure.figsize'] = (10, 6)
- plt.scatter(x_train, y_train)
- plt.xlabel('x_train')
- plt.ylabel('y_train')
2. The Saver class
Saver類是TensorFlow庫提供的類,它是保存圖形結(jié)構(gòu)和變量的***方法。
(1) 保存模型
在以下幾行代碼中,我們定義一個(gè)Saver對(duì)象,并在train_graph()函數(shù)中,經(jīng)過100次迭代的方法最小化成本函數(shù)。然后,在每次迭代中以及優(yōu)化完成后,將模型保存到磁盤。每個(gè)保存在磁盤上創(chuàng)建二進(jìn)制文件被稱為“檢查點(diǎn)”。
- # Create a Saver object
- saver = tf.train.Saver()
- init = tf.global_variables_initializer()
- # Run a session. Go through 100 iterations to minimize the cost
- def train_graph():
- with tf.Session() as sess:
- sess.run(init)
- for i in range(100):
- for (x, y) in zip(x_train, y_train):
- # Feed actual data to the train operation
- sess.run(trainop, feed_dict={X: x, Y: y})
- # Create a checkpoint in every iteration
- saver.save(sess, 'model_iter', global_step=i)
- # Save the final model
- saver.save(sess, 'model_final')
- h_ = sess.run(h_est)
- v_ = sess.run(v_est)
- return h_, v_
現(xiàn)在讓我們用上述功能訓(xùn)練模型,并打印出訓(xùn)練的參數(shù)。
- result = train_graph()
- print("h_est = %.2f, v_est = %.2f" % result)
- $ python tf_save.py
- h_est = 1.01, v_est = -1.96
Okay,參數(shù)是非常準(zhǔn)確的。如果我們檢查我們的文件系統(tǒng),***4次迭代中保存有文件以及最終的模型。
保存模型時(shí),你會(huì)注意到需要4種類型的文件才能保存:
- “.meta”文件:包含圖形結(jié)構(gòu)。
- “.data”文件:包含變量的值。
- “.index”文件:標(biāo)識(shí)檢查點(diǎn)。
- “checkpoint”文件:具有最近檢查點(diǎn)列表的協(xié)議緩沖區(qū)。
圖1:檢查點(diǎn)文件保存到磁盤
調(diào)用tf.train.Saver()方法,如上所示,將所有變量保存到一個(gè)文件。通過將它們作為參數(shù),表情通過列表或dict傳遞來保存變量的子集,例如:tf.train.Saver({‘hor_estimate’: h_est})。
Saver構(gòu)造函數(shù)的一些其他有用的參數(shù),也可以控制整個(gè)過程,它們是:
- max_to_keep:最多保留的檢查點(diǎn)數(shù)。
- keep_checkpoint_every_n_hours:保存檢查點(diǎn)的時(shí)間間隔。如果你想要了解更多信息,請(qǐng)查看官方文檔的Saver類,它提供了其它有用的信息,你可以探索查看。
- Restoring Models
恢復(fù)TensorFlow模型時(shí)要做的***件事就是將圖形結(jié)構(gòu)從“.meta”文件加載到當(dāng)前圖形中。
- tf.reset_default_graph()
- imported_meta = tf.train.import_meta_graph("model_final.meta")
也可以使用以下命令探索當(dāng)前圖形tf.get_default_graph()。接著第二步是加載變量的值。提醒:值僅存在于會(huì)話(session)中。
- with tf.Session() as sess:
- imported_meta.restore(sess, tf.train.latest_checkpoint('./'))
- h_est2 = sess.run('hor_estimate:0')
- v_est2 = sess.run('ver_estimate:0')
- print("h_est: %.2f, v_est: %.2f" % (h_est2, v_est2))
- $ python tf_restore.py
- INFO:tensorflow:Restoring parameters from ./model_final
- h_est: 1.01, v_est: -1.96
如前面所提到的,這種方法只保存圖形結(jié)構(gòu)和變量,這意味著通過占位符“X”和“Y”輸入的訓(xùn)練數(shù)據(jù)不會(huì)被保存。
無論如何,在這個(gè)例子中,我們將使用我們定義的訓(xùn)練數(shù)據(jù)tf,并且可視化模型擬合。
- plt.scatter(x_train, y_train, label='train data')
- plt.plot(x_train, (x_train - h_est2) ** 2 + v_est2, color='red', label='model')
- plt.xlabel('x_train')
- plt.ylabel('y_train')
- plt.legend()
Saver這個(gè)類允許使用一個(gè)簡(jiǎn)單的方法來保存和恢復(fù)你的TensorFlow模型(圖形和變量)到/從文件,并保留你工作中的多個(gè)檢查點(diǎn),這可能是有用的,它可以幫助你的模型在訓(xùn)練過程中進(jìn)行微調(diào)。
4. SavedModel格式(Format)
在TensorFlow中保存和恢復(fù)模型的一種新方法是使用SavedModel,Builder和loader功能。這個(gè)方法實(shí)際上是Saver提供的更高級(jí)別的序列化,它更適合于商業(yè)目的。
雖然這種SavedModel方法似乎不被開發(fā)人員完全接受,但它的創(chuàng)作者指出:它顯然是未來。與Saver主要關(guān)注變量的類相比,SavedModel嘗試將一些有用的功能包含在一個(gè)包中,例如Signatures:允許保存具有一組輸入和輸出的圖形,Assets:包含初始化中使用的外部文件。
(1) 使用SavedModel Builder保存模型
接下來我們嘗試使用SavedModelBuilder類完成模型的保存。在我們的示例中,我們不使用任何符號(hào),但也足以說明該過程。
- tf.reset_default_graph()
- # Re-initialize our two variables
- h_est = tf.Variable(h_est2, name='hor_estimate2')
- v_est = tf.Variable(v_est2, name='ver_estimate2')
- # Create a builder
- builder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/')
- # Add graph and variables to builder and save
- with tf.Session() as sess:
- sess.run(h_est.initializer)
- sess.run(v_est.initializer)
- builder.add_meta_graph_and_variables(sess,
- [tf.saved_model.tag_constants.TRAINING],
- signature_def_map=None,
- assets_collection=None)
- builder.save()
- $ python tf_saved_model_builder.py
- INFO:tensorflow:No assets to save.
- INFO:tensorflow:No assets to write.
- INFO:tensorflow:SavedModel written to: b'./SavedModel/saved_model.pb'
運(yùn)行此代碼時(shí),你會(huì)注意到我們的模型已保存到位于“./SavedModel/saved_model.pb”的文件中。
(2) 使用SavedModel Loader程序恢復(fù)模型
模型恢復(fù)使用tf.saved_model.loader,并且可以恢復(fù)會(huì)話范圍中保存的變量,符號(hào)。
在下面的例子中,我們將加載模型,并打印出我們的兩個(gè)系數(shù)(h_est和v_est)的數(shù)值。數(shù)值如預(yù)期的那樣,我們的模型已經(jīng)被成功地恢復(fù)了。
- with tf.Session() as sess:
- tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], './SavedModel/')
- h_est = sess.run('hor_estimate2:0')
- v_est = sess.run('ver_estimate2:0')
- print("h_est: %.2f, v_est: %.2f" % (h_est, v_est))
- $ python tf_saved_model_loader.py
- INFO:tensorflow:Restoring parameters from b'./SavedModel/variables/variables'
- h_est: 1.01, v_est: -1.96
5. 結(jié)論
如果你知道你的深度學(xué)習(xí)網(wǎng)絡(luò)的訓(xùn)練可能會(huì)花費(fèi)很長(zhǎng)時(shí)間,保存和恢復(fù)TensorFlow模型是非常有用的功能。該主題太廣泛,無法在一篇博客文章中詳細(xì)介紹。不管怎樣,在這篇文章中我們介紹了兩個(gè)工具:Saver和SavedModel builder/loader,并創(chuàng)建一個(gè)文件結(jié)構(gòu),使用簡(jiǎn)單的線性回歸來說明實(shí)例。希望這些能夠幫助到你訓(xùn)練出更好的神經(jīng)網(wǎng)絡(luò)模型。