譯者 | 朱先忠
審校 | 重樓
簡介
模型蒸餾是一種機器學(xué)習(xí)新技術(shù),其基本思想是讓較小的模型(學(xué)生)模仿較大的模型(老師)的行為。當(dāng)前,已經(jīng)存在幾種方法可以實現(xiàn)這一技術(shù)(將在下文中展開具體介紹),但其目標(biāo)都是在學(xué)生模型中獲得比從頭開始訓(xùn)練更好的泛化能力。
模型蒸餾示例:學(xué)生(較?。┠P褪褂谜麴s損失函數(shù)從教師模型中學(xué)習(xí),該函數(shù)使用“軟標(biāo)簽”和預(yù)測(使用OpenAI GPT4o生成的圖表)
一、為什么模型蒸餾很重要?
模型蒸餾是開發(fā)和部署大型語言模型(LLM)的關(guān)鍵技術(shù)。它解決了與這些模型的大小和復(fù)雜性相關(guān)的如下幾個挑戰(zhàn):
- 資源效率:大型模型(例如具有1750億個參數(shù)的GPT-3)需要大量計算資源進行訓(xùn)練和推理。這使得它們的部署和維護成本高昂。相反地,蒸餾可減小模型大小,從而降低內(nèi)存使用量并加快推理時間,這對于硬件功能有限的應(yīng)用程序尤其有益。
- 部署在邊緣設(shè)備上:許多應(yīng)用程序需要實時處理和低延遲,尤其是在智能手機或物聯(lián)網(wǎng)設(shè)備等邊緣設(shè)備上運行的應(yīng)用程序。精簡模型更加輕量,可以部署在此類設(shè)備上,從而無需依賴持續(xù)的云連接即可實現(xiàn)AI功能。
- 降低成本:由于能耗和專用硬件的需求,運行大型模型的成本很高。通過將模型蒸餾為較小的版本,公司或組織可以顯著降低運營費用,同時保持相當(dāng)?shù)男阅芩健?/span>
- 提高訓(xùn)練效率:通過蒸餾得到的較小模型需要更少的數(shù)據(jù)和計算能力來針對特定任務(wù)進行微調(diào)。這種效率加快了開發(fā)周期,使資源有限的研究人員和從業(yè)者更容易利用先進的人工智能模型。
為了說明模型蒸餾的影響,請考慮大型模型與其蒸餾模型之間的以下基準(zhǔn)比較:
表1:大型語言模型與精簡模型(較小)之間的示例比較。注意:這些數(shù)字僅供參考,實際性能指標(biāo)可能因?qū)嵤┖陀布?/span>
在此示例中,蒸餾模型實現(xiàn)了與GPT-3相當(dāng)?shù)臏?zhǔn)確率,同時顯著減少了參數(shù)數(shù)量和推理時間。這證明了蒸餾如何使AI模型在實際應(yīng)用中更加實用且更具成本效益。
到目前為止,我們已經(jīng)理解了為什么蒸餾如此重要。現(xiàn)在,讓我們更深入地了解模型蒸餾的細(xì)節(jié)。
二、什么是模型蒸餾?
想象一下,你正在向一位世界級專家(老師)學(xué)習(xí)一個復(fù)雜的主題,比如量子物理學(xué)。這位專家無所不知,但他們使用復(fù)雜的語言,需要很長時間才能解釋清楚?,F(xiàn)在再想象一下,另一個人——一位偉大的溝通者(學(xué)生)——向這位專家學(xué)習(xí),然后以一種更簡單、更快捷的方式教你相同的內(nèi)容,而不會丟失核心信息。這就是模型蒸餾背后的主要思想。
更正式地說,模型蒸餾是一個過程,其中訓(xùn)練一個較小、更高效的模型(稱為學(xué)生)來復(fù)制一個較大、更強大的模型(稱為老師)的行為。目標(biāo)是讓學(xué)生更快、更輕松,同時在相同的任務(wù)上仍然表現(xiàn)良好。
蒸餾類型
模型蒸餾并不局限于教學(xué)生最終的答案是什么。學(xué)生可以通過多種方式向老師學(xué)習(xí)。以下是三種主要類型:
1.基于Logit的蒸餾(軟標(biāo)簽):學(xué)生模型從老師的概率分布中學(xué)習(xí)
我們不只是對學(xué)生進行正確答案(硬標(biāo)簽)的訓(xùn)練,還讓它了解老師對每個答案的信心程度——這些被稱為軟標(biāo)簽或軟目標(biāo)。
為什么要使用軟目標(biāo)?
假設(shè)你正在訓(xùn)練一個模型來對動物進行分類:
- 輸入是一張狼的圖像。
- 老師輸出:
[Wolf: 1.0, Dog: 0.0, Cat: 0.0, Fox: 0.0]
- 學(xué)生不僅知道應(yīng)該說“狼”,還知道狗和狐貍有些相似。這種額外的細(xì)微差別有助于學(xué)生學(xué)習(xí)更好的決策界限。
相比之下,僅使用硬標(biāo)簽進行訓(xùn)練將會是這樣的:
[Wolf: 1.0, Dog: 0.0, Cat: 0.0, Fox: 0.0]
沒有細(xì)微差別=更難學(xué)習(xí)細(xì)微的差別。
學(xué)生學(xué)到了什么:
- 哪些類別可能或不可能
- 老師如何處理不確定性
為什么這種類型很有用:
- 軟目標(biāo)就像一個更平滑的訓(xùn)練信號,特別是對于難以學(xué)習(xí)的例子
- 鼓勵學(xué)生更好地概括,而不是僅僅依靠硬標(biāo)簽
這是最常用的方法,尤其是對于分類任務(wù)。
2.基于特征的蒸餾:學(xué)生模仿老師的中間層表征
學(xué)生被訓(xùn)練模仿老師的隱藏層激活,而不僅僅是輸出。
你可以將教師模型視為在內(nèi)部逐步解決復(fù)雜問題。在基于特征的蒸餾中,學(xué)生模型會嘗試復(fù)制教師解決問題的方式,而不僅僅是最終答案。
學(xué)生學(xué)到了什么:
- 內(nèi)部推理模式
- 輸入數(shù)據(jù)的分層表示
- 嵌入結(jié)構(gòu)和注意力圖(在Transformer中)
為什么這種類型很有用:
- 幫助學(xué)生學(xué)習(xí)更豐富的表現(xiàn)形式,尤其是在模型較小的情況下
- 可以提高對新任務(wù)的可轉(zhuǎn)移性
- 促進多模式或復(fù)雜模型的更好對齊
此類型用于TinyBERT和MobileBERT等蒸餾技術(shù)。
3.基于關(guān)系的蒸餾:捕獲多個實例之間的關(guān)系,而不僅僅是實例方面的知識
學(xué)生模型不僅從老師的輸出中學(xué)習(xí),還從老師的表征空間中不同數(shù)據(jù)樣本之間的關(guān)系中學(xué)習(xí)。
基于關(guān)系的蒸餾并不關(guān)注個別的預(yù)測,而是教導(dǎo)學(xué)生保留數(shù)據(jù)的結(jié)構(gòu)——例如,如果老師發(fā)現(xiàn)兩個句子相似,那么學(xué)生也應(yīng)該學(xué)會將它們視為相似。
學(xué)生學(xué)到了什么:
- 實例之間的相對距離和相似性
- 嵌入空間中的分組、聚類或其他結(jié)構(gòu)知識
為什么這種類型很有用:
- 在度量學(xué)習(xí)、對比學(xué)習(xí)或檢索任務(wù)中尤其有效
- 鼓勵學(xué)生學(xué)習(xí)與老師相同的“思維導(dǎo)圖”
- 使學(xué)生能夠適應(yīng)輸入分布的變化
此類型用于更高級或以研究為重點的蒸餾方法(例如,RKD——關(guān)系知識蒸餾)。
小結(jié)
三、如何進行模型蒸餾(附代碼示例)
第1步:加載預(yù)訓(xùn)練教師模型
使用Hugging Face的轉(zhuǎn)換器來加載大型模型(例如DistilBERT的老師:BERT)。
from transformers import AutoModelForSequenceClassification, AutoTokenizer
teacher_model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model_name)
第2步:從頭開始訓(xùn)練學(xué)生模型
初始化一個較小的模型,例如:
distilbert-base-uncased。
student_model_name = "distilbert-base-uncased"
student_model =
AutoModelForSequenceClassification.from_pretrained(student_model_name)
第3步:實現(xiàn)知識蒸餾損失
在老師和學(xué)生的預(yù)測之間使用KL散度。
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, T=2.0):
hard_loss = F.cross_entropy(student_logits, labels)
soft_loss = F.kl_div(
F.log_softmax(student_logits / T, dim=1),
F.softmax(teacher_logits / T, dim=1),
reductinotallow="batchmean"
) * (T ** 2)
return alpha * soft_loss + (1 – alpha) * hard_loss
第4步:使用教師的軟目標(biāo)訓(xùn)練學(xué)生模型
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)
for batch in train_dataloader:
input_ids, attention_mask, labels = batch
with torch.no_grad():
teacher_logits = teacher_model(input_ids, attention_mask=attention_mask).logits
student_logits = student_model(input_ids, attention_mask=attention_mask).logits
loss = distillation_loss(student_logits, teacher_logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
第5步:評估蒸餾后的模型
比較原始模型和蒸餾模型的準(zhǔn)確度和推理時間:
import time
def evaluate_model(model, dataloader):
model.eval()
correct, total = 0, 0
start_time = time.time()
with torch.no_grad():
for batch in dataloader:
input_ids, attention_mask, labels = batch
outputs = model(input_ids, attention_mask=attention_mask).logits
predictions = torch.argmax(outputs, dim=1)
correct += (predictions == labels).sum().item()
total += labels.size(0)
inference_time = time.time() - start_time
accuracy = correct / total
return accuracy, inference_time
teacher_acc, teacher_time = evaluate_model(teacher_model, test_dataloader)
student_acc, student_time = evaluate_model(student_model, test_dataloader)
print(f"Teacher Accuracy: {teacher_acc:.4f}, Inference Time: {teacher_time:.2f}s")
print(f"Student Accuracy: {student_acc:.4f}, Inference Time: {student_time:.2f}s")
結(jié)論
隨著大型語言模型不斷突破AI的極限,它們也帶來了現(xiàn)實的弊端:推理速度慢、能耗高、部署能力有限。模型蒸餾通過將大型模型的功能壓縮為更小、更快、更高效的版本,為這些挑戰(zhàn)提供了一種實用而優(yōu)雅的解決方案。
本文中,我們探索了模型蒸餾的緣由、內(nèi)容和方式——從通過軟標(biāo)簽學(xué)習(xí)到模仿內(nèi)部表示,甚至保留數(shù)據(jù)點之間的關(guān)系。無論你是為移動應(yīng)用程序、低延遲API還是邊緣設(shè)備構(gòu)建模型,蒸餾都是在不降低性能的情況下縮小模型的關(guān)鍵工具。
蒸餾技術(shù)最重要的貢獻在哪里?在于這種技術(shù)不僅適用于研究實驗室或科技巨頭。借助Hugging Face Transformers和PyTorch等開源工具,任何人都可以立即開始蒸餾模型。
蒸餾不僅僅是為了讓模型更小,而且還為了讓它們更智能、更快、更易于訪問。隨著人工智能從集中式數(shù)據(jù)中心轉(zhuǎn)移到日常設(shè)備和應(yīng)用程序,蒸餾只會變得越來越重要。
譯者介紹
朱先忠,51CTO社區(qū)編輯,51CTO專家博客、講師,濰坊一所高校計算機教師,自由編程界老兵一枚。
原文標(biāo)題:Understanding Model Distillation in Large Language Models (With Code Examples),作者:Edgar