自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

LLM高效推理:KV緩存與分頁(yè)注意力機(jī)制深度解析

人工智能
隨著大型語(yǔ)言模型(LLM)規(guī)模和復(fù)雜性的持續(xù)增長(zhǎng),高效推理的重要性日益凸顯。KV(鍵值)緩存與分頁(yè)注意力是兩種優(yōu)化LLM推理的關(guān)鍵技術(shù)。本文將深入剖析這些概念,闡述其重要性,并探討它們?cè)趦H解碼器(decoder-only)模型中的工作原理。

隨著大型語(yǔ)言模型(LLM)規(guī)模和復(fù)雜性的持續(xù)增長(zhǎng),高效推理的重要性日益凸顯。KV(鍵值)緩存與分頁(yè)注意力是兩種優(yōu)化LLM推理的關(guān)鍵技術(shù)。本文將深入剖析這些概念,闡述其重要性,并探討它們?cè)趦H解碼器(decoder-only)模型中的工作原理。

常規(guī)推理機(jī)制

首先,我們通過(guò)一個(gè)簡(jiǎn)單的例子來(lái)理解Transformer模型中典型的推理過(guò)程。假設(shè)我們需要生成短語(yǔ):

“The quick brown fox jumped”

以下是常規(guī)推理的簡(jiǎn)化實(shí)現(xiàn):

import numpy as np
# 簡(jiǎn)化的嵌入表示,僅用于演示
embeddings = {
    'The': np.array([1, 0, 0, 0]),
    'quick': np.array([0, 1, 0, 0]),
    'brown': np.array([0, 0, 1, 0]),
    'fox': np.array([0, 0, 0, 1]),
    'jumped': np.array([1, 1, 0, 0])
}

# 權(quán)重矩陣(簡(jiǎn)化)
W_Q = W_K = W_V = np.array([[1, 0],
    [0, 1],
    [0, 0],
    [0, 0]])

def compute_attention(self, input_words):
    # 將單詞轉(zhuǎn)換為嵌入向量
    E = np.array([embeddings[word] for word in input_words])

    # 計(jì)算所有token的K和V矩陣
    K = E @ W_K  # 形狀: (seq_len, 2)
    V = E @ W_V  # 形狀: (seq_len, 2)

    # 計(jì)算最后一個(gè)token的Q矩陣
    Q = E[-1] @ W_Q  # 形狀: (1, 2)

    # 計(jì)算縮放的點(diǎn)積注意力得分
    scale = np.sqrt(2)  # 縮放因子,為key/query維度(此處為2)的平方根
    scores = (Q @ K.T) / scale  # 形狀: (1, seq_len)

    # 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
    attention_weights = self.softmax(scores)  # 形狀: (1, seq_len)

    # 將注意力權(quán)重應(yīng)用于V矩陣
    output = attention_weights @ V  # 形狀: (1, 2)

     return output

以下是逐步生成的過(guò)程:

# 步驟1: 生成 "brown"
input_words_step1 = ['The', 'quick']
output_step1 = compute_attention(input_words_step1)
# 步驟2: 生成 "fox"
input_words_step2 = ['The', 'quick', 'brown']
output_step2 = compute_attention(input_words_step2)
# 步驟3: 生成 "jumped"
input_words_step3 = ['The', 'quick', 'brown', 'fox']
 output_step3 = compute_attention(input_words_step3)

冗余計(jì)算:觀察上述代碼可以發(fā)現(xiàn)對(duì)于每個(gè)新生成的token:

  1. 需要為所有先前的token重新計(jì)算K和V矩陣。
  2. 矩陣的大小隨著token數(shù)量的增加而增大。
  3. 存在大量不必要的重復(fù)計(jì)算。

KV緩存機(jī)制

當(dāng)使用Transformer模型生成文本時(shí),通過(guò)緩存鍵(K)和值(V)矩陣,可以顯著優(yōu)化推理過(guò)程。下圖展示了KV緩存的工作原理:

