自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

比標(biāo)準(zhǔn)Attention提速5-9倍,大模型都在用的FlashAttention v2來了

人工智能 新聞
一年時(shí)間,斯坦福大學(xué)提出的新型 Attention 算法 ——FlashAttention 完成了進(jìn)化。這次在算法、并行化和工作分區(qū)等方面都有了顯著改進(jìn),對(duì)大模型的適用性也更強(qiáng)了。

近來,幾種長上下文語言模型陸續(xù)問世,包括 GPT-4(上下文長度為 32k)、MosaicML 的 MPT(上下文長度為 65k)Anthropic 的 Claude(上下文長度為 100k)。長文檔查詢和故事寫作等新興用例已經(jīng)表明擴(kuò)展語言模型上下文窗口是非常必要的。

然而,擴(kuò)大 Transformer 的上下文長度是一個(gè)挑戰(zhàn),因?yàn)槠浜诵牡淖⒁饬釉跁r(shí)間復(fù)雜度和空間復(fù)雜度與輸入序列長度的平方成正比。

一年前,來自斯坦福大學(xué)、紐約州立大學(xué)布法羅分校的研究者共同提出一種快速、內(nèi)存高效的注意力算法 ——FlashAttention。該算法無需任何近似即可加速注意力并減少內(nèi)存占用?,F(xiàn)在,已經(jīng)有許多機(jī)構(gòu)和研究實(shí)驗(yàn)室采用 FlashAttention 來加速訓(xùn)練和推理。

FlashAttention 示意圖。FlashAttention 示意圖。

盡管 FlashAttention 的速度已經(jīng)是優(yōu)化基線的 2-4 倍,但它仍然有相當(dāng)大的改進(jìn)空間。FlashAttention 仍然不如優(yōu)化過的矩陣乘法 (GEMM) 運(yùn)算快,僅達(dá)到理論最大 FLOPs/s 的 25-40%。

現(xiàn)在,研究團(tuán)隊(duì)宣布推出 FlashAttention-2。FlashAttention-2 完全從頭開始重寫,使用 Nvidia 的 CUTLASS 3.x 及其核心庫 CuTe 的原語(primitive)。

圖片圖片

FlashAttention-2 開發(fā)者 Tri Dao。他是斯坦福大學(xué)博士生,還是 Together.AI 首席科學(xué)家,并將于 2024 年 9 月開始任職普林斯頓大學(xué)計(jì)算機(jī)科學(xué)助理教授。

FlashAttention-2 的速度是 FlashAttention 的 2 倍,在 A100 GPU 上達(dá)到 230 TFLOPs/s。在端到端訓(xùn)練 GPT 類語言模型時(shí),F(xiàn)lashAttention-2 可讓訓(xùn)練速度高達(dá) 225 TFLOPs/s(模型 FLOP 利用率為 72%)。

FlashAttention-2 將加速現(xiàn)有模型的訓(xùn)練、微調(diào)和推理。這意味著我們可以用相同成本訓(xùn)練 2 倍上下文長度的語言模型。這將有助于語言模型理解長篇書籍和報(bào)告、高分辨率圖像、音頻和視頻。

圖片圖片

  • 項(xiàng)目地址:https://github.com/Dao-AILab/flash-attention
  • 技術(shù)報(bào)告:https://tridao.me/publications/flash2/flash2.pdf

FlashAttention 是什么?

FlashAttention 是一種重新排序注意力計(jì)算的算法,它利用平鋪、重計(jì)算等經(jīng)典技術(shù)來顯著提升計(jì)算速度,并將序列長度中的內(nèi)存使用實(shí)現(xiàn)從二次到線性減少。其中平鋪意味著將輸入塊從 HBM(GPU 內(nèi)存)加載到 SRAM(快速緩存),并對(duì)該塊執(zhí)行注意力操作,更新 HBM 中的輸出。

此外通過不將大型中間注意力矩陣寫入 HBM,內(nèi)存讀寫量減少,帶來了 2-4 倍的時(shí)鐘時(shí)間加速。

