自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

低內(nèi)存占用也能實(shí)現(xiàn)滿血訓(xùn)練?!北理北大港中文MMLab推出Fira訓(xùn)練框架

人工智能 新聞
來(lái)自北理、北大和港中文MMLab的研究團(tuán)隊(duì)提出了一種滿足低秩約束的大模型全秩訓(xùn)練框架——Fira,成功打破了傳統(tǒng)低秩方法中內(nèi)存占用與訓(xùn)練表現(xiàn)的“非此即彼”僵局。

內(nèi)存占用小,訓(xùn)練表現(xiàn)也要好……大模型訓(xùn)練成功實(shí)現(xiàn)二者兼得。

來(lái)自北理、北大和港中文MMLab的研究團(tuán)隊(duì)提出了一種滿足低秩約束大模型全秩訓(xùn)練框架——Fira,成功打破了傳統(tǒng)低秩方法中內(nèi)存占用與訓(xùn)練表現(xiàn)的“非此即彼”僵局。

圖片

展開(kāi)來(lái)說(shuō)——

為了突破內(nèi)存瓶頸,許多低秩訓(xùn)練方法應(yīng)運(yùn)而生,如LoRA(分解參數(shù)矩陣)和GaLore(分解梯度矩陣)。

圖片△圖1:從宏觀層面分析三種內(nèi)存高效低秩訓(xùn)練方法

然而,如上圖所示,LoRA將訓(xùn)練局限于參數(shù)的低秩子空間,降低了模型的表征能力,難以實(shí)現(xiàn)預(yù)訓(xùn)練;GaLore將訓(xùn)練局限于梯度的低秩子空間,造成了子空間外梯度的信息損失。

相較于全秩訓(xùn)練,這兩種方法由于施加了低秩約束,會(huì)導(dǎo)致訓(xùn)練表現(xiàn)有所下降。

但是,若提高秩值,則會(huì)相應(yīng)地增加內(nèi)存占用。

因此,在實(shí)際應(yīng)用中,它們需要在確保訓(xùn)練表現(xiàn)與降低內(nèi)存消耗之間找到一個(gè)恰當(dāng)?shù)钠胶恻c(diǎn)。

這引發(fā)了一個(gè)核心問(wèn)題:

能否在維持低秩約束以確保內(nèi)存高效的同時(shí),實(shí)現(xiàn)全秩參數(shù)、全秩梯度的訓(xùn)練以提升表現(xiàn)?

Fira即為最新答案,它有三大亮點(diǎn):

  • 即插即用:Fira簡(jiǎn)單易用,其核心實(shí)現(xiàn)僅涉及兩行關(guān)鍵公式,現(xiàn)已封裝進(jìn)Python庫(kù),可直接融入現(xiàn)有的大模型訓(xùn)練流程中,替換原有優(yōu)化器。代碼示例如下:
from fira import FiraAdamW, divide_params
param_groups = divide_params(model, target_modules_list = [“Linear”], rank=8)
optimizer = FiraAdamW(param_groups, lr=learning_rate)
  • 雙贏解決方案:在維持低秩約束的前提下,F(xiàn)ira實(shí)現(xiàn)了大模型的全秩訓(xùn)練,打破了內(nèi)存占用與訓(xùn)練表現(xiàn)的取舍難題。與此同時(shí),區(qū)別于系統(tǒng)方法(如梯度檢查點(diǎn)),F(xiàn)ira不以時(shí)間換內(nèi)存;
  • 實(shí)驗(yàn)驗(yàn)證:Fira在多種規(guī)模的模型(60M至7B參數(shù))以及預(yù)訓(xùn)練微調(diào)任務(wù)中均展現(xiàn)出卓越性能,優(yōu)于現(xiàn)有的LoRA和GaLore,甚至能達(dá)到或超越全秩訓(xùn)練的效果。

打造Fira訓(xùn)練框架

Fira訓(xùn)練框架由兩部分組成:

1) 基于梯度模長(zhǎng)的縮放策略:利用了團(tuán)隊(duì)在大模型低秩和全秩訓(xùn)練中發(fā)現(xiàn)的共通點(diǎn)——自適應(yīng)優(yōu)化器對(duì)原始梯度的修正效應(yīng),實(shí)現(xiàn)了低秩約束下的全秩訓(xùn)練。