在上圖中:

  1. q_new表示最新token的查詢向量。
  2. K_prev和V_prev是從先前計(jì)算中緩存得到的鍵和值矩陣。
  3. k_new和v_new僅為當(dāng)前新token計(jì)算。
  4. 藍(lán)色箭頭表示如何利用緩存值和新值計(jì)算注意力。

以下是KV緩存的實(shí)現(xiàn)示例:

def compute_attention_with_cache(self, input_words):
    """使用KV緩存計(jì)算注意力"""
    # 獲取新token(序列中的最后一個(gè)單詞)
    new_word = input_words[-1]
    e_new = embeddings[new_word]

    # 計(jì)算新token的K和V矩陣
    K_new = e_new @ W_K  # 形狀: (2,)
    V_new = e_new @ W_V  # 形狀: (2,)

    # 更新緩存的K和V矩陣
    if self.cached_K is None:
        self.cached_K = K_new.reshape(1, -1)  # 形狀: (1, 2)
        self.cached_V = V_new.reshape(1, -1)  # 形狀: (1, 2)
    else:
        self.cached_K = np.vstack([self.cached_K, K_new])  # 形狀: (seq_len, 2)
        self.cached_V = np.vstack([self.cached_V, V_new])  # 形狀: (seq_len, 2)

    # 計(jì)算最后一個(gè)token的Q矩陣
    Q = e_new @ W_Q  # 形狀: (2,)

    # 使用緩存的K矩陣計(jì)算縮放的點(diǎn)積注意力得分
    scale = np.sqrt(2)  # 縮放因子,為key/query維度(此處為2)的平方根
    scores = (Q @ self.cached_K.T) / scale  # 形狀: (1, seq_len)

    # 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
    attention_weights = self.softmax(scores)  # 形狀: (1, seq_len)

    # 使用緩存的V矩陣計(jì)算注意力輸出
    output = attention_weights @ self.cached_V  # 形狀: (1, 2)

     return output

以下是逐步生成的過(guò)程:

# 步驟1: 生成 "brown"
input_words_step1 = ['The', 'quick']
output_step1 = compute_attention_with_cache(input_words_step1)
# 步驟2: 生成 "fox"
input_words_step2 = ['The', 'quick', 'brown']
output_step2 = compute_attention_with_cache(input_words_step2)
# 步驟 3: 生成 "jumped"
input_words_step3 = ['The', 'quick', 'brown', 'fox']
 output_step3 = compute_attention_with_cache(input_words_step3)

比較有無(wú)KV緩存的推理計(jì)算

內(nèi)存需求與挑戰(zhàn)

我們來(lái)看一個(gè)使用典型模型參數(shù)的實(shí)際例子:

  • 序列長(zhǎng)度: 4096
  • 層數(shù): 32
  • 注意力頭數(shù): 32
  • 頭維度: 128
  • 精度: FP16 (2 bytes)

每個(gè)token所需的內(nèi)存:

KV_cache_per_token = 2×num_layers×(num_heads×head_dim)×precision
 = 2 × 32 × (32 × 128) × 2 bytes
 = 2 × 32 × 4096 × 2 bytes
 = 524,288 bytes
 ≈ 0.5 MB

KV緩存的低效性

盡管KV緩存顯著提高了計(jì)算效率,但它也帶來(lái)了內(nèi)存管理方面的挑戰(zhàn)。以下是三種主要的內(nèi)存低效類(lèi)型:

內(nèi)部碎片

  • 由因未知輸出長(zhǎng)度而導(dǎo)致的過(guò)度分配引起。
  • 示例:在圖像中,2040個(gè)槽位從未被使用。
  • 影響:可能浪費(fèi)高達(dá)60-80%的已分配內(nèi)存。
  • 解決方案:更精確的輸出長(zhǎng)度估計(jì)或動(dòng)態(tài)分配策略。

