自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制

發(fā)布于 2024-6-28 10:56
瀏覽
0收藏

本文作者李宏康,美國倫斯勒理工大學(xué)電氣、計算機(jī)與系統(tǒng)工程系在讀博士生,本科畢業(yè)于中國科學(xué)技術(shù)大學(xué)。研究方向包括深度學(xué)習(xí)理論,大語言模型理論,統(tǒng)計機(jī)器學(xué)習(xí)等等。目前已在 ICLR/ICML/Neurips 等 AI 頂會發(fā)表多篇論文。


上下文學(xué)習(xí) (in-context learning, 簡寫為 ICL) 已經(jīng)在很多 LLM 有關(guān)的應(yīng)用中展現(xiàn)了強(qiáng)大的能力,但是對其理論的分析仍然比較有限。人們依然試圖理解為什么基于 Transformer 架構(gòu)的 LLM 可以展現(xiàn)出 ICL 的能力。


近期,一個來自美國倫斯勒理工大學(xué)和 IBM 研究院的團(tuán)隊從優(yōu)化和泛化理論的角度分析了帶有非線性注意力模塊 (attention) 和多層感知機(jī) (MLP) 的 Transformer 的 ICL 能力。他們特別從理論端證明了單層 Transformer 首先在 attention 層根據(jù) query 選擇一些上下文示例,然后在 MLP 層根據(jù)標(biāo)簽嵌入進(jìn)行預(yù)測的 ICL 機(jī)制。該文章已收錄在 ICML 2024。


ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)


  • 論文題目:How Do Nonlinear Transformers Learn and Generalize in In-Context Learning?
  • 論文地址:https://arxiv.org/pdf/2402.15607


背景介紹


上下文學(xué)習(xí) in context learning (ICL)


上下文學(xué)習(xí) (ICL) 是一種新的學(xué)習(xí)范式,在大語言模型 (LLM) 中非常流行。它具體是指在測試查詢 (testing query)

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

前添加 N 個測試樣本 testing examples (上下文),即測試輸入

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

和測試輸出

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

的組合,從而構(gòu)成一個 testing prompt:

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

,作為模型的輸入以引導(dǎo)模型作出正確的推斷。這種方式不同于經(jīng)典的對預(yù)訓(xùn)練模型進(jìn)行微調(diào)的方式,它不需要改變模型的權(quán)重,從而更加的高效。


ICL 理論工作的進(jìn)展


近期的很多理論工作都是基于 [1] 所提出的研究框架,即人們可以直接使用 prompt 的格式來對 Transformer 進(jìn)行訓(xùn)練 (這一步也可以理解為在模擬一種簡化的 LLM 預(yù)訓(xùn)練模式),從而使得模型具有 ICL 能力。已有的理論工作聚焦于模型的表達(dá)能力 (expressive power) 的角度 [2]。他們發(fā)現(xiàn),人們能夠找到一個有著 “完美” 的參數(shù)的 Transformer 可以通過前向運算執(zhí)行 ICL,甚至隱含地執(zhí)行梯度下降等經(jīng)典機(jī)器學(xué)習(xí)算法。但是這些工作無法回答為什么 Transformer 可以被訓(xùn)練成這樣 “完美” 的,具有 ICL 能力的參數(shù)。因此,還有一些工作試圖從 Transformer 的訓(xùn)練或泛化的角度理解 ICL 機(jī)制 [3,4]。不過,受制于分析 Transformer 結(jié)構(gòu)的復(fù)雜性,這些工作目前止步于研究線性回歸任務(wù),而所考慮的模型通常會略去 Transformer 中的非線形部分。


本文從優(yōu)化和泛化理論的角度分析了帶有非線性 attention 和 MLP 的 Transformer 的 ICL 能力和機(jī)制:


  • 基于一個簡化的分類模型,本文具體量化了數(shù)據(jù)的特征如何影響了一層單頭 Transformer 的域內(nèi) (in-domain) 和域外 (out-of-domain, OOD) 的 ICL 泛化能力。
  • 本文進(jìn)一步闡釋了 ICL 是如何通過被訓(xùn)練的 Transformer 來實現(xiàn)了。
  • 基于被訓(xùn)練的 Transformer 的特點,本文還分析了在 ICL 推斷的時候使用基于幅值的模型剪枝 (magnitude-based pruning) 的可行性。


理論部分


問題描述


本文考慮一個二分類問題,即將

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

通過一個任務(wù)

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

映射到

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

。為了解決這樣的一個問題,本文構(gòu)建了 prompt 來進(jìn)行學(xué)習(xí)。這里的 prompt 被表示為:


ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)


訓(xùn)練網(wǎng)絡(luò)為一個單層單頭 Transformer:


ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)


預(yù)訓(xùn)練過程是求解一個對所有訓(xùn)練任務(wù)的經(jīng)驗風(fēng)險最小化 (empirical risk minimization)。損失函數(shù)使用的是適合二分類問題的 Hinge loss,訓(xùn)練算法是隨機(jī)梯度下降。


本文定義了兩種 ICL 泛化的情況。一個是 in-domain 的,即泛化的時候測試數(shù)據(jù)的分布和訓(xùn)練數(shù)據(jù)一樣,注意這個情況里面測試任務(wù)不必和訓(xùn)練任務(wù)一樣,即這里已經(jīng)考慮了對未見任務(wù) (unseen task) 的泛化。另一個是 out-of-domain 的,即測試、訓(xùn)練數(shù)據(jù)分布不一樣。


本文還涉及了在 ICL 推斷的時候進(jìn)行 magnitude-based pruning 的分析,這里的剪枝方式是指對于訓(xùn)練得到的中的各個神經(jīng)元,根據(jù)其幅值大小,進(jìn)行從小到大的刪除。


