自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

三天把Llama訓成Mamba,性能不降,推理更快!

人工智能
近日,Mamba方面又搞出了有意思的研究:來自康奈爾、普林斯頓等機構的研究人員成功將Llama提煉成了Mamba模型,并且設計了新的推測解碼算法,加速了模型的推理。

先來看一張其樂融融的圖片(一眼AI):

圖片圖片

右邊的小羊駝代表Llama,而左邊的蛇(Mamba)也是我們的老熟人了。

至于到底能不能其樂融融,咱就不管了,之所以有此場景,是因為Mamba方面又搞出了有意思的研究:

——如何把Llama變成Mamba?

圖片圖片

論文地址:https://arxiv.org/pdf/2408.15237

代碼地址:https://github.com/jxiw/MambaInLlama

近日,來自康奈爾、普林斯頓等機構的研究人員推出了上面這篇工作,將Llama這樣的大型Transformer提煉成了Mamba模型,

并且成功在Mamba架構上應用了帶有硬件感知的推測解碼算法,提高了整個模型的推理速度。

為什么要把Llama變成Mamba?

因為從頭開始訓練一個大模型太貴了。

Mamba也火了這么長時間了,相關的研究每天都有,但自己訓練大尺寸Mamba模型的卻很少。

目前比較有名的是AI21的Jamba(進化到了1.5版本,最大398B,MoE),以及NVIDIA的Hybrid Mamba2模型(8B)。

圖片圖片

不過世界上有那么多成功的Transformer大模型,而知識就包含在這些模型參數(shù)里。

如果能夠鎖住知識,同時把Transformer微調(diào)成Mamba,不就解決問題了?

圖片圖片

在本文中,研究人員結(jié)合漸進式蒸餾、監(jiān)督微調(diào)(SFT)和定向偏好優(yōu)化(DPO)等方法達成了這一目標。

光是變大還不夠

在性能匹配Transformer的前提下,速度也要夠快才行。

Mamba憑借固定的推理開銷,在長序列中的優(yōu)勢明顯,但Transformer這邊也是有推理加速方案的,比如推測解碼。

而由于Mamba本身的結(jié)構特性,不能直接應用這種方案,所以作者設計了全新的算法,并結(jié)合硬件的性質(zhì)來實現(xiàn)基于Mamba的推測解碼。

圖片圖片

最終,研究人員將Zephyr-7B、Llama-3 8B提煉為了線性RNN模型(混合Mamba和Mamba2),且性能與蒸餾之前的標準模型相當。

圖片

整個訓練過程只使用了20B的token,效果卻能夠與使用1.2T個token從頭開始訓練的Mamba 7B模型,以及使用3.5T個token訓練的NVIDIA Hybrid Mamba2模型相媲美。

圖片

從 Transformer 到 Mamba

在介紹Mamba 2的時候我們講過,線性RNN(或SSM)跟線性注意力是一回事。

所以可以根據(jù)x,B,C與V,K,Q的對應關系直接復用注意力中的投影矩陣。

圖片圖片

額外的參數(shù)包括SSM需要的A矩陣和Δt(由x投影得到),這就完成了基本的參數(shù)初始化。

之后就是SSM的運算過程,再通過投影和累加得到輸出。

模型架構和訓練

下圖給出了模型的架構,因為Transformer的知識存在于MLP層,所以凍結(jié)這部分參數(shù)。

除了用線性RNN層(Mamba)替換掉注意力頭,還有一些組件需要處理,比如跨頭共享鍵和值的分組查詢注意力(GQA)。

圖片圖片

知識蒸餾(Knowledge distillation,KD)是一種常用的壓縮技術,用來訓練模仿較大模型(teacher)行為的較小網(wǎng)絡(student)。

圖片圖片

根據(jù)經(jīng)驗,這里采用逐步替換Attention層的策略,先是每2層進行蒸餾,然后每4層繼續(xù)蒸餾......

監(jiān)督微調(diào)

有兩種常見的蒸餾方法。一種方法是使用word-level的KL散度,此時訓練student模型去匹配teacher模型輸出的完整概率分布。

第二種方法是序列級知識蒸餾(SeqKD),直接使用teacher模型的輸出作為ground truth來訓練student模型(也稱為偽標簽)。

圖片

這里θ是student模型的可訓練參數(shù),α和β分別控制序列和詞的loss項的權重。

偏好優(yōu)化

LLM指令調(diào)優(yōu)的第二階段是使其符合用戶偏好。這個階段,使用一組期望的偏好對來改進模型的輸出。

優(yōu)化的目標是使獎勵模型最大化,同時保持產(chǎn)生的輸出接近參考模型。

通常,參考模型使用上一步監(jiān)督微調(diào)后的模型。這里因為是蒸餾,直接可以用teacher模型:

圖片圖片

偏好模型的獎勵函數(shù)定義取決于所使用的方法,本文采用直接偏好優(yōu)化(DPO),通過直接梯度更新有效地到達優(yōu)化目標。

圖片圖片

DPO表明,對于給定的提示x ,如果我們能夠獲得preferred和dispreferred兩種輸出,就可以將這個優(yōu)化問題重新表述為:

圖片圖片

這種優(yōu)化可以在序列級別上執(zhí)行,讓teacher模型和student模型一起對preferred和dispreferred輸出進行評分,然后反向傳播給student模型。