預(yù)留浪費(fèi)

  • 為將來(lái)的token生成而預(yù)留的內(nèi)存。
  • 在圖像中顯示為“3 slots future used (reserved)”。
  • 維持生成連續(xù)性的必要措施。
  • 可以通過(guò)更好地預(yù)測(cè)所需的未來(lái)槽位來(lái)優(yōu)化。

外部碎片

  • 由處理具有不同序列長(zhǎng)度的多個(gè)請(qǐng)求導(dǎo)致。
  • 在不同請(qǐng)求之間創(chuàng)建內(nèi)存間隙。
  • 解決方案包括內(nèi)存碎片整理和智能請(qǐng)求批處理。

如上圖所示,通常僅有20-40%的KV緩存被用于存儲(chǔ)實(shí)際的token狀態(tài)。

分頁(yè)注意力:解決內(nèi)存低效的方案

為了應(yīng)對(duì)這些內(nèi)存挑戰(zhàn),可以采用分頁(yè)注意力機(jī)制。

分頁(yè)注意力是一種用于有效處理Transformer模型中長(zhǎng)序列的技術(shù),它通過(guò)將注意力計(jì)算分解為更小、更易于管理的“頁(yè)”或“塊”來(lái)實(shí)現(xiàn)。這種方法降低了內(nèi)存消耗和計(jì)算復(fù)雜度,從而能夠處理原本因過(guò)大而無(wú)法放入內(nèi)存的序列。

def compute_attention_with_paging(self, input_words):
    """使用分頁(yè)KV緩存計(jì)算注意力"""
    # 獲取新token(序列中的最后一個(gè)單詞)
    new_word = input_words[-1]
    e_new = embeddings[new_word]

    # 計(jì)算新token的K和V矩陣
    K_new = e_new @ W_K  # 形狀: (2,)
    V_new = e_new @ W_V  # 形狀: (2,)

    # 確定當(dāng)前頁(yè)的索引
    total_tokens = sum(len(K_page) for K_page in self.cached_K_pages) + 1
    current_page_idx = (total_tokens - 1) // PAGE_SIZE

    # 如果需要,初始化新頁(yè)
    if len(self.cached_K_pages) <= current_page_idx:
        self.cached_K_pages.append([])
        self.cached_V_pages.append([])

    # 將K和V添加到當(dāng)前頁(yè)的緩存中
    self.cached_K_pages[current_page_idx].append(K_new)
    self.cached_V_pages[current_page_idx].append(V_new)

    # 計(jì)算當(dāng)前token的Q矩陣
    Q = e_new @ W_Q  # Shape: (2,)

    # 僅在當(dāng)前頁(yè)內(nèi)計(jì)算注意力
    K_current_page = np.array(self.cached_K_pages[current_page_idx])
    V_current_page = np.array(self.cached_V_pages[current_page_idx])

    # 添加縮放因子,用于點(diǎn)積注意力
    scale = np.sqrt(2)  # 縮放因子,為key/query維度(此處為2)的平方根
    scores = (Q @ K_current_page.T) / scale

    # 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
    attention_weights = self.softmax(scores)  # 形狀: (1, current_page_size)

    # 將注意力權(quán)重應(yīng)用于當(dāng)前頁(yè)中的V矩陣
    output = attention_weights @ V_current_page

     return output

以下是逐步生成的過(guò)程:

# 步驟1: 生成 "brown"
input_words_step1 = ['The', 'quick']
output_step1 = compute_attention_with_paging(input_words_step1)
# 步驟2: 生成 "fox"
input_words_step2 = ['The', 'quick', 'brown']
output_step2 = compute_attention_with_paging(input_words_step2)
# 步驟3: 生成 "jumped"
input_words_step3 = ['The', 'quick', 'brown', 'fox']
 output_step3 = compute_attention_with_paging(input_words_step3)