對數(shù)據(jù)和任務(wù)的構(gòu)建


這一部分請參考原文的 Section 3.2,這里只做一個概述。本文的理論分析是基于最近比較火熱的 feature learning 路線,即通常將數(shù)據(jù)假設(shè)為可分(通常是正交)的 pattern,從而推導(dǎo)出基于不同 pattern 的梯度變化。本文首先定義了一組 in-domain-relevant (IDR) pattern 用于決定 in-domain 任務(wù)的分類,和一組與任務(wù)無關(guān)的 in-domain-irrelevant (IDI) pattern,這些 pattern 之間互相正交。IDR pattern 有

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

個,IDI pattern 有

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

個。一個

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

被表示為一個 IDR pattern 和一個 IDI pattern 的和。一個 in-domain 任務(wù)就被定義為基于某兩個 IDR pattern 的分類問題。


類似地,本文通過定義 out-of-domain-relevant (ODR) pattern 和 out-of-domain-irrelevant (ODI) pattern,可以刻畫 OOD 泛化時候的數(shù)據(jù)和任務(wù)。


本文對 prompt 的表示可以用下圖的例子來闡述,其中

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

是 IDR pattern,

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

是 IDI pattern。這里在做的任務(wù)是基于 x 中的

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

做分類,如果是

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

那么其標(biāo)簽為 + 1,對應(yīng)于 +q,如果是

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

那么其標(biāo)簽為 - 1,對應(yīng)于 -q。α,α' 分別被定義為訓(xùn)練和測試 prompt 中跟 query 的 IDR/ODR pattern 一樣的上下文示例。下圖中的例子里面,

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

。


ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)


理論結(jié)果


首先,對于 in-domain 的情況,本文先給了一個 condition 3.2 來規(guī)定訓(xùn)練任務(wù)需要滿足的條件,即訓(xùn)練任務(wù)需要覆蓋所有的 IDR pattern 和標(biāo)簽。然后 in-domain 的結(jié)果如下:


ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)


這里表明:1,訓(xùn)練任務(wù)的數(shù)量只需要在全部任務(wù)中占比達(dá)到滿足 condition 3.2 的小比例,我們就可以對 unseen task 實現(xiàn)很好的泛化;2,跟當(dāng)前任務(wù)相關(guān)的 IDR pattern 在 prompt 中的比例越高,就可以以更少的訓(xùn)練數(shù)據(jù),訓(xùn)練迭代次數(shù),以及更短的 training/testing prompt 實現(xiàn)理想的泛化。


接下來是 out-of-domain 泛化的結(jié)果。


ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)


這里說明,如果 ODR pattern 是 IDR pattern 的線性組合且系數(shù)和大于 1,那么此時 OOD ICL 泛化可以達(dá)到理想的效果。這個結(jié)果給出了在 ICL 的框架下,好的 OOD 泛化所需要的訓(xùn)練和測試數(shù)據(jù)之間的內(nèi)在聯(lián)系。該定理也通過 GPT-2 的實驗得到了驗證。如下圖所示,當(dāng) (12) 中的系數(shù)和

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

大于 1 的時候,OOD 分類可以達(dá)到理想的結(jié)果。與此同時,當(dāng)

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

,即 prompt 中和分類任務(wù)相關(guān)的 ODR/IDR pattern 比例越高的時候,所需要的 context 長度越小。


ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)


然后,本文給出了帶有 magnitude-based pruning 的 ICL 泛化結(jié)果。


ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)


這個結(jié)果表明,首先,訓(xùn)練得到的

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

中有一部分(常數(shù)比例)神經(jīng)元的幅值很小,而剩下的相對比較大(公式 14)。當(dāng)我們只枝剪小神經(jīng)元的時候,對泛化結(jié)果基本沒有影響,而當(dāng)枝剪比例增加到要剪大神經(jīng)元的時候,泛化誤差會隨之顯著變大(公式 15,16)。以下實驗驗證了定理 3.7。下圖 A 中淺藍(lán)色的豎線表示訓(xùn)練得到的

ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)

呈現(xiàn)出了公式 14 的結(jié)果。而對小神經(jīng)元進(jìn)行枝剪不會使泛化變差,這個結(jié)果符合理論。圖 B 反映出當(dāng) prompt 中和任務(wù)相關(guān)的上下文越多的時候,我們可以允許更大的枝剪比例以達(dá)到相同的泛化性能。


ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)


ICL 機(jī)制


通過對預(yù)訓(xùn)練過程的刻畫,本文得到了單層單頭非線性 Transformer 做 ICL 的內(nèi)在機(jī)制,這一部分在原文的 Section 4。該過程可以用下圖表示。


ICML 2024 | 揭示非線形Transformer在上下文學(xué)習(xí)中學(xué)習(xí)和泛化的機(jī)制-AI.x社區(qū)


簡而言之,attention 層會選擇和 query 的 ODR/IDR pattern 一樣的上下文,賦予它們幾乎全部 attention 權(quán)重,然后 MLP 層會重點根據(jù) attention 層輸出中的標(biāo)簽嵌入來作出最后的分類。


總結(jié)


本文講解了在 ICL 當(dāng)中,非線性 Transformer 的訓(xùn)練機(jī)制,以及對于新任務(wù)和分布偏移數(shù)據(jù)的泛化能力。理論結(jié)果對于設(shè)計 prompt 選擇算法和 LLM 剪枝算法有一定實際意義。


本文轉(zhuǎn)自 機(jī)器之心 ,作者:機(jī)器之心


原文鏈接:??https://mp.weixin.qq.com/s/SJQiIp1W5kwWSVJaOXA9yA??

標(biāo)簽
已于2024-6-28 10:57:20修改
收藏
回復(fù)
舉報
回復(fù)
相關(guān)推薦