LLM高效推理:KV緩存與分頁(yè)注意力機(jī)制深度解析
隨著大型語(yǔ)言模型(LLM)規(guī)模和復(fù)雜性的持續(xù)增長(zhǎng),高效推理的重要性日益凸顯。KV(鍵值)緩存與分頁(yè)注意力是兩種優(yōu)化LLM推理的關(guān)鍵技術(shù)。本文將深入剖析這些概念,闡述其重要性,并探討它們?cè)趦H解碼器(decoder-only)模型中的工作原理。
常規(guī)推理機(jī)制
首先,我們通過(guò)一個(gè)簡(jiǎn)單的例子來(lái)理解Transformer模型中典型的推理過(guò)程。假設(shè)我們需要生成短語(yǔ):
“The quick brown fox jumped”
以下是常規(guī)推理的簡(jiǎn)化實(shí)現(xiàn):
import numpy as np
# 簡(jiǎn)化的嵌入表示,僅用于演示
embeddings = {
'The': np.array([1, 0, 0, 0]),
'quick': np.array([0, 1, 0, 0]),
'brown': np.array([0, 0, 1, 0]),
'fox': np.array([0, 0, 0, 1]),
'jumped': np.array([1, 1, 0, 0])
}
# 權(quán)重矩陣(簡(jiǎn)化)
W_Q = W_K = W_V = np.array([[1, 0],
[0, 1],
[0, 0],
[0, 0]])
def compute_attention(self, input_words):
# 將單詞轉(zhuǎn)換為嵌入向量
E = np.array([embeddings[word] for word in input_words])
# 計(jì)算所有token的K和V矩陣
K = E @ W_K # 形狀: (seq_len, 2)
V = E @ W_V # 形狀: (seq_len, 2)
# 計(jì)算最后一個(gè)token的Q矩陣
Q = E[-1] @ W_Q # 形狀: (1, 2)
# 計(jì)算縮放的點(diǎn)積注意力得分
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ K.T) / scale # 形狀: (1, seq_len)
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, seq_len)
# 將注意力權(quán)重應(yīng)用于V矩陣
output = attention_weights @ V # 形狀: (1, 2)
return output
以下是逐步生成的過(guò)程:
# 步驟1: 生成 "brown"
input_words_step1 = ['The', 'quick']
output_step1 = compute_attention(input_words_step1)
# 步驟2: 生成 "fox"
input_words_step2 = ['The', 'quick', 'brown']
output_step2 = compute_attention(input_words_step2)
# 步驟3: 生成 "jumped"
input_words_step3 = ['The', 'quick', 'brown', 'fox']
output_step3 = compute_attention(input_words_step3)
冗余計(jì)算:觀察上述代碼可以發(fā)現(xiàn)對(duì)于每個(gè)新生成的token:
- 需要為所有先前的token重新計(jì)算K和V矩陣。
- 矩陣的大小隨著token數(shù)量的增加而增大。
- 存在大量不必要的重復(fù)計(jì)算。
KV緩存機(jī)制
當(dāng)使用Transformer模型生成文本時(shí),通過(guò)緩存鍵(K)和值(V)矩陣,可以顯著優(yōu)化推理過(guò)程。下圖展示了KV緩存的工作原理:
在上圖中:
- q_new表示最新token的查詢向量。
- K_prev和V_prev是從先前計(jì)算中緩存得到的鍵和值矩陣。
- k_new和v_new僅為當(dāng)前新token計(jì)算。
- 藍(lán)色箭頭表示如何利用緩存值和新值計(jì)算注意力。
以下是KV緩存的實(shí)現(xiàn)示例:
def compute_attention_with_cache(self, input_words):
"""使用KV緩存計(jì)算注意力"""
# 獲取新token(序列中的最后一個(gè)單詞)
new_word = input_words[-1]
e_new = embeddings[new_word]
# 計(jì)算新token的K和V矩陣
K_new = e_new @ W_K # 形狀: (2,)
V_new = e_new @ W_V # 形狀: (2,)
# 更新緩存的K和V矩陣
if self.cached_K is None:
self.cached_K = K_new.reshape(1, -1) # 形狀: (1, 2)
self.cached_V = V_new.reshape(1, -1) # 形狀: (1, 2)
else:
self.cached_K = np.vstack([self.cached_K, K_new]) # 形狀: (seq_len, 2)
self.cached_V = np.vstack([self.cached_V, V_new]) # 形狀: (seq_len, 2)
# 計(jì)算最后一個(gè)token的Q矩陣
Q = e_new @ W_Q # 形狀: (2,)
# 使用緩存的K矩陣計(jì)算縮放的點(diǎn)積注意力得分
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ self.cached_K.T) / scale # 形狀: (1, seq_len)
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, seq_len)
# 使用緩存的V矩陣計(jì)算注意力輸出
output = attention_weights @ self.cached_V # 形狀: (1, 2)
return output
以下是逐步生成的過(guò)程:
# 步驟1: 生成 "brown"
input_words_step1 = ['The', 'quick']
output_step1 = compute_attention_with_cache(input_words_step1)
# 步驟2: 生成 "fox"
input_words_step2 = ['The', 'quick', 'brown']
output_step2 = compute_attention_with_cache(input_words_step2)
# 步驟 3: 生成 "jumped"
input_words_step3 = ['The', 'quick', 'brown', 'fox']
output_step3 = compute_attention_with_cache(input_words_step3)
比較有無(wú)KV緩存的推理計(jì)算
內(nèi)存需求與挑戰(zhàn)
我們來(lái)看一個(gè)使用典型模型參數(shù)的實(shí)際例子:
- 序列長(zhǎng)度: 4096
- 層數(shù): 32
- 注意力頭數(shù): 32
- 頭維度: 128
- 精度: FP16 (2 bytes)
每個(gè)token所需的內(nèi)存:
KV_cache_per_token = 2×num_layers×(num_heads×head_dim)×precision
= 2 × 32 × (32 × 128) × 2 bytes
= 2 × 32 × 4096 × 2 bytes
= 524,288 bytes
≈ 0.5 MB
KV緩存的低效性
盡管KV緩存顯著提高了計(jì)算效率,但它也帶來(lái)了內(nèi)存管理方面的挑戰(zhàn)。以下是三種主要的內(nèi)存低效類(lèi)型:
內(nèi)部碎片
- 由因未知輸出長(zhǎng)度而導(dǎo)致的過(guò)度分配引起。
- 示例:在圖像中,2040個(gè)槽位從未被使用。
- 影響:可能浪費(fèi)高達(dá)60-80%的已分配內(nèi)存。
- 解決方案:更精確的輸出長(zhǎng)度估計(jì)或動(dòng)態(tài)分配策略。
預(yù)留浪費(fèi)
- 為將來(lái)的token生成而預(yù)留的內(nèi)存。
- 在圖像中顯示為“3 slots future used (reserved)”。
- 維持生成連續(xù)性的必要措施。
- 可以通過(guò)更好地預(yù)測(cè)所需的未來(lái)槽位來(lái)優(yōu)化。
外部碎片
- 由處理具有不同序列長(zhǎng)度的多個(gè)請(qǐng)求導(dǎo)致。
- 在不同請(qǐng)求之間創(chuàng)建內(nèi)存間隙。
- 解決方案包括內(nèi)存碎片整理和智能請(qǐng)求批處理。
如上圖所示,通常僅有20-40%的KV緩存被用于存儲(chǔ)實(shí)際的token狀態(tài)。
分頁(yè)注意力:解決內(nèi)存低效的方案
為了應(yīng)對(duì)這些內(nèi)存挑戰(zhàn),可以采用分頁(yè)注意力機(jī)制。
分頁(yè)注意力是一種用于有效處理Transformer模型中長(zhǎng)序列的技術(shù),它通過(guò)將注意力計(jì)算分解為更小、更易于管理的“頁(yè)”或“塊”來(lái)實(shí)現(xiàn)。這種方法降低了內(nèi)存消耗和計(jì)算復(fù)雜度,從而能夠處理原本因過(guò)大而無(wú)法放入內(nèi)存的序列。
def compute_attention_with_paging(self, input_words):
"""使用分頁(yè)KV緩存計(jì)算注意力"""
# 獲取新token(序列中的最后一個(gè)單詞)
new_word = input_words[-1]
e_new = embeddings[new_word]
# 計(jì)算新token的K和V矩陣
K_new = e_new @ W_K # 形狀: (2,)
V_new = e_new @ W_V # 形狀: (2,)
# 確定當(dāng)前頁(yè)的索引
total_tokens = sum(len(K_page) for K_page in self.cached_K_pages) + 1
current_page_idx = (total_tokens - 1) // PAGE_SIZE
# 如果需要,初始化新頁(yè)
if len(self.cached_K_pages) <= current_page_idx:
self.cached_K_pages.append([])
self.cached_V_pages.append([])
# 將K和V添加到當(dāng)前頁(yè)的緩存中
self.cached_K_pages[current_page_idx].append(K_new)
self.cached_V_pages[current_page_idx].append(V_new)
# 計(jì)算當(dāng)前token的Q矩陣
Q = e_new @ W_Q # Shape: (2,)
# 僅在當(dāng)前頁(yè)內(nèi)計(jì)算注意力
K_current_page = np.array(self.cached_K_pages[current_page_idx])
V_current_page = np.array(self.cached_V_pages[current_page_idx])
# 添加縮放因子,用于點(diǎn)積注意力
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ K_current_page.T) / scale
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, current_page_size)
# 將注意力權(quán)重應(yīng)用于當(dāng)前頁(yè)中的V矩陣
output = attention_weights @ V_current_page
return output
以下是逐步生成的過(guò)程:
# 步驟1: 生成 "brown"
input_words_step1 = ['The', 'quick']
output_step1 = compute_attention_with_paging(input_words_step1)
# 步驟2: 生成 "fox"
input_words_step2 = ['The', 'quick', 'brown']
output_step2 = compute_attention_with_paging(input_words_step2)
# 步驟3: 生成 "jumped"
input_words_step3 = ['The', 'quick', 'brown', 'fox']
output_step3 = compute_attention_with_paging(input_words_step3)
為何需要分頁(yè)注意力?
- 內(nèi)存約束:由于注意力矩陣的規(guī)模與序列長(zhǎng)度呈平方關(guān)系,Transformer模型在處理長(zhǎng)序列時(shí)面臨嚴(yán)重的內(nèi)存限制。
- 長(zhǎng)序列處理:在諸如語(yǔ)言建?;蛭臋n摘要等任務(wù)中,序列可能非常長(zhǎng)。
- 效率:通過(guò)以分頁(yè)的方式處理注意力計(jì)算,可以將內(nèi)存使用量保持在一個(gè)常量水平,從而不受序列長(zhǎng)度的影響。
分頁(yè)注意力如何工作?
- 分割序列:將輸入序列分割成更小的塊或頁(yè)。
- 局部注意力:在每個(gè)頁(yè)內(nèi)計(jì)算注意力。
- 跨頁(yè)注意力:可選地,允許有限的跨頁(yè)注意力,以捕獲頁(yè)之間的依賴關(guān)系。
- 滑動(dòng)窗口:使用重疊的頁(yè)來(lái)確保連續(xù)性。
上述實(shí)現(xiàn)僅限于局部注意力,跨頁(yè)注意力和滑動(dòng)窗口的實(shí)現(xiàn)超出了本文的范圍,將在后續(xù)文章中詳細(xì)介紹。
分頁(yè)注意力的討論
優(yōu)勢(shì)
- 內(nèi)存效率:注意力計(jì)算被限制在頁(yè)大小內(nèi),內(nèi)存使用量保持恒定,不受總序列長(zhǎng)度的影響。
- 計(jì)算效率:降低了注意力計(jì)算的復(fù)雜度。
- 可擴(kuò)展性:能夠處理原本無(wú)法放入內(nèi)存的超長(zhǎng)序列。
權(quán)衡與考慮
- 上下文信息受限:模型會(huì)丟失跨頁(yè)的一些依賴關(guān)系,這對(duì)于需要全局上下文的任務(wù)可能很重要。
可能的解決方案:
- 重疊頁(yè):允許頁(yè)之間重疊一定數(shù)量的token,重疊區(qū)域的token可以關(guān)注前一頁(yè)的token。
- 分層注意力:使用更高層次的注意力機(jī)制來(lái)連接跨頁(yè)的信息。
重疊頁(yè)、分層注意力、跨頁(yè)注意力和滑動(dòng)窗口的完整實(shí)現(xiàn)超出了本文的范圍。
以下實(shí)現(xiàn)僅捕獲局部注意力,作為示例不應(yīng)在實(shí)際應(yīng)用中使用:
# 本實(shí)現(xiàn)僅為演示和理解目的而設(shè)計(jì)的簡(jiǎn)化版本。
# 實(shí)際應(yīng)用中需要更高效和可擴(kuò)展的實(shí)現(xiàn)。
import numpy as np
embeddings = {
'The': np.array([1, 0, 0, 0]),
'quick': np.array([0, 1, 0, 0]),
'brown': np.array([0, 0, 1, 0]),
'fox': np.array([0, 0, 0, 1]),
'jumped': np.array([1, 1, 0, 0])
}
W_Q = W_K = W_V = np.array([[1, 0],
[0, 1],
[0, 0],
[0, 0]])
PAGE_SIZE = 2 # 演示用的小頁(yè)尺寸
class AttentionWithCache:
def __init__(self):
self.cached_K = None # 形狀: (seq_len, 2)
self.cached_V = None # 形狀: (seq_len, 2)
self.cached_K_pages = [] # 包含K向量的頁(yè)列表
self.cached_V_pages = [] # 包含V向量的頁(yè)列表
def softmax(self, x, axis=-1):
"""
為x中的每組分?jǐn)?shù)計(jì)算Softmax值。
包含數(shù)值穩(wěn)定性改進(jìn)。
"""
# 應(yīng)用最大值減法以提高數(shù)值穩(wěn)定性
x_max = np.max(x, axis=axis, keepdims=True)
exp_x = np.exp(x - x_max)
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def compute_attention(self, input_words):
# 將單詞轉(zhuǎn)換為嵌入向量
E = np.array([embeddings[word] for word in input_words])
# 計(jì)算所有token的K和V矩陣
K = E @ W_K # 形狀: (seq_len, 2)
V = E @ W_V # 形狀: (seq_len, 2)
# 計(jì)算最后一個(gè)token的Q矩陣
Q = E[-1] @ W_Q # 形狀: (1, 2)
# 計(jì)算縮放的點(diǎn)積注意力得分
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ K.T) / scale # 形狀: (1, seq_len)
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, seq_len)
# 將注意力權(quán)重應(yīng)用于V矩陣
output = attention_weights @ V # 形狀: (1, 2)
return output
def compute_attention_with_cache(self, input_words):
"""使用KV緩存計(jì)算注意力"""
# 獲取新token(序列中的最后一個(gè)單詞)
new_word = input_words[-1]
e_new = embeddings[new_word]
# 計(jì)算新token的K和V矩陣
K_new = e_new @ W_K # 形狀: (2,)
V_new = e_new @ W_V # 形狀: (2,)
# 更新緩存的K和V矩陣
if self.cached_K is None:
self.cached_K = K_new.reshape(1, -1) # 形狀: (1, 2)
self.cached_V = V_new.reshape(1, -1) # 形狀: (1, 2)
else:
self.cached_K = np.vstack([self.cached_K, K_new]) # 形狀: (seq_len, 2)
self.cached_V = np.vstack([self.cached_V, V_new]) # 形狀: (seq_len, 2)
# 計(jì)算最后一個(gè)token的Q矩陣
Q = e_new @ W_Q # 形狀: (2,)
# 使用緩存的K矩陣計(jì)算縮放的點(diǎn)積注意力得分
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ self.cached_K.T) / scale # 形狀: (1, seq_len)
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, seq_len)
# 使用緩存的V矩陣計(jì)算注意力輸出
output = attention_weights @ self.cached_V # 形狀: (1, 2)
return output
def compute_attention_with_paging(self, input_words):
"""使用分頁(yè)KV緩存計(jì)算注意力"""
# 獲取新token(序列中的最后一個(gè)單詞)
new_word = input_words[-1]
e_new = embeddings[new_word]
# 計(jì)算新token的K和V矩陣
K_new = e_new @ W_K # 形狀: (2,)
V_new = e_new @ W_V # 形狀: (2,)
# 確定當(dāng)前頁(yè)的索引
total_tokens = sum(len(K_page) for K_page in self.cached_K_pages) + 1
current_page_idx = (total_tokens - 1) // PAGE_SIZE
# 如果需要,初始化新頁(yè)
if len(self.cached_K_pages) <= current_page_idx:
self.cached_K_pages.append([])
self.cached_V_pages.append([])
# 將K和V添加到當(dāng)前頁(yè)的緩存中
self.cached_K_pages[current_page_idx].append(K_new)
self.cached_V_pages[current_page_idx].append(V_new)
# 計(jì)算當(dāng)前token的Q矩陣
Q = e_new @ W_Q # Shape: (2,)
# 僅在當(dāng)前頁(yè)內(nèi)計(jì)算注意力
K_current_page = np.array(self.cached_K_pages[current_page_idx])
V_current_page = np.array(self.cached_V_pages[current_page_idx])
# 添加縮放因子,用于點(diǎn)積注意力
scale = np.sqrt(2) # 縮放因子,為key/query維度(此處為2)的平方根
scores = (Q @ K_current_page.T) / scale
# 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
attention_weights = self.softmax(scores) # 形狀: (1, current_page_size)
# 將注意力權(quán)重應(yīng)用于當(dāng)前頁(yè)中的V矩陣
output = attention_weights @ V_current_page
return output
def compare_implementations():
print("原始實(shí)現(xiàn):")
attention1 = AttentionWithCache()
# 使用原始方法處理序列
for i in range(len(['The', 'quick', 'brown', 'fox'])):
words = ['The', 'quick', 'brown', 'fox'][:i + 1]
output = attention1.compute_attention(words)
print(f"處理 {words} 后的輸出:")
print(f"Output: {output}")
print("\nKV緩存實(shí)現(xiàn):")
attention2 = AttentionWithCache()
# 使用KV緩存處理序列
for i in range(len(['The', 'quick', 'brown', 'fox'])):
words = ['The', 'quick', 'brown', 'fox'][:i + 1]
output = attention2.compute_attention_with_cache(words)
print(f"處理 {words} 后的輸出:")
print(f"Output: {output}")
print("\n分頁(yè)注意力實(shí)現(xiàn):")
attention3 = AttentionWithCache()
# 使用分頁(yè)注意力處理序列
for i in range(len(['The', 'quick', 'brown', 'fox'])):
words = ['The', 'quick', 'brown', 'fox'][:i + 1]
output = attention3.compute_attention_with_paging(words)
print(f"處理 {words} 后的輸出:")
print(f"Output: {output}")
print(f"頁(yè)數(shù): {len(attention3.cached_K_pages)}")
print(f"當(dāng)前頁(yè)大小: {len(attention3.cached_K_pages[-1])}\n")
if __name__ == "__main__":
compare_implementations()
總結(jié)
KV緩存和分頁(yè)注意力是提升LLM推理效率和可擴(kuò)展性的重要技術(shù)。KV緩存通過(guò)消除冗余計(jì)算來(lái)優(yōu)化計(jì)算過(guò)程,而分頁(yè)注意力則解決了處理長(zhǎng)序列時(shí)面臨的內(nèi)存限制。
隨著模型規(guī)模和復(fù)雜性的不斷增長(zhǎng),這些優(yōu)化技術(shù)對(duì)于實(shí)際應(yīng)用變得至關(guān)重要。深入理解和有效實(shí)施這些技術(shù),可以顯著提升LLM部署的性能和效率。