自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

Binary Block Masking:加快稀疏 Attention 的一種新方法

發(fā)布于 2024-9-30 15:18
瀏覽
0收藏

一、背景

我們在之前的文章中簡單介紹了 Sample Packing 相關(guān)的技術(shù)方案及涉及的問題,也在看其中 Attention 計算帶來的各種挑戰(zhàn)。機(jī)緣巧合正好看到一篇文章試圖解決相應(yīng)的 Attention 計算問題,這里進(jìn)行簡單介紹。

對應(yīng)的論文為:[2409.15097] Efficiently Dispatching Flash Attention For Partially Filled Attention Masks

相關(guān)工作可以參考我們之前的文章:

二、摘要

Transformer 已廣泛應(yīng)用于各種應(yīng)用場景,會產(chǎn)生稀疏或部分填充的注意力矩陣,比如旨在降低 Attention 二次復(fù)雜度的 Attention Mask、Sample Packing 或其他技術(shù)引入的稀疏 Attention,例如用于 Medusa 投機(jī)采樣快速驗證的 Tree Attention。盡管這些矩陣具有固定的稀疏性,但 SOTA 的 FlashAttention 仍然將其當(dāng)做稠密矩陣處理。

本文中,作者提出了 Binary Block Masking (BinBlkMsk) ,通過使 FlashAttention 具備 Mask 感知能力來增強(qiáng) FlashAttention。作者也進(jìn)一步提出了兩種優(yōu)化方案,在真實場景的 Attention Mask 實驗表明,運(yùn)行速度可以提升 9x。

  • 一種針對具有連續(xù)非零 Mask:Dense Binary Block Masking
  • 一種針對極度稀疏 Mask:Binary Block Masking with RCM。?

Binary Block Masking:加快稀疏 Attention 的一種新方法-AI.x社區(qū)

PS:其實上述的 Baseline 有問題,F(xiàn)lashAttention 可以通過 Varlen 方案支持 Sample Packing 引入的 Block Diagonal Mask,避免當(dāng)做稠密矩陣計算。此外,Pytorch 官方也發(fā)布了 FlexAttention,可以支持各種 Attention Mask 變體,本文作者也并沒有對比。

三、方案

3.1 Binary Block Masking

以 Sample Packing 引入的 Block Diagonal Mask 為例,如下圖所示,Max Sequence Length 為 16,拼接了 3 個 Sample,長度分別為 4,5,7。因此會形成一個 16x16 的 Global Mask,其中包含 3 個 Causal Mask。

在 FlashAttention 實際的計算中會將其劃分為不同的 Block 計算。本文提出的 Binary Block Masking 是通過一個額外的 Binary Block Matrix 來標(biāo)記哪些 Block 全是 0,而哪些中有非 0 值。如下圖,假設(shè) Block 的大小為 2x2,則會生成一個 8x8 Binary Matrix,如下圖紅框和藍(lán)框?qū)?yīng)的位置為 1,其他位置為 0。在計算時,如果 Binary Block Matrix 中為 0 的 Mask 可以忽略不計算。

Binary Block Masking:加快稀疏 Attention 的一種新方法-AI.x社區(qū)

3.2 Dense Binary Block Masking

在自然語言場景中,通常上述的 Attention Mask 的 Block 中會存在很多連續(xù)全 1 的 Block,只在邊界的 Block 里既有 0 又有 1。對于全為 1 的 Block,在計算時不用再讀取 Attention Mask,減少訪存并提升計算效率。因此作者使用兩個額外的數(shù)組 total_ones 和 offset 來標(biāo)識這些連續(xù)全 1 Block,如下圖所示,total_ones[7] 為 2,表示最后一行有 2 個全 1 Block;offset[7] 為 5,表示最后一行連續(xù)全 1 Block 的起始索引為 5。

  • total_ones 存儲全 1 Block 的個數(shù)。
  • offset 存儲 Binary Block Matrix 中每行連續(xù)全 1 Block 的起始位置。
  • 右圖 Binary Block Matrix 中紅色 1表示 Block 中全是 1,黑色 1 表示 Block 中部分為 1。?

Binary Block Masking:加快稀疏 Attention 的一種新方法-AI.x社區(qū)

PS:實際計算中,對于所有 Layer,所有 Attention Head 都只需計算一次。

3.3 Binary Block Masking with RCM

在處理非常稀疏的注意力掩碼時,如果使用標(biāo)準(zhǔn)的 Binary Block Masking 方法,可能會存在效率問題。這是因為即使整個塊中只有一個非零值,也必須處理整個塊,導(dǎo)致計算資源的浪費。作者使用 Reverse Cuthill-McKee (RCM) 算法來有效地重新組織 Mask 矩陣的結(jié)構(gòu),從而減少必須處理的 Block 的數(shù)量。如下圖 Figure 3 所示:

Binary Block Masking:加快稀疏 Attention 的一種新方法-AI.x社區(qū)

