您如何判斷是否用足夠的數(shù)據(jù)訓(xùn)練了模型?
譯文【51CTO.com快譯】深度神經(jīng)網(wǎng)絡(luò)(DNN)需要大量訓(xùn)練數(shù)據(jù),即使微調(diào)模型也需要大量訓(xùn)練數(shù)據(jù)。那么您如何知道是否已用了足夠的數(shù)據(jù)?如果是計(jì)算機(jī)視覺(CV)模型,您始終可以查看測(cè)試錯(cuò)誤。但是如果微調(diào)BERT或GPT之類的大型transformer模型,又該如何?
- 評(píng)估模型的最佳度量指標(biāo)是什么?
- 您如何確信已用足夠的數(shù)據(jù)訓(xùn)練了模型?
- 您的客戶如何確信?
WeightWatcher可助您一臂之力。
- pip install weightwatcher
WeightWatcher是一種開源診斷工具,用于評(píng)估(預(yù))訓(xùn)練和微調(diào)的深度神經(jīng)網(wǎng)絡(luò)的性能。它基于研究深度學(xué)習(xí)為何有效的前沿成果。最近,它已在《自然》雜志上刊登。
本文介紹如何使用WeightWatcher來確定您的DNN模型是否用足夠的數(shù)據(jù)加以訓(xùn)練。
我們?cè)诒疚闹锌紤]GPT vs GPT2這個(gè)例子。GPT是一種NLP Transformer模型,由OpenAI開發(fā),用于生成假文本。最初開發(fā)時(shí),OpenAI發(fā)布了GPT模型,該模型專門用小的數(shù)據(jù)集進(jìn)行訓(xùn)練,因此無法生成假文本。后來,他們認(rèn)識(shí)到假文本是個(gè)好生意,于是發(fā)布了GPT2,GPT2就像GPT 一樣,但用足夠的數(shù)據(jù)來加以訓(xùn)練,確保有用。
我們可以將WeightWatcher運(yùn)用于GPT和GPT2,比較結(jié)果;我們將看到WeightWatcher log spectral norm和 alpha(冪律)這兩個(gè)度量指標(biāo)可以立即告訴我們GPT模型出了岔子。這在論文的圖6中顯示;
圖 6
我們?cè)谶@里將詳細(xì)介紹如何針對(duì)WeightWatcher冪律(PL)alpha度量指標(biāo)執(zhí)行此操作,并解釋如何解讀這些圖。
建議在Jupiter筆記本或Google Colab中運(yùn)行這些計(jì)算。(作為參考,您還可以查看用于在論文中創(chuàng)建圖解的實(shí)際筆記本,然而這里使用的是舊版本的WeightWatcher)。
出于本文需要,我們?cè)赪eightWatcher github代碼存儲(chǔ)庫中提供了切實(shí)有效的筆記本。
WeightWatcher了解基本的Huggingface模型。的確,WeightWatcher支持以下:
- TF2.0/Keras
- pyTorch 1.x
- HuggingFace
- 很快會(huì)支持ONNX(當(dāng)前主干中)
目前,我們支持Dense層和Conv2D層。即將支持更多層。針對(duì)我們的NLP Transformer模型,我們只需要支持Dense層。
首先,我們需要GPT和GPT2 pyTorch模型。我們將使用流行的HuggingFace transformers軟件包。
- !pip install transformers
其次,我們需要導(dǎo)入pyTorch和weightwatcher
- Import torch
- Import weightwatcher as ww
我們還需要pandas庫和matplotlib庫來幫助我們解讀weightwatcher度量指標(biāo)。在Jupyter筆記本中,這看起來像:
- import pandas as pd
- import matplotlib
- import matplotlib.pyplot as plt
- %matplotlib inline
我們現(xiàn)在導(dǎo)入transformers軟件包和2個(gè)模型類
- import transformers
- from transformers import OpenAIGPTModel,GPT2Model
我們要獲取2個(gè)預(yù)訓(xùn)練的模型,并運(yùn)行model.eval()
- gpt_model = OpenAIGPTModel.from_pretrained('openai-gpt')
- gpt_model.eval();
- gpt2_model = GPT2Model.from_pretrained('gpt2')
- gpt2_model.eval();
想使用WeightWatcher分析我們的GPT模型,只需創(chuàng)建一個(gè)watcher實(shí)例,然后運(yùn)行 watcher.analyze()。這將返回的Pandas數(shù)據(jù)幀,附有每一層的度量指標(biāo)。
- watcher = ww.WeightWatcher(model=gpt_model)
- gpt_details = watcher.analyze()
細(xì)節(jié)數(shù)據(jù)幀報(bào)告可用于分析模型性能的質(zhì)量度量指標(biāo)——無需訪問測(cè)試數(shù)據(jù)或訓(xùn)練數(shù)據(jù)。最重要的度量指標(biāo)是冪律度量指標(biāo)。WeightWatcher報(bào)告每一層 的
。GPT模型有近50層,因此將所有層 alpha作為直方圖(使用pandas API)一次性檢查顯得很方便。
- gpt_details.alpha.plot.hist(bins=100, color='red', alpha=0.5, density=True, label='gpt')
- plt.xlabel(r"alpha $(\alpha)$ PL exponent")
- plt.legend()
圖2
從這個(gè)直方圖中,我們可以立即看到模型的兩個(gè)問題
•有幾個(gè)異常值是,表明幾個(gè)層訓(xùn)練欠佳。
所以對(duì)GPT一無所知,也從未見過測(cè)試訓(xùn)練或訓(xùn)練數(shù)據(jù),WeightWatcher告訴我們這個(gè)模型永遠(yuǎn)不該進(jìn)入生產(chǎn)環(huán)境。
現(xiàn)在不妨看看GPT2,它有相同的架構(gòu),但使用更多更好的數(shù)據(jù)加以訓(xùn)練。我們?cè)俅问褂弥付ǖ哪P蛣?chuàng)建一個(gè)watcher實(shí)例,然后運(yùn)行 watcher.analyze()
- watcher = ww.WeightWatcher(model=gpt2_model)
- gpt2_details = watcher.analyze()
現(xiàn)在不妨比較GPT和GPT2的冪律alpha度量指標(biāo)。我們就創(chuàng)建2個(gè)直方圖,每個(gè)模型1個(gè)直方圖,并疊加這2個(gè)圖。
- gpt_details.alpha.plot.hist(bins=100, color='red', alpha=0.5, density=True, label='gpt')
- gpt2_details.alpha.plot.hist(bins=100, color='green', density=True, label='gpt2')
- plt.xlabel(r"alpha $(\alpha)$ PL exponent")
- plt.legend()
GPT的層alpha顯示紅色,GPT2的層alpha顯示綠色,直方圖差異很大。對(duì)于GPT2,峰值$alpha\sim 3.5&bg=ffffff$,更重要的是沒有異常值$latex \alpha>6&bg=ffffff$。Alpha越小越好,GPT2模型比GPT好得多,原因在于它用更多更好的數(shù)據(jù)加以訓(xùn)練。
圖3
WeightWatcher 有許多功能可以幫助您評(píng)估模型。它可以做這樣的事情:
- 幫助您決定是否用足夠的數(shù)據(jù)對(duì)其進(jìn)行了訓(xùn)練(如圖所示)
- 檢測(cè)過度訓(xùn)練的潛在層
- 用于獲取提前停止的標(biāo)準(zhǔn)(當(dāng)您無法查看測(cè)試數(shù)據(jù)時(shí))
- 針對(duì)不同的模型和超參數(shù),預(yù)測(cè)測(cè)試精度方面的趨勢(shì)
等等
不妨試一下。如果它對(duì)您有用,請(qǐng)告訴我。
原文標(biāo)題:How to Tell if You Have Trained Your Model with Enough Data,作者:Charles Martin
【51CTO譯稿,合作站點(diǎn)轉(zhuǎn)載請(qǐng)注明原文譯者和出處為51CTO.com】