自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

一文詳解MHA、GQA、MQA原理 原創(chuàng)

發(fā)布于 2024-11-14 15:40
瀏覽
0收藏

前言

本文回顧一下MHA、GQA、MQA,詳細(xì)解讀下MHA、GQA、MQA這三種常見(jiàn)注意力機(jī)制的原理。

一文詳解MHA、GQA、MQA原理-AI.x社區(qū)

圖1 MHA、GQA、MQA一覽

self-attention

一文詳解MHA、GQA、MQA原理-AI.x社區(qū)

self-attention

在自注意力機(jī)制中,輸入通常是一個(gè)統(tǒng)一的輸入矩陣,而這個(gè)矩陣后續(xù)會(huì)通過(guò)乘以不同的權(quán)重矩陣來(lái)轉(zhuǎn)換成三個(gè)不同的向量集合:查詢(xún)向量Q、鍵向量K和值向量V。這三組向量是通過(guò)線(xiàn)性變換方式生成:

1.查詢(xún)向量 (Q): Q=XWQ

2.鍵向量 (K): K=XWK

3.值向量 (V): V=XWV

W,WK和WV可學(xué)習(xí)的權(quán)重矩陣,分別對(duì)應(yīng)于查詢(xún)、鍵和值。這些矩陣的維度取決于模型的設(shè)計(jì),通常它們的輸出維度(列數(shù)) 是預(yù)先定義的,以滿(mǎn)足特定的模型架構(gòu)要求。 在Transformer模型中,使用不同的權(quán)重矩陣W,WK和WV來(lái)分別生成查詢(xún)向量Q、鍵向量K和值向量V的目的是為了允許模型在不同的表示空間中學(xué)習(xí)和抽取特征。這樣做增加了模型的靈活性和表達(dá)能力,允許模型分別優(yōu)化用于匹配(Q 和K)和用于輸出信息合成(V)的表示。

在自注意力和多頭注意力機(jī)制中,使用

一文詳解MHA、GQA、MQA原理-AI.x社區(qū)

作為縮放因子進(jìn)行縮放操作是為了防止在計(jì)算點(diǎn)積時(shí)由于維度較高導(dǎo)致的數(shù)值穩(wěn)定性問(wèn)題。這里的dk是鍵向量的維度。如果不進(jìn)行縮放,當(dāng)dk較大時(shí),點(diǎn)積的結(jié)果可能會(huì)變得非常大,這會(huì)導(dǎo)致在應(yīng)用softmax函數(shù)時(shí)產(chǎn)生的梯度非常小。因?yàn)閟oftmax函數(shù)是通過(guò)指數(shù)函數(shù)計(jì)算的,大的輸入值會(huì)使得部分輸出接近于1,而其他接近于0,從而導(dǎo)致梯度消失,這會(huì)在反向傳播過(guò)程中造成梯度非常小,使得學(xué)習(xí)變得非常緩慢。

通過(guò)點(diǎn)積結(jié)果除以

一文詳解MHA、GQA、MQA原理-AI.x社區(qū)

 ,可以調(diào)整這些值的范圍,使得它們不會(huì)太大。這樣,softmax的輸入在一個(gè)合適的范圍內(nèi),有助于避免極端的指數(shù)運(yùn)算結(jié)果,從而保持?jǐn)?shù)值穩(wěn)定性和更有效的梯度流。這個(gè)操作確保了即使在dk很大的情況下, 注意力機(jī)制也能穩(wěn)定并有效地學(xué)習(xí)。

代碼實(shí)現(xiàn)

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, seq_length):
        super(SelfAttention, self).__init__()
        self.input_size = seq_length
        # 定義三個(gè)權(quán)重矩陣:Wq、Wk、Wv
        self.Wq = nn.Linear(seq_length, seq_length)  # 線(xiàn)性變換
        self.Wk = nn.Linear(seq_length, seq_length)
        self.Wv = nn.Linear(seq_length, seq_length)

    def forward(self, input):
        # 計(jì)算Q,K,V 三個(gè)矩陣
        q = self.Wq(input)
        k = self.Wk(input)
        v = self.Wv(input)

        # 計(jì)算QK^T,即向量之間的相關(guān)度
        attention_scores = torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(torch.tensor(float(self.input_size)))
        # 計(jì)算向量權(quán)重,softmax歸一化
        attention_weight = F.softmax(attention_scores, dim=-1)
        # 計(jì)算輸出
        output = torch.matmul(attention_weight, v)
        return output