2) 梯度模長(zhǎng)限制器,通過(guò)限制梯度模長(zhǎng)的相對(duì)增長(zhǎng)比例,解決了大模型訓(xùn)練中常出現(xiàn)的損失尖峰問(wèn)題。

背景動(dòng)機(jī)

大模型訓(xùn)練常常面臨顯著的內(nèi)存瓶頸,尤其是其中的優(yōu)化器狀態(tài)

舉例來(lái)說(shuō),使用Adam優(yōu)化器從頭預(yù)訓(xùn)練一個(gè)LLaMA 7B模型(batchsize為1,精度為BF16)可能需要至少58GB內(nèi)存。

其中14GB用于加載參數(shù),14GB用于儲(chǔ)存梯度,28GB用于儲(chǔ)存優(yōu)化器狀態(tài),剩下2GB用于儲(chǔ)存激活值。

在這之中,優(yōu)化器狀態(tài)所占內(nèi)存甚至要大于參數(shù)本身。

因此,使用低秩方法來(lái)減少這一部分內(nèi)存,實(shí)現(xiàn)大模型的內(nèi)存高效訓(xùn)練十分重要。

而在現(xiàn)有的低秩方法中,LoRA通過(guò)分解參數(shù)矩陣,使用低秩適配器來(lái)減少內(nèi)存占用;Galore通過(guò)分解梯度矩陣,在自適應(yīng)優(yōu)化器中儲(chǔ)存低秩梯度來(lái)減少內(nèi)存占用。

鑒于使用LoRA低秩適配器方法來(lái)實(shí)現(xiàn)全參數(shù)訓(xùn)練的困難性,團(tuán)隊(duì)選擇拓展Galore的梯度投影方法來(lái)實(shí)現(xiàn)全秩訓(xùn)練。

在Galore中,全秩梯度G?? ? ?mxn,會(huì)被投影矩陣P?? ? ?mxr分解成兩項(xiàng)低秩梯度P??R??和(G??—P??R??),其中圖片。

為減少像Adam這樣的自適應(yīng)優(yōu)化器在內(nèi)存中對(duì)應(yīng)的狀態(tài)占用,Galore僅在優(yōu)化器核心??中保留低秩梯度R??,而非全秩梯度G??

而另一項(xiàng)梯度(G??—P??R??),則會(huì)因?yàn)槿鄙賹?duì)應(yīng)的優(yōu)化器狀態(tài),被Galore直接丟棄,從而造成嚴(yán)重的信息損失。

這也解釋了,為什么Galore的性能會(huì)在rank值減小時(shí),顯著衰減。

△圖2:Fira與Galore及其變體的訓(xùn)練損失對(duì)比

為了彌補(bǔ)上述信息損失,最直觀的方法是直接加上這一部分梯度(G??—P??R??):

圖片

其中,W是參數(shù)矩陣, ??是學(xué)習(xí)率。

然而,如圖所示,使用這種方法(Galore-add)不僅未能帶來(lái)性能提升,反而可能導(dǎo)致訓(xùn)練過(guò)程更加不穩(wěn)定,且結(jié)果更差。

分析原因可歸結(jié)于這一部分的梯度缺乏優(yōu)化器狀態(tài),直接使用會(huì)退化為單純的SGD算法,并且可能與前面使用的Adam優(yōu)化器的梯度不匹配,導(dǎo)致效果不佳。

基于梯度模長(zhǎng)的縮放策略

為了解決上述挑戰(zhàn),團(tuán)隊(duì)提出了scaling factor概念,來(lái)描述Adam這樣的自適應(yīng)優(yōu)化器對(duì)原始梯度的修正效應(yīng),并揭示了它在大模型的低秩訓(xùn)練和全秩訓(xùn)練之間的相似性。

圖片

其中,?? 就是scaling factor,代表經(jīng)過(guò)優(yōu)化器修正過(guò)的梯度與原始梯度的模長(zhǎng)比例。

