如何理解模型的蒸餾和量化
在LLM領域內,經(jīng)常會聽到兩個名詞:蒸餾和量化。這代表了LLM兩種不同的技術,它們之間有什么區(qū)別呢?本次我們就來詳細聊一下。
一、模型蒸餾
1.1 什么是模型蒸餾
模型蒸餾是一種知識遷移技術,通過將一個大規(guī)模、預訓練的教師模型(Teacher Model)所蘊含的知識傳遞給一個規(guī)模較小的學生模型(Student Model),使學生模型在性能上接近教師模型,同時顯著降低計算資源消耗。
以一種更為通俗的方式來解釋:
假設你有一個特別聰明的學霸朋友(大模型),他考試能考100分,但做題速度慢(計算量大),沒法幫你考場作弊。
于是你想:能不能讓學霸把他的“解題思路”教給你,讓你變成一個小號的學霸(小模型),做題又快又準?
這就是模型蒸餾的思想。
1.2 蒸餾的核心原理
學霸的“秘密武器”不是答案本身,而是他的“思考過程”!
- 普通訓練:老師(訓練數(shù)據(jù))直接告訴你答案(標簽),比如“這張圖是貓”。
- 蒸餾訓練:學霸(大模型)不僅告訴你答案,還告訴你:“這張圖80%像貓,15%像豹子,5%像狗”(軟標簽),因為貓和豹子都有毛茸茸的特征。
小模型通過學霸的“思考細節(jié)”,能學得更深,甚至發(fā)現(xiàn)學霸自己都沒總結出的規(guī)律。
1.3 蒸餾的工作原理
- 教師模型訓練:首先訓練一個性能強大的教師模型,該模型通常具有復雜的結構和大量的參數(shù)。
教師模型就是常規(guī)訓練的LLM,比如GPT4。 - 生成軟標簽:教師模型對訓練數(shù)據(jù)進行預測,生成軟標簽(概率分布),這些軟標簽包含了教師模型對各類別的置信度信息。
本質來說就是通過softmax將預測結果轉化為概率分布,表示模型預測每個類別的可能性。 - 學生模型訓練:學生模型使用教師模型生成的軟標簽進行訓練,同時也可以結合真實標簽進行聯(lián)合訓練。通過優(yōu)化損失函數(shù)(KL散度),使學生模型的輸出盡可能接近教師模型的輸出。
注:Kullback-Leibler (KL) 散度,也稱為相對熵,是衡量一個概率分布與第二個參考概率分布之間差異程度的指標。 簡單來說,它衡量的是兩個概率分布有多么不同。 - 微調:在蒸餾完成后,進一步微調學生模型以提高其性能表現(xiàn)
1.4 舉個例子
比如有這樣一個任務:需要識別不同動物的圖片。
- 學霸(大模型):看到一張貓的圖片,輸出概率:貓(95%)、豹子(4%)、狗(1%)。
- 普通小模型:只知道正確答案是“貓”,拼命記貓的特征,但遇到豹子可能認錯。
- 蒸餾后的小模型:學霸告訴它:“重點看耳朵形狀和花紋,貓和豹子有點像,但豹子花紋更復雜”。于是小模型學會區(qū)分細微差別,準確率更高!
1.5 為什么蒸餾有效?
通過硬標簽向軟標簽的轉換,讓笨徒弟(小模型)偷師學霸(大模型)的“內功心法”,而不是只抄答案。
- 硬標簽(正確答案):只告訴小模型“是貓”,就像只背答案,不懂原理。
- 軟標簽(概率分布):告訴小模型“貓、豹子、狗的相似點”,就像學霸教你舉一反三。
- 防止學死記硬背:小模型不會過度依賴訓練數(shù)據(jù)中的偶然特征(減少過擬合)。
1.6 模型蒸餾的具體實現(xiàn)
1.6.1 準備教師模型和學生模型
教師模型:通常是一個預訓練好的復雜模型(如ResNet-50、BERT等)。
學生模型:結構更簡單的小模型(如MobileNet、TinyBERT等),參數(shù)少但需要與教師模型兼容。
1.6.2 定義損失函數(shù)
蒸餾損失(Distillation Loss):學生模型模仿教師模型的輸出分布。
可以使用KL散度或交叉熵衡量兩者的輸出差異。
學生損失(Student Loss):學生模型預測結果與真實標簽的交叉熵。
總損失:加權結合兩種損失:
1.6.3 訓練過程
- 溫度參數(shù):軟化輸出分布,通常取2~5。訓練完成推理時設置為1。
- 數(shù)據(jù)選擇:使用教師模型生成軟標簽的數(shù)據(jù)(可以是訓練集或額外數(shù)據(jù))。
- 優(yōu)化器:選擇Adam、SGD等,學習率通常低于普通訓練(例如0.001)。
- 訓練細節(jié):
- 先固定教師模型,僅訓練學生模型。
- 可以逐步調整溫度參數(shù)或損失權重。
import torch
import torch.nn as nn
import torch.optim as optim
# 定義教師模型和學生模型
teacher_model = ... # 預訓練好的復雜模型
student_model = ... # 待訓練的小模型
# 定義損失函數(shù)
criterion_hard = nn.CrossEntropyLoss() # 學生損失(硬標簽)
criterion_soft = nn.KLDivLoss(reductinotallow='batchmean') # 蒸餾損失(軟標簽)
# 溫度參數(shù)和權重
temperature = 5
alpha = 0.7
# 優(yōu)化器
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)
# 訓練循環(huán)
for inputs, labels in dataloader:
# 教師模型推理(不計算梯度)
with torch.no_grad():
teacher_logits = teacher_model(inputs)
# 學生模型推理
student_logits = student_model(inputs)
# 計算損失
loss_student = criterion_hard(student_logits, labels)
# 軟化教師和學生輸出
soft_teacher = torch.softmax(teacher_logits / temperature, dim=-1)
soft_student = torch.log_softmax(student_logits / temperature, dim=-1)
loss_distill = criterion_soft(soft_student, soft_teacher) * (temperature**2)
# 總損失
total_loss = alpha * loss_distill + (1 - alpha) * loss_student
# 反向傳播
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
二、模型量化
2.1 什么是模型量化
模型量化(Model Quantization)是一種通過降低模型參數(shù)的數(shù)值精度(如將32位浮點數(shù)轉換為8位整數(shù))來壓縮模型大小、提升推理速度并降低功耗的技術。
舉個具體例子:
假設模型記住了一群人的體重:
- 原版:[55.3kg, 61.7kg, 48.9kg](精確到小數(shù)點)
- 量化版:[55kg, 62kg, 49kg](四舍五入取整)
誤差就像體重秤的±0.5kg,不影響判斷「是否超重」
2.2 為什么要模型量化?
1、體積暴減
- 原模型像裝滿礦泉水瓶的箱子(500MB)
- 量化后像壓扁的易拉罐(125MB)
2、速度起飛
- 原來用大象運貨(FP32計算)
- 現(xiàn)在換快遞小車(INT8計算)
NVIDIA顯卡上推理速度提升2-4倍
3、省電耐耗
- 原本手機跑模型像開空調(耗電快)
- 量化后像開電風扇(省電60%)
2.3 如何進行模型量化?
1、劃定范圍
- 找出最輕48.9kg和最重61.7kg
- 就像量身高要站在標尺前
2、標刻度
- 把48.9-61.7kg映射到0-100的整數(shù)
- 公式:量化值 = round( (原值 - 最小值) / 步長 )
- 步長 = (61.7-48.9)/100 = 0.128
3、壓縮存儲
- 55.3kg → (55.3-48.9)/0.128 ≈ 50 → 存為整數(shù)50
- 使用時還原:50×0.128+48.9 ≈ 55.3kg
- 誤差控制:就像買菜抹零,5.2元算5元,差2毛不影響做菜
2.4 常用量化方式
1、事后減肥法(訓練后量化)
- 適用場景:模型已經(jīng)訓練好,直接壓縮
- 操作:像用榨汁機把水果變成果汁(保持營養(yǎng)但損失纖維)
import torch
# 準備模型(插入量化模塊)
model.eval() # 確保模型處于評估模式
model.qconfig = torch.quantization.default_qconfig # 設置默認量化配置
quantized_model = torch.quantization.prepare(model) # 插入觀察器
# 收集校準數(shù)據(jù)
for data, _ in calibration_data:
quantized_model(data.to('cpu')) # 在 CPU 上運行,避免對模型結構的影響
quantized_model = torch.quantization.convert(quantized_model) # 轉換為量化模型
- 優(yōu)點:快!5分鐘搞定
- 缺點:可能損失關鍵精度
2、健康瘦身法(量化感知訓練)
- 適用場景:訓練時就控制模型「體重」
- 操作:像健身教練全程監(jiān)督,邊訓練邊控制飲食
# PyTorch示例(訓練時插偽量化節(jié)點)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
model = torch.ao.quantization.prepare_qat(model)
# 正常訓練...
model = torch.ao.quantization.convert(model)
- 優(yōu)點:精度更高(像保留肌肉的減肥)
- 缺點:要重新訓練(耗時久)
3、混合套餐法(混合精度量化)
- 核心思想:重要部分用高精度,次要部分用低精度
例如:
人臉識別:眼睛區(qū)域用FP16,背景用INT8
語音識別:關鍵詞用16bit,靜音段用4bit
雖然說量化后模型不如原模型精度效果好,但是推理性能的提升相較性能損失在可控范圍內,性價比上量化是更優(yōu)的。