自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

Mamba作者新作:將Llama3蒸餾成混合線性 RNN

人工智能 新聞
最近,一篇題為《The Mamba in the Llama: Distilling and Accelerating Hybrid Models》的論文證明:通過重用注意力層的權(quán)重,大型 transformer 可以被蒸餾成大型混合線性 RNN,只需最少的額外計(jì)算,同時(shí)可保留其大部分生成質(zhì)量。

Transformer 在深度學(xué)習(xí)領(lǐng)域取得巨大成功的關(guān)鍵是注意力機(jī)制。注意力機(jī)制讓基于 Transformer 的模型關(guān)注與輸入序列相關(guān)的部分,實(shí)現(xiàn)了更好的上下文理解。然而,注意力機(jī)制的缺點(diǎn)是計(jì)算開銷大,會(huì)隨輸入規(guī)模而二次增長(zhǎng),Transformer 也因此難以處理非常長(zhǎng)的文本。

前段時(shí)間,Mamba 的出現(xiàn)打破了這一局面,它可以隨上下文長(zhǎng)度的增加實(shí)現(xiàn)線性擴(kuò)展。隨著 Mamba 的發(fā)布,這些狀態(tài)空間模型 (SSM) 在中小型規(guī)模上已經(jīng)可以與 Transformer 匹敵,甚至超越 Transformer,同時(shí)還能維持隨序列長(zhǎng)度的線性可擴(kuò)展性,這讓 Mamba 具有有利的部署特性。

簡(jiǎn)單來說,Mamba 首先引入了一個(gè)簡(jiǎn)單卻有效的選擇機(jī)制,其可根據(jù)輸入對(duì) SSM 進(jìn)行重新參數(shù)化,從而可讓模型在濾除不相關(guān)信息的同時(shí)無限期地保留必要和相關(guān)的數(shù)據(jù)。

最近,一篇題為《The Mamba in the Llama: Distilling and Accelerating Hybrid Models》的論文證明:通過重用注意力層的權(quán)重,大型 transformer 可以被蒸餾成大型混合線性 RNN,只需最少的額外計(jì)算,同時(shí)可保留其大部分生成質(zhì)量。

由此產(chǎn)生的混合模型包含四分之一的注意力層,在聊天基準(zhǔn)測(cè)試中實(shí)現(xiàn)了與原始 Transformer 相當(dāng)?shù)男阅埽⑶以诹奶旎鶞?zhǔn)測(cè)試和一般基準(zhǔn)測(cè)試中優(yōu)于使用數(shù)萬億 token 從頭開始訓(xùn)練的開源混合 Mamba 模型。此外,該研究還提出了一種硬件感知推測(cè)解碼算法,可以加快 Mamba 和混合模型的推理速度。

圖片

論文地址:https://arxiv.org/pdf/2408.15237

該研究的性能最佳模型是從 Llama3-8B-Instruct 中蒸餾出來的,在 AlpacaEval 2 上相對(duì)于 GPT-4 實(shí)現(xiàn)了 29.61 的長(zhǎng)度控制(length-controlled)勝率,在 MT-Bench 上實(shí)現(xiàn)了 7.35 的勝率,超越了最好的指令調(diào)整線性 RNN 模型。

方法

知識(shí)蒸餾(KD)作為一種模型壓縮技術(shù),用于將大型模型(教師模型)的知識(shí)遷移到較小的模型(學(xué)生模型)中,旨在訓(xùn)練學(xué)生網(wǎng)絡(luò)模仿教師網(wǎng)絡(luò)的行為。該研究旨在對(duì) Transformer 進(jìn)行蒸餾,使其性能與原始語言模型相當(dāng)。

該研究提出了一種多級(jí)蒸餾方法,結(jié)合了漸進(jìn)式蒸餾、監(jiān)督微調(diào)和定向偏好優(yōu)化。與普通蒸餾相比,這種方法可以獲得更好的困惑度和下游評(píng)估結(jié)果。

該研究假設(shè)來自 Transformer 的大部分知識(shí)都保留在從原始模型遷移而來的 MLP 層中,并專注于蒸餾 LLM 的微調(diào)和對(duì)齊步驟。在此階段,MLP 層保持凍結(jié)狀態(tài),Mamba 層進(jìn)行訓(xùn)練。

圖片

該研究認(rèn)為線性 RNN 和注意力機(jī)制之間天然存在一些聯(lián)系。通過刪除 softmax 可以線性化注意力公式:

