LLM核心損失函數(shù)深度剖析——KL散度與交叉熵?fù)p失
在深度學(xué)習(xí)和機(jī)器學(xué)習(xí)領(lǐng)域,損失函數(shù)是模型優(yōu)化的核心工具之一。它不僅決定了模型的訓(xùn)練方向,還直接影響模型的性能和泛化能力。隨著大語(yǔ)言模型(LLM)的興起,對(duì)損失函數(shù)的理解和應(yīng)用變得更加重要。本文將深入探討兩種常用的損失函數(shù)——KL散度和交叉熵?fù)p失,并分析它們?cè)趯?shí)際應(yīng)用中的區(qū)別和聯(lián)系。
一、KL散度
KL散度(Kullback-Leibler Divergence)用于衡量?jī)蓚€(gè)概率分布之間的相似性,常用于知識(shí)蒸餾(Knowledge Distillation)和對(duì)抗訓(xùn)練(Adversarial Training)等任務(wù)。其公式為:
應(yīng)用場(chǎng)景
- 知識(shí)蒸餾:在知識(shí)蒸餾中,KL散度損失用于衡量學(xué)生模型的輸出分布與教師模型的輸出分布之間的差異。通過(guò)最小化這種差異,學(xué)生模型可以學(xué)習(xí)到教師模型的“知識(shí)”,從而提高性能。
- 對(duì)抗訓(xùn)練:KL散度損失可以用于衡量模型在對(duì)抗樣本上的輸出分布與真實(shí)分布之間的差異,從而增強(qiáng)模型的魯棒性。
物理意義
KL散度(Kullback-Leibler Divergence)的物理意義可以從信息論和統(tǒng)計(jì)學(xué)的角度來(lái)理解,它是一種衡量?jī)蓚€(gè)概率分布之間差異的工具,具有重要的理論和實(shí)際應(yīng)用價(jià)值。
- 信息論角度
KL散度最初來(lái)源于信息論,用于衡量?jī)蓚€(gè)概率分布之間的“信息差距”。具體來(lái)說(shuō),它量化了當(dāng)我們用一個(gè)概率分布Q來(lái)近似另一個(gè)概率分布P時(shí),所導(dǎo)致的額外信息損失。這種信息損失可以理解為編碼數(shù)據(jù)時(shí)所需的“額外比特?cái)?shù)”,即使用Q來(lái)編碼P數(shù)據(jù)時(shí)的效率損失。
從熵的角度來(lái)看,KL散度可以表示為真實(shí)分布P的熵與P和Q之間的交叉熵的差值。因此,KL散度實(shí)際上衡量了使用Q而非P所引入的額外不確定性。 - 非對(duì)稱性和非負(fù)性
KL散度具有兩個(gè)關(guān)鍵性質(zhì):
非負(fù)性:KL散度始終大于等于零,表示兩個(gè)分布之間的差異不會(huì)產(chǎn)生“負(fù)的信息損失”。
非對(duì)稱性:KL散度是不對(duì)稱的,即。這意味著選擇不同的分布作為真實(shí)分布和近似分布會(huì)導(dǎo)致不同的結(jié)果。
KL散度定義為使用Q來(lái)編碼P時(shí)的額外信息量怎么解讀
- 信息量與編碼
在信息論中,信息量通常與編碼長(zhǎng)度相關(guān)。對(duì)于一個(gè)概率分布P,如果某個(gè)事件x發(fā)生的概率為P(x),那么該事件的信息量可以表示為?log?P(x)。這意味著概率越小的事件,其信息量越大,因?yàn)樗鼈兏俺鋈艘饬稀薄?/span> - 使用Q編碼P
當(dāng)我們使用另一個(gè)概率分布Q來(lái)編碼P時(shí),事件x的編碼長(zhǎng)度將基于Q(x)而不是P(x)。因此,事件x的編碼長(zhǎng)度變?yōu)?logQ(x)。 - 額外信息量
使用Q來(lái)編碼P時(shí)的額外信息量,就是基于Q的編碼長(zhǎng)度與基于P的編碼長(zhǎng)度之間的差值。對(duì)于所有可能的事件x,這個(gè)差值的期望值就是KL散度:
這個(gè)公式可以解釋為:對(duì)于每個(gè)事件x,我們計(jì)算使用Q編碼x時(shí)比使用P編碼x多出的信息量,然后根據(jù)P的概率分布對(duì)所有事件求和。
4. 直觀理解
- 如果P(x)和Q(x)非常接近,那么
接近零,表示使用Q編碼P時(shí)的額外信息量很小。
- 如果P(x)和Q(x)差距很大,那么
的絕對(duì)值很大,表示使用Q編碼P時(shí)的額外信息量很大。
二、交叉熵?fù)p失
交叉熵?fù)p失函數(shù)(Cross-Entropy Loss Function)主要用于衡量模型預(yù)測(cè)的概率分布與真實(shí)標(biāo)簽之間的差異。
其中,p表示真實(shí)標(biāo)簽,q表示模型預(yù)測(cè)的標(biāo)簽,N表示樣本數(shù)量。該公式可以看作是一個(gè)基于概率分布的比較方式,即將真實(shí)標(biāo)簽看做一個(gè)概率分布,將模型預(yù)測(cè)的標(biāo)簽也看做一個(gè)概率分布,然后計(jì)算它們之間的交叉熵。
應(yīng)用場(chǎng)景
- 分類問(wèn)題:在分類問(wèn)題中,它通常用于衡量模型的預(yù)測(cè)分布與實(shí)際標(biāo)簽分布之間的差異。
- 下一個(gè)單詞預(yù)測(cè):在語(yǔ)言模型中通過(guò)最小化模型預(yù)測(cè)的概率分布與真實(shí)單詞的概率分布之間的差異,用于下一個(gè)單詞預(yù)測(cè)任務(wù)。
物理意義
交叉熵?fù)p失函數(shù)本質(zhì)上是衡量?jī)蓚€(gè)概率分布之間的差異,這種差異反映了信息的“不確定性”或“信息量”。
- 信息量與不確定性
在信息論中,熵(Entropy)是衡量信息不確定性的一個(gè)重要概念。熵越高,表示信息的不確定性越大;熵越低,表示信息的不確定性越小。例如,一個(gè)均勻分布的隨機(jī)變量(如拋硬幣)具有較高的熵,因?yàn)樗嗟牟淮_定性;而一個(gè)確定性事件(如拋一枚兩面都是正面的硬幣)的熵為零,因?yàn)樗鼪](méi)有任何不確定性。
交叉熵?fù)p失函數(shù)的核心是交叉熵(Cross-Entropy),它衡量的是模型預(yù)測(cè)的概率分布與真實(shí)分布之間的信息量。具體來(lái)說(shuō),交叉熵?fù)p失反映了模型預(yù)測(cè)分布對(duì)真實(shí)分布的“驚訝程度”或“不確定性”。
- 如果模型的預(yù)測(cè)分布與真實(shí)分布完全一致,交叉熵?fù)p失會(huì)達(dá)到最小值。
- 如果模型的預(yù)測(cè)分布與真實(shí)分布相差很大,交叉熵?fù)p失會(huì)很大,表示模型對(duì)真實(shí)結(jié)果感到“非常驚訝”。
- 信息編碼與傳輸
從信息編碼的角度來(lái)看,交叉熵?fù)p失也可以理解為一種“編碼代價(jià)”。假設(shè)我們用模型預(yù)測(cè)的概率分布來(lái)編碼真實(shí)數(shù)據(jù),交叉熵?fù)p失表示了這種編碼所需的“平均比特?cái)?shù)”。
真實(shí)分布P表示數(shù)據(jù)的真實(shí)生成過(guò)程,模型預(yù)測(cè)的分布Q表示模型對(duì)數(shù)據(jù)生成過(guò)程的估計(jì)。如果Q與P非常接近,那么用Q來(lái)編碼P所需的信息量就會(huì)很少(即交叉熵?fù)p失很?。?。反之,如果Q與P差距很大,編碼所需的比特?cái)?shù)就會(huì)很多(即交叉熵?fù)p失很大)。 - 數(shù)學(xué)上的直觀理解
假設(shè)我們有一個(gè)二分類問(wèn)題,真實(shí)標(biāo)簽為y∈{0,1},模型預(yù)測(cè)為正類的概率為p。交叉熵?fù)p失可以表示為:
從這個(gè)公式可以看出,交叉熵?fù)p失懲罰了模型對(duì)真實(shí)結(jié)果的“不確定性”(即p遠(yuǎn)離真實(shí)標(biāo)簽)。當(dāng)模型預(yù)測(cè)越準(zhǔn)確時(shí),損失越小,這與信息論中“減少不確定性”的目標(biāo)一致。
- 如果真實(shí)標(biāo)簽y=1,損失為?log(p)。此時(shí),p越接近 1(即預(yù)測(cè)越準(zhǔn)確),損失越小。
- 如果真實(shí)標(biāo)簽y=0,損失為?log(1?p)。此時(shí),p越接近 0(即預(yù)測(cè)越準(zhǔn)確),損失越小。
分類問(wèn)題為什么用交叉熵?fù)p失函數(shù)不用均方誤差(MSE)
交叉熵?fù)p失函數(shù)通常在分類問(wèn)題中使用,而均方誤差(MSE)損失函數(shù)通常用于回歸問(wèn)題。這是因?yàn)榉诸悊?wèn)題和回歸問(wèn)題具有不同的特點(diǎn)和需求。
分類問(wèn)題的目標(biāo)是將輸入樣本分到不同的類別中,輸出為類別的概率分布。交叉熵?fù)p失函數(shù)可以度量?jī)蓚€(gè)概率分布之間的差異,使得模型更好地?cái)M合真實(shí)的類別分布。它對(duì)概率的細(xì)微差異更敏感,可以更好地區(qū)分不同的類別。此外,交叉熵?fù)p失函數(shù)在梯度計(jì)算時(shí)具有較好的數(shù)學(xué)性質(zhì),有助于更穩(wěn)定地進(jìn)行模型優(yōu)化。
相比之下,均方誤差(MSE)損失函數(shù)更適用于回歸問(wèn)題,其中目標(biāo)是預(yù)測(cè)連續(xù)數(shù)值而不是類別。MSE損失函數(shù)度量預(yù)測(cè)值與真實(shí)值之間的差異的平方,適用于連續(xù)數(shù)值的回歸問(wèn)題。在分類問(wèn)題中使用MSE損失函數(shù)可能不太合適,因?yàn)樗鼘?duì)概率的微小差異不夠敏感,而且在分類問(wèn)題中通常需要使用激活函數(shù)(如sigmoid或softmax)將輸出映射到概率空間,使得MSE的數(shù)學(xué)性質(zhì)不再適用。
綜上所述,交叉熵?fù)p失函數(shù)更適合分類問(wèn)題,而MSE損失函數(shù)更適合回歸問(wèn)題。
CrossEntropyLoss基于pytorch的源代碼實(shí)現(xiàn)
CrossEntropyLoss
的實(shí)現(xiàn)原理CrossEntropyLoss
的計(jì)算過(guò)程可以分為兩步:log_softmax:對(duì)模型的輸出應(yīng)用log_softmax,將輸出轉(zhuǎn)換為對(duì)數(shù)概率分布。
NLLLoss:計(jì)算負(fù)對(duì)數(shù)似然損失,即對(duì)真實(shí)標(biāo)簽對(duì)應(yīng)的對(duì)數(shù)概率取負(fù)值。
數(shù)學(xué)公式可以表示為:
其中,是真實(shí)標(biāo)簽(通常是 one-hot 編碼),
是模型預(yù)測(cè)的類別概率。
- PyTorch 源代碼中的實(shí)現(xiàn)
# PyTorch 的 CrossEntropyLoss 實(shí)現(xiàn)
class CrossEntropyLoss(torch.nn.Module):
def __init__(self, weight=None, size_average=None, ignore_index=-100, reductinotallow='mean'):
super(CrossEntropyLoss, self).__init__()
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
def forward(self, input, target):
return F.cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index, reductinotallow=self.reduction)
在torch.nn.functional.cross_entropy中,實(shí)際調(diào)用了log_softmax和nll_loss:
def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reductinotallow='mean'):
return F.nll_loss(F.log_softmax(input, 1), target, weight=weight, size_average=size_average, ignore_index=ignore_index, reductinotallow=reduction)
三、KL散度與交叉熵的區(qū)別
定義上的區(qū)別
- KL散度:KL散度是衡量?jī)蓚€(gè)概率分布P和Q之間的差異的度量。它定義為使用Q來(lái)編碼P時(shí)的額外信息量,即P和Q的交叉熵與P的熵之差。
- 交叉熵:交叉熵是衡量使用一個(gè)概率分布Q來(lái)編碼另一個(gè)概率分布P時(shí)所需的平均信息量。
應(yīng)用上的區(qū)別
- KL散度:KL散度在深度學(xué)習(xí)中常用于衡量?jī)蓚€(gè)概率分布之間的差異,如在變分推斷、生成模型(如VAE)、強(qiáng)化學(xué)習(xí)等領(lǐng)域。它也用于檢測(cè)數(shù)據(jù)分布的漂移。
- 交叉熵:交叉熵在深度學(xué)習(xí)中主要用作損失函數(shù),特別是在分類任務(wù)中。它用于衡量模型預(yù)測(cè)的概率分布與真實(shí)標(biāo)簽的概率分布之間的差異,從而指導(dǎo)模型的優(yōu)化。
KL散度與交叉熵的關(guān)系
交叉熵可以表示為KL散度與真實(shí)分布的熵之和:
其中H(P)是真實(shí)分布P的熵。因此,最小化交叉熵等價(jià)于最小化KL散度,因?yàn)檎鎸?shí)分布的熵是固定的。
四、其他
多任務(wù)學(xué)習(xí)各loss差異過(guò)大怎樣處理
多任務(wù)學(xué)習(xí)中,如果各任務(wù)的損失差異過(guò)大,可以通過(guò)動(dòng)態(tài)調(diào)整損失權(quán)重、使用任務(wù)特定的損失函數(shù)、改變模型架構(gòu)或引入正則化等方法來(lái)處理。目標(biāo)是平衡各任務(wù)的貢獻(xiàn),以便更好地訓(xùn)練模型。
如果softmax的e次方超過(guò)float的值了怎么辦
- 可以使用數(shù)值穩(wěn)定性技巧來(lái)避免溢出。具體來(lái)說(shuō),可以在計(jì)算
之前從每個(gè)
中減去 x 中的最大值
:
- 使用對(duì)數(shù)概率來(lái)避免溢出
對(duì)數(shù)函數(shù)log(x)具有將乘法運(yùn)算轉(zhuǎn)換為加法運(yùn)算的性質(zhì):
這使得在處理非常小或非常大的概率值時(shí),可以避免直接相乘導(dǎo)致的數(shù)值下溢(underflow)或上溢(overflow)。
在計(jì)算交叉熵?fù)p失時(shí),可以使用log_softmax
,公式為: