斯坦福博士獨作!大模型訓(xùn)練速度再翻倍,還官宣加入明星創(chuàng)業(yè)公司當首席科學(xué)家
本文經(jīng)AI新媒體量子位(公眾號ID:QbitAI)授權(quán)轉(zhuǎn)載,轉(zhuǎn)載請聯(lián)系出處。
現(xiàn)有大語言模型的訓(xùn)練和推理速度,還能再快一點——
快多少?2-4倍。
各種大模型都在用的FlashAttention今天正式發(fā)布第2代并開源,所有Transformer架構(gòu)的模型都可使用它來加速。
圖片
一代方法去年6月發(fā)布,無需任何近似即可加速注意力并減少內(nèi)存占用。
現(xiàn)在,F(xiàn)lashAttention-2將它再度升級,使其核心注意力操作的速度再提高2倍,端到端訓(xùn)練Transformer時的速度再提高1.3倍,并可在英偉達A100上訓(xùn)練時實現(xiàn)72%的模型FLOP利用率(一般模型都在50%上下)。
圖片
鑒于現(xiàn)在煉一個大語言模型的成本高達數(shù)千萬美元,F(xiàn)lashAttention-2這一系列操作直接就能幫我們省掉數(shù)百萬(美元)!
網(wǎng)友驚得臟話都出來了(狗頭):
圖片
目前,這個項目已在GitHub上收獲4.4k標星。
與此同時,我們注意到,它的一作已經(jīng)完成斯坦福博士學(xué)位并加盟大模型創(chuàng)業(yè)公司Together AI。
具體實現(xiàn)
據(jù)介紹,一代FlashAttention是一種對注意力計算重新排序的算法,它利用經(jīng)典方法如tiling(切片)來顯著加快計算速度,并將序列長度的內(nèi)存使用量從二次方減為線性。
其中tiling方法指的是將輸入塊從HBM(GPU內(nèi)存)加載到SRAM(快速緩存),然后對該塊進行attention操作,再更新HBM中的輸出。
對HBM的反復(fù)讀寫就成了最大的性能瓶頸。
圖片
正是這種通過避免將大型中間注意力矩陣寫入HBM的方法,F(xiàn)lashAttention減少了內(nèi)存讀/寫量,從而帶來2-4倍的時鐘時間加速。
然而,這個算法仍然存在一些低效率的問題,導(dǎo)致它仍然不如優(yōu)化矩陣乘法 (GEMM) 運算來得快,最終僅達到理論最大FLOPs/s的25-40%(例如在A100上最多124 TFLOPs/s)。
究其原因,還是因為不同線程塊之間的工作和GPU上的wrap劃分不理想。
在此,F(xiàn)lashAttention-2進行了三方面的改進。
首先,在基礎(chǔ)算法上,減少非matmul(矩陣乘法) FLOP的數(shù)量。
一層原因是由于現(xiàn)代GPU具有專門的計算單元,matmul速度更快。例如A100上FP16/BF16 matmul的最大理論吞吐量為312TFLOPs/s,但非matmul FP32的理論吞吐量僅為19.5 TFLOPs/s。
另一層原因是價格考量,畢竟每個非matmul FLOP比matmul FLOP貴16倍。同時在matmul FLOP上花費盡可能多的時間也能保持高吞吐量。
為此,作者重寫了FlashAttention中的softmax trick,無需更改輸出即可減少重新縮放操作的數(shù)量,以及邊界檢查和因果屏蔽操作(causal masking operation)。
其次,當batch size較小時并行化以獲得更高的占用率。
FlashAttention一代在batch size和注意力頭數(shù)量上進行并行化。
由于它使用1個線程塊來處理1個注意力頭,總共就有(batch_size*注意力頭數(shù))個線程塊,每個線程塊被安排在流式多處理器 (SM) 上運行。
當在像A100這樣有108個SM處理器上操作時,如果線程塊很多比如>=80,這樣的調(diào)度安排就很有效。
而在長序列的情況下,也就是batch size和頭數(shù)量很少(小)時,就需要在序列長度維度上另外進行并行化來更好地利用GPU上的多處理器了。
這個改進也是FlashAttention-2速度顯著提升的一大原因。
最后,改進工作分區(qū)。
在線程塊內(nèi),我們必須確定如何在不同的warp之間劃分工作。通常是每個塊使用4或8個warp,現(xiàn)在,作者改進了這一方式,來減少不同warp之間的同步和通信量,從而減少共享內(nèi)存讀寫操作。
如下圖左所示,F(xiàn)lashAttention一代的做法是將K和V分割到4個warp上,同時保持Q可被所有warp訪問。這樣的后果是所有warp都需要將其中間結(jié)果寫入共享內(nèi)存,然后進行同步再將中間結(jié)果相加,非常低效,減慢了FlashAttention中的前向傳播速度。
圖片
而在FlashAttention-2中,作者將Q分為四個warp,同時保證所有warp都可訪問K和V。
每個warp執(zhí)行矩陣乘法獲得Q K^T的切片后,只需與V的共享切片相乘即可獲得相應(yīng)的輸出。也就是說warp之間不需要通信,那么共享內(nèi)存讀寫操作就少了很多,速度也就提上來了。
除了這三個大改進,F(xiàn)lashAttention-2還有兩個小改動:
一是注意力頭數(shù)從128增至256,這意味著GPT-J、CodeGen和CodeGen2以及StableDiffusion 1.x等模型都可以使用 FlashAttention-2來進行加速和內(nèi)存節(jié)省了;
二是支持多查詢注意力(MQA)和分組查詢注意力(GQA)。
實驗評估
作者在A100 80GB SXM4 GPU上對不同配置(有無causal mask,頭數(shù)量64或128)下的運行時間進行了測量。
結(jié)果發(fā)現(xiàn):
FlashAttention-2比FlashAttention(包括xformers庫和Triton中的其他實現(xiàn))快大約2倍,這也意味我們可以用與之前訓(xùn)練8k上下文模型相同的價格來訓(xùn)練具有16k上下文的模型了(也就是模型上下文長度加倍)。
而與PyTorch中的標準注意力實現(xiàn)相比,F(xiàn)lashAttention-2的速度最高可達9倍。
圖片
此外,有了FlashAttention-2,我們只需在H100 GPU上運行相同的實現(xiàn)(不使用特殊指令利用TMA和第四代Tensor Core等新硬件功能),訓(xùn)練速度就可以跑到高達335TFLOPs/s的成績。
圖片
以及當用于端到端訓(xùn)練GPT式模型時,F(xiàn)lashAttention-2還能在A100上實現(xiàn)高達225TFLOPs/s的速度(模型FLOPs利用率達72%)。這與已經(jīng)優(yōu)化程序足夠高的FlashAttention相比,速度再提高了1.3倍。
圖片
一作加入大模型創(chuàng)業(yè)公司
FlashAttention-2論文僅顯示一位作者:Tri Dao。他也是FlashAttention一代的兩位共同作者之一。
圖片
據(jù)了解,Tri Dao的研究方向為機器學(xué)習(xí)和系統(tǒng)的交叉領(lǐng)域,去年拿下ICML 2022杰出論文亞軍獎。
最近他剛剛獲得斯坦福大學(xué)計算機科學(xué)博士學(xué)位,即將上升普林斯頓大學(xué)助理教授,并已宣布加盟生成式AI創(chuàng)業(yè)公司Together AI(該司主要目標構(gòu)建一個用于運行、訓(xùn)練和微調(diào)開源模型的云平臺)擔(dān)任首席科學(xué)家。
One More Thing
最后,有網(wǎng)友發(fā)現(xiàn),除了FlashAttention-2,最近還有一系列類似成果,包括DeepSpeed的ZeRO++、馬薩諸塞大學(xué)de ReLoRA。
它們都是用于加速大型模型預(yù)訓(xùn)練和微調(diào),這些研究成果讓他覺得:
未來在低vram低帶寬的消費顯卡上訓(xùn)練大模型,似乎已不是在做夢了。
圖片
大家認為呢?
論文地址:https://tridao.me/publications/flash2/flash2.pdf
博文地址:https://princeton-nlp.github.io/flash-atttention-2/
GitHub主頁:https://github.com/Dao-AILab/flash-attention