圖片

但線性化注意力會(huì)導(dǎo)致模型能力退化。為了設(shè)計(jì)一個(gè)有效的蒸餾線性 RNN,該研究盡可能接近原始 Transformer 參數(shù)化,同時(shí)以有效的方式擴(kuò)展線性 RNN 的容量。該研究沒有嘗試讓新模型捕獲精確的原始注意力函數(shù),而是使用線性化形式作為蒸餾的起點(diǎn)。

如算法 1 所示,該研究將來自注意力機(jī)制的標(biāo)準(zhǔn) Q、K、V 頭直接饋入到 Mamba 離散化中,然后應(yīng)用得到的線性 RNN。這可以看作是使用線性注意力進(jìn)行粗略初始化,并允許模型通過擴(kuò)展的隱藏狀態(tài)學(xué)習(xí)更豐富的交互。

圖片

該研究用微調(diào)線性 RNN 層直接替換 Transformer 注意力頭,保持 Transformer MLP 層不變,不訓(xùn)練它們。這種方法還需要處理其他組件,例如跨頭共享鍵和值的分組查詢注意力。研究團(tuán)隊(duì)注意到,這種架構(gòu)與許多 Mamba 系統(tǒng)中使用的架構(gòu)不同,這種初始化允許用線性 RNN 塊替換任何注意力塊。

圖片

該研究還提出了一種使用硬件感知多步生成的線性 RNN 推測(cè)解碼新算法。

算法 2 和圖 2 顯示了完整的算法。該方法僅在緩存中保留一個(gè) RNN 隱藏狀態(tài)以進(jìn)行驗(yàn)證,并根據(jù)多步內(nèi)核的成功來延遲推進(jìn)它。由于蒸餾模型包含 transformer 層,該研究還將推測(cè)解碼擴(kuò)展到 Attention/RNN 混合架構(gòu)。在此設(shè)置中,RNN 層根據(jù)算法 2 執(zhí)行驗(yàn)證,而 Transformer 層僅執(zhí)行并行驗(yàn)證。

圖片

為了驗(yàn)證這種方法的有效性,該研究使用 Mamba 7B 和 Mamba 2.8B 作為目標(biāo)模型進(jìn)行推測(cè)。結(jié)果如表 1 所示。

圖片

圖 3 顯示了多步內(nèi)核本身的性能特征。

圖片

H100 GPU 上的加速。該研究提出的算法在 Ampere GPU 上表現(xiàn)出強(qiáng)大的性能,如上表 1 所示。但在 H100 GPU 上面臨巨大挑戰(zhàn)。這主要是因?yàn)?GEMM 操作速度太快,這使得緩存和重新計(jì)算操作產(chǎn)生的開銷更加明顯。實(shí)際上,該研究的算法的簡(jiǎn)單實(shí)現(xiàn)(使用多個(gè)不同的內(nèi)核調(diào)用)在 3090 GPU 上實(shí)現(xiàn)了相當(dāng)大的加速,但在 H100 上根本沒有加速。

實(shí)驗(yàn)及結(jié)果

該研究使用兩個(gè) LLM 聊天模型進(jìn)行實(shí)驗(yàn):Zephyr-7B 是在 Mistral 7B 模型的基礎(chǔ)上微調(diào)而來, 以及 Llama-3 Instruct 8B。對(duì)于線性 RNN 模型,該研究使用 Mamba 和 Mamba2 的混合版本,其中注意力層分別為 50%、25%、12.5% 和 0%,并將 0% 稱為純 Mamba 模型。Mamba2 是 Mamba 的一種變體架構(gòu),主要針對(duì)最近的 GPU 架構(gòu)而設(shè)計(jì)。

在聊天基準(zhǔn)上的評(píng)估

表 2 顯示了模型在聊天基準(zhǔn)上的性能,主要對(duì)比的模型是大型 Transformer 模型。結(jié)果顯示:

蒸餾后的混合 Mamba 模型 (50%) 在 MT 基準(zhǔn)測(cè)試中取得的分?jǐn)?shù)與教師模型相似,在 LC 勝率和總體勝率方面都略優(yōu)于 AlpacaEval 基準(zhǔn)測(cè)試中的教師模型。

蒸餾后的混合 Mamba (25% 和 12.5%) 的性能在 MT 基準(zhǔn)測(cè)試中略遜于教師模型,但即使在 AlpcaaEval 中具有更多參數(shù),它仍然超越了一些大型 Transformer。