為何需要分頁(yè)注意力?

  • 內(nèi)存約束:由于注意力矩陣的規(guī)模與序列長(zhǎng)度呈平方關(guān)系,Transformer模型在處理長(zhǎng)序列時(shí)面臨嚴(yán)重的內(nèi)存限制。
  • 長(zhǎng)序列處理:在諸如語(yǔ)言建?;蛭臋n摘要等任務(wù)中,序列可能非常長(zhǎng)。
  • 效率:通過(guò)以分頁(yè)的方式處理注意力計(jì)算,可以將內(nèi)存使用量保持在一個(gè)常量水平,從而不受序列長(zhǎng)度的影響。

分頁(yè)注意力如何工作?

  • 分割序列:將輸入序列分割成更小的塊或頁(yè)。
  • 局部注意力:在每個(gè)頁(yè)內(nèi)計(jì)算注意力。
  • 跨頁(yè)注意力:可選地,允許有限的跨頁(yè)注意力,以捕獲頁(yè)之間的依賴關(guān)系。
  • 滑動(dòng)窗口:使用重疊的頁(yè)來(lái)確保連續(xù)性。

上述實(shí)現(xiàn)僅限于局部注意力,跨頁(yè)注意力和滑動(dòng)窗口的實(shí)現(xiàn)超出了本文的范圍,將在后續(xù)文章中詳細(xì)介紹。

分頁(yè)注意力的討論

優(yōu)勢(shì)

  • 內(nèi)存效率:注意力計(jì)算被限制在頁(yè)大小內(nèi),內(nèi)存使用量保持恒定,不受總序列長(zhǎng)度的影響。
  • 計(jì)算效率:降低了注意力計(jì)算的復(fù)雜度。
  • 可擴(kuò)展性:能夠處理原本無(wú)法放入內(nèi)存的超長(zhǎng)序列。
權(quán)衡與考慮
  • 上下文信息受限:模型會(huì)丟失跨頁(yè)的一些依賴關(guān)系,這對(duì)于需要全局上下文的任務(wù)可能很重要。

可能的解決方案:

  • 重疊頁(yè):允許頁(yè)之間重疊一定數(shù)量的token,重疊區(qū)域的token可以關(guān)注前一頁(yè)的token。
  • 分層注意力:使用更高層次的注意力機(jī)制來(lái)連接跨頁(yè)的信息。

重疊頁(yè)、分層注意力、跨頁(yè)注意力和滑動(dòng)窗口的完整實(shí)現(xiàn)超出了本文的范圍。

以下實(shí)現(xiàn)僅捕獲局部注意力,作為示例不應(yīng)在實(shí)際應(yīng)用中使用:

# 本實(shí)現(xiàn)僅為演示和理解目的而設(shè)計(jì)的簡(jiǎn)化版本。
# 實(shí)際應(yīng)用中需要更高效和可擴(kuò)展的實(shí)現(xiàn)。

import numpy as np

embeddings = {
    'The': np.array([1, 0, 0, 0]),
    'quick': np.array([0, 1, 0, 0]),
    'brown': np.array([0, 0, 1, 0]),
    'fox': np.array([0, 0, 0, 1]),
    'jumped': np.array([1, 1, 0, 0])
}

W_Q = W_K = W_V = np.array([[1, 0],
                            [0, 1],
                            [0, 0],
                            [0, 0]])

PAGE_SIZE = 2  # 演示用的小頁(yè)尺寸


class AttentionWithCache:
    def __init__(self):
        self.cached_K = None  # 形狀: (seq_len, 2)
        self.cached_V = None  # 形狀: (seq_len, 2)
        self.cached_K_pages = []  # 包含K向量的頁(yè)列表
        self.cached_V_pages = []  # 包含V向量的頁(yè)列表

    def softmax(self, x, axis=-1):
        """
        為x中的每組分?jǐn)?shù)計(jì)算Softmax值。
        包含數(shù)值穩(wěn)定性改進(jìn)。
        """
        # 應(yīng)用最大值減法以提高數(shù)值穩(wěn)定性
        x_max = np.max(x, axis=axis, keepdims=True)
        exp_x = np.exp(x - x_max)
        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

    def compute_attention(self, input_words):
        # 將單詞轉(zhuǎn)換為嵌入向量
        E = np.array([embeddings[word] for word in input_words])

        # 計(jì)算所有token的K和V矩陣
        K = E @ W_K  # 形狀: (seq_len, 2)
        V = E @ W_V  # 形狀: (seq_len, 2)

        # 計(jì)算最后一個(gè)token的Q矩陣
        Q = E[-1] @ W_Q  # 形狀: (1, 2)

        # 計(jì)算縮放的點(diǎn)積注意力得分
        scale = np.sqrt(2)  # 縮放因子,為key/query維度(此處為2)的平方根
        scores = (Q @ K.T) / scale  # 形狀: (1, seq_len)

        # 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
        attention_weights = self.softmax(scores)  # 形狀: (1, seq_len)

        # 將注意力權(quán)重應(yīng)用于V矩陣
        output = attention_weights @ V  # 形狀: (1, 2)

        return output

    def compute_attention_with_cache(self, input_words):
        """使用KV緩存計(jì)算注意力"""
        # 獲取新token(序列中的最后一個(gè)單詞)
        new_word = input_words[-1]
        e_new = embeddings[new_word]

        # 計(jì)算新token的K和V矩陣
        K_new = e_new @ W_K  # 形狀: (2,)
        V_new = e_new @ W_V  # 形狀: (2,)

        # 更新緩存的K和V矩陣
        if self.cached_K is None:
            self.cached_K = K_new.reshape(1, -1)  # 形狀: (1, 2)
            self.cached_V = V_new.reshape(1, -1)  # 形狀: (1, 2)
        else:
            self.cached_K = np.vstack([self.cached_K, K_new])  # 形狀: (seq_len, 2)
            self.cached_V = np.vstack([self.cached_V, V_new])  # 形狀: (seq_len, 2)

        # 計(jì)算最后一個(gè)token的Q矩陣
        Q = e_new @ W_Q  # 形狀: (2,)

        # 使用緩存的K矩陣計(jì)算縮放的點(diǎn)積注意力得分
        scale = np.sqrt(2)  # 縮放因子,為key/query維度(此處為2)的平方根
        scores = (Q @ self.cached_K.T) / scale  # 形狀: (1, seq_len)

        # 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
        attention_weights = self.softmax(scores)  # 形狀: (1, seq_len)

        # 使用緩存的V矩陣計(jì)算注意力輸出
        output = attention_weights @ self.cached_V  # 形狀: (1, 2)

        return output

    def compute_attention_with_paging(self, input_words):
        """使用分頁(yè)KV緩存計(jì)算注意力"""
        # 獲取新token(序列中的最后一個(gè)單詞)
        new_word = input_words[-1]
        e_new = embeddings[new_word]

        # 計(jì)算新token的K和V矩陣
        K_new = e_new @ W_K  # 形狀: (2,)
        V_new = e_new @ W_V  # 形狀: (2,)

        # 確定當(dāng)前頁(yè)的索引
        total_tokens = sum(len(K_page) for K_page in self.cached_K_pages) + 1
        current_page_idx = (total_tokens - 1) // PAGE_SIZE

        # 如果需要,初始化新頁(yè)
        if len(self.cached_K_pages) <= current_page_idx:
            self.cached_K_pages.append([])
            self.cached_V_pages.append([])

        # 將K和V添加到當(dāng)前頁(yè)的緩存中
        self.cached_K_pages[current_page_idx].append(K_new)
        self.cached_V_pages[current_page_idx].append(V_new)

        # 計(jì)算當(dāng)前token的Q矩陣
        Q = e_new @ W_Q  # Shape: (2,)

        # 僅在當(dāng)前頁(yè)內(nèi)計(jì)算注意力
        K_current_page = np.array(self.cached_K_pages[current_page_idx])
        V_current_page = np.array(self.cached_V_pages[current_page_idx])

        # 添加縮放因子,用于點(diǎn)積注意力
        scale = np.sqrt(2)  # 縮放因子,為key/query維度(此處為2)的平方根
        scores = (Q @ K_current_page.T) / scale

        # 應(yīng)用Softmax函數(shù),獲得注意力權(quán)重
        attention_weights = self.softmax(scores)  # 形狀: (1, current_page_size)

        # 將注意力權(quán)重應(yīng)用于當(dāng)前頁(yè)中的V矩陣
        output = attention_weights @ V_current_page

        return output


