DeepMind升級(jí)Transformer,前向通過(guò)FLOPs最多可降一半 精華
Transformer 的重要性無(wú)需多言,目前也有很多研究團(tuán)隊(duì)致力于改進(jìn)這種變革性技術(shù),其中一個(gè)重要的改進(jìn)方向是提升 Transformer 的效率,比如讓其具備自適應(yīng)計(jì)算能力,從而可以節(jié)省下不必要的計(jì)算。
正如不久前 Transformer 架構(gòu)的提出之一、NEAR Protocol 聯(lián)合創(chuàng)始人 Illiya Polosukhin 在與黃仁勛的對(duì)話中說(shuō)到的那樣:「自適應(yīng)計(jì)算是接下來(lái)必須出現(xiàn)的。我們要關(guān)注,在特定問(wèn)題上具體要花費(fèi)多少計(jì)算資源。」
其實(shí)人類就天生具備自適應(yīng)計(jì)算的能力 —— 人在解決各種不同的問(wèn)題時(shí),會(huì)自然地分配不同的時(shí)間和精力。
語(yǔ)言建模也應(yīng)如此,為了得到準(zhǔn)確的預(yù)測(cè)結(jié)果,并不需要為所有 token 和序列都投入同樣的時(shí)間或資源。但是,Transformer 模型在一次前向傳播中卻會(huì)為每個(gè) token 花費(fèi)同等的計(jì)算量。這不禁讓人哀嘆:大部分計(jì)算都被浪費(fèi)了!
理想情況下,如果可以不執(zhí)行非必要的計(jì)算,就可以降低 Transformer 的計(jì)算預(yù)算。
條件式計(jì)算這種技術(shù)可在需要執(zhí)行計(jì)算時(shí)才執(zhí)行計(jì)算,由此可以減少總計(jì)算量。之前許多研究者已經(jīng)提出了多種可以評(píng)估何時(shí)執(zhí)行計(jì)算以及使用多少計(jì)算量的算法。
但是,對(duì)于這個(gè)頗具挑戰(zhàn)性的問(wèn)題,普遍使用的解決形式可能無(wú)法很好地應(yīng)對(duì)現(xiàn)有的硬件限制,因?yàn)樗鼈兺鶗?huì)引入動(dòng)態(tài)計(jì)算圖。最有潛力的條件式計(jì)算方法反而可能是那些能協(xié)調(diào)使用當(dāng)前硬件棧的方法,其會(huì)優(yōu)先使用靜態(tài)計(jì)算圖和已知的張量大?。ɑ趯?duì)硬件的最大利用而選取這個(gè)張量大?。?。
近日,Google DeepMind 研究了這個(gè)問(wèn)題,他們希望使用更低的計(jì)算預(yù)算來(lái)縮減 Transformer 使用的計(jì)算量。
- 論文標(biāo)題:Mixture-of-Depths: Dynamically allocating compute in transformer-based language models
- 論文地址:https://arxiv.org/pdf/2404.02258.pdf?
他們?cè)O(shè)想:在每一層中,網(wǎng)絡(luò)必須學(xué)會(huì)為每個(gè) token 做決策,從而動(dòng)態(tài)地分配可用計(jì)算預(yù)算。在他們的具體實(shí)現(xiàn)中,總計(jì)算量由用戶在訓(xùn)練之前設(shè)定并且不再更改,而非網(wǎng)絡(luò)工作時(shí)執(zhí)行決策的函數(shù)。這樣一來(lái),便可以提前預(yù)知并利用硬件效率收益(比如內(nèi)存足跡減少量或每次前向傳播的 FLOPs 減少量)。該團(tuán)隊(duì)的實(shí)驗(yàn)表明:可以在不損害網(wǎng)絡(luò)整體性能的前提下獲得這些收益。
DeepMind 的這個(gè)團(tuán)隊(duì)采用了類似于混合專家(MoE) Transformer 的方法,其中會(huì)在整個(gè)網(wǎng)絡(luò)深度上執(zhí)行動(dòng)態(tài) token 層面的路由決策。
而與 MoE 不同的是,這里他們的選擇是:要么是將計(jì)算應(yīng)用于 token(和標(biāo)準(zhǔn) Transformer 一樣),要么就是通過(guò)一個(gè)殘差連接繞過(guò)它(保持不變,節(jié)省計(jì)算)。另一個(gè)與 MoE 的不同之處是:這里是將這種路由機(jī)制同時(shí)用在 MLP 和多頭注意力上。因此,這也會(huì)影響網(wǎng)絡(luò)處理的鍵值和查詢,因此該路由不僅要決定更新哪些 token,還要決定哪些 token 可供關(guān)注。
DeepMind 將這一策略命名為 Mixture-of-Depths(MoD),以突顯這一事實(shí):各個(gè) token 在 Transformer 深度上通過(guò)不同數(shù)量的層或模塊。我們這里將其翻譯成「混合深度」,見(jiàn)圖 1。
MoD 支持使用者權(quán)衡考量性能與速度。一方面,使用者可以使用與常規(guī) Transformer 同等的訓(xùn)練 FLOPs 來(lái)訓(xùn)練 MoD Transformer,這可為最終的對(duì)數(shù)概率訓(xùn)練目標(biāo)帶來(lái)多達(dá) 1.5% 的提升。另一方面,MoD Transformer 使用更少的計(jì)算量就能達(dá)到與常規(guī) Transformer 同樣的訓(xùn)練損失 —— 每一次前向傳播的 FLOPs 可少最多 50%。
這些結(jié)果表明,MoD Transformer 可以學(xué)習(xí)智能地路由(即跳過(guò)不必要的計(jì)算)。
實(shí)現(xiàn)混合深度(MoD)Transformer
概況來(lái)說(shuō),其策略如下:
- 設(shè)定一個(gè)靜態(tài)的計(jì)算預(yù)算,該預(yù)算低于等價(jià)的常規(guī) Transformer 所需的計(jì)算量;做法是限制序列中可參與模塊計(jì)算(即自注意力模塊和后續(xù)的 MLP)的 token 數(shù)量。舉個(gè)例子,常規(guī) Transformer 可能允許序列中的所有 token 都參與自注意力計(jì)算,但 MoD Transformer 可限定僅使用序列中 50% 的 token。
- 針對(duì)每個(gè) token,每個(gè)模塊中都有一個(gè)路由算法給出一個(gè)標(biāo)量權(quán)重;該權(quán)重表示路由對(duì)各個(gè) token 的偏好 —— 是參與模塊的計(jì)算還是繞過(guò)去。
- 在每個(gè)模塊中,找到最大的前 k 個(gè)標(biāo)量權(quán)重,它們對(duì)應(yīng)的 token 會(huì)參與到該模塊的計(jì)算中。由于必定只有 k 個(gè) token 參與到該模塊的計(jì)算中,因此其計(jì)算圖和張量大小在訓(xùn)練過(guò)程中是靜態(tài)的;這些 token 都是路由算法認(rèn)定的動(dòng)態(tài)且與上下文有關(guān)的 token。
路由方案
該團(tuán)隊(duì)考慮了兩種學(xué)習(xí)到的路由方案(見(jiàn)圖 2):token 選擇型和專家選擇型。
在 token 選擇型路由方案中,路由算法會(huì)跨計(jì)算路徑(比如跨 MoE Transformer 中的專家身份)生成針對(duì)每個(gè) token 的概率分布。然后 token 會(huì)被傳送到它們偏好的路徑(即概率最高的路徑),而輔助損失可以確保所有 token 不會(huì)收斂到同一路徑。token 選擇型路由可能會(huì)有負(fù)載平衡問(wèn)題,因?yàn)椴荒艽_保 token 在可能的路徑之間劃分適當(dāng)。
專家選擇型路由則是將上述方案反過(guò)來(lái):不是讓 token 選擇它們偏好的路徑,而是讓每條路徑基于 token 偏好選擇前 k 個(gè) token(top-k)。這能確保負(fù)載完美平衡,因?yàn)槊織l路徑總是保證 k 個(gè) token。但是,這也可能導(dǎo)致某些 token 被過(guò)處理或欠處理,因?yàn)槟承?token 可能是多條路徑的前 k 名,另一些 token 則可能沒(méi)有相應(yīng)路徑。
DeepMind 的選擇是使用專家選擇型路由,原因有三。
第一,它無(wú)需輔助性的平衡損失。
第二,由于選取前 k 名這一操作取決于路由權(quán)重的幅度,因此該路由方案允許使用相對(duì)路由權(quán)重,這有助于確定當(dāng)前模塊計(jì)算最需要哪些 token;路由算法可以通過(guò)適當(dāng)?shù)卦O(shè)定權(quán)重來(lái)盡力確保最關(guān)鍵的 token 是在前 k 名之中 —— 這是 token 選擇型路由方案無(wú)法做到的。在具體的用例中,有一條計(jì)算路徑本質(zhì)上是 null 操作,因此應(yīng)該避免將重要 token 路由到 null。
第三,由于路由只會(huì)經(jīng)由兩條路徑,因此單次 top-k 操作就能高效地將 token 分成兩個(gè)互斥的集合(每條計(jì)算路徑一個(gè)集合),這能應(yīng)對(duì)上面提及的過(guò)處理或欠處理問(wèn)題。
該路由方案的具體實(shí)現(xiàn)請(qǐng)參看原論文。
采樣
盡管專家選擇型路由有很多優(yōu)點(diǎn),但它也有一個(gè)很明顯的問(wèn)題:top-k 操作是非因果式的。也就是說(shuō),一個(gè)給定 token 的路由權(quán)重是否在前 k 名取決于其之后的路由權(quán)重的值,但在執(zhí)行自回歸采樣時(shí),我們無(wú)法獲得這些權(quán)重。
為了解決這個(gè)問(wèn)題,該團(tuán)隊(duì)測(cè)試了兩種方法。
第一種是引入一個(gè)簡(jiǎn)單的輔助損失;實(shí)踐證明,其對(duì)語(yǔ)言建模主目標(biāo)的影響程度為 0.2%? 0.3%,但卻能夠讓模型自回歸地采樣。他們使用了一個(gè)二元交叉熵?fù)p失,其中路由算法的輸出提供 logit,通過(guò)選取這些 logit 中的 top-k,就能提供目標(biāo)(即,如果一個(gè) token 在 top-k 中,就為 1,否則為 0)。
第二種方法是引入一個(gè)小的輔助 MLP 預(yù)測(cè)器(就像是又一個(gè)路由算法),其輸入與路由算法的一樣(具有 stop gradient),但其輸出是一個(gè)預(yù)測(cè)結(jié)果:token 是否在序列的 top-k 中。該方法不會(huì)影響語(yǔ)言建模目標(biāo),實(shí)驗(yàn)表明也不會(huì)顯著影響該步驟的速度。
有了這些新方法,就可以通過(guò)選擇路由到的 token 來(lái)執(zhí)行自回歸采樣,也可以根據(jù)路由算法的輸出繞過(guò)一個(gè)模塊,這無(wú)需依賴任何未來(lái) token 的信息。實(shí)驗(yàn)結(jié)果表明,這是一種相對(duì)簡(jiǎn)單輔助任務(wù),可以很快實(shí)現(xiàn) 99% 的準(zhǔn)確度。
結(jié)果
訓(xùn)練,isoFLOP 比較
首先,該團(tuán)隊(duì)訓(xùn)練了一些 FLOP 預(yù)算相對(duì)較?。?e18)的模型,以確定最優(yōu)的超參數(shù)(見(jiàn)下圖 3)。
總體而言,可以看到 MoD Transformer 會(huì)將基準(zhǔn) isoFLOP 曲線向右下方拖動(dòng)。也就是說(shuō),最優(yōu)的 MoD Transformer 的損失比最優(yōu)的基準(zhǔn)模型更低,同時(shí)參數(shù)也更多。這種效應(yīng)帶來(lái)了一個(gè)幸運(yùn)的結(jié)果:存在一些和最優(yōu)基準(zhǔn)模型表現(xiàn)一樣好甚至更好的 MoD 模型(同時(shí)步驟速度更快),盡管它們本身在其超參數(shù)設(shè)置下并不是 isoFLOP 最優(yōu)的。舉個(gè)例子,一個(gè) 220M 參數(shù)量的 MoD 變體(圖 3 中的 3 號(hào)模型)稍優(yōu)于 isoFLOP 最優(yōu)基準(zhǔn)模型(參數(shù)量也是 220M,圖 3 中的 1 號(hào)模型),但這個(gè) MoD 變體在訓(xùn)練期間的步驟速度快了 60% 以上。
下圖 4 給出了總 FLOPs 為 6e18、2e19 和 1e20 時(shí)的 isoFLOP 分析??梢钥吹?,當(dāng) FLOP 預(yù)算更大時(shí),趨勢(shì)依然繼續(xù)。
下圖 5 給出了一個(gè)使用交織的路由模塊訓(xùn)練的 MoD Transformer 的路由決策。盡管其中存在大量繞過(guò)模塊的情況,但這個(gè) MoD Transformer 依然能實(shí)現(xiàn)優(yōu)于常規(guī) Transformer 的性能
自回歸評(píng)估
他們也評(píng)估了 MoD 變體的自回歸采樣表現(xiàn),結(jié)果見(jiàn)下圖 6。這些結(jié)果表明 MoD Transformer 所帶來(lái)的計(jì)算節(jié)省不僅僅局限于訓(xùn)練設(shè)置。
混合深度與專家(MoDE)
MoD 技術(shù)可以自然地與 MoE 模型整合起來(lái),組成所謂的 MoDE 模型。下圖 7 展示了 MoDE 及其帶來(lái)的提升。
MoDE 有兩種變體:分階段 MoDE 和集成式 MoDE。
其中分階段 MoDE 是在自注意力步驟之前進(jìn)行路由繞過(guò)或到達(dá) token 的操作;而集成式 MoDE 則是通過(guò)在常規(guī) MLP 專家之間集成「無(wú)操作」專家來(lái)實(shí)現(xiàn) MoD 路由。前者的優(yōu)勢(shì)是允許 token 跳過(guò)自注意力步驟,而后者的好處在于其路由機(jī)制很簡(jiǎn)單。
該團(tuán)隊(duì)注意到,以集成方式實(shí)現(xiàn) MoDE 明顯優(yōu)于直接降低專家的能力、依靠丟棄 token 來(lái)實(shí)現(xiàn)殘差路由的設(shè)計(jì)。
本文轉(zhuǎn)自 機(jī)器之心 ,作者:機(jī)器之心