下圖為 FlashAttention 的前向傳遞圖:通過平鋪和 softmax 重新縮放,研究者按塊進(jìn)行操作,避免從 HBM 中讀取 / 寫入,同時(shí)獲得正確的輸出,無需近似操作。

圖片圖片

然而,F(xiàn)lashAttention 仍然存在一些低效率問題,原因在于不同線程塊之間的工作分區(qū)不理想以及 GPU 上的 warp。這些導(dǎo)致低占用率或不必要的共享內(nèi)存讀寫。

FlashAttention-2

更好的算法、并行化和工作分區(qū)


更少的非矩陣乘法 Flops

研究者調(diào)整了 FlashAttention 的算法,從而減少了非矩陣乘法(non-matmul)的 Flops 數(shù)量。這點(diǎn)很重要,因?yàn)楝F(xiàn)代 GPU 具有專門的計(jì)算單元(例如 Nvidia GPU 上的張量核心),使得矩陣乘法速度更快。

舉例而言,A100 GPU 的 FP16/BF16 矩陣乘法的最大理論吞吐量為 312 TFLOPs/s,但非矩陣乘法 FP32 的理論吞吐量僅為 19.5 TFLOPs/s。

換一種思考方式,每個(gè)非矩陣乘法 FLOP 比矩陣乘法 FLOP 的代價(jià)高 16 倍。為了保持高吞吐量,研究者希望在矩陣乘法 FLOP 上花費(fèi)盡可能多的時(shí)間。因此他們重寫了 FlashAttention 中使用的在線 softmax 技巧,以減少重新縮放操作、邊界檢查和因果掩碼操作的數(shù)量,而無需更改輸出。

更好的并行化

FlashAttention v1 在批大小和頭(head)數(shù)量上進(jìn)行并行化。研究者使用 1 個(gè)線程塊來處理一個(gè)注意力頭,總共有(批大小 * 頭數(shù)量)個(gè)線程塊。每個(gè)線程塊都計(jì)劃在流式多處理器(SM)上運(yùn)行,例如 A100 GPU 上有 108 個(gè)這樣的 SM。當(dāng)這個(gè)數(shù)字非常大(如 >= 80)時(shí),這種調(diào)度是有效的,這時(shí)可以高效地使用 GPU 上幾乎所有計(jì)算資源。

在長序列的情況下(通常意味著小批量或少量頭),為了更好地利用 GPU 上的多處理器,現(xiàn)在研究者在序列長度維數(shù)上額外地進(jìn)行并行化,使該機(jī)制顯著加速。

更好的工作分區(qū)

即使在每個(gè)線程塊內(nèi),研究者也必須決定如何在不同的 warp 之間劃分工作(一組 32 個(gè)線程一起工作)。通常情況下,每個(gè)線程塊使用 4 或 8 個(gè) warp,分區(qū)方案如下圖所述。 

研究者改進(jìn)了 FlashAttention-2 中的這種分區(qū),減少不同 warp 之間的同步和通信量,進(jìn)而減少共享內(nèi)存讀寫

圖片圖片

對(duì)于每個(gè)塊,F(xiàn)lashAttention 將 K 和 V 分割到 4 個(gè) warp 上,同時(shí)保持 Q 可被所有 warp 訪問。這被稱為「sliced-K」方案。不過,這種方案是低效的,原因在于所有 warp 都需要將它們的中間結(jié)果寫入共享內(nèi)存,并同步,然后將中間結(jié)果相加。這些共享內(nèi)存讀寫會(huì)減慢 FlashAttention 中的前向傳遞速度。

在 FlashAttention-2 中,研究者將 Q 分割在 4 個(gè) warp 上,同時(shí)保持 K 和 V 可被所有的 warp 訪問。每個(gè) warp 執(zhí)行矩陣乘法以獲得 Q K^T 的切片,然后只需與 V 的共享切片相乘就能獲得相應(yīng)的輸出切片。warp 之間不需要通信。共享內(nèi)存讀寫的減少也可以提升速度。