x = torch.randn(2, 3, 4)
Self_Attention = SelfAttention(4)  # 傳入輸入向量的維度
output = Self_Attention(x)
print(output.shape)

MHA(多頭注意力)

一文詳解MHA、GQA、MQA原理-AI.x社區(qū)

Transformer 編碼器塊內(nèi)的縮放點(diǎn)積注意力機(jī)制和多頭注意力機(jī)制

一文詳解MHA、GQA、MQA原理-AI.x社區(qū)

MHA計(jì)算過(guò)程

一文詳解MHA、GQA、MQA原理-AI.x社區(qū)

代碼實(shí)現(xiàn)

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.wq = nn.Linear(embed_dim, embed_dim)
        self.wk = nn.Linear(embed_dim, embed_dim)
        self.wv = nn.Linear(embed_dim, embed_dim)
        self.wo = nn.Linear(embed_dim, embed_dim)

    def mh_split(self, hidden):
        batch_size = hidden.shape[0]
        x = hidden.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        return x

    def forward(self, hidden_states, mask=None):
        batch_size = hidden_states.size(0)

        # 線(xiàn)性變換
        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)

        # 多頭切分
        q, k, v = self.mh_split(q), self.mh_split(k), self.mh_split(v)

        # 注意力計(jì)算
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)

        # 拼接多頭
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)

        # 線(xiàn)性變換
        output = self.wo(output)

        return output

x = torch.rand(2, 3, 36)
print(x)
output = MultiHeadAttention(36, 6)
y = output(x)
print(y.shape)

MHA 能夠理解輸入不同部分之間的關(guān)系。然而,這種復(fù)雜性是有代價(jià)的——對(duì)內(nèi)存帶寬的需求很大,尤其是在解碼器推理期間。主要問(wèn)題的關(guān)鍵在于內(nèi)存開(kāi)銷(xiāo)。在自回歸模型中,每個(gè)解碼步驟都需要加載解碼器權(quán)重以及所有注意鍵和值。這個(gè)過(guò)程不僅計(jì)算量大,而且內(nèi)存帶寬也大。隨著模型規(guī)模的擴(kuò)大,這種開(kāi)銷(xiāo)也會(huì)增加,使得擴(kuò)展變得越來(lái)越艱巨。

因此,多查詢(xún)注意 (MQA) 應(yīng)運(yùn)而生,成為緩解這一瓶頸的解決方案。其理念簡(jiǎn)單而有效:使用多個(gè)查詢(xún)頭,但只使用一個(gè)鍵和值頭。這種方法顯著減少了內(nèi)存負(fù)載,提高了推理速度。

MQA(多查詢(xún)注意力)

一文詳解MHA、GQA、MQA原理-AI.x社區(qū)

圖2 MHA和MQA的差別

MQA是MHA的一種變體,也是用于自回歸解碼的一種注意力機(jī)制。,圖1、圖2很形象的描繪了MHA和MQA的對(duì)比,與MHA 不同的是,MQA 讓所有的Head之間共享同樣的一份 K 和 V 矩陣(意味K和V的計(jì)算唯一),只讓 Q 保留了原始多頭的性質(zhì)(每個(gè)Head存在不同的轉(zhuǎn)換),從而大大減少 K 和 V 矩陣的參數(shù)量以及KV Cache的顯存占用,以此來(lái)達(dá)到提升推理速度,但是會(huì)帶來(lái)精度上的損失。MQA被大量應(yīng)用于LLM中,如ChatGLM2。

一文詳解MHA、GQA、MQA原理-AI.x社區(qū)

左 - 多頭注意力,中 - 多查詢(xún)注意力,右 - 將現(xiàn)有的 MHA 檢查點(diǎn)轉(zhuǎn)換為 MQA