def compare_implementations():
    print("原始實(shí)現(xiàn):")
    attention1 = AttentionWithCache()

    # 使用原始方法處理序列
    for i in range(len(['The', 'quick', 'brown', 'fox'])):
        words = ['The', 'quick', 'brown', 'fox'][:i + 1]
        output = attention1.compute_attention(words)
        print(f"處理 {words} 后的輸出:")
        print(f"Output: {output}")

    print("\nKV緩存實(shí)現(xiàn):")
    attention2 = AttentionWithCache()

    # 使用KV緩存處理序列
    for i in range(len(['The', 'quick', 'brown', 'fox'])):
        words = ['The', 'quick', 'brown', 'fox'][:i + 1]
        output = attention2.compute_attention_with_cache(words)
        print(f"處理 {words} 后的輸出:")
        print(f"Output: {output}")

    print("\n分頁(yè)注意力實(shí)現(xiàn):")
    attention3 = AttentionWithCache()

    # 使用分頁(yè)注意力處理序列
    for i in range(len(['The', 'quick', 'brown', 'fox'])):
        words = ['The', 'quick', 'brown', 'fox'][:i + 1]
        output = attention3.compute_attention_with_paging(words)
        print(f"處理 {words} 后的輸出:")
        print(f"Output: {output}")
        print(f"頁(yè)數(shù): {len(attention3.cached_K_pages)}")
        print(f"當(dāng)前頁(yè)大小: {len(attention3.cached_K_pages[-1])}\n")


if __name__ == "__main__":
     compare_implementations()

總結(jié)

KV緩存和分頁(yè)注意力是提升LLM推理效率和可擴(kuò)展性的重要技術(shù)。KV緩存通過(guò)消除冗余計(jì)算來(lái)優(yōu)化計(jì)算過(guò)程,而分頁(yè)注意力則解決了處理長(zhǎng)序列時(shí)面臨的內(nèi)存限制。

隨著模型規(guī)模和復(fù)雜性的不斷增長(zhǎng),這些優(yōu)化技術(shù)對(duì)于實(shí)際應(yīng)用變得至關(guān)重要。深入理解和有效實(shí)施這些技術(shù),可以顯著提升LLM部署的性能和效率。

責(zé)任編輯:華軒 來(lái)源: DeepHub IMBA
相關(guān)推薦

2018-08-26 22:25:36

自注意力機(jī)制神經(jīng)網(wǎng)絡(luò)算法

2023-11-13 18:19:54

模型訓(xùn)練

2025-01-17 13:20:00

2024-09-19 10:07:41

2023-11-24 12:36:00

模型訓(xùn)練

2025-01-13 08:23:07

LLMMHAMLP

2025-02-25 09:40:00

模型數(shù)據(jù)AI

2010-11-25 09:37:14

MySQL查詢緩存機(jī)制

2022-03-25 11:29:04

視覺(jué)算法美團(tuán)

2024-06-28 08:04:43

語(yǔ)言模型應(yīng)用

2024-10-31 10:00:39

注意力機(jī)制核心組件

2024-12-25 16:42:18

2020-09-17 12:40:54

神經(jīng)網(wǎng)絡(luò)CNN機(jī)器學(xué)習(xí)

2024-12-09 00:00:10

2025-02-26 14:32:51

2024-12-20 16:46:22

Spring三級(jí)緩存

2024-06-06 09:18:48

2023-05-05 13:11:16

2024-12-04 09:25:00

2024-11-04 10:40:00

AI模型
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)