如下圖所示,作者展示了一個合成掩碼的結(jié)果,其中 RCM 預(yù)處理減少了 90% 的 Block 數(shù)量,從而獲得顯著的性能提升。Binary Block Masking with RCM 的速度明顯快于 Binary Block Masking。

Binary Block Masking:加快稀疏 Attention 的一種新方法-AI.x社區(qū)

四、實驗和結(jié)果

4.1 配置

作者主要評估三種 Mask 類型,包括 Medusa([2401.10774] Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads) 里面的 Tree Attention Mask,Sample Packing 中的 Block Diagonal Mask 以及 LongFormer([2004.05150] Longformer: The Long-Document Transformer) 中的稀疏 Mask。測試時使用 Batch Size 為 4,32 個 Attention Head。FlashAttention 中的 Block Size 固定為 (128, 32)。

此外,測試的 FlashAttention 方案并不是開源的 FlashAttention 實現(xiàn),而是作者基于 Triton 實現(xiàn)的,以方便對比。

4.2 預(yù)處理

如下圖 Table 1 所示,作者使用 Triton 實現(xiàn)了基于 Attention Mask 生成 Binary Block Mask 的 Kernel,并測試了相應(yīng)的時間。雖然 Binary Block Mask 的生成時間和 Batch Size 為 1,1 個 Head 的計算時間差不多,但是整體只用計算一次,影響比較小。

Binary Block Masking:加快稀疏 Attention 的一種新方法-AI.x社區(qū)

4.3 Medusa Tree Attention

如下圖 Figure 3 所示為對 Medusa Tree Attention 的優(yōu)化效果。左圖為 Tree Attention 的 Mask,其中 K 表示 Medusa 中 Head 個數(shù),sk 表示每個 Head 候選 Token 個數(shù)。(b) 表示固定候選 Token 為 3,不同 Head 個數(shù)下的速度;(c) 表示固定 4 個 Head,不同候選 Token 的速度。

  • Naive Attention Masking:所有 Block 都計算。
  • Bash FlashAttention:Causal Mask,只計算下三角的 Block,當(dāng) Head 或者候選 Token 很多時 FlashAttention 的 Causal Mask 時間大概是 Naive 的一半(計算量大約是一半)。
  • Binary Block Masking:只計算部分 Block,速度遠(yuǎn)快于上述兩種方案。?

Binary Block Masking:加快稀疏 Attention 的一種新方法-AI.x社區(qū)

4.4 Sample Packing Block Diagonal Mask

作者同樣評估了 Sample Packing 場景的 Block Diagonal Mask,如下圖(a)表示標(biāo)準(zhǔn) Causal Mask 的組合;(b)表示 Input 雙向 Attention,而 Output Causal Mask 的情況。從(b)和(c)可以得出類似 Tree Mask 的結(jié)論,當(dāng)序列比較長時,本文的 Binary Block Masking 可以帶來 9x 加速。不過需要說明的是:Base FlashAttention 的時間幾乎是 Naive Attention Masking 的一半,也可以推導(dǎo)出作者使用的 FlashAttention 方案確實計算了整個下三角 Block,而沒有使用 Varlen 方案。

Binary Block Masking:加快稀疏 Attention 的一種新方法-AI.x社區(qū)

4.5 LongFormer Attention Mask

作者同樣評估了 LongFormer 場景的各種稀疏 Mask,如下圖分別為 Sliding Window Attention,Dilated Attention 以及 Global Attention 等場景,可以得出類似的結(jié)論。需要說明的是,此時 Base FlashAttention 相比 Naive Attention 的提升比較有限。

Binary Block Masking:加快稀疏 Attention 的一種新方法-AI.x社區(qū)

4.6 Pytorch FlexAttention

Pytorch 在 2.5.0 版本引入了 FlexAttention(FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention),可以很容易支持各種 Attention Mask 變種,比如標(biāo)準(zhǔn) Causal Mask、Sliding Window 、Prefix Mask 以及 Block Diagonal Mask 等,相比 FlashAttention 靈活了很多,本文提到的各種 Attention 基本也都能實現(xiàn)。

如下表所示,我們同樣做了 Sample Packing 的實驗,隨著 Sequence 中 Sample 分布的不同,計算的耗時甚至可能差 10x,與本文中 9x 的提升也能對上。

Binary Block Masking:加快稀疏 Attention 的一種新方法-AI.x社區(qū)

五、參考鏈接

  1. ??https://arxiv.org/abs/2409.15097??
  2. ??https://arxiv.org/abs/2401.10774??
  3. ??https://arxiv.org/abs/2004.05150??
  4. ??https://pytorch.org/blog/flexattention/???

本文轉(zhuǎn)載自 ??AI閑談??,作者: AI閑談

標(biāo)簽
已于2024-9-30 18:06:49修改
收藏
回復(fù)
舉報
回復(fù)
相關(guān)推薦