斯坦福讓“GPU高速運轉(zhuǎn)”的新工具火了,比FlashAttention2更快
AI算力資源越發(fā)緊張的當(dāng)下,斯坦福新研究將GPU運行效率再提升一波——
內(nèi)核只有100行代碼,讓H100比使用FlashAttention-2,性能還要提升30%。
怎么做到的?
研究人員從“硬件實際需要什么?如何滿足這些需求?”這兩個問題出發(fā),設(shè)計了 一個嵌入式CUDA DSL工具,名為ThunderKittens(暫且譯為雷貓)。
雷貓可簡化AI內(nèi)核的編寫,同時充分利用底層硬件能力。
具體來說,雷貓的主要抽象是寄存器和共享內(nèi)存中的小型張量塊(tile),和目前GPU中對小矩陣乘法的優(yōu)化相匹配。
通過操作這些tile,開發(fā)者可相對簡單地編寫代碼,充分利用張量核心、異步數(shù)據(jù)傳輸和共享內(nèi)存等硬件特性。
使用雷貓實現(xiàn)的注意力機(jī)制內(nèi)核,代碼量少且能實現(xiàn)很高的硬件利用率,性能超過直接使用底層庫(如Cutlass)。
詳細(xì)討論過程以及雷貓是怎么設(shè)計出的,研究人員以“GPUs Go Brrr”為題,發(fā)在了斯坦福Hazy Research的Blog網(wǎng)站上。
網(wǎng)友們對此討論也十分熱烈。
有網(wǎng)友表示讀這篇Blog時,讓他想起了初次了解超標(biāo)量CPU架構(gòu)時的驚訝感受:
GPU真的達(dá)到了新高度。
還有網(wǎng)友表示:
這篇文章重新點燃了我在CS 149并行編程課中所感受到的快樂。
H100里有什么?
斯坦福研究人員以H100為例,探討了優(yōu)化GPU的方法。
首先,回顧一下H100的硬件細(xì)節(jié),這對于接下來的討論非常重要。
一個H100 SXM GPU包含:
(1)80GB的HBM3內(nèi)存,帶寬為3TB/s(實際帶寬略低)。
(2)50MB的L2緩存,帶寬為12TB/s,在GPU上分為兩個25MB的部分,通過交叉開關(guān)連接(這個交叉開關(guān)表現(xiàn)不佳)。
(3)132個流式多處理器(SM),每個包含:
- 高達(dá)227KB的共享內(nèi)存位于256KB的L1緩存中(這些加起來的帶寬大約33TB/s)。
- 一個張量內(nèi)存加速器(TMA)——這是英偉達(dá)Hopper架構(gòu)中的一種新硬件組件,可進(jìn)行異步地址生成和內(nèi)存獲取,還能促進(jìn)片上內(nèi)存網(wǎng)絡(luò)。
- 4個子單元,每個含:一個warp scheduler;512個向量寄存器(每個包含32個4字節(jié)的詞);一個用于執(zhí)行矩陣乘法的張量核心;一組內(nèi)置指令,如求和、乘法等,這些指令能夠并行操作這些向量寄存器。
除了這些,一個GPU還包括內(nèi)存控制器、指令緩存……但對于這項研究而言不重要。
重要的是,所有的計算都發(fā)生在流式多處理器中,大部分計算是在寄存器中。
H100 GPU擁有989 TFLOPs的半精度矩陣乘法計算能力,以及約60 TFLOPs的“其他”計算能力。因此,每個周期內(nèi)張量核心被使用時,至少能達(dá)到94%的硬件利用率。而張量核心不被使用時,硬件的利用率不會超過6%。
換句話說:
H100的利用率=張量核心活躍周期的百分比+/- 6%。
所以要充分發(fā)揮H100的能力,關(guān)鍵是保持張量核心持續(xù)運算。
榨干H100,要注意什么?
然鵝,要保持張量核心持續(xù)運行并不容易。
研究人員發(fā)現(xiàn)GPU硬件具有一些特性,對于保持矩陣乘法的運行非常重要:
- WGMMA指令雖然是必要的,但使用起來頗為麻煩。
- 共享內(nèi)存的速度并不如預(yù)期的快,使用時還需格外注意。
- 生成地址的成本較高。
- 保持高占用率對于提升性能是有益的,寄存器至關(guān)重要。
這些特性在非H100 GPU上也有所適用,在H100上更加典型,就拿RTX 4090來說,相比H100處理起來簡單得多。
所以接下來還是以H100為例,展開探討這幾點特性。
WGMMA指令
H100引入了一套新的指令集,名為“warp group matrix multiply accumulate”(在PTX中為wgmma.mma_async,在SASS中為HGMMA/IGMMA/QGMMA/BGMMA)。
要理解這些指令的特點,需回顧以往張量核心的使用方式。
早期GPU中的張量核心指令如wmma.mma.sync和mma.sync,要求SM一個子單元內(nèi)的32個線程的一個warp同步傳輸數(shù)據(jù)塊至張量核心并等待結(jié)果。
wgmma.mma_async指令則不同。它允許128個連續(xù)線程跨SM所有子單元協(xié)作同步,并從共享內(nèi)存及寄存器(可選)異步啟動矩陣乘法。這使得這些warp在等待矩陣乘法結(jié)果時可以處理其他任務(wù)。
研究人員通過微觀基準(zhǔn)測試,發(fā)現(xiàn)這些指令是充分發(fā)揮H100計算能力所必需的。沒有這些指令,GPU的峰值利用率大約只有63%。
他們推測,這是由于張量核心需要從本地資源維持一個深度硬件pipeline。
然而,這些指令的內(nèi)存布局極其復(fù)雜。未重排的共享內(nèi)存布局合并性差,需要額外的L2帶寬。重排的內(nèi)存布局記錄不準(zhǔn)確,研究人員花費了大量時間才弄明白。
最終發(fā)現(xiàn),這些布局只適用于特定矩陣形狀,并與wgmma.mma_async指令的其他部分不兼容,例如硬件僅在未重排的布局下轉(zhuǎn)置子矩陣。
此外,未重排的wgmma布局內(nèi)存合并性差且有bank conflicts。盡管TMA和L2緩存在如flash attention這類內(nèi)核上能較好地掩蓋這些問題,但要充分利用硬件,必須精心控制內(nèi)存請求的合并和避免bank conflicts。
盡管有這些問題,但這些指令對于充分利用H100是必不可少的。沒有它們,GPU的潛在性能就損失了37%。
共享內(nèi)存
共享內(nèi)存的單次訪問延遲約為30個周期(這也與研究人員觀察的相符),這看似不多,但在這段時間內(nèi),SM的張量核心幾乎能完成兩次完整的32x32方陣乘法。
以前的研究,如Flash Attention,研究人員更多關(guān)注的是HBM-SRAM的瓶頸。但隨著HBM速度的提升和張量核心的快速發(fā)展,即使是共享內(nèi)存的相對較小延遲也變得尤為關(guān)鍵。
由于共享內(nèi)存被分為32個獨立的存儲單元,處理不當(dāng)可能會引發(fā)bank conflicts,即同一個內(nèi)存bank同時被多個請求訪問,這種情況會導(dǎo)致請求被序列化。研究人員實驗后認(rèn)為,這會顯著拖慢內(nèi)核速度,且wgmma與mma指令需要的寄存器布局容易受到bank conflicts的影響。
解決方法是通過各種“重排”模式調(diào)整共享內(nèi)存的配置,避免bank conflicts,但細(xì)節(jié)要處理得當(dāng)。
此外研究人員發(fā)現(xiàn),盡可能避免在寄存器和共享內(nèi)存之間的移動數(shù)據(jù)非常重要??赡艿脑挘墒褂脙?nèi)置硬件(如wgmma和TMA指令)進(jìn)行異步數(shù)據(jù)傳輸。實在沒法子了,再使用warp進(jìn)行同步數(shù)據(jù)傳輸。
地址生成
H100還有一個有趣的特性,其張量核心和內(nèi)存都足夠快,以至于僅生成用于獲取數(shù)據(jù)的內(nèi)存地址就占用了芯片的大量資源,特別是加入復(fù)雜的交錯或重排模式時,這種情況更為明顯。
研究人員表示,英偉達(dá)提供了張量內(nèi)存加速器(TMA),似乎就是已經(jīng)意識到了這個問題。
TMA允許用戶在全局和共享內(nèi)存中指定多維張量布局,命令其異步提取張量的一部分,并在完成后觸發(fā)一個屏障。這大大節(jié)省了地址生成的開銷,并簡化了pipelines的構(gòu)建。
研究人員認(rèn)為,TMA對于充分發(fā)揮H100的潛力至關(guān)重要,可能比wgmma.mma_async更為關(guān)鍵。
它不僅節(jié)省了寄存器資源和指令派發(fā),還提供了如異步在全局內(nèi)存上執(zhí)行歸約等實用功能——這在處理復(fù)雜的反向內(nèi)核時尤其有用。
雖然TMA的重排模式解讀有一定難度,需要進(jìn)行一些逆向工程,但研究人員表示,相比之下,他們在這上面遇到的問題要少得多。
占用率
占用率指的是在GPU的相同執(zhí)行硬件上同時調(diào)度的線程數(shù)。每個周期,SM的某一子單元的warp scheduler會嘗試向準(zhǔn)備就緒的warp線程發(fā)出指令。
研究人員認(rèn)為,英偉達(dá)采用這種模型可以更容易地保持硬件的滿負(fù)荷運行。例如,當(dāng)一個線程warp等待執(zhí)行矩陣乘法時,另一個可以被指派執(zhí)行使用快速指數(shù)運算的指令。
在某些方面,H100對占用率的依賴程度低于前幾代硬件。
它的異步特性使得即使單一指令流也能使多個硬件部分同時持續(xù)運行,包括讀取內(nèi)存、執(zhí)行矩陣乘法、進(jìn)行共享內(nèi)存的歸約,同時還能在寄存器上進(jìn)行計算。
但高占用率容易隱藏缺陷或同步問題,一個設(shè)計良好的pipeline即使在占用率不高的情況下也能運行得相當(dāng)快。
據(jù)研究人員觀察,英偉達(dá)在設(shè)計GPU時確實考慮到了占用率。且由于存在足夠多的同步操作和足夠多的錯誤可能性,根據(jù)他們的經(jīng)驗,提高占用率通常能顯著增加硬件的實際利用率。
此外,相比H100,A100和RTX 4090更依賴同步指令調(diào)度,占用率更重要。
用雷貓優(yōu)化GPU
鑒于以上情況,如何才能更輕松地編寫所需的內(nèi)核類型,同時充分發(fā)揮硬件的全部潛力?
雷貓(ThunderKittens)登場了。
這是一個嵌入在CUDA中的DSL,本是斯坦福研究人員設(shè)計出來給自己內(nèi)部使用的,后來發(fā)現(xiàn)還真挺好使。
Ps:起這么個名,一是他們覺得小貓很可愛,二來他們覺得大伙兒在代碼中輸入kittens::會很有趣。
具體來說,雷貓包含四種模板類型:
- 寄存器tiles:在寄存器文件上表示二維張量。
- 寄存器向量:在寄存器文件上表示一維張量。
- 共享tiles:在共享內(nèi)存中表示二維張量。
- 共享向量:在共享內(nèi)存中表示一維張量。
tiles通過高度、寬度和布局進(jìn)行參數(shù)化;寄存器向量通過長度和布局進(jìn)行參數(shù)化;而共享向量僅通過長度進(jìn)行參數(shù)化,通常不會遇到bank conflicts問題。
此外,研究人員提供了一系列操作來處理這些張量,既可在warp級別使用,也可用于多個warp協(xié)作,包含初始化器,如將共享向量清零;一元操作,如exp;二元操作,如mul;行/列操作,例如行求和。
雷貓作為一個嵌入到CUDA中的庫,其提供的抽象層在遇到不支持的功能時能夠很好地處理。如果雷貓缺少某些功能,可以直接擴(kuò)展它來實現(xiàn)你想要的效果。
以Tri的flash attention算法為例,在實際應(yīng)用中,即使是使用英偉達(dá)的Cutlass庫,實現(xiàn)起來也是相當(dāng)復(fù)雜。
以下是一個在RTX 4090上使用雷貓編寫的簡單flash attention內(nèi)核的示例。
總共約60行CUDA代碼,硬件利用率達(dá)到了75%。代碼復(fù)雜性主要在于算法本身,而非交織模式或寄存器布局。
#define NUM_WORKERS 16 // This kernel uses 16 workers in parallel per block, to help issue instructions more quickly.
using namespace kittens; // this kernel only handles headdim=64 for simplicity. Also n should be a multiple of 256 here.
__global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {
auto warpid = kittens::warpid();
auto block_start = blockIdx.x*(n*64);
const bf16 *_q = __q__ + block_start, *_k = __k__ + block_start, *_v = __v__ + block_start;
bf16 *_o = __o__ + block_start;
extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory
shared_allocator al((int*)&__shm[0]);
// K and V live in shared memory -- this is about all that will fit.
st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
// Initialize all of the register tiles.
rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg need to be swapped into col_l
rt_fl_1x1<> att_block;
rt_bf_1x1<> att_block_mma;
rt_fl_1x4<> o_reg;
rt_fl_1x1<>::col_vec max_vec_last, max_vec; // these are column vectors for the attention block
rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // these are column vectors for the attention block
int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS);
for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {
// each warp loads its own Q tile of 16x64, and then multiplies by 1/sqrt(d)
load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
mul(q_reg, q_reg, __float2bfloat16(0.125f)); // temperature adjustment
// zero flash attention L, M, and O registers.
neg_infty(max_vec); // zero registers for the Q chunk
zero(norm_vec);
zero(o_reg);
// iterate over k, v for these q's that have been loaded
for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) {
// each warp loads its own chunk of k, v into shared memory
load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
__syncthreads(); // we need to make sure all memory is loaded before we can begin the compute phase
// now each warp goes through all of the subtiles, loads them, and then does the flash attention internal alg.
for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {
load(k_reg, k_smem[subtile]); // load k from shared into registers
zero(att_block); // zero 16x16 attention tile
mma_ABt(att_block, q_reg, k_reg, att_block); // Q@K.T
copy(norm_vec_last, norm_vec);
copy(max_vec_last, max_vec);
row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec
sub_row(att_block, att_block, max_vec); // subtract max from attention -- now all <=0
exp(att_block, att_block); // exponentiate the block in-place.
sub(max_vec_last, max_vec_last, max_vec); // subtract new max from old max to find the new normalization.
exp(max_vec_last, max_vec_last); // exponentiate this vector -- this is what we need to normalize by.
mul(norm_vec, norm_vec, max_vec_last); // and the norm vec is now normalized.
row_sum(norm_vec, att_block, norm_vec); // accumulate the new attention block onto the now-rescaled norm_vec
div_row(att_block, att_block, norm_vec); // now the attention block is correctly normalized
mul(norm_vec_last, norm_vec_last, max_vec_last); // normalize the previous norm vec according to the new max
div(norm_vec_last, norm_vec_last, norm_vec); // normalize the previous norm vec according to the new norm
copy(att_block_mma, att_block); // convert to bf16 for mma_AB
load(v_reg, v_smem[subtile]); // load v from shared into registers.
rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // this is a reference and the call has invalidated v_reg
mul_row(o_reg, o_reg, norm_vec_last); // normalize o_reg in advance of mma_AB'ing onto it
mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // mfma onto o_reg with the local attention@V matmul.
}
__syncthreads(); // we need to make sure all warps are done before we can start loading the next kv chunk
}
store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // write out o. compiler has an issue with register usage if d is made constexpr q_reg.rows :/
}
}
關(guān)于TMA、WGMMA、交織模式和描述符的復(fù)雜性,這里展示了一個使用雷貓編寫的,針對H100的FlashAttention-2算法的前向傳遞示例。
template<int D>
__global__ __launch_bounds__((NUM_WORKERS)*kittens::WARP_THREADS, 2)
void fwd_attend_ker_dim(int N, const CUtensorMap* tma_q, const CUtensorMap* tma_k, const CUtensorMap* tma_v, CUtensorMap* tma_o) {
extern __shared__ int __shm[]; // this is the CUDA shared memory
tma_swizzle_allocator al((int*)&__shm[0]);
constexpr int tile_width = fwd_attend_ker_tile_dims<D>::tile_width; // constants
constexpr int qo_height = fwd_attend_ker_tile_dims<D>::qo_height;
constexpr int kv_height = fwd_attend_ker_tile_dims<D>::kv_height;
st_bf<qo_height, tile_width, layout_q> (&q_smem) [NUM_WARPGROUPS] = al.allocate<st_bf<qo_height, tile_width, layout_q>, NUM_WARPGROUPS>();
st_bf<kv_height, tile_width, layout_k> (&k_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_k>, 2, NUM_WORKERS_KV>();
st_bf<kv_height, tile_width, layout_v> (&v_smem)[2][NUM_WORKERS_KV] = al.allocate<st_bf<kv_height, tile_width, layout_v>, 2, NUM_WORKERS_KV>();
int tic = 0, toc = 1;
rt_fl<1, kv_height> att_block;
rt_bf<1, kv_height> att_block_mma;
rt_fl<1, qo_height> o_prev;
col_vec<rt_fl<1, kv_height>> max_vec_last, max_vec;
col_vec<rt_fl<1, kv_height>> norm_vec_last, norm_vec;
int warpid = kittens::warpid();
int warpgroupid = warpid/kittens::WARPGROUP_WARPS;
int kv_blocks = N / (NUM_WORKERS_KV*k_smem[0][0].rows);
__shared__ uint64_t qsmem_barrier, kvsmem_barrier;//, vsmem_barrier;
int q_phasebit = 0;
int kv_phasebit = 0;
if (threadIdx.x == 0) {
tma::init_barrier<st_bf<qo_height, tile_width, layout_q>, NUM_WARPGROUPS>(qsmem_barrier, 1);
tma::init_barrier<st_bf<kv_height, tile_width, layout_k>, NUM_WORKERS_KV*2>(kvsmem_barrier, 1);
}
if (warpid == 0) {
for (int wg = 0; wg < NUM_WORKERS/kittens::WARPGROUP_WARPS; wg++) { // load q
int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + wg;
tma::load_async((q_smem[wg]), tma_q, qsmem_barrier, tile_idx);
}
for (int w = 0; w < NUM_WORKERS_KV; w++) { // load k, v
int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + (0 * NUM_WORKERS_KV) + w;
tma::load_async((k_smem[tic][w]), tma_k, kvsmem_barrier, tile_idx);
tma::load_async((v_smem[tic][w]), tma_v, kvsmem_barrier, tile_idx);
}
}
neg_infty(max_vec); // zero registers for the Q chunk
zero(norm_vec);
zero(o_prev);
__syncthreads();
tma::arrive_and_wait(qsmem_barrier, q_phasebit);
q_phasebit ^= 1;
if constexpr (D == 64) { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.125f)); }
else { warpgroup::mul(q_smem[warpgroupid], q_smem[warpgroupid], __float2bfloat16(0.08838834764f)); }
for (auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++, tic ^= 1, toc ^= 1) {
tma::arrive_and_wait(kvsmem_barrier, kv_phasebit);
kv_phasebit ^= 1;
__syncthreads();
if (warpid == 0) {
tma::set_bytes(kvsmem_barrier, 2 * NUM_WORKERS_KV * k_smem[0][0].num_elements * sizeof(bf16));
if (kv_idx + 1 < kv_blocks) {
for (int w = 0; w < NUM_WORKERS_KV; w++) {
int tile_idx = (blockIdx.y * NUM_WORKERS_KV * kv_blocks) + ((kv_idx + 1) * NUM_WORKERS_KV) + w;
tma::load_async((k_smem[toc][w]), tma_k, kvsmem_barrier, tile_idx);
tma::load_async((v_smem[toc][w]), tma_v, kvsmem_barrier, tile_idx);
}
}
}
warpgroup::mma_fence(att_block);
warpgroup::mm_ABt(att_block, q_smem[warpgroupid], k_smem[tic][0]);
warpgroup::mma_commit_group();
copy(norm_vec_last, norm_vec);
copy(max_vec_last, max_vec);
warpgroup::mma_async_wait();
row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec
sub_row(att_block, att_block, max_vec);
exp(att_block, att_block);
sub(max_vec_last, max_vec_last, max_vec);
exp(max_vec_last, max_vec_last);
mul(norm_vec, norm_vec, max_vec_last);
row_sum(norm_vec, att_block, norm_vec); // accumulate onto the norm_vec
div_row(att_block, att_block, norm_vec);
mul(norm_vec_last, norm_vec_last, max_vec_last);
div(norm_vec_last, norm_vec_last, norm_vec);
copy(att_block_mma, att_block); // convert to bf16 for mma
mul_row(o_prev, o_prev, norm_vec_last); // normalize o_prev in advance of mma'ing onto it
warpgroup::mma_fence(o_prev);
warpgroup::mma_AB(o_prev, att_block_mma, v_smem[tic][0]);
warpgroup::mma_commit_group();
}
auto (*o_smem) = reinterpret_cast<st_bf<qo_height, tile_width, layout_o>(*)>(q_smem); // reuse q memory
warpgroup::store(o_smem[warpgroupid], o_prev);
__syncthreads();
if (warpid % 4 == 0) { // store o
int tile_idx = (blockIdx.y * NUM_WARPGROUPS * gridDim.x) + (blockIdx.x * NUM_WARPGROUPS) + warpgroupid;
tma::store_async(tma_o, (o_smem[warpgroupid]), tile_idx);
tma::store_commit_group();
}
tma::store_async_wait();
}
那么,它的表現(xiàn)如何?
這個內(nèi)核只有100行代碼,實際上它在H100上的性能比FlashAttention-2高出約30%。雷貓負(fù)責(zé)包裝布局和指令,提供了一個可以在GPU上使用的迷你pytorch環(huán)境。
△FA2(通過Pytorch實現(xiàn))與TK在H100 SXM上的多種配置比較
此外,研究人員還發(fā)布了基于線性注意力和其他新架構(gòu)的內(nèi)核。其中基于線性注意力的內(nèi)核的運行速度可達(dá)215 TFLOPs,如果考慮到算法中固有的重計算,速度可超過300 TFLOPs。
盡管線性注意力在理論上效率更高,但此前在實際硬件上表現(xiàn)并不佳。因此,研究人員認(rèn)為這可能促進(jìn)一系列高吞吐量應(yīng)用的發(fā)展。
small tile符合AI和硬件發(fā)展趨勢
最后,雷貓研究團(tuán)隊總結(jié)了開發(fā)雷貓的一些思考。在他們看來,雷貓之所以有效,是因為它的目標(biāo)并不是試圖做所有事:
CUDA的確比雷貓表達(dá)能力更廣,雷貓小而簡單,功能有限。但雷貓的small tiles抽象設(shè)計符合AI和硬件的發(fā)展趨勢。
雖然雷貓不支持小于16的維度,但研究人員認(rèn)為這并不重要,因為硬件也不傾向于支持過小的維度。
如果你的矩陣乘法小于16x16,你確定你正在做的是AI嗎?
從理論出發(fā),研究人員認(rèn)為需要進(jìn)行一種框架轉(zhuǎn)變。
“寄存器當(dāng)然不應(yīng)該像舊CPU那樣32位字。CUDA使用的1024位寬向量寄存器確實是朝著正確方向邁出的一步。但對我們來說,寄存器是16x16的數(shù)據(jù)tile。我們認(rèn)為AI需要這樣的設(shè)計,畢竟,它仍然只是矩陣乘法、歸約和重塑。我們認(rèn)為硬件也需要這樣的設(shè)計,小型矩陣乘法迫切需要超出系統(tǒng)級MMA的硬件支持。”
研究人員認(rèn)為,應(yīng)該根據(jù)硬件特性來重新定義AI的設(shè)計理念。例如,循環(huán)狀態(tài)應(yīng)該有多大?應(yīng)該足夠大以適應(yīng)一個SM。計算的密度應(yīng)該有多高?不應(yīng)低于硬件的需求。
我們未來工作的一個重要方向是利用我們對硬件的了解來幫助我們設(shè)計與之匹配的AI。