推測解碼

經(jīng)過上面的一套小連招,模型轉(zhuǎn)換就搞定了,下面開始想辦法應用Transformer那邊的推測解碼。

推測解碼(Speculative Decoding)可以簡單理解為下面這張圖。

圖片圖片

Transformer做推理的時候,除了要處理不斷變長的KV cache之外,計算效率也是個問題。

因為顯卡的設計是計算高于訪存的,具體到計算單元就是做矩陣乘法。

而推理的時候每次只能進入一個詞向量,顯卡的很多計算就被浪費了。

圖片圖片

推測解碼給出的解決方案是,使用一個小模型做生成,然后拿顯卡多余的計算做驗證。

小模型跑得快,可以一口氣生成很多輸出向量,但是可能效果差一點。這時候用大模型作為驗證,一次計算之前生成的很多個向量。

所以小模型串行跑得快,大模型可以并行計算跑得也快,遇到驗證不通過的就直接回滾,整體上提高了推理的速度。

圖片圖片

Transformer可以方便地回滾,因為KV cache跟時間是一一對應的,但Mamba這邊只有一個當前的中間狀態(tài)ht,你總不能把所有中間狀態(tài)都存起來吧。

為了解決這個問題,研究人員設計了下面的算法:

圖片圖片

簡單來說就是每次使用小模型(draft model)生成一組輸出,然后大模型(verification model)驗證這一組輸出,根據(jù)驗證匹配的位置來更新需要保存的中間狀態(tài)。

我們可以從下面的偽代碼了解詳細的過程:

圖片圖片

每次生成K個草稿輸出,驗證模型通過MultiStep函數(shù)返回K個真正的輸出,以及上一次校驗成功位置的cache(中間狀態(tài)hj)和本次最后位置的cache(hk)。

圖片圖片

Multi-Step內(nèi)核的性能特征

通過FirstConflict函數(shù)找到最后匹配(校驗成功)的位置,如果所有都匹配,則cache可以更新到最后的hk,否則就只更新到上一次的hj。

兵馬后動,糧草先行,不耽誤輸出和校驗,同時只需要多存儲一個中間狀態(tài)。

當然,如果草稿模型也用Mamba的話,算法的推測部分會變得復雜一些,因為草稿模型需要重新計算上一次迭代中驗證成功位置的狀態(tài)。

硬件特定優(yōu)化

下面使用Mamba 7B和 Mamba 2.8B作為目標模型進行推測實驗。

最初,作者搞了一版簡單的算法實現(xiàn),結(jié)果在Ampere架構的GPU(3090)上面效果顯著,Mamba 2.8B獲得了1.5倍的推理加速, 同時有60%的接受率。

但是這種實現(xiàn)方式在H100 GPU上不太好使,主要是因為GEMM操作的速度更快了,使得緩存和重新計算產(chǎn)生的開銷更加明顯。

圖片

所以,作者通過融合內(nèi)核以及調(diào)整實現(xiàn)方式來優(yōu)化算法。

對于驗證模型,首先從緩存中重新計算之前的步驟,然后對新的草稿token序列進行多步解碼,最后在單個內(nèi)核中進行緩存。

對于草稿模型,重新計算、解碼和緩存也融合在單個內(nèi)核中。最終實現(xiàn)了上表中的加速效果。

實驗

研究人員使用兩個LLM聊天模型進行實驗:Zephyr-7B和Llama-3 Instruct 8B。

采用三階段蒸餾。在第一階段,使用UltraChat和UltraFeedback作為種子提示,并使用teacher模型生成偽標簽。

使用AdamW優(yōu)化器訓練模型,β=(0.9,0.98) ,批量大小64。先使用線性學習率預熱,然后進行余弦退火。

第二階段,在一個epoch中使用SFT在GenQA、InfinityInstruct和OpenHermes 2.5數(shù)據(jù)集上對模型進行監(jiān)督微調(diào),采用與Zephyr相同的超參數(shù)。

最后一個階段,對于從Zephyr中提取的模型,在UltraFeedback數(shù)據(jù)集上使用DPO與標準模型進行蒸餾對齊。

過程中只在第一階段凍結(jié)MLP層,后兩個階段所有參數(shù)都進行訓練。

作者表示,通常只需要在8卡80G A100上運行3到4天,即可重現(xiàn)本文的結(jié)果。

圖片

參考資料:https://arxiv.org/abs/2408.15237

責任編輯:武曉燕 來源: 新智元
相關推薦

2024-09-02 08:45:00

模型生成

2024-03-04 13:23:34

數(shù)據(jù)模型

2025-04-21 09:07:00

2025-04-09 10:40:32

2024-09-24 13:00:00

大語言模型AI

2024-09-13 09:14:32

2023-11-23 13:23:41

AI訓練

2024-04-01 12:09:16

模型數(shù)據(jù)

2025-01-20 07:58:51

2025-03-05 00:15:00

2024-05-30 13:10:10

2024-06-11 08:25:00

2024-08-16 12:46:08

2025-04-18 07:43:41

2024-08-13 12:49:29

2023-03-08 18:43:50

GPU模型隔離

2024-09-10 13:30:00

2024-09-23 08:20:00

模型訓練

2024-08-15 15:45:00

AI訓練

2012-05-15 15:32:45

Web App瀏覽器
點贊
收藏

51CTO技術棧公眾號