新特性:頭維數(shù)高達(dá) 256、多查詢注意力

我們知道,F(xiàn)lashAttention 僅支持最高 128 的頭維數(shù),這適用于大多數(shù)模型,但有一些模型被遺漏了。

因此,FlashAttention-2 支持了高達(dá) 256 的頭維數(shù),這意味著 GPT-J、CodeGen 和 CodeGen2、StableDiffusion 1.x 等模型可以使用 FlashAttention-2 來獲得加速和節(jié)省內(nèi)存。

此外,F(xiàn)lashAttention-2 還支持了多查詢注意力(multi-query attention, MQA)以及分組查詢注意力(grouped-query attention, GQA)。它們是注意力的變體,其中多個(gè)查詢頭關(guān)注相同的鍵和值頭,以減少推理過程中 KV 緩存的大小,并可以顯著提高推理吞吐量。

注意力基準(zhǔn)結(jié)果

研究者在 A100 80GB SXM4 GPU 上,測(cè)量不同設(shè)置(無 / 有因果掩碼、頭維數(shù) 64 或 128)下不同注意力方法的運(yùn)行時(shí)。 

結(jié)果發(fā)現(xiàn), FlashAttention-2 的速度是 FlashAttention(以及 xformers 庫和 Triton 中的其他實(shí)現(xiàn))的 2 倍。與 PyTorch 中的標(biāo)準(zhǔn)注意力實(shí)現(xiàn)相比,F(xiàn)lashAttention-2 的速度最高是它們的 9 倍。

A100 GPU 上的注意力前向 + 后向速度

A100 GPU 上的注意力前向 + 后向速度。

此外只需要在 H100 GPU 上 運(yùn)行相同的實(shí)現(xiàn)(不使用特殊指令來利用 TMA 和第四代 Tensor Core 等新硬件功能),研究者最高獲得了 335 TFLOPs/s。

圖片

H100 GPU 上的注意力前向 + 后向速度。

當(dāng)用于端到端 GPT 類模型訓(xùn)練時(shí),F(xiàn)lashAttention-2 有助于在 A100 GPU 上實(shí)現(xiàn)最高 225 TFLOPs/s(模型 FLOPs 利用率為 72%)。與優(yōu)化良好的 FlashAttention 模型相比,端到端實(shí)現(xiàn) 1.3 倍加速。

圖片

這里的基線是不使用 FlashAttention 的 Megatron-LM,它現(xiàn)在也可以選擇使用 FlashAttention 了。不久的將來,FlashAttention-2 也將集成到 Megatron-LM 中。

研究團(tuán)隊(duì)表示:下一步將針對(duì) H100 GPU 優(yōu)化 FlashAttention-2,以使用新的硬件功能。

責(zé)任編輯:張燕妮 來源: 機(jī)器之心
相關(guān)推薦

2023-01-09 12:41:55

模型

2020-06-12 14:25:36

框架PyTorch開發(fā)

2023-07-18 14:18:00

Attention模型圖像

2024-06-28 16:03:38

2023-09-29 11:55:55

2024-12-27 09:30:00

AI數(shù)據(jù)訓(xùn)練

2024-01-03 13:06:50

2024-12-27 09:50:00

模型數(shù)據(jù)測(cè)試

2024-05-10 08:47:22

標(biāo)準(zhǔn)庫v2Go

2013-07-17 10:07:29

Windows Pho功能

2024-03-13 13:49:22

Sora核心組件DiT

2010-08-05 17:00:04

RIP V2協(xié)議

2010-08-06 14:07:21

RIP V2

2023-12-11 15:40:32

PyTorch代碼大模型

2023-06-27 17:35:39

FastSAM模型SAM

2023-06-20 08:01:09

RoseDB存儲(chǔ)數(shù)據(jù)

2024-01-02 15:15:00

AI模型開源

2023-08-05 13:49:31

鴻蒙操作系統(tǒng)

2021-12-15 06:58:28

RedisEhCache緩存

2024-12-20 07:00:00

大模型人工智能AI
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)