如何將現(xiàn)有的預(yù)訓(xùn)練多頭注意力模型轉(zhuǎn)換為多查詢(xún)注意力模型 (MQA)?從現(xiàn)有的多頭模型創(chuàng)建多查詢(xún)注意力模型涉及兩個(gè)步驟:模型結(jié)構(gòu)的轉(zhuǎn)換和隨后的預(yù)訓(xùn)練。

  • 模型結(jié)構(gòu)的轉(zhuǎn)換:此步驟將多頭模型的結(jié)構(gòu)轉(zhuǎn)換為多查詢(xún)模型。它是通過(guò)將原始模型的多個(gè)頭的鍵和值的投影矩陣(線(xiàn)性層)合并(均值池化)為鍵和值的單個(gè)投影矩陣來(lái)實(shí)現(xiàn)的。這種均值池化方法被發(fā)現(xiàn)比選擇現(xiàn)有鍵和值頭之一或從頭開(kāi)始初始化新的鍵和值頭更有效。生成的結(jié)構(gòu)具有合并的鍵和值投影,這是多查詢(xún)模型的特征。
  • 對(duì)轉(zhuǎn)換后的模型進(jìn)行預(yù)訓(xùn)練:結(jié)構(gòu)轉(zhuǎn)換后,模型將接受額外的訓(xùn)練。此訓(xùn)練不像原始模型訓(xùn)練那樣廣泛;它只是原始模型訓(xùn)練步驟的一小部分(表示為 α)。此預(yù)訓(xùn)練階段的目的是讓模型根據(jù)其新的簡(jiǎn)化注意力機(jī)制調(diào)整和優(yōu)化其性能。訓(xùn)練遵循與原始相同的方法,確保學(xué)習(xí)動(dòng)態(tài)的一致性。

代碼實(shí)現(xiàn)

import torch
import torch.nn as nn


class MultiQuerySelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiQuerySelfAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.wq = nn.Linear(embed_dim, embed_dim)

        # MHA
        # self.wk = nn.Linear(embed_dim, embed_dim)
        # self.wv = nn.Linear(embed_dim, embed_dim)

        # MQA
        self.wk = nn.Linear(embed_dim, self.head_dim)
        self.wv = nn.Linear(embed_dim, self.head_dim)
        self.wo = nn.Linear(embed_dim, embed_dim)

    def q_h_split(self, hidden, head_num=None):
        batch_size, seq_len = hidden.size()[:2]
        # q拆分多頭
        if head_num == None:
            x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            return x
        else:
            # 這是MQA: 需要拆分k和v,這里面的head_num =1 的
            # 最終返回維度(batch_size, 1, seq_len, head_dim)
            return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, mask=None):
        batch_size = hidden_states.size(0)

        # 線(xiàn)性變換
        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)

        # 多頭切分
        # 這是MHA的
        # q, k ,v  = self.split(q), self.split(k), self.split(v)
        # 這是MQA的
        q, k, v = self.q_h_split(q), self.q_h_split(k, 1), self.q_h_split(v, 1)

        # 注意力計(jì)算
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        print("scores:", scores.shape)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)

        # 多頭合并
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        # 線(xiàn)性變換
        output = self.wo(output)
        return output


x = torch.rand(3, 12, 512)
atten = MultiQuerySelfAttention(512, 8)
y = atten(x)
print(y.shape)

GQA(分組查詢(xún)注意力)

一文詳解MHA、GQA、MQA原理-AI.x社區(qū)

雖然MQA方式大幅減小了參數(shù)數(shù)量,但是,帶來(lái)推理加速的同時(shí)會(huì)造成模型性能損失,且在訓(xùn)練過(guò)程使得模型變得不穩(wěn)定(復(fù)雜度的降低可能會(huì)導(dǎo)致質(zhì)量下降和訓(xùn)練不穩(wěn)定),因此在此基礎(chǔ)上提出了GQA,它將Query進(jìn)行分組,每個(gè)組內(nèi)共享一組Key、Value。(GQA在LLaMA-2 和 Mistral7B得到應(yīng)用)

GQA 的數(shù)學(xué)原理

分組:在 GQA 中,傳統(tǒng)多頭模型中的查詢(xún)頭 (Q) 被分成 G 組。每組分配一個(gè)鍵 (K) 和值 (V) 頭。此配置表示為 GQA-G,其中 G 表示組數(shù)。

