30行代碼,500萬長文本推理提速8倍!「樹注意力」讓GPU越多省的越多
跨GPU的注意力并行,最高提速8倍,支持512萬序列長度推理。
環(huán)注意力(Ring Attention)后繼者——樹注意力(Tree Attention)來了。
最關鍵之處在于,通信步數(shù)隨設備數(shù)量成對數(shù)增長,而不是線性增長。
換句話說,樹注意力的優(yōu)勢隨著設備數(shù)量增大會更加明顯。實驗中,在128卡、512萬序列長度設置時達到最高8倍加速。
與環(huán)注意力相比,峰值內(nèi)存占用也能節(jié)省不少。
相關代碼已經(jīng)開源,基于谷歌jax框架,已和Flash Attention整合,實現(xiàn)起來只需要30行代碼。
論文一公布,就被業(yè)界評價為“對高推理需求的大型公司很重要”。
這下和黃仁勛的GPU“買的越多,省的越多”論對上了,英偉達再次贏麻。
注意力機制的能量視角
首先簡單回顧一下這次被拿來對比的環(huán)注意力,由UC伯克利大牛Pieter Abeel團隊提出。
環(huán)注意力被認為是讓上一波大模型紛紛擴展到百萬上下文的關鍵,從谷歌Gemini 1.5到后來的Llama 3.1系列都用了它的某種變體。
簡單來說,環(huán)注意力的核心思想是將長序列分成多個Block,每個GPU處理一個。在拓撲意義上相當于所有GPU排成一個圓環(huán),將Key-Value信息傳下去,同時從上一個GPU接收信息。
只要保證計算時間比數(shù)據(jù)傳輸時間長,這個過程就不會造成額外開銷。
同時與之前的近似方法不同,環(huán)注意力不會損失精度,保持了完整的注意力計算。
最新的樹注意力,在分塊計算、跨設備并行、保持精度特性的基礎上,提出了一種自注意力的能量函數(shù),通過計算梯度利用樹形拓撲優(yōu)化多GPU間的通信。
傳統(tǒng)上,人們把注意力看作Query向量與Key向量的相似度匹配,再對Value向量做加權求和。
樹注意力團隊在Hopfield網(wǎng)絡等基于能量的模型相關研究基礎上,將注意力解釋為一個能量函數(shù)對某變量的梯度。
存在一個標量能量函數(shù)F,它依賴于Key、Query、Value以及一個輔助變量ζ,而注意力的結(jié)果恰好等于F對ζ的梯度在ζ=0處的值。
結(jié)合自動微分等技術,從能量和梯度的視角看待自注意力,暗示了只要能高效計算F就能高效計算自注意力。
具體到語言模型中基于KV緩存的解碼,能量函數(shù)可以表示成:
由于logsumexp和max運算操作都滿足結(jié)合律,可以按任意順序進行,而不會影響最終結(jié)果。
在此前提下,團隊設計了新的并行化算法,先在各GPU上并行計算局部能量函數(shù),再通過樹狀的Allreduce匯總各處結(jié)果,最后用自動微分取梯度,即可得到注意力的輸出。
全過程僅需與計算能量函數(shù)相同的時間開銷,而顯存占用也幾乎沒有額外負擔。
樹注意力在設計上還充分利用了GPU集群的兩級拓撲特點——即同節(jié)點內(nèi)使用高速NVLink,而節(jié)點間則依賴IB或以太網(wǎng)等。
相比之下,環(huán)形注意力天然不適應這種拓撲,難以將通信與計算很好地重疊,終會被最慢的互聯(lián)帶寬所制約。
最后值得一提的是,雖然理論上單GPU內(nèi)部也可用類似策略提速,但當前硬件的流式處理器(SM)間通信還是共享內(nèi)存,優(yōu)勢并不明顯。
不過,英偉達在H100上實驗性地支持了SM間點對點的指令,這為未來單卡注意力優(yōu)化帶來了新的想象空間。
最被低估的AI實驗室之一
樹注意力團隊主要成員來自Zyphra,一家新興的AI創(chuàng)業(yè)公司,被評價為“當前最被低估的AI實驗室之一”。
Zyphra重點關注邊緣AI、端側(cè)AI,曾發(fā)布基于Mamba架構的基礎模型Zamba。
創(chuàng)始人Krithik Puthalath以及樹注意力共同一作Vasudev Shyam、Jonathan Pilault都有數(shù)學和理論物理學術背景。