如下圖,如果根據(jù)scaling factor的平均值對(duì)參數(shù)矩陣進(jìn)行排序,可以發(fā)現(xiàn)低秩和全秩之間的排序非常相似。

△圖3:scaling factor在大模型低秩和全秩訓(xùn)練間的相似性

基于這個(gè)觀察,團(tuán)隊(duì)就嘗試在矩陣層面用低秩梯度R??的scaling factor,作為全秩梯度G??的scaling factor的替代,從而近似地修正(G??—P??R??),彌補(bǔ)其缺少的優(yōu)化器狀態(tài):

圖片

這樣團(tuán)隊(duì)就在低秩約束下成功實(shí)現(xiàn)了全秩訓(xùn)練。

進(jìn)一步來(lái)說(shuō),剛才是從矩陣層面來(lái)考慮scaling factor。

順理成章地,團(tuán)隊(duì)可以從更細(xì)粒度的角度——列的層面,來(lái)考慮scaling factor,實(shí)現(xiàn)更加精細(xì)地修正。

圖片

其中R??,:,?? 是低秩梯度R??的第i列,圖片是scaling factor的第i項(xiàng)。

梯度模長(zhǎng)限制器

在訓(xùn)練過(guò)程中,梯度常常會(huì)突然增大,導(dǎo)致?lián)p失函數(shù)出現(xiàn)尖峰,從而影響訓(xùn)練的表現(xiàn)。

經(jīng)過(guò)分析,可能原因是Galore在切換投影矩陣時(shí)存在不穩(wěn)定性,以及維持(G??—P??R??)這種原始梯度的方向的方式,無(wú)法像Adam這樣的自適應(yīng)算法,有效應(yīng)對(duì)大模型訓(xùn)練中存在的陡峭損失景觀。

圖片△圖4:3種Fira變體的訓(xùn)練損失與梯度模長(zhǎng)

然而,常見(jiàn)的梯度裁剪方法(如圖中的Fira-gradient-clipping)由于采用絕對(duì)裁剪,難以適應(yīng)不同參數(shù)矩陣間梯度的較大差異,從而可能導(dǎo)致次優(yōu)的訓(xùn)練結(jié)果。

為此,團(tuán)隊(duì)提出了一種新的梯度模長(zhǎng)限制器,它通過(guò)限制梯度模長(zhǎng)的相對(duì)增長(zhǎng)比例,來(lái)更好地適應(yīng)不同梯度的變化:

圖片

其中??是比例增長(zhǎng)的上限,S??=????(R??)(G??—P??R??)是原始梯度(G??—P??R??)修正后的結(jié)果。

通過(guò)提出的控制梯度相對(duì)增長(zhǎng)比例的方法,能夠?qū)⑻荻鹊捏E然增大轉(zhuǎn)化為平緩的上升,從而有效穩(wěn)定訓(xùn)練過(guò)程。

如圖2和圖3所示,團(tuán)隊(duì)的限制器成功避免了損失函數(shù)的尖峰情況,并顯著提升了訓(xùn)練表現(xiàn)。

實(shí)驗(yàn)結(jié)果

如下表所示,在預(yù)訓(xùn)練任務(wù)中,F(xiàn)ira在保持內(nèi)存高效的前提下,驗(yàn)證集困惑度(↓)顯著超過(guò)各類(lèi)基線方法,甚至超越全秩方法。

具體來(lái)說(shuō),在預(yù)訓(xùn)練LLaMA 1B模型時(shí),F(xiàn)ira節(jié)約了61.1%優(yōu)化器狀態(tài)所占內(nèi)存,并且取得了比全秩訓(xùn)練更加好的結(jié)果。

圖片△使用C4數(shù)據(jù)集預(yù)訓(xùn)練不同大小的LLaMA模型驗(yàn)證集困惑度(↓)對(duì)比

在預(yù)訓(xùn)練LLaMA 7B模型時(shí),F(xiàn)ira在使用了比Galore小8倍的秩rank的情況下,訓(xùn)練表現(xiàn)遠(yuǎn)超Galore。

這展現(xiàn)了Fira在大規(guī)模大模型上的有效性,以及相較Galore更高的內(nèi)存減少能力。

