ICML2024高分!魔改注意力,讓小模型能打兩倍大的模型
改進(jìn)Transformer核心機(jī)制注意力,讓小模型能打兩倍大的模型!
ICML 2024高分論文,彩云科技團(tuán)隊(duì)構(gòu)建DCFormer框架,替換Transformer核心組件多頭注意力模塊(MHA),提出可動(dòng)態(tài)組合的多頭注意力(DCMHA)。
DCMHA解除了MHA注意力頭的查找選擇回路和變換回路的固定綁定,讓它們可以根據(jù)輸入動(dòng)態(tài)組合,從根本上提升了模型的表達(dá)能力。
可以近似理解為,原來每層有固定的H個(gè)注意力頭,現(xiàn)在用幾乎同樣的參數(shù)量和算力,可按需動(dòng)態(tài)組合出多至HxH個(gè)注意力頭。
DCMHA即插即用,可在任何Transformer架構(gòu)中替換MHA,得到通用、高效和可擴(kuò)展的新架構(gòu)DCFormer。
這項(xiàng)工作由來自北京郵電大學(xué)、AI創(chuàng)業(yè)公司彩云科技的研究人員共同完成。
研究人員用在DCFormer基礎(chǔ)上打造的模型DCPythia-6.9B,在預(yù)訓(xùn)練困惑度和下游任務(wù)評估上都優(yōu)于開源Pythia-12B。
DCFormer模型在性能上與那些計(jì)算量是其1.7-2倍的Transformer模型相當(dāng)。
多頭注意力模塊有何局限?
大模型的scaling law告訴我們,隨著算力的提升,模型更大、數(shù)據(jù)更多,模型效果會越來越好。雖然還沒有人能明確說明這條路的天花板有多高,能否達(dá)到AGI,但這確實(shí)是目前大家最普遍的做法。
但除此以外,另一個(gè)問題同樣值得思考:目前絕大多數(shù)大模型都基于Transformer,它們都是用一個(gè)一個(gè)Transformer塊像搭積木一樣搭起來的,那作為積木塊的Transformer本身,還有多大的改進(jìn)提升空間?
這是模型結(jié)構(gòu)研究要回答的基本問題,也正是彩云科技和北京郵電大學(xué)聯(lián)合完成的DCFormer這項(xiàng)工作的出發(fā)點(diǎn)。
在Transformer的多頭注意力模塊(MHA)中,各個(gè)注意力頭彼此完全獨(dú)立的工作。
這個(gè)設(shè)計(jì)因其簡單易實(shí)現(xiàn)的優(yōu)點(diǎn)已在實(shí)踐中大獲成功,但同時(shí)也帶來注意力分?jǐn)?shù)矩陣的低秩化削弱了表達(dá)能力、注意力頭功能的重復(fù)冗余浪費(fèi)了參數(shù)和計(jì)算資源等一些弊端?;诖耍陙碛幸恍┭芯抗ぷ髟噲D引入某種形式的注意力頭間的交互。
根據(jù)Transformer回路理論,在MHA中 ,每個(gè)注意力頭的行為由WQ、WK、WV、WO四個(gè)權(quán)重矩陣刻畫(其中WO由MHA的輸出投影矩陣切分得到)。
其中,WQWK叫做QK回路(或叫查找選擇回路),決定從當(dāng)前token關(guān)注上下文中的哪個(gè)(些)token,例如:
WOWV叫做OV回路(或叫投影變換回路),決定從關(guān)注到的token取回什么信息(或投影什么屬性)寫入當(dāng)前位置的殘差流,進(jìn)而預(yù)測下一個(gè)token。例如:
研究人員注意到,查找(從哪拿)和變換(拿什么)本來是獨(dú)立的兩件事,理應(yīng)可以分別指定并按需自由組合(就像在SQL查詢中,WHERE后的選擇條件和SELECT后的屬性投影是分開寫的一樣),MHA硬把它們放到一個(gè)注意力頭的QKOV里“捆綁銷售”,限制了靈活性和表達(dá)能力。
例如,假設(shè)有個(gè)模型存在注意力頭A、B、C其QK和OV回路能夠完成上面的例子=,那換成:
需要交叉組合現(xiàn)有注意力頭的QK和OV回路,模型就可能“轉(zhuǎn)不過彎兒”了(經(jīng)研究人員系統(tǒng)構(gòu)造的合成測試集驗(yàn)證,<=6B的中小尺寸模型在這類看似簡單的任務(wù)上確實(shí)表現(xiàn)不佳)。
動(dòng)態(tài)組合多頭注意力長啥樣?
以此為出發(fā)點(diǎn),本文研究團(tuán)隊(duì)在MHA中引入compose操作:
如下圖所示,得到DCMHA:
△圖1. DCMHA總體結(jié)構(gòu)
將QWQ和KWK算出的注意力分?jǐn)?shù)矩陣AS和注意力權(quán)重矩陣AW,與VWV相乘之前,對其在num_heads維上做線性映射得到新的矩陣A’,通過不同的線性映射矩陣(composition map),以實(shí)現(xiàn)各種注意力頭組合的效果。
例如圖2(c)中將head 3和7的QK回路與head 1的OV回路組合在一起,形成一個(gè)“新的”注意力頭。
△圖2. 8個(gè)注意力頭的簡化的典型composition map的功能,淺色表示大值
為了最大限度的增強(qiáng)表達(dá)能力,研究人員希望映射矩陣由輸入動(dòng)態(tài)生成,即動(dòng)態(tài)決定注意力頭怎樣組合。
但他們要生成的映射矩陣不是一個(gè),而是對序列中每對源位置的query Qi和目的位置的key Kj,都要生成這樣一個(gè)矩陣,計(jì)算開銷和顯存占用都將難以接受。
為此,他們進(jìn)一步將映射矩陣分解為一個(gè)輸入無關(guān)的靜態(tài)矩陣Wb、一個(gè)低秩矩陣w1w2和一個(gè)對角矩陣Diag(wg)之和,分別負(fù)責(zé)基礎(chǔ)組合、注意力頭間的有限方式(即秩R<=2)的動(dòng)態(tài)組合和頭自身的動(dòng)態(tài)門控(見圖2(d)和圖3(b))。其中后兩個(gè)矩陣由Q矩陣和K矩陣動(dòng)態(tài)生成。
在不犧牲效果的前提下,將計(jì)算和參數(shù)復(fù)雜度降低到幾乎可以忽略的程度(詳見論文中復(fù)雜度分析)。再結(jié)合JAX和PyTorch實(shí)現(xiàn)層面的優(yōu)化,讓DCFormer可以高效訓(xùn)練和推理。
△圖3. Compose的計(jì)算
效果如何?
規(guī)模擴(kuò)展
評估一個(gè)架構(gòu)的好壞,研究人員關(guān)注的最核心指標(biāo)是算力轉(zhuǎn)化為智能的效率(或叫性能算力比),即投入單位算力能帶來的模型性能提升——花更少的算力,得到更好的模型。
從圖4和圖5的scaling law曲線(在對數(shù)坐標(biāo)下,每個(gè)模型架構(gòu)的損失隨算力的變化可畫出一條近似直線,損失越低,模型越好)可以看出,DCFormer可以達(dá)到1.7~2倍算力的Transformer模型的效果,即算力智能轉(zhuǎn)化率提升了1.7~2倍。
△圖4. Transformer和DCFormer的規(guī)模擴(kuò)展效果
△圖5. Pythia和DCPythia的規(guī)模擴(kuò)展效果
怎么理解這個(gè)提升幅度呢?
自2017年Transformer誕生至今,從改進(jìn)性能算力比的角度,GLU MLP和旋轉(zhuǎn)位置編碼RoPE是經(jīng)大量實(shí)踐驗(yàn)證普適有效且被廣泛采用的為數(shù)不多的兩項(xiàng)架構(gòu)改進(jìn)。
在原始Transformer中加入這兩項(xiàng)改進(jìn)的架構(gòu)也叫Transformer++,Llama、Mistral等最強(qiáng)開源模型均采用該架構(gòu)。無論Transformer還是Transformer++架構(gòu),都可通過DCMHA獲得顯著改進(jìn)。
在1.4B模型規(guī)模下,DCMHA的改進(jìn)幅度大于Transformer++的兩項(xiàng)改進(jìn)之和,且擴(kuò)展性更好(圖4下藍(lán)綠線和黑線的對比,DCMHA的改進(jìn)幅度隨算力增加衰減的更慢,以及圖4和圖5的對比)。
可以說,DCFormer讓Transformer的能力又躍上一個(gè)新臺階。
下游任務(wù)評測
研究團(tuán)隊(duì)訓(xùn)練了DCPythia-2.8B和DCPythia-6.9B兩個(gè)模型在主流NLP下游任務(wù)上進(jìn)行測評并和同規(guī)模的開源模型Pythia進(jìn)行比較(訓(xùn)練采用和Pythia完全相同超參數(shù)設(shè)置)。
△表1. DCFormer 和 Pythia 在下游任務(wù)中的表現(xiàn)
從表1中可以看出,DCPythia-2.8B和6.9B不僅在Pile驗(yàn)證集上的ppl 更低,而且在大部分下游任務(wù)上都顯著超過了Pythia,DCPythia6.9B在 ppl 和下游任務(wù)上的平均準(zhǔn)確率甚至超過了Pythia-12B。
DCFormer++2.8B相對于DCPythia-2.8B有進(jìn)一步的提升,驗(yàn)證了DCMHA和Lllama架構(gòu)結(jié)合的有效性。
訓(xùn)練和推理速度
雖然引入DCMHA會帶來額外的訓(xùn)練和推理開銷,但是從表2中可以看出DCFormer++的訓(xùn)練速度是Transformer++的74.5%-89.2%,推理速度則是81.1%-89.7%,而且隨著模型參數(shù)的增長,額外的計(jì)算開銷會逐漸降低。
△表2. Transformer++和DCFormer++的訓(xùn)練和推理速度對比
訓(xùn)練速度是在TPU v3 pod,序列長度為2048,batch_size為1k的情況下對比得到的;推理速度是在A100 80G GPU上進(jìn)行評測的,輸入長度1024,生成長度128。
消融實(shí)驗(yàn)
結(jié)果如下:
△表3. DCMHA的消融實(shí)驗(yàn)
從表3中可以看出以下幾點(diǎn):
- 雖然加入靜態(tài)的組合權(quán)重就可以降低ppl,但引入動(dòng)態(tài)的組合權(quán)重可以進(jìn)一步降低ppl,說明了動(dòng)態(tài)組合的必要性。
- 低秩動(dòng)態(tài)組合比動(dòng)態(tài)門控的效果更好。
- 只用query-wise或者key-wise的動(dòng)態(tài)組合得到的ppl相當(dāng),與DCFormer++的差距很小。
- 在softmax后做注意力頭組合比在softmax前做更有效,可能是因?yàn)閟oftmax后的概率能更直接影響輸出。
- 動(dòng)態(tài)組合權(quán)重的秩無需設(shè)置過大,也說明了組合權(quán)重的低秩性。
此外,研究人員還通過增加局部注意力層的比例和只用query-wise動(dòng)態(tài)組合的方式去進(jìn)一步減少訓(xùn)練和推理開銷,詳見論文Table 10。
總的來說,研究團(tuán)隊(duì)有兩點(diǎn)總結(jié)。
關(guān)于動(dòng)態(tài)權(quán)重:近期Mamba,GLA,RWKV6,HGRN等SSM和線性注意力/RNN的工作,通過引入動(dòng)態(tài)(input-dependent)權(quán)重的方式,追趕上了Transformer++,但DCFormer用動(dòng)態(tài)組合注意力頭的方式說明了在使用 softmax 注意力的情況下,通過引入動(dòng)態(tài)權(quán)重也可以大幅提升Transformer++的效果。
關(guān)于模型架構(gòu)創(chuàng)新:這項(xiàng)工作表明,如果存在一個(gè)具有極限算力智能轉(zhuǎn)化效率的“理想模型架構(gòu)”,當(dāng)前的Transformer架構(gòu)雖已非常強(qiáng)大,但距離這個(gè)理想架構(gòu)很可能還存在很大的差距,仍有廣闊的提升空間。因此,除了堆算力堆數(shù)據(jù)的大力出奇跡路線,模型架構(gòu)創(chuàng)新同樣大有可為。
研究團(tuán)隊(duì)還表示,彩云科技會率先在旗下產(chǎn)品彩云天氣、彩云小譯、彩云小夢上應(yīng)用DCformer。
有關(guān)更多研究細(xì)節(jié),可參閱原始論文。
ICML2024論文鏈接:https://icml.cc/virtual/2024/poster/34047。
Arxiv 論文鏈接:https://arxiv.org/abs/2405.08553。
代碼鏈接:https://github.com/Caiyun-AI/DCFormer。