蒸餾后的純 (0%) Mamba 模型的準(zhǔn)確性確實(shí)顯著下降。

值得注意的是,蒸餾后的混合模型的表現(xiàn)優(yōu)于 Falcon Mamba,后者是從頭開始訓(xùn)練的,使用了超過 5T 的 token。

圖片

一般基準(zhǔn)評(píng)估

零樣本評(píng)估。表 3 顯示了從不同教師模型中蒸餾出的 Mamba 和 Mamba2 在 LM Eval 基準(zhǔn)中的零樣本性能。從 Llama-3 Instruct 8B 中蒸餾出的混合 Mamba-Llama3 和 Mamba2-Llama3 模型與從頭開始訓(xùn)練的開源 TRI Mamba 和 Nvidia Mamba 模型相比表現(xiàn)更好。

圖片

基準(zhǔn)評(píng)估。表 4 顯示經(jīng)過蒸餾的混合模型的性能與 Open LLM Leaderboard 上最好的開源線性 RNN 模型相匹配,同時(shí)在 GSM8K 和 CRUX 中優(yōu)于相應(yīng)的開源指令模型。

圖片

混合推測(cè)性解碼

對(duì)于 50% 和 25% 的蒸餾模型,與非推測(cè)基線相比,該研究在 Zephyr-Hybrid 上實(shí)現(xiàn)了超過 1.8 倍的加速。

實(shí)驗(yàn)還表明,該研究訓(xùn)練的 4 層 draft 模型實(shí)現(xiàn)了更高的接收率,不過由于 draft 模型規(guī)模的增加,額外開銷也變大了。在后續(xù)工作中,該研究將專注于縮小這些 draft 模型。

圖片

與其它蒸餾方法的比較:表 6(左)比較了不同模型變體的困惑度。該研究在一個(gè) epoch 內(nèi)使用 Ultrachat 作為種子提示進(jìn)行蒸餾,并比較困惑度。結(jié)果發(fā)現(xiàn)刪除更多層會(huì)使情況變得更糟。該研究還將蒸餾方法與之前的基線進(jìn)行了比較,發(fā)現(xiàn)新方法顯示出較小的退化,而 Distill Hyena 模型是在 WikiText 數(shù)據(jù)集中使用小得多的模型進(jìn)行訓(xùn)練的,并且顯示出較大的困惑度退化。

表 6(右)展示了單獨(dú)使用 SFT 或 DPO 不會(huì)產(chǎn)生太大的改進(jìn),而使用 SFT + DPO 會(huì)產(chǎn)生最佳分?jǐn)?shù)。

圖片

表 7 比較了幾種不同模型的消融研究。表 7(左)展示了使用各種初始化的蒸餾結(jié)果,表 7(右)顯示漸進(jìn)式蒸餾和將注意層與 Mamba 交錯(cuò)帶來的收益較小。

圖片

表 8 比較了使用兩種不同初始化方法的混合模型的性能:結(jié)果證實(shí)注意力權(quán)重的初始化至關(guān)重要。

圖片

表 9 比較了有 Mamba 塊和沒有 Mamba 塊的模型的性能。有 Mamba 塊的模型性能明顯優(yōu)于沒有 Mamba 塊的模型。這證實(shí)了添加 Mamba 層至關(guān)重要,并且性能的提高不僅僅歸功于剩余的注意力機(jī)制。

圖片

感興趣的讀者可以閱讀論文原文,了解更多研究?jī)?nèi)容。

責(zé)任編輯:張燕妮 來源: 機(jī)器之心
相關(guān)推薦

2024-09-10 13:30:00

2024-05-27 09:00:00

2024-09-05 12:27:17

2024-07-16 09:41:01

2024-07-15 08:20:00

2024-05-16 09:20:29

OllamaLlama3框架

2024-05-16 10:44:10

2024-03-04 08:40:44

Llama3AI谷歌

2024-03-04 13:23:34

數(shù)據(jù)模型

2024-03-15 09:00:00

2024-04-25 09:41:24

項(xiàng)目模型

2024-04-30 08:28:44

開源大模型Llama

2025-04-24 08:20:00

C#Llama3人工智能

2024-04-26 07:48:45

DockerLLama3模型

2024-04-02 09:03:43

TransformeMambaRNN

2023-12-29 08:02:17

大模型人工智能AI

2024-05-21 13:06:02

2024-11-27 15:00:00

點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)