GQA 的特殊情況

  • GQA-1 = MQA:只有一個(gè)組(G = 1),GQA 等同于 MQA,因?yàn)樗胁樵?xún)頭只有一個(gè)鍵和值頭。
  • GQA-H = MHA:當(dāng)組數(shù)等于頭數(shù)(G = H)時(shí),GQA 退化為 MHA,每個(gè)查詢(xún)頭都有其唯一的鍵和值頭。

對(duì)每個(gè)組中原始頭部的鍵和值投影矩陣進(jìn)行均值池化,以將MHA模型轉(zhuǎn)換為 GQA 模型。此技術(shù)對(duì)組中每個(gè)頭部的投影矩陣進(jìn)行平均,從而為該組生成單個(gè)鍵和值投影。

通過(guò)利用 GQA,該模型在 MHA 質(zhì)量和 MQA 速度之間保持平衡。由于鍵值對(duì)較少,內(nèi)存帶寬和數(shù)據(jù)加載需求被最小化。G 的選擇代表了一種權(quán)衡:更多的組(更接近 MHA)可帶來(lái)更高的質(zhì)量但性能較慢,而更少的組(接近 MQA)可提高速度但有犧牲質(zhì)量的風(fēng)險(xiǎn)。此外,隨著模型規(guī)模的擴(kuò)大,GQA 允許內(nèi)存帶寬和模型容量按比例減少,與模型規(guī)模相對(duì)應(yīng)。相比之下,對(duì)于更大的模型,在 MQA 中減少到單個(gè)鍵和值頭可能會(huì)過(guò)于嚴(yán)重。

代碼實(shí)現(xiàn)

import torch
import torch.nn as nn


class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(GroupedQueryAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.wq = nn.Linear(embed_dim, embed_dim)

        # 這是MHA的
        # self.wk = nn.Linear(embed_dim, embed_dim)
        # self.wv = nn.Linear(embed_dim, embed_dim)

        # 這是MQA的
        # self.wk = nn.Linear(embed_dim, self.head_dim)
        # self.wv = nn.Linear(embed_dim, self.head_dim)

        # 這是GQA的
        self.group_num = 4  # 這是4個(gè)組
        self.wk = nn.Linear(embed_dim, self.group_num * self.head_dim)
        self.wv = nn.Linear(embed_dim, self.group_num * self.head_dim)

        self.wo = nn.Linear(embed_dim, embed_dim)

    def split(self, hidden, group_num=None):
        batch_size, seq_len = hidden.size()[:2]
        # q需要拆分多頭
        if group_num == None:
            x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            return x
        else:
            # 這是kv需要拆分的多頭
            x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)
            x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len,
                                           self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)
            return x

    def forward(self, hidden_states, mask=None):
        batch_size = hidden_states.size(0)

        # 線(xiàn)性變換
        q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)

        # 多頭切分
        # 這是MHA的
        # q, k ,v  = self.split(q), self.split(k), self.split(v)
        # 這是MQA的
        # q, k ,v  = self.split(q), self.split(k, 1), self.split(v, 1)
        # 這是GQA的
        q, k, v = self.split(q), self.split(k, self.group_num), self.split(v, self.group_num)

        # 注意力計(jì)算
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        print("scores:", scores.shape)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)

        # 合并多頭
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)

        # 線(xiàn)性變換
        output = self.wo(output)

        return output


x = torch.ones(3, 12, 512)
atten = GroupedQueryAttention(512, 8)
y = atten(x)
print(y.shape)

參考文獻(xiàn)

  • GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,https://arxiv.org/pdf/2305.13245
  • Attention Is All You Need,https://arxiv.org/pdf/1706.03762
  • Fast Transformer Decoding: One Write-Head is All You Need,https://arxiv.org/pdf/1911.02150v1


本文轉(zhuǎn)載自公眾號(hào)大模型自然語(yǔ)言處理  作者:余俊暉

原文鏈接:??https://mp.weixin.qq.com/s/72fGm-qYV5DdCGz-bNjuXQ??


?著作權(quán)歸作者所有,如需轉(zhuǎn)載,請(qǐng)注明出處,否則將追究法律責(zé)任
已于2024-11-28 18:52:23修改
收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