斯坦福大學(xué)CS博士新作:新型Attention提速2-4倍,BERT單節(jié)點(diǎn)訓(xùn)練最快
一種快速、內(nèi)存高效的注意力算法來了,被命名為 FlashAttention。通過減少 GPU 內(nèi)存讀取 / 寫入,F(xiàn)lashAttention 的運(yùn)行速度比 PyTorch 標(biāo)準(zhǔn)注意力快 2-4 倍,所需內(nèi)存減少 5-20 倍。
這項(xiàng)研究由斯坦福大學(xué)、紐約州立大學(xué)布法羅分校的研究者共同完成。共同一作是兩位斯坦福計(jì)算機(jī)博士生 Tri Dao 和 Dan Fu。
下面我們介紹一下論文具體內(nèi)容。
FlashAttention
Transformer 已然成為自然語言處理和圖像分類等應(yīng)用中最廣泛使用的架構(gòu)。隨著研究的不斷前進(jìn),Transformer 尺寸變得越來越大、層數(shù)也越來越深,但是給 Transformer 配備更長的上下文仍然很困難,因?yàn)?Transformer 核心自注意力模塊的時(shí)間復(fù)雜度以及內(nèi)存復(fù)雜度在序列長度上是二次方的。
有研究者提出一些近似注意力的方法,旨在減少注意力計(jì)算和內(nèi)存需求。這些方法包括稀疏近似、低秩近似以及它們的組合。從序列長度來看,盡管這些方法可以將計(jì)算降低到線性或接近線性,但它們并沒有顯示出針對(duì)標(biāo)準(zhǔn)注意力的 wall-clock 加速,因而沒有被廣泛使用。這其中一個(gè)主要原因是這些研究專注于減少 FLOP(這可能與 wall-clock 速度無關(guān))并且傾向于忽略來自內(nèi)存訪問 (IO) 的開銷。
在本文中,該研究認(rèn)為應(yīng)該讓注意力算法具有 IO 感知——即考慮顯存級(jí)間的讀寫?,F(xiàn)代 GPU 計(jì)算速度超過了內(nèi)存速度,transformer 中的大多數(shù)操作都被內(nèi)存訪問所阻塞。IO 感知算法對(duì)于類似的內(nèi)存綁定操作至關(guān)重要,這種重要性體現(xiàn)在當(dāng)讀寫數(shù)據(jù)占據(jù)很大運(yùn)行時(shí)——例如數(shù)據(jù)庫連接、圖像處理、數(shù)值線性代數(shù)等。然而,用于深度學(xué)習(xí)的常見 Python 接口,如 PyTorch 和 Tensorflow,不允許對(duì)內(nèi)存訪問進(jìn)行細(xì)粒度控制。
論文地址:https://arxiv.org/pdf/2205.14135.pdfGitHub 地址:https://github.com/HazyResearch/flash-attention
該研究提出了一種新的注意力算法 FlashAttention,它可以使用更少的內(nèi)存訪問來計(jì)算精確的注意力。FlashAttention 旨在避免從 HBM(High Bandwidth Memory)中讀取和寫入注意力矩陣。這需要做到:(i) 在不訪問整個(gè)輸入的情況下計(jì)算 softmax reduction;(ii) 在后向傳播中不能存儲(chǔ)中間注意力矩陣。
該研究采用兩種成熟的技術(shù)來應(yīng)對(duì)這些挑戰(zhàn):
(i) 該研究重組注意力計(jì)算,將輸入分成塊,并在輸入塊上進(jìn)行多次傳遞,從而逐步執(zhí)行 softmax reduction(也稱為 tiling);(ii) 該研究存儲(chǔ)前向傳遞的 softmax 歸一化因子,在后向傳播中快速重新計(jì)算片上注意力,這比從 HBM 中讀取中間注意力矩陣的標(biāo)準(zhǔn)方法更快。
該研究在 CUDA 中實(shí)現(xiàn) FlashAttention ,以達(dá)到對(duì)內(nèi)存訪問的細(xì)粒度控制,并將所有注意力操作融合到一個(gè) GPU 內(nèi)核中。即使由于重新計(jì)算導(dǎo)致 FLOPs 增加,但其運(yùn)行速度更快(在 GPT-2 上高達(dá) 7.6 倍,圖 1 右圖)并且使用更少的內(nèi)存(序列長度線性),主要是因?yàn)榇蟠鬁p少了 HBM 訪問量。
?
該研究分析了 FlashAttention 的 IO 復(fù)雜度,證明它需要??(??^2??^2^???1)HBM 訪問,其中??是 head 維度,??是 SRAM 的大小,而標(biāo)準(zhǔn)的注意力需要Ω(???? + ??^2 )HBM 訪問。對(duì)于?? 和 ?? 的典型值,與標(biāo)準(zhǔn)注意力相比,F(xiàn)lashAttention 需要的 HBM 訪問次數(shù)要少很多(最多減少 9 倍,如圖 2 所示)。此外,該研究還提供了一個(gè)下限,表明沒有精確的注意力算法可以漸近地提高所有 SRAM 大小的 HBM 訪問次數(shù)。
?
該研究還表明,F(xiàn)lashAttention 可以作為一種原語(primitive),通過克服內(nèi)存訪問開銷問題來實(shí)現(xiàn)近似注意力算法。作為概念證明,該研究實(shí)現(xiàn)了塊稀疏 FlashAttention,這是一種稀疏注意力算法,比 FlashAttention 快 2-4 倍,可擴(kuò)展到 64k 的序列長度。該研究證明了塊稀疏 FlashAttention 比 FlashAttention 具有更好的 IO 復(fù)雜度。
值得一提的是,該研究還開源了 FlashAttention。
實(shí)驗(yàn)結(jié)果
BERT:FlashAttention 得到了最快的單節(jié)點(diǎn) BERT 訓(xùn)練速度。該研究在 Wikipedia 上用 FlashAttention 訓(xùn)練了一個(gè) BERT-large 模型。表 1 將 FlashAttention 訓(xùn)練時(shí)間與 Nvidia MLPerf 1.1 進(jìn)行了比較,結(jié)果表明 FlashAttention 的訓(xùn)練速度提高了 15%。
?
GPT-2:表 2 顯示,與 HuggingFace 相比,F(xiàn)lashAttention 端到端加速可達(dá) 3 倍,與 Megatron-LM 相比,加速可達(dá) 1.7 倍
Long-range Arena:該研究在 long-range arena (LRA) 基準(zhǔn)上進(jìn)行了實(shí)驗(yàn),他們測量了準(zhǔn)確率、吞吐量、訓(xùn)練時(shí)間。每個(gè)任務(wù)有不同的序列長度,從 1024 到 4096 不等。此外,實(shí)驗(yàn)遵循 Tay 和 Xiong 等人的實(shí)驗(yàn)設(shè)置。表 3 顯示,與標(biāo)準(zhǔn)注意力相比,F(xiàn)lashAttention 的速度提高了 2.4 倍。塊稀疏 FlashAttention 比所有近似注意力方法都要快。
具有長上下文的語言模型:FlashAttention 的運(yùn)行時(shí)間和內(nèi)存效率允許我們將 GPT-2 的上下文長度增加 4 倍,同時(shí)仍然比 Megatron-LM 的運(yùn)行更快。從表 4 可以看出,上下文長度為 4K 的 FlashAttention GPT-2 仍然比上下文長度為 1K 的 Megatron 的 GPT-2 快 30%,同時(shí) perplexity 提高了 0.7。
表 5 表明,在 MIMIC 上序列長度為 16K 的性能比長度為 512 的高出 4.3 個(gè)點(diǎn),而在 ECtHR 上,序列長度為 8K 的比長度 512 高出 8.5 個(gè)點(diǎn)。
表 6 展示了 Transformer 模型可以解決 Path-X、Path-256 問題。該研究在 Path-64 上預(yù)訓(xùn)練 transformer,然后通過空間插值位置嵌入遷移到 Path-X。FlashAttention 在 Path-X 上達(dá)到 61.4 的準(zhǔn)確率。此外,塊稀疏 FlashAttention 使得 Transformers 將序列擴(kuò)展到 64K,在 Path-256 實(shí)現(xiàn) 63.1 的準(zhǔn)確率。
圖 3(左) 報(bào)告了以毫秒為單位的 FlashAttention 和塊稀疏 FlashAttention 前向 + 后向傳播的運(yùn)行時(shí)間與基準(zhǔn)比較,圖 3(右) 顯示了與各種精確、近似和稀疏注意基線相比,F(xiàn)lashAttention 和塊稀疏 FlashAttention 的內(nèi)存占用情況。