△使用C4數(shù)據(jù)集預(yù)訓(xùn)練LLaMA 7B的驗(yàn)證集困惑度(↓)對(duì)比

在八個(gè)常識(shí)推理數(shù)據(jù)集微調(diào)LLaMA 7B的任務(wù)中,相較其他基線方法,F(xiàn)ira在一半的數(shù)據(jù)集下表現(xiàn)最好,平均準(zhǔn)確率最高的同時(shí)實(shí)現(xiàn)了內(nèi)存高效。

△在八個(gè)常識(shí)推理數(shù)據(jù)集微調(diào)LLaMA 7B準(zhǔn)確率對(duì)比

另外,消融實(shí)驗(yàn)也顯示了

  • Fira-w.o.-scaling說(shuō)明了Fira使用基于梯度模長(zhǎng)的縮放策略的有效性;
  • Fira-matrix說(shuō)明了從更細(xì)粒度的列級(jí)別,而不是矩陣級(jí)別,考慮scaling factor的有效性;
  • Fira-w.o.-limiter說(shuō)明了Fira中梯度模長(zhǎng)限制器的有效性;
  • Fira-gradient-clipping說(shuō)明了梯度裁剪可能無(wú)法完全解決損失尖峰問(wèn)題,導(dǎo)致結(jié)果次優(yōu)。

△消融實(shí)驗(yàn)

與GaLore相比,F(xiàn)ira的表現(xiàn)幾乎不受秩rank值減少的影響。

在低秩的情況下(rank=16, rank=4),F(xiàn)ira仍然能與全秩訓(xùn)練相當(dāng),相較Galore更加內(nèi)存高效。

△不同rank下的預(yù)訓(xùn)練驗(yàn)證集困惑度(↓)

最后,團(tuán)隊(duì)在不同模型大小,以及低秩和全秩條件下,訓(xùn)練10,000步,并對(duì)得到的矩陣和列級(jí)別上Scaling factor做平均。

接著,使用了斯皮爾曼(Spearman)和肯德?tīng)枺↘endall)相關(guān)系數(shù)分析了Scaling factor在矩陣和列級(jí)別上大小順序的相關(guān)性。

其中,Coefficient中1代表完全正相關(guān),-1代表完全負(fù)相關(guān),而P-value越小越好(通常小于0.05為顯著)。

在所有規(guī)模的LLaMA模型中,Scaling factor在矩陣和列的級(jí)別上都表現(xiàn)出很強(qiáng)的正相關(guān)關(guān)系,并且所有的P-value小于0.05,非常顯著,為Fira中基于梯度模長(zhǎng)的縮放策略提供了堅(jiān)實(shí)的實(shí)驗(yàn)基礎(chǔ)。

△矩陣和列級(jí)別上的Scaling factor低秩與全秩相似性分析

更多細(xì)節(jié)歡迎查閱原論文。

論文鏈接:https://arxiv.org/abs/2410.01623
代碼倉(cāng)庫(kù):https://github.com/xichen-fy/Fira

責(zé)任編輯:張燕妮 來(lái)源: 量子位
相關(guān)推薦

2023-11-10 12:51:29

微軟AI

2025-03-13 10:28:07

2024-07-15 07:30:00

自動(dòng)駕駛AI

2023-11-07 11:50:14

AI訓(xùn)練

2023-10-04 00:16:00

Chinchilla小模型

2025-03-07 10:02:10

2023-04-15 19:37:50

OpenAIGPT-5

2025-03-06 10:00:00

2023-09-12 13:43:00

智能技術(shù)

2025-01-17 09:20:00

2017-06-11 21:55:47

深度學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)模型

2025-02-10 14:05:00

訓(xùn)練模型AI

2017-08-14 16:36:23

ASActivity內(nèi)存

2025-02-24 08:30:00

視覺(jué)模型訓(xùn)練

2021-07-27 11:20:10

模型人工智能深度學(xué)習(xí)

2022-03-18 12:08:10

微分計(jì)算模式

2023-11-06 11:29:02

機(jī)器人視覺(jué)

2022-03-04 19:07:03

模型視覺(jué)人工智能

2024-05-24 08:42:29

智能體訓(xùn)練
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)