被DeepSeek帶火的知識蒸餾詳解!
今天來詳細了解DeepSeek中提到的知識蒸餾技術,主要內容來自三巨頭之一Geoffrey Hinton的一篇經典工作:https://arxiv.org/pdf/1503.02531。
主要從背景、定義、原理、代碼復現等幾個方面來介紹:
1、背景介紹
訓練與部署的不一致性
在機器學習和深度學習領域,訓練模型和部署模型通常存在顯著差異。訓練階段,為了追求最佳性能,我們通常會使用復雜的模型架構和大量的計算資源,從海量且高度冗余的數據集中提取有用信息。例如,一些最先進的模型可能包含數十億甚至上百億的參數,或者通過多個模型集成來進一步提升性能。然而,這些龐大的模型在實際部署時面臨諸多問題:
- 推斷速度慢大模型在處理數據時需要更多的時間來完成計算,這在需要實時響應的場景中是不可接受的。
- 資源要求高大模型需要大量的內存和顯存來存儲模型參數和中間計算結果,這使得它們難以部署在資源受限的設備上,如移動設備或嵌入式系統(tǒng)。
因此,在部署階段,我們對模型的延遲和計算資源有著嚴格的限制。這就引出了模型壓縮的需求——在盡量不損失性能的前提下,減少模型的參數量,使其更適合實際應用環(huán)境。
模型壓縮與知識蒸餾
模型壓縮用于解決訓練階段與部署階段之間的不一致性,特別是在模型規(guī)模與實際應用需求之間的矛盾,在盡量不損失模型性能的前提下,減少模型的參數量和計算復雜度,使其更適合在資源受限的環(huán)境中部署。
知識蒸餾(Knowledge Distillation)是其中一種非常有效的模型壓縮技術。
2、什么是知識蒸餾?
知識蒸餾是一種模型壓縮技術,通過訓練一個小而高效的學生模型來模仿一個預訓練的大且復雜的教師模型(或一組模型)的行為。這種訓練設置通常被稱為“教師-學生”模式,其中大型模型作為教師,小型模型作為學生。教師模型的知識通過最小化損失函數傳遞給學生模型,目標是匹配教師模型預測的類概率分布。
知識蒸餾的核心思想是將一個復雜且性能強大的“教師”模型的知識遷移到一個更小、更輕量的“學生”模型中。通過這種方式,學生模型可以在保持較小參數量的同時,盡可能地繼承教師模型的性能。
該方法最早由Bucila等人在2006年提出(https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf),并在2015年由Hinton等人推廣,成為知識蒸餾領域的奠基之作。Hinton的工作引入了帶溫度參數的softmax函數,進一步增強了知識轉移的有效性。
在 Geoffrey Hinton 等人的論文《Distilling the Knowledge in a Neural Network》中,作者通過昆蟲的幼蟲和成蟲形態(tài)來比喻機器學習中的訓練階段和部署階段。這種比喻強調了不同階段對于模型的不同需求:就像昆蟲的幼蟲形態(tài)專注于從環(huán)境中吸收能量和養(yǎng)分,而其成蟲形態(tài)則優(yōu)化于移動和繁殖;同樣地,在機器學習中,訓練階段需要處理大量數據以提取結構信息,而部署階段則更關注于實時性和計算資源的有效利用。
教師與學生模型
教師模型:指的是那些龐大、復雜且可能計算成本高昂的模型或模型集合(ensemble)。這些模型雖然在訓練階段能夠提供優(yōu)秀的性能,但由于其復雜性,不適合直接用于實際部署。教師模型通常經過充分訓練,可以很好地泛化到新數據上。
學生模型:相對較小且計算效率更高的模型,旨在模仿教師模型的表現同時保持低延遲和高效能,便于大規(guī)模部署。學生模型通過“知識蒸餾”過程從教師模型那里獲得知識,從而能夠在資源受限的環(huán)境下運行。
知識的理解
關于“知識”的理解存在一個常見的誤解,即將其簡單等同于模型的權重參數。然而,Hinton 等人提倡一種更為抽象的觀點,即知識應被視為從輸入向量到輸出向量的學習映射關系。這不僅包括了對正確答案的概率預測,還涵蓋了對不正確答案之間細微差異的理解。例如,在圖像分類任務中,即使某個類別的概率非常小,它與其他錯誤類別相比仍然可能存在顯著差異,這些差異反映了模型如何泛化的關鍵信息。
蒸餾技術
知識蒸餾是一種技術,它允許我們將教師模型中的知識轉移到學生模型中。這不僅僅是復制權重的過程,而是涉及到使用教師模型生成的軟目標(soft targets)作為指導,幫助學生模型學習到相似的泛化能力。通過調整 softmax 層的溫度參數,可以使這些軟目標更加平滑,從而讓學生模型能夠捕捉到更多有用的信息,并減少過擬合的風險。
3、知識蒸餾的工作原理
Soft Target vs Hard Target
Hard Targets是最常見的訓練標簽形式,通常指的是每個訓練樣本對應的一個確切的類別標簽。例如,在分類問題中,如果任務是對圖像進行數字識別(如MNIST數據集),那么hard target就是一個具體的數字(比如1)。在這種情況下,模型被訓練來最大化正確類別的概率,而其他所有類別的概率則應盡可能地小。這通常通過最小化交叉熵損失函數來實現,該函數懲罰了模型對正確標簽預測的概率不足,并且不考慮錯誤標簽之間的相對概率。
相比之下,Soft Targets提供了一個更加細致的概率分布,不僅包含了正確的類別,還包括了模型認為可能相關的其他類別的概率。這意味著除了給出最有可能的類別之外,soft targets還提供了關于模型對于哪些類別可能是正確的、以及這些類別之間如何相互關聯的信息。這種類型的標簽可以通過一個已經訓練好的教師模型生成,它為每個輸入產生一個概率分布而不是單一的類別標簽。
基于soft targets訓練模型相較于hard targets訓練模型,尤其是在資源受限的情況下有效地轉移復雜模型的知識,具有以下幾個顯著的優(yōu)勢和意義:
- 傳遞更豐富的信息:Soft targets不僅包含正確類別的概率,還包括了對其他類別的相對概率估計。這意味著模型可以學習到不同類別之間的相似性和差異性,使得即使是對于錯誤的類別預測,也能反映出它們之間的細微差別,從而提供比單一正確答案更多的信息。例如,在圖像識別任務中,一個特定圖像可能與多個類別有一定的相似性,而這些信息在hard target中是丟失的。
- 減少梯度方差:當使用較高的溫度參數T時,softmax輸出的概率分布變得更加平滑,這意味著每個樣本提供的信息量相對于hard target來說增加了,同時減少了梯度估計中的方差。這對于小數據集上的訓練尤其有益,因為它可以防止過擬合并幫助模型更好地泛化到未見過的數據。
- 提高泛化能力:相比于hard targets僅給出一個確切的類別標簽,通過模仿教師模型(通常是大型或集成模型)生成的soft targets,學生模型能夠學習到教師模型如何進行泛化的細節(jié)。這有助于模型更好地學習數據中的潛在結構和模式,特別是在處理復雜或模糊邊界的問題時。對于那些難以明確區(qū)分的類別,soft targets能夠指導模型認識到哪些錯誤類別之間更加接近,從而改進其泛化能力。
- 加速收斂和降低過擬合風險:由于soft targets通常比hard targets擁有更高的熵,它們能提供更多的信息,并且減少梯度估計的方差。這對于小規(guī)模數據集特別有用,因為它可以幫助防止模型過度擬合訓練數據中的噪聲。論文中提到的一個實驗顯示,使用soft targets的學生模型即使只用3%的數據進行訓練,也能夠幾乎恢復全量數據所能提供的信息,同時不需要早期停止來防止過擬合。這表明soft targets作為一種有效的正則化方法,能夠幫助模型更好地泛化。
- 增強模型的魯棒性:通過利用soft targets進行訓練,模型可以學習到輸入數據的內在分布特性,而不僅僅是表面特征。這意味著模型對于輸入數據的小變化(如輕微的圖像變換)會更加穩(wěn)健,因為它們已經學會了識別那些對最終分類決策影響較小的變化。
帶溫度的Softmax
傳統(tǒng)softmax函數傾向于產生極端的概率分布,導致非正確類別的概率接近于零,這限制了其對模型訓練的幫助。
如圖所示,當輸入一張馬的圖片時,對于未調整溫度(默認為1)的 Softmax 輸出,正標簽的概率接近 1,而負標簽的概率接近 0。這種尖銳的分布對學生模型不夠友好,因為它只提供了關于正確答案的信息,而忽略了錯誤答案的信息。即驢比汽車更像馬,識別為驢的概率應該大于識別為汽車的概率。
帶溫度的Softmax是一種調整softmax函數輸出的方法,通過引入一個額外參數——溫度(temperature, T),來控制輸出概率分布的平滑度。這個概念在知識蒸餾中尤為重要,因為它能夠影響教師模型如何向學生模型傳遞知識。
- 當T=1時,該公式退化為標準的softmax。
- 當T>1時,輸出的概率分布變得更平滑,即不同類別的概率差異減小,這使得即使是不太可能的類別也會分配到一定的概率值。
- 當T<1時,結果是相反的,輸出的概率分布變得更加尖銳,增加了最有可能類別的概率,同時進一步降低了其他類別的概率。
關于T的取值,原文說是“不要太大,也不要太小”,如果T的值太小,就會導致原本概率值比較小的類別除以后還是比較小,就獲取不到數據中有效的信息。如果T的值太大,負標簽帶的信息會被放大,就會引入噪聲,也就是將一些沒有用信息給放的很大。這個實際還是調參吧,記得之前看到過文章說一般是20以內。當需要考慮負標簽之間的關系時,可以采用較大的溫度。例如,在自然語言處理任務中,模型可能需要學習到“貓”和“狗”之間的相似性,而不僅僅是它們的硬標簽。在這種情況下,較大的溫度可以使模型更好地捕捉到這些關系。反之,如果為了消除負標簽中噪聲的影響,可以采用較小的溫度。
蒸餾過程
準備階段:
- 一個已經預訓練好的、泛化能力強的教師網絡。
- 構造好數據集,這個數據集可以使用用來訓練教師網絡的數據集,也可以專門準備一個數據集來做蒸餾的過程。
- 搭建好學生網絡。
蒸餾過程:
- 輸入數據通過教師網絡: 輸入數據首先通過教師網絡得到logits(即softmax層之前的輸出)。使用較高的溫度t對這些logits應用softmax函數,生成soft labels。
- 輸入數據通過學生網絡: 同樣的輸入數據也通過學生網絡得到自己的logits。 這些logits分別通過相同的溫度t和標準溫度T=1的softmax函數處理,得到soft predictions和hard predictions。Soft predictions用于與教師網絡的soft labels比較,而hard predictions則用于與真實標簽比較。
- 計算損失:
Distillation Loss:使用交叉熵損失函數計算教師網絡的soft labels和學生網絡的soft predictions之間的差異。這一步幫助學生網絡學習到教師網絡對于不同類別間細微差別的理解。
Student Loss:同樣使用交叉熵損失函數計算學生網絡的hard predictions與真實標簽之間的差異。這確保了學生網絡也能直接從數據中學習。
最終損失:這兩個損失(distillation loss和student loss)通常會加權求和形成最終的損失函數。權重的選擇可以根據具體任務調整,但一般情況下,distillation loss的權重相對較小,以避免過度依賴教師網絡的預測。
其中,α是介于0和1之間的一個系數,用來平衡兩個損失項的重要性。作者發(fā)現學生網絡的損失的權重小一點比較好,比如可以嘗試 0.3、0.4。
最后作者說,由于的梯度大約是
的梯度的
,因此在
前乘上
可以保證兩個損失部分的梯度量貢獻基本一致。
4、代碼復現
使用mnist公開數據集簡單復現了上述流程,知識蒸餾主要代碼如下:
def train_kd_student(epoch):
student_kd.train()
optimizer = torch.optim.SGD(student_kd.parameters(), lr=lr, momentum=momentum)
print(f'\nTraining KD Student Epoch: {epoch}')
train_loss = 0
correct = 0
total = 0
with tqdm(train_loader, desc=f"Training KD Student Epoch {epoch}", total=len(train_loader)) as pbar:
for batch_idx, (inputs, targets) in enumerate(pbar):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
logits_student = student_kd(inputs)
with torch.no_grad():
logits_teacher = teacher_net(inputs)
ce_loss = nn.CrossEntropyLoss()(logits_student, targets)
kd_loss = loss(logits_student, logits_teacher, temperature=T)
total_loss = ALPHA * ce_loss + BETA * kd_loss
total_loss.backward()
optimizer.step()
train_loss += total_loss.item()
_, predicted = logits_student.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
# 記錄訓練損失
writers['student_kd'].add_scalar('Training Loss', ce_loss.item(),
epoch * len(train_loader) + batch_idx)
writers['student_kd'].add_scalar('Training Loss', kd_loss.item(),
epoch * len(train_loader) + batch_idx)
writers['student_kd'].add_scalar('Training Loss', total_loss.item(),
epoch * len(train_loader) + batch_idx)
pbar.set_postfix(loss=train_loss/(batch_idx+1), acc=f"{100.*correct/total:.1f}%")
# 記錄訓練準確率
acc = 100. * correct / total
writers['student_kd'].add_scalar('Training Accuracy', acc, epoch)
學生模型和教師模型代碼如下:
class LeNet(nn.Module):
def __init__(self, num_classes=10):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
self.fc2 = nn.Linear(in_features=120, out_features=84)
self.fc3 = nn.Linear(in_features=84, out_features=num_classes)
def forward(self, x):
x = self.maxpool(F.relu(self.conv1(x)))
x = self.maxpool(F.relu(self.conv2(x)))
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
class LeNetHalfChannel(nn.Module):
def __init__(self, num_classes=10):
super(LeNetHalfChannel, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5)
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(in_features=3 * 12 * 12, out_features=num_classes)
def forward(self, x):
x = self.maxpool(F.relu(self.conv1(x)))
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
return x
# 初始化模型
teacher_net = LeNet().to(device=device)
student_plain = LeNetHalfChannel().to(device=device) # 單獨訓練的學生
student_kd = LeNetHalfChannel().to(device=device) # 知識蒸餾的學生
teacher_net.load_state_dict(torch.load('./model/model.pt'))
對比僅使用學生模型、僅使用教師模型和使用教師模型蒸餾學生模型的模型性能,結果如下:
Teacher Model Accuracy: 97.99%
Plain Student Best Accuracy: 85.77%
KD Student Best Accuracy: 96.70%
可以看到,使用知識蒸餾得到的模型性能十分接近教師模型,遠超僅訓練學生模型的性能。
完整工程代碼地址:https://github.com/jinbo0906/awesome-model-compression/blob/main/knowledge%20distillation/kd.ipynb
5、DeepSeek-R1的蒸餾方法
DeepSeek-R1 的蒸餾過程基于其自身生成的合成推理數據。由DeepSeek-R1 模型生成的 800,000 個數據樣本,對較小的基礎模型(例如 Qwen 和 Llama 系列)僅進行監(jiān)督微調,從而將大型模型的推理能力高效地遷移到小型模型中。
從這里可以了解到,在大模型時代,蒸餾不僅僅完全通過學習教師模型的軟標簽實現,也可以通過學習教師模型的輸出結構化數據,從而學習到教師模型中一些強大的性能。這種方式對算力的需求大大減少,推動未來AI能力的普惠化。
這種方式降低了對模型訓練方法的要求,但是對數據質量的要求則大大增加,從這里也可以看到,DeepSeek-R1模型生成數據的質量非常高,可能遠遠超過了人工標注,在DeepSeek-R1的論文中也提到模型的self-evolution能力,可能是未來非常值得關注的一個研究方向。
6、總結
知識蒸餾不僅能夠有效地降低模型的復雜度和計算成本,還能保持較高的模型性能。通過對教師模型知識的巧妙遷移,學生模型能夠在資源受限的環(huán)境中展現出色的表現。未來知識蒸餾將在更多領域發(fā)揮重要作用。