告別800秒魔咒!硬件級STA革新視頻DiT注意力,讓HunyuanVideo效率提升3.5倍!
論文鏈接:https://arxiv.org/pdf/2502.04507Git鏈接:https://github.com/hao-ai-lab/FastVideoHuggingface:https://huggingface.co/FastVideo/FastHunyuan
亮點(diǎn)直擊
- 識別并量化了最先進(jìn)的視頻 DiT 中的 3D 局部性和頭部 specialization,揭示了完整 3D 注意力中的大量冗余。
- 引入了SLIDING TILE ATTENTION,一種基于分塊的滑動窗口注意力機(jī)制。優(yōu)化內(nèi)核與 FlashAttention 3 相比實(shí)現(xiàn)了最小的開銷,MFU 達(dá)到 58.79%。
- STA 將注意力加速超過 10 倍,并將端到端視頻生成加速高達(dá) 3.53 倍,且沒有或僅有最小的質(zhì)量損失。
使用 DiTs 進(jìn)行視頻生成速度極慢——即使在 H100 上配備 FlashAttention-3,HunyuanVideo 生成 5 秒視頻也需 16 分鐘。而本次分享的滑動塊注意力(STA)技術(shù)將其縮短至 5 分鐘,無任何質(zhì)量損失,無需額外訓(xùn)練。STA 在注意力計算上的加速比為 FlashAttention-2 的 2.8–17 倍,F(xiàn)lashAttention-3 的 1.6–10 倍。結(jié)合 STA 及其他優(yōu)化方案,解決方案在無質(zhì)量損失且無需訓(xùn)練的情況下,相比 FA3 全注意力基線提升端到端生成速度 2.98 倍。啟用微調(diào)還能帶來更大的加速效果!
視頻 DiT 中的注意力
最先進(jìn)的 Video DiT 依賴 3D 全注意力來捕捉空間和時間關(guān)系,使每個 token 都能在空間和時間維度上關(guān)注其他所有 token。然而,現(xiàn)代視頻模型會生成大量 token——例如,HunyuanVideo 僅生成一個 5 秒 720p 視頻就需要 115K 個 token。隨著分辨率或時長的增加,問題變得更加嚴(yán)重:對于形狀為LXLXL的視頻(假設(shè)時間和空間維度相等),即使L略微增加,token 數(shù)量也會呈立方級增長。由于注意力計算的復(fù)雜度為二次方,這使得它迅速成為主要的計算瓶頸。如圖 1(a) 所示,注意力計算在推理成本中占據(jù)主導(dǎo)地位。
Figure 1: (a) Generating a 5s 720P clip in Hunyuan involves processing 115K tokens, making attention the dominant cost. (b) Attention latency comparison: existing methods fail to translate FLOP reduction into wall-clock speedup; STA is hardware-efficient and achieves proportional speedup with sparsity
假設(shè) 3D 全注意力存在大量冗余,如果能有效利用這些冗余,將大幅加速推理。為驗(yàn)證這一點(diǎn),在圖 2(左)中可視化了 HunyuanVideo 的注意力分?jǐn)?shù),并發(fā)現(xiàn)了明顯的 3D 局部性模式:查詢主要關(guān)注空間和時間上臨近的鍵。為了量化這一現(xiàn)象,我們計算了注意力召回率——即集中在局部窗口內(nèi)的注意力分?jǐn)?shù)占總注意力分?jǐn)?shù)的比例。如圖 2(中)所示,盡管 HunyuanVideo 經(jīng)過全 3D 注意力訓(xùn)練,但其仍表現(xiàn)出強(qiáng)烈的局部性:僅占總空間 15.52% 的小局部窗口就捕獲了 70% 的總注意力。
Figure 2. Left: Instead of attending to the entire image, the query (green dot)’ only attends to keys within a local window. Mid: Attention scores within the local window accouts for mojority of the entire attention. Right: Despite the different recall across heads, the standard deviation across prompts remains low.
本文的分析指出了一個看似顯而易見的解決方案:用局部化注意力替代全 3D 注意力,以加速視頻生成。一個自然的做法是滑動窗口注意力(SWA),這種方法在 NLP 的 1D 序列處理中被廣泛使用。然而,我們發(fā)現(xiàn) SWA 在 2D 和 3D 中完全失效!盡管其理論上可行,但目前尚無適用于 3D 視頻 DiT 的高效實(shí)現(xiàn)。
更糟糕的是,如圖 1(右)所示,現(xiàn)有的 SWA 方法(如 CLEAR 和 NATTEN)雖然降低了 FLOPs,但未能真正提升速度,原因在于硬件利用率極低。為什么會這樣?因?yàn)楦唠A滑動窗口注意力與現(xiàn)代 FlashAttention(FA)存在根本性不兼容問題,在 GPU 上的計算效率極低。
滑動窗口注意力的低效性
要理解為什么 SWA 與 FlashAttention(FA)不兼容,首先需要回顧 FA 的分塊計算模式。FA 并非逐個處理 token,而是將輸入序列劃分為小塊(通常為 128,64),并在 GPU 上高效計算。為簡單起見,我們在討論中假設(shè)使用方塊結(jié)構(gòu)。
FA 會將整個塊的Q 、K和 V加載到 GPU 的 SRAM,執(zhí)行所有必要的計算,并僅將輸出矩陣O寫回 HBM,從而避免存儲諸如注意力掩碼或注意力分?jǐn)?shù)等中間值。如圖 3 所示,F(xiàn)A 將注意力圖劃分為更小的塊,使每個塊成為計算的基本單元。
為什么這很重要?首先,這避免了存儲大規(guī)模的中間張量,從而節(jié)省大量內(nèi)存。其次,GPU 設(shè)計本質(zhì)上適用于矩陣乘法,不擅長處理標(biāo)量甚至向量運(yùn)算;它們在塊級計算上表現(xiàn)出色,而非逐個 token 處理。
Figure 3. The attention map of NATTEN, Tiled NATTEN, and STA. We plot with an image size 24×24 and a 12×12 local window. The tile size is set to 4×4. Note that we mainly use STA in 3D scenarios for video generation in this paper, but for better illustration, we present the 2D scenario in this plot.
在 FlashAttention 中實(shí)現(xiàn) 2D/3D SWA 主要面臨一個關(guān)鍵挑戰(zhàn):如何定義其注意力掩碼。根據(jù)掩碼的應(yīng)用方式,我們將注意力塊分為三種類型:
- 密集塊(Dense blocks):保留所有注意力分?jǐn)?shù)(計算效率極高 ?)。
- 空塊(Empty blocks):所有值均被掩碼(可完全跳過 ?)。
- 混合塊(Mixed blocks):部分分?jǐn)?shù)被保留,部分被掩碼(效率災(zāi)難 ?)。
雖然密集塊和空塊能夠很好地適配 FA,但混合塊會引入嚴(yán)重的計算開銷,主要原因如下:
- 計算浪費(fèi):由于塊是最小計算單元,F(xiàn)A 必須先計算整個塊,再應(yīng)用掩碼,這會導(dǎo)致大量無用計算。
- GPU 不友好的掩碼處理:塊內(nèi)部的掩碼不僅依賴于用戶定義的注意力模式,還取決于該塊在注意力圖中的位置。更糟糕的是,這種掩碼無法預(yù)先計算,否則會帶來二次方級的內(nèi)存開銷。即使在 FlexAttention 中,一個簡單的因果掩碼也會帶來 15% 的額外開銷,而在 3D SWA 中,掩碼處理的開銷甚至可能超過計算整個塊的成本!這就是為什么高階 SWA 在 GPU 上表現(xiàn)不佳——它會生成過多的混合塊!
為了說明這一問題,在圖 3(a) 中分析了 NATTEN——一種改進(jìn)的 SWA 變體。NATTEN 通過在圖像/視頻邊界處移動窗口中心,確保每個查詢點(diǎn)始終關(guān)注固定數(shù)量的鍵。然而,這種方法導(dǎo)致查詢點(diǎn)關(guān)注不同的鍵組,破壞了注意力圖的均勻性,并產(chǎn)生大量混合塊。
為緩解這一問題,Tiled NATTEN 通過重新排序輸入數(shù)據(jù),提高密集塊的比例(如圖 3(b) 所示)。然而,仍有相當(dāng)一部分塊屬于混合塊,使得 SWA 在 GPU 上的計算效率始終低下。
理解 SWA 在圖 3 中產(chǎn)生“之”字形注意力圖的原因可能并不直觀。為了直觀展示這一現(xiàn)象,我們提供了一個動畫,演示在(10,10)大小的圖像上,使用(5,5)窗口的 2D SWA 過程。
滑動塊注意力(Sliding Tile Attention, STA)
STA的核心思想很簡單:GPU 最適合逐塊計算,但滑動窗口注意力(SWA)逐 token 滑動窗口,效率低下。我們提出的 STA 通過逐塊滑動解決了這一問題。在 3D 場景中,將一個分塊定義為一個連續(xù)的 token 組,形成一個時空立方體,其大小由 FlashAttention 中的塊大小決定。這一小改動消除了注意力圖中的混合塊,并顯著提高了計算效率。
- SWA:逐 token 移動,創(chuàng)建不規(guī)則的注意力圖,GPU 難以高效處理。
- STA:逐塊移動,形成密集和空的注意力塊,對 GPU 友好。
具體來說:
下面的視頻展示了 STA 的工作原理。為了更好地說明,我們使用了一個 2D 場景。在此示例中,我們將 STA 應(yīng)用于一個 10×10 的圖像,分塊大小為 (2,2),窗口大小為 (6,6)。
STA可以通過 FlexAttention 高效實(shí)現(xiàn),它提供了足夠的功能來跳過所有空塊,并避免在密集塊上添加不必要的塊內(nèi)掩碼。我們可以通過將塊間掩碼邏輯與計算內(nèi)核解耦來進(jìn)一步優(yōu)化稀疏注意力掩碼。因此,我們的注意力內(nèi)核基于 ThunderKittens 和 FlashAttention-3 實(shí)現(xiàn)。
STA 的內(nèi)核級優(yōu)化
受到 FlashAttention-3 和 ThunderKittens 的啟發(fā),實(shí)現(xiàn)將線程塊(threadblock)分為計算工作組(compute warpgroups)和數(shù)據(jù)工作組(data warpgroups),其中塊間掩碼完全由數(shù)據(jù)工作組管理。每個計算工作組負(fù)責(zé)計算一個查詢塊(query block),該查詢塊始終駐留在 SRAM 中(Split-Q)。數(shù)據(jù)工作組負(fù)責(zé)異步地將鍵值塊(KV blocks)從 HBM(高帶寬內(nèi)存)加載到 SRAM 中。對于每個查詢塊,數(shù)據(jù)工作組需要確定該查詢塊在 STA 中會關(guān)注哪些鍵值塊,并僅加載這些塊。由于數(shù)據(jù)工作組是異步的,計算塊間掩碼和決定加載哪些數(shù)據(jù)的開銷可以通過重疊操作來隱藏。另一方面,計算工作線程完全不需要感知稀疏注意力模式。它使用數(shù)據(jù)工作線程加載到共享內(nèi)存中的鍵值塊執(zhí)行注意力計算,一旦循環(huán)緩存(circular cache)中的所有數(shù)據(jù)被消耗完畢,計算即完成。
Table 1. Forward speed of sparse attention kernels in a setup aligned with HunyuanVideo’s inference configuration (bf16, 720P, 5s, 115.2K seq len, dhead = 128, # heads = 24). Config controls the window size of each sparse attention.
內(nèi)核性能
在下表1中報告了內(nèi)核性能。結(jié)果顯示,現(xiàn)有的局部注意力方法在效率上存在困難。例如,盡管CLEAR將FLOPs減少到15.65,但實(shí)際上推理速度降低了14%。NATTEN也表現(xiàn)不佳——盡管實(shí)現(xiàn)了91%的稀疏性,但其基本版本比完整注意力慢了15%,即使在FlexAttention中優(yōu)化的分塊變體也只能將速度提升1.27倍。在當(dāng)前選項中,Swin是唯一一個內(nèi)存利用率因子(MFU)超過40%且內(nèi)核效率超過60%的內(nèi)核,但它犧牲了注意力機(jī)制的靈活性——Swin不是局部注意力的變體。
相比之下,在FlexAttention中測試時,STA將MFU從Tiled NATTEN的8.20%提高到41.03%。通過進(jìn)一步的內(nèi)核優(yōu)化,STA相比完整注意力實(shí)現(xiàn)了10.45倍的加速。即使在58.33%的稀疏性下,它仍然提供了2.37倍的更快處理速度。這意味著STA可以處理更大的注意力窗口,同時仍然優(yōu)于NATTEN。據(jù)我們所知,STA是第一種將高效的3D稀疏局部注意力與實(shí)際速度提升相結(jié)合的方法。
窗口大小校準(zhǔn)實(shí)現(xiàn)無需訓(xùn)練的速度提升
STA通過利用3D完整注意力中的冗余來加速注意力。另一種加速視頻生成的方法側(cè)重于緩存,利用擴(kuò)散步驟之間的冗余。我們證明STA與TeaCache兼容,TeaCache是一種基于緩存的最先進(jìn)的擴(kuò)散加速技術(shù)。結(jié)合我們的解決方案,實(shí)現(xiàn)了3倍加速,將DiT推理時間從945秒減少到317秒,且沒有質(zhì)量損失。
我們在MovieGen Bench中隨機(jī)選擇的200個提示上評估了我們的方法。以下,我們提供了原始Hunyuan模型與我們的3倍加速解決方案之間的額外非精選定性比較。
使用 STA 訓(xùn)練解鎖更大的加速
除了為每個注意力頭搜索最佳稀疏掩碼外,還可以使用固定窗口并對 STA 進(jìn)行微調(diào),以在保持高稀疏性的同時最大化性能。由于 STA 遵循 3D 局部性特性,這種適應(yīng)可以通過最小的訓(xùn)練開銷高效學(xué)習(xí)。在本文的實(shí)驗(yàn)中,微調(diào)僅需在 8 個 H100 GPU 上花費(fèi) 8 小時——與預(yù)訓(xùn)練視頻擴(kuò)散模型的成本相比可以忽略不計。盡管每個注意力層在受限的局部窗口上操作,但通過堆疊的 Transformer 層,感受野得以擴(kuò)展,使得 Diffusion Transformer 能夠生成全局一致的視頻。
思考
高效的 2D/3D 滑動窗口注意力在 STA 之前并不存在,這似乎令人驚訝——畢竟,這是一個基本概念,并且在 1D 上下文中被廣泛使用。那么,為什么直到現(xiàn)在才有人破解 2D/3D 的內(nèi)核呢?
回顧一下,讓我們看看 Swin Transformer。作者們面臨著同樣的挑戰(zhàn):高效的 2D 滑動窗口注意力內(nèi)核實(shí)現(xiàn)起來并不簡單。他們的解決方案是什么?完全避免它。Swin 沒有使用真正的滑動窗口,而是采用了非重疊的靜態(tài)窗口分區(qū),繞過了效率問題,但代價是破壞了跨窗口注意力,而這對于視頻任務(wù)至關(guān)重要。當(dāng)然,Swin 能夠成功是因?yàn)樗糜陬A(yù)訓(xùn)練設(shè)置——模型通過學(xué)習(xí)在層之間通過移動窗口拼接信息來彌補(bǔ)這一限制。當(dāng)你有預(yù)訓(xùn)練的奢侈條件時,這沒問題,但在像我們這樣的無需訓(xùn)練或微調(diào)場景中,它就不那么有效了。
所以,至少可以欣慰地知道,解決這個問題從來都不容易——但這正是讓進(jìn)展變得更加令人興奮的原因!
結(jié)論
本研究引入SLIDING TILE ATTENTION 來加速視頻擴(kuò)散模型,通過為高階滑動窗口類注意力優(yōu)化內(nèi)核,實(shí)現(xiàn)了高效的 GPU 執(zhí)行,同時保留了局部性特性。實(shí)驗(yàn)表明,SLIDING TILE ATTENTION 在加速視頻生成的同時,僅帶來最小或沒有質(zhì)量損失。從概念上講,本文的方法與其他加速技術(shù)(如緩存和一致性蒸餾)是正交的。相信 STA 的潛力遠(yuǎn)不止于加速視頻擴(kuò)散模型。它可以應(yīng)用于預(yù)訓(xùn)練,并推廣到其他高階數(shù)據(jù)。局部性幾乎是所有數(shù)據(jù)模態(tài)的普遍屬性。希望 STA 能夠激發(fā)跨領(lǐng)域的新的、更高效的模型。
本文轉(zhuǎn)自AI生成未來 ,作者:AI生成未來
