大模型訓(xùn)練核心算法之——反向傳播算法 原創(chuàng)
“ 反向傳播是大模型訓(xùn)練的核心,沒(méi)有反向傳播就沒(méi)有大模型”
了解過(guò)大模型技術(shù)的人應(yīng)該都知道,大模型有幾個(gè)核心模塊;對(duì)應(yīng)的也有幾個(gè)核心技術(shù)點(diǎn),比如訓(xùn)練數(shù)據(jù)的準(zhǔn)備,機(jī)器學(xué)習(xí)(神經(jīng)網(wǎng)絡(luò))模型的設(shè)計(jì),損失函數(shù)的設(shè)計(jì),反向傳播算法等。
而今天討論的就是反向傳播算法,其可以說(shuō)是模型訓(xùn)練的核心模塊,沒(méi)有反向傳播模型訓(xùn)練就無(wú)從談起。
那么,反向傳播算法是怎么實(shí)現(xiàn)的呢?其技術(shù)原理是什么?有哪些注意點(diǎn)?
反向傳播算法的實(shí)現(xiàn)
介紹
反向傳播是深度神經(jīng)網(wǎng)絡(luò)訓(xùn)練的核心算法,旨在通過(guò)計(jì)算和傳播梯度來(lái)優(yōu)化模型參數(shù);以下是從原理,實(shí)現(xiàn)和技術(shù)細(xì)節(jié)等多個(gè)方面對(duì)反向傳播進(jìn)行介紹。
原理
反向傳播算法的核心是鏈?zhǔn)椒▌t,目的是通過(guò)計(jì)算損失函數(shù)對(duì)模型參數(shù)的梯度來(lái)優(yōu)化模型。具體來(lái)說(shuō):
鏈?zhǔn)椒▌t:反向傳播利用鏈?zhǔn)椒▌t將損失函數(shù)對(duì)模型輸出的梯度逐層傳播到網(wǎng)絡(luò)中的每個(gè)參數(shù)。鏈?zhǔn)椒▌t的核心思想是:
如果一個(gè)函數(shù) zz 是由兩個(gè)函數(shù) ff 和 gg 組合而成,即 z=f(g(x))z=f(g(x)),那么 zz 對(duì) xx 的導(dǎo)數(shù)可以表示為 dzdx=dzdg?dgdxdxdz=dgdz?dxdg
梯度下降:計(jì)算出的梯度用來(lái)調(diào)整模型參數(shù),以減少損失函數(shù)值;參數(shù)更新的步驟通常是基于梯度下降算法
實(shí)現(xiàn)步驟
前向傳播
在反向傳播之前,首先要進(jìn)行前向傳播以計(jì)算預(yù)測(cè)值和損失差:
輸入數(shù)據(jù):將數(shù)據(jù)傳人網(wǎng)絡(luò)的輸入層
計(jì)算每層的輸出:
對(duì)于沒(méi)一層計(jì)算加權(quán)和并加上偏執(zhí)
應(yīng)用激活函數(shù)得到該層的輸出
計(jì)算損失:用損失函數(shù)(如均方差,交叉熵等)計(jì)算預(yù)測(cè)值與實(shí)際標(biāo)簽之間的差距
計(jì)算損失對(duì)輸出的梯度
損失函數(shù)對(duì)輸出的梯度:計(jì)算損失函數(shù)對(duì)網(wǎng)絡(luò)輸出的偏導(dǎo)數(shù),這一過(guò)程取決于損失函數(shù)的類型
反向傳播梯度
輸出層到倒數(shù)第二層
計(jì)算輸出層的梯度(損失對(duì)輸出的梯度),并通過(guò)鏈?zhǔn)椒▌t計(jì)算每一層的梯度
對(duì)于每層 ll,計(jì)算:
- 激活函數(shù)的導(dǎo)數(shù)。
- 損失函數(shù)對(duì)每個(gè)神經(jīng)元的梯度。
- 權(quán)重和偏置的梯度
從倒數(shù)第二層到第一層:
- 繼續(xù)向前一層傳播梯度。
- 更新每層的權(quán)重和偏置。
更新參數(shù)
使用計(jì)算得到的梯度來(lái)更新權(quán)重和偏置:
WL:=WL?η??L?WL
bL:=bL?η??L?bL
其中,η是學(xué)習(xí)率,?L?WL和 ?L?bL
技術(shù)細(xì)節(jié)
激活函數(shù)和其導(dǎo)數(shù)
常見(jiàn)激活函數(shù):
Sigmoid:σ(x)=11+e?x
ReLU:ReLU(x)=max?(0,x)
Tanh:Tanh(x)=ex?e?xex+e?x
激活函數(shù)的導(dǎo)數(shù):
Sigmoid:σ′(x)=σ(x)?(1?σ(x))
ReLU:ReLU′(x)={1 if x>00 if x≤0
Tanh:Tanh′(x)=1?Tanh2(x)
梯度計(jì)算
權(quán)重梯度:對(duì)于每個(gè)權(quán)重 WW,梯度為:
?W/?L=δ?Aprev
其中 δ是當(dāng)前層的誤差項(xiàng),aprev是前一層的激活值。
- 偏置梯度:對(duì)于每個(gè)偏置 bb,梯度為:
?L?b=δ
參數(shù)更新
學(xué)習(xí)率:決定了每次更新的步長(zhǎng),通常使用較小的學(xué)習(xí)率,以確保穩(wěn)定的收斂
優(yōu)化算法:除了標(biāo)準(zhǔn)的梯度下降,還可以使用動(dòng)量,RMSprop、Adam 等優(yōu)化算法來(lái)提高訓(xùn)練效率與效果
正則化
L1/L2正則化:通過(guò)在損失函數(shù)中加入權(quán)重的L1和L2范數(shù)來(lái)防止過(guò)擬合
Dropout:在訓(xùn)練過(guò)程中隨機(jī)忽略一些神經(jīng)元,以防止網(wǎng)絡(luò)對(duì)訓(xùn)練數(shù)據(jù)的過(guò)擬合
數(shù)值穩(wěn)定性
梯度消失:在深層網(wǎng)絡(luò)中,梯度可能會(huì)變得非常小,導(dǎo)致學(xué)習(xí)過(guò)程緩慢或停滯??梢允褂肦eLU激活函數(shù)或歸一化技術(shù)(如批量歸一化)來(lái)緩解
梯度爆炸:梯度值變得非常大,可能導(dǎo)致訓(xùn)練不穩(wěn)定,可以使用梯度裁剪來(lái)限制梯度大小
實(shí)際應(yīng)用
框架支持:現(xiàn)代深度學(xué)習(xí)框架(如TensorFlow,PyTorch等)提供了自動(dòng)微分功能,簡(jiǎn)化了反向傳播的實(shí)現(xiàn)和梯度計(jì)算
并行計(jì)算:使用GPU加速前向傳播和反向傳播的計(jì)算,提高訓(xùn)練效率
總結(jié)
反向傳播算法通過(guò)計(jì)算損失函數(shù)對(duì)網(wǎng)絡(luò)參數(shù)的梯度,利用鏈?zhǔn)椒▌t將梯度從輸出層逐層傳播到輸入層,從而更新網(wǎng)絡(luò)的權(quán)重與偏執(zhí);其核心在于計(jì)算梯度并利用優(yōu)化算法進(jìn)行參數(shù)更新;掌握反向傳播的原理和技術(shù)細(xì)節(jié)對(duì)于訓(xùn)練神經(jīng)網(wǎng)絡(luò)非常重要。
本文轉(zhuǎn)載自公眾號(hào)AI探索時(shí)代 作者:DFires
原文鏈接:?????https://mp.weixin.qq.com/s/s8owdnwSZ4ggKiNfkCxDQg???
