三行代碼無損加速40%,尤洋團隊AI訓(xùn)練加速器入選ICLR Oral論文
用剪枝的方式加速AI訓(xùn)練,也能實現(xiàn)無損操作了,只要三行代碼就能完成!
今年的深度學(xué)習(xí)頂會ICLR上,新加坡國立大學(xué)尤洋教授團隊的一項成果被收錄為Oral論文。
利用這項技術(shù),可以在沒有損失的前提下,節(jié)約最高40%的訓(xùn)練成本。
這項成果叫做InfoBatch,采用的依然是修剪樣本的加速方式。
但通過動態(tài)調(diào)整剪枝的內(nèi)容,InfoBatch解決了加速帶來的訓(xùn)練損失問題。
而且即插即用,不受架構(gòu)限制,CNN網(wǎng)絡(luò)和Transformer模型都能優(yōu)化。
目前,該算法已經(jīng)受到了多家云計算公司的關(guān)注。
那么,InfoBatch能實現(xiàn)怎樣的加速效果呢?
無損降低40%訓(xùn)練成本
研究團隊在多個數(shù)據(jù)集上開展的實驗。都驗證了InfoBatch的有效性。
這些實驗涵蓋的任務(wù)包括圖像的分類、分割和生成,以及語言模型的指令微調(diào)等。
在圖像分類任務(wù)上,研究團隊使用CIFAR10和CIFAR100數(shù)據(jù)集訓(xùn)練了ResNet-18。
結(jié)果在30%、50%和70%的剪枝率下,InfoBatch的準(zhǔn)確率都超越了隨機剪枝和其他baseline方法,而且在30%的剪枝率下沒有任何精度損失。
在剪枝率從30%增加到70%的過程中,InfoBatch的精度損失也顯著低于其他方式。
使用ImageNet-1K數(shù)據(jù)集訓(xùn)練的ResNet-50時,在剪枝率為40%、epoch數(shù)量為90的條件下,InfoBatch可以實現(xiàn)UCB相同的訓(xùn)練時間,但擁有更高的準(zhǔn)確率,甚至超越了全數(shù)據(jù)訓(xùn)練。
同時,ImageNet的額外(OverHead)時間成本顯著低于其他方式,僅為0.0028小時,也就是10秒鐘。
在訓(xùn)練Vit-Base(pre-train階段300epoch,fine-tune階段100epoch模型時,InfoBatch依然可以在24.8%的成本節(jié)約率下保持與全量訓(xùn)練相當(dāng)?shù)臏?zhǔn)確率。
跨架構(gòu)測試比對結(jié)果還表明,面對不同的模型架構(gòu),InfoBatch表現(xiàn)出了較強的魯棒性。
除此之外,InfoBatch還能兼容現(xiàn)有的優(yōu)化器,在與不同優(yōu)化器共同使用時都體現(xiàn)了良好的無損加速效果。
不僅是這些視覺任務(wù),InfoBatch還可以應(yīng)用于語言模型的監(jiān)督微調(diào)。
在常識(MMLU)、推理(BBH、DROP)等能力沒有明顯損失,甚至編程能力(HumanEval)還有小幅提升的情況下,InfoBatch可以在DQ的基礎(chǔ)上額外減少20%的時間消耗。
另外,根據(jù)作者最新更新,InfoBatch在檢測任務(wù)(YOLOv8)上也取得了無損加速30%的效果,代碼將會在github更新。
那么,InfoBatch是如何做到無損加速的呢?
動態(tài)調(diào)整剪枝內(nèi)容
究其核心奧義,是無偏差的動態(tài)數(shù)據(jù)修剪。
為了消除傳統(tǒng)剪枝方法梯度期望值方向偏差以及總更新量的減少的問題,InfoBatch采用了動態(tài)剪枝方式。
InfoBatch的前向傳播過程中,維護了每個樣本的分值(loss),并以均值為閾值,隨機對一定比例的低分樣本進行修剪。
同時,為了維護梯度更新期望,剩余的低分樣本的梯度被相應(yīng)放大。
通過這種方式,InfoBatch訓(xùn)練結(jié)果和原始數(shù)據(jù)訓(xùn)練結(jié)果的性能差距相比于之前方法得到了改善。
具體來看,在訓(xùn)練的前向過程中,InfoBatch會記錄樣本的損失值(loss)來作為樣本分?jǐn)?shù),這樣基本沒有額外打分的開銷。
對于首個epoch,InfoBatch初始化默認(rèn)保留所有樣本;之后的每個epoch開始前,InfoBatch會按照剪枝概率r來隨機對分?jǐn)?shù)小于平均值的樣本進行剪枝。
概率的具體表達(dá)式如下:
對于分?jǐn)?shù)小于均值但留下繼續(xù)參與訓(xùn)練的樣本,InfoBatch采用了重縮放方式,將對應(yīng)梯度增大到了1/(1-r),這使得整體更新接近于無偏。
此外,InfoBatch還采用了漸進式的修剪過程,在訓(xùn)練后期會使用完整的數(shù)據(jù)集。
這樣做的原因是,雖然理論上的期望更新基本一致,上述的期望值實際包含時刻t的多次取值。
也就是說,如果一個樣本在中間的某個輪次被剪枝,后續(xù)依舊大概率被訓(xùn)練到;但在剩余更新輪次不足時,這個概率會大幅下降,導(dǎo)致殘余的梯度期望偏差。
因此,在最后的幾個訓(xùn)練輪次中(通常是12.5%~17.5%左右),InfoBatch會采用完整的原始數(shù)據(jù)進行訓(xùn)練。
論文地址:https://arxiv.org/abs/2303.04947
GitHub主頁:
https://github.com/NUS-HPC-AI-Lab/InfoBatch