字節(jié)豆包大模型團(tuán)隊(duì)突破殘差連接局限!預(yù)訓(xùn)練收斂最快加速80%
自從 ResNet 提出后,殘差連接已成為深度學(xué)習(xí)模型的基礎(chǔ)組成部分。其主要作用是 —— 緩解梯度消失問題,使得網(wǎng)絡(luò)的訓(xùn)練更加穩(wěn)定。
但是,現(xiàn)有殘差連接變體在梯度消失和表示崩潰之間存在一種 “蹺蹺板式” 的權(quán)衡,無法同時(shí)解決。
為此,字節(jié)豆包大模型 Foundation 團(tuán)隊(duì)于近日提出超連接(Hyper-Connections),針對(duì)上述 “蹺蹺板式” 困境,實(shí)現(xiàn)了顯著提升。
該方法適用于大規(guī)模語言模型(LLMs)的預(yù)訓(xùn)練,在面向 Dense 模型和 MoE 模型的實(shí)驗(yàn)中,展示了顯著性能提升效果,使預(yù)訓(xùn)練收斂速度最高可加速 80%。
研究團(tuán)隊(duì)還發(fā)現(xiàn),超連接在兩個(gè)小型的視覺任務(wù)中表現(xiàn)同樣優(yōu)異,這表明,該方法在多個(gè)領(lǐng)域有廣泛的應(yīng)用前景。
- 論文標(biāo)題:Hyper-Connections
- 論文鏈接:https://arxiv.org/pdf/2409.19606
1. 超連接的核心思想
前文提及,殘差連接的兩種主要變體 Pre-Norm 和 Post-Norm 各自都有其局限性,具體體現(xiàn)如下:
- Pre-Norm:在每個(gè)殘差塊之前進(jìn)行歸一化操作,可有效減少梯度消失問題。然而,Pre-Norm 在較深網(wǎng)絡(luò)中容易導(dǎo)致表示崩潰,即深層隱藏表示過于相似,從而削弱了模型學(xué)習(xí)能力。
- Post-Norm:在殘差塊之后進(jìn)行歸一化操作,有助于減少表示崩潰問題,但也重新引入梯度消失問題。在 LLM 中,通常不會(huì)采用此方法。
超連接的核心思路在于 —— 引入可學(xué)習(xí)的深度連接(Depth-connections)和寬度連接(Width-connections)。
從理論上,這使得模型不僅能夠動(dòng)態(tài)調(diào)整不同層之間的連接強(qiáng)度,甚至能重新排列網(wǎng)絡(luò)層次結(jié)構(gòu),彌補(bǔ)了殘差連接在梯度消失和表示崩潰(Representation Collapse)之間的權(quán)衡困境。
深度連接與寬度連接
起初,該方法會(huì)將網(wǎng)絡(luò)輸入擴(kuò)展為 n 個(gè)隱向量(n 稱作 Expansion rate)。之后每一層的輸入都會(huì)是 n 個(gè)隱向量,超連接會(huì)對(duì)這些隱向量建立以下兩類連接:
- 深度連接(Depth-Connections):這些連接類似于殘差連接,只為輸入與輸出之間的連接分配權(quán)重,允許網(wǎng)絡(luò)學(xué)習(xí)不同層之間的連接強(qiáng)度。
- 寬度連接(Width-Connections):這些連接使得每一層多個(gè)隱藏向量之間可進(jìn)行信息交換,從而提高模型表示能力。
靜態(tài)與動(dòng)態(tài)超連接
超連接可以是靜態(tài)的,也可以是動(dòng)態(tài)的。
其中,靜態(tài)超連接(Static Hyper-Connections, SHC)意味著連接權(quán)重在訓(xùn)練結(jié)束后固定不變。而動(dòng)態(tài)超連接(Dynamic Hyper-Connections, DHC)則對(duì)應(yīng)連接權(quán)重可根據(jù)輸入動(dòng)態(tài)調(diào)整。實(shí)驗(yàn)表明,動(dòng)態(tài)超連接效果更好。
2. 技術(shù)細(xì)節(jié)
超連接(Hyper-connections)
首先,考慮第 k 層的輸入隱藏向量,網(wǎng)絡(luò)的初始輸入為
,并將其復(fù)制 n 次,形成初始的超隱藏矩陣(Hyper Hidden Matrix):
這里,n 稱為擴(kuò)展率(Expansion Rate)。在第 k 層,輸入是上一層的超隱藏矩陣,即:
對(duì)最后一層的超隱藏矩陣逐行求和,得到所需的隱藏向量,并通過一個(gè)投影層輸出網(wǎng)絡(luò)最終的結(jié)果(在 Transformer 中即為歸一化層和解嵌入層)。
為了簡化后續(xù)分析的符號(hào)表示,作者省略層索引,直接將超隱藏矩陣表示為:
超連接可以用一個(gè)矩陣來表示,對(duì)于擴(kuò)展率為 n 的情況,超連接矩陣 HC 如下:
考慮一層網(wǎng)絡(luò),它可能是 Transformer 中的 attention 層或者是 FFN 層。超連接的輸出
可以簡單地表示為:
也就是說,用 作為權(quán)重對(duì)輸入
進(jìn)行加權(quán)求和,得到當(dāng)前層的輸入
:
同時(shí),用于將
映射到殘差超隱藏矩陣
,表示如下:
最終的輸出表達(dá)式為:
偽代碼如下:
動(dòng)態(tài)超連接的實(shí)現(xiàn)
超連接矩陣 的元素可以動(dòng)態(tài)依賴于輸入
,動(dòng)態(tài)超連接的矩陣表示為:
同樣,給定層 和輸入
,可以得到動(dòng)態(tài)超連接的輸出:
在實(shí)際操作中,團(tuán)隊(duì)結(jié)合了靜態(tài)和動(dòng)態(tài)矩陣來實(shí)現(xiàn)動(dòng)態(tài)超連接,動(dòng)態(tài)參數(shù)通過線性變換獲得。
為了穩(wěn)定訓(xùn)練過程,團(tuán)隊(duì)在線性變換前引入歸一化,并在其后應(yīng)用 tanh 激活函數(shù),通過一個(gè)可學(xué)習(xí)的小因子進(jìn)行縮放。動(dòng)態(tài)參數(shù)的計(jì)算公式如下:
實(shí)驗(yàn)表明,動(dòng)態(tài)超連接在語言建模任務(wù)中優(yōu)于靜態(tài)超連接。
3. 為什么使用超連接(Hyper-Connections)
研究團(tuán)隊(duì)認(rèn)為,殘差連接的兩種變體,即前歸一化(Pre-Norm)和后歸一化(Post-Norm),可以被視為不可訓(xùn)練的超連接。
隨后,團(tuán)隊(duì)引入了順序 - 并行二象性概念,展示了超連接如何動(dòng)態(tài)優(yōu)化層的排列以提升網(wǎng)絡(luò)性能。
殘差連接是不可訓(xùn)練的超連接
前歸一化和后歸一化的殘差連接可以表示為以下擴(kuò)展率為 的超連接矩陣:
其中,和
分別表示神經(jīng)網(wǎng)絡(luò)層輸入和輸出的標(biāo)準(zhǔn)差,
表示它們之間的協(xié)方差。
對(duì)于 Pre-Norm,其超連接矩陣是一個(gè) 的矩陣,右下三角部分填充為 1,其余部分為占位符 0。對(duì)于 Post-Norm,權(quán)重依賴于輸入和輸出的方差及協(xié)方差,形成一個(gè)
的矩陣。因此,它們的超連接矩陣是不可訓(xùn)練的。
而本工作提出的方法的超連接矩陣是 矩陣,且權(quán)重是可訓(xùn)練的,甚至可以基于輸入進(jìn)行動(dòng)態(tài)預(yù)測。
順序 - 并行二象性
給定一系列神經(jīng)網(wǎng)絡(luò)模塊,我們可以將它們順序排列或并行排列。作者認(rèn)為,超連接可以學(xué)習(xí)如何將這些層重新排列,形成順序和并行配置的混合。
在不失一般性的情況下,可以將擴(kuò)展率設(shè)置為 n=2。如果超連接以如下矩陣形式學(xué)習(xí),神經(jīng)網(wǎng)絡(luò)將被順序排列:
在這種情況下,深度連接退化為殘差連接,如圖 (a) 所示。
當(dāng)奇數(shù)層和偶數(shù)層的超連接矩陣分別定義為以下形式時(shí),神經(jīng)網(wǎng)絡(luò)每兩層將被并行排列,類似于 Transformer 中的 parallel transformer block 的排列方式,如圖 (b) 所示。
因此,通過學(xué)習(xí)不同形式的超連接矩陣,網(wǎng)絡(luò)層的排列可以超越傳統(tǒng)的順序和并行配置,形成軟混合甚至動(dòng)態(tài)排列。對(duì)于靜態(tài)超連接,網(wǎng)絡(luò)中的層排列在訓(xùn)練后保持固定;而對(duì)于動(dòng)態(tài)超連接,排列可以根據(jù)每個(gè)輸入動(dòng)態(tài)調(diào)整。
4. 實(shí)驗(yàn)結(jié)果
實(shí)驗(yàn)主要集中在大規(guī)模語言模型的預(yù)訓(xùn)練上,涵蓋了 Dense 模型和 MoE 模型。
實(shí)驗(yàn)結(jié)果表明,使用超連接的模型顯著優(yōu)于使用殘差連接的模型。
1B Dense 模型實(shí)驗(yàn)
只要擴(kuò)展率 > 1,效果就十分顯著,且訓(xùn)練更穩(wěn)定,消掉了訓(xùn)練 loss 的 spikes。
7B Dense 模型實(shí)驗(yàn)
團(tuán)隊(duì)甚至 Scale 到了 7B 模型,效果也十分亮眼,同時(shí)可以看到有超連接的網(wǎng)絡(luò)訓(xùn)練更穩(wěn)定。
7B 候選激活 1.3B 的 MoE 模型實(shí)驗(yàn)
可以看到,下游指標(biāo)全漲,在 ARC-Challenge 上甚至漲了 6 個(gè)百分點(diǎn)。
綜上,研究團(tuán)隊(duì)介紹了超連接(Hyper-Connections),它解決了殘差連接在梯度消失和表示崩潰之間的權(quán)衡問題。實(shí)驗(yàn)結(jié)果表明,超連接在大規(guī)模語言模型的預(yù)訓(xùn)練以及視覺任務(wù)中都表現(xiàn)出顯著的性能提升。
值得注意的是,超連接的引入幾乎不增加額外的計(jì)算開銷或參數(shù)量,團(tuán)隊(duì)認(rèn)為,該成果具有廣泛的應(yīng)用潛力,可以推廣到文音視圖模態(tài)的不同任務(wù)上,包括多模態(tài)理解、生成基座模型等。
5. 寫在最后
團(tuán)隊(duì)關(guān)注底層問題,尤其在 LLMs 和多模態(tài)方面,期望實(shí)現(xiàn)更多突破。
更多團(tuán)隊(duì)技術(shù)研究進(jìn)展,可以進(jìn)入「豆包大模型團(tuán)隊(duì)」技術(shù)解讀欄目了解。