Keras vs PyTorch,哪一個(gè)更適合做深度學(xué)習(xí)?
如何選擇工具對深度學(xué)習(xí)初學(xué)者是個(gè)難題。本文作者以 Keras 和 Pytorch 庫為例,提供了解決該問題的思路。
當(dāng)你決定學(xué)習(xí)深度學(xué)習(xí)時(shí),有一個(gè)問題會一直存在——學(xué)習(xí)哪種工具?
深度學(xué)習(xí)有很多框架和庫。這篇文章對兩個(gè)流行庫 Keras 和 Pytorch 進(jìn)行了對比,因?yàn)槎叨己苋菀咨鲜郑鯇W(xué)者能夠輕松掌握。
那么到底應(yīng)該選哪一個(gè)呢?本文分享了一個(gè)解決思路。
做出合適選擇的最佳方法是對每個(gè)框架的代碼樣式有一個(gè)概覽。開發(fā)任何解決方案時(shí)首先也是最重要的事就是開發(fā)工具。你必須在開始一項(xiàng)工程之前設(shè)置好開發(fā)工具。一旦開始,就不能一直換工具了,否則會影響你的開發(fā)效率。
作為初學(xué)者,你應(yīng)該多嘗試不同的工具,找到最適合你的那一個(gè)。但是當(dāng)你認(rèn)真開發(fā)一個(gè)項(xiàng)目時(shí),這些事應(yīng)該提前計(jì)劃好。
每天都會有新的框架和工具投入市場,而最好的工具能夠在定制和抽象之間做好平衡。工具應(yīng)該和你的思考方式和代碼樣式同步。因此要想找到適合自己的工具,首先你要多嘗試不同的工具。
我們同時(shí)用 Keras 和 PyTorch 訓(xùn)練一個(gè)簡單的模型。如果你是深度學(xué)習(xí)初學(xué)者,對有些概念無法完全理解,不要擔(dān)心。從現(xiàn)在開始,專注于這兩個(gè)框架的代碼樣式,盡量去想象哪個(gè)最適合你,使用哪個(gè)工具你最舒服,也最容易適應(yīng)。
這兩個(gè)工具最大的區(qū)別在于:PyTorch 默認(rèn)為 eager 模式,而 Keras 基于 TensorFlow 和其他框架運(yùn)行(現(xiàn)在主要是 TensorFlow),其默認(rèn)模式為圖模式。最新版本的 TensorFlow 也提供類似 PyTorch 的 eager 模式,但是速度較慢。
如果你熟悉 NumPy,你可以將 PyTorch 視為有 GPU 支持的 NumPy。此外,現(xiàn)在有多個(gè)具備高級 API(如 Keras)且以 PyTorch 為后端框架的庫,如 Fastai、Lightning、Ignite 等。如果你對它們感興趣,那你選擇 PyTorch 的理由就多了一個(gè)。
在不同的框架里有不同的模型實(shí)現(xiàn)方法。讓我們看一下這兩種框架里的簡單實(shí)現(xiàn)。本文提供了 Google Colab 鏈接。打開鏈接,試驗(yàn)代碼。這可以幫助你找到最適合自己的框架。
我不會給出太多細(xì)節(jié),因?yàn)樵诖耍覀兊哪繕?biāo)是看一下代碼結(jié)構(gòu),簡單熟悉一下框架的樣式。
Keras 中的模型實(shí)現(xiàn)
以下示例是數(shù)字識別的實(shí)現(xiàn)。代碼很容易理解。你需要打開 colab,試驗(yàn)代碼,至少自己運(yùn)行一遍。
Keras 自帶一些樣本數(shù)據(jù)集,如 MNIST 手寫數(shù)字?jǐn)?shù)據(jù)集。以上代碼可以加載這些數(shù)據(jù),數(shù)據(jù)集圖像是 NumPy 數(shù)組格式。Keras 還做了一點(diǎn)圖像預(yù)處理,使數(shù)據(jù)適用于模型。
以上代碼展示了模型。在 Keras(TensorFlow)上,我們首先需要定義要使用的東西,然后立刻運(yùn)行。在 Keras 中,我們無法隨時(shí)隨地進(jìn)行試驗(yàn),不過 PyTorch 可以。
以上的代碼用于訓(xùn)練和評估模型。我們可以使用 save() 函數(shù)來保存模型,以便后續(xù)用 load_model() 函數(shù)加載模型。predict() 函數(shù)則用來獲取模型在測試數(shù)據(jù)上的輸出。
現(xiàn)在我們概覽了 Keras 基本模型實(shí)現(xiàn)過程,現(xiàn)在來看 PyTorch。
PyTorch 中的模型實(shí)現(xiàn)
研究人員大多使用 PyTorch,因?yàn)樗容^靈活,代碼樣式也是試驗(yàn)性的。你可以在 PyTorch 中調(diào)整任何事,并控制全部,但控制也伴隨著責(zé)任。
在 PyTorch 里進(jìn)行試驗(yàn)是很容易的。因?yàn)槟悴恍枰榷x好每一件事再運(yùn)行。我們能夠輕松測試每一步。因此,在 PyTorch 中 debug 要比在 Keras 中容易一些。
接下來,我們來看簡單的數(shù)字識別模型實(shí)現(xiàn)。
以上代碼導(dǎo)入了必需的庫,并定義了一些變量。n_epochs、momentum 等變量都是必須設(shè)置的超參數(shù)。此處不討論細(xì)節(jié),我們的目的是理解代碼的結(jié)構(gòu)。
以上代碼旨在聲明用于加載訓(xùn)練所用批量數(shù)據(jù)的數(shù)據(jù)加載器。下載數(shù)據(jù)有很多種方式,不受框架限制。如果你剛開始學(xué)習(xí)深度學(xué)習(xí),以上代碼可能看起來比較復(fù)雜。
在此,我們定義了模型。這是一種創(chuàng)建網(wǎng)絡(luò)的通用方法。我們擴(kuò)展了 nn.Module,在前向傳遞中調(diào)用 forward() 函數(shù)。
PyTorch 的實(shí)現(xiàn)比較直接,且能夠根據(jù)需要進(jìn)行修改。
以上代碼段定義了訓(xùn)練和測試函數(shù)。在 Keras 中,我們需要調(diào)用 fit() 函數(shù)把這些事自動做完。但是在 PyTorch 中,我們必須手動執(zhí)行這些步驟。像 Fastai 這樣的高級 API 庫會簡化它,訓(xùn)練所需的代碼也更少。
最后,保存和加載模型,以進(jìn)行二次訓(xùn)練或預(yù)測。這部分沒有太多差別。PyTorch 模型通常有 pt 或 pth 擴(kuò)展。
關(guān)于框架選擇的建議
學(xué)會一種模型并理解其概念后,再轉(zhuǎn)向另一種模型,并不是件難事,只是需要一些時(shí)間。本文作者給出的建議是兩個(gè)都學(xué),但是不需要兩個(gè)都深入地學(xué)。
你應(yīng)該從一個(gè)開始,然后在該框架中實(shí)現(xiàn)模型,同時(shí)也應(yīng)當(dāng)掌握另一個(gè)框架的知識。這有助于你閱讀別人用另一個(gè)框架寫的代碼。永遠(yuǎn)不要被框架限制住。
先從適合自己的框架開始,然后嘗試學(xué)習(xí)另一個(gè)。如果你發(fā)現(xiàn)另一個(gè)用起來更合適,那么轉(zhuǎn)換成另一個(gè)。因?yàn)?PyTorch 和 Keras 的大多數(shù)核心概念是類似的,二者之間的轉(zhuǎn)換非常容易。
Colab 鏈接:
- PyTorch:https://colab.research.google.com/drive/1irYr0byhK6XZrImiY4nt9wX0fRp3c9mx?usp=sharing
- Keras:https://colab.research.google.com/drive/1QH6VOY_uOqZ6wjxP0K8anBAXmI0AwQCm?usp=sharing
【本文是51CTO專欄機(jī)構(gòu)“機(jī)器之心”的原創(chuàng)譯文,微信公眾號“機(jī)器之心( id: almosthuman2014)”】