一文詳解MHA、GQA、MQA原理 原創(chuàng)
前言
本文回顧一下MHA、GQA、MQA,詳細(xì)解讀下MHA、GQA、MQA這三種常見(jiàn)注意力機(jī)制的原理。
圖1 MHA、GQA、MQA一覽
self-attention
self-attention
在自注意力機(jī)制中,輸入通常是一個(gè)統(tǒng)一的輸入矩陣,而這個(gè)矩陣后續(xù)會(huì)通過(guò)乘以不同的權(quán)重矩陣來(lái)轉(zhuǎn)換成三個(gè)不同的向量集合:查詢(xún)向量Q、鍵向量K和值向量V。這三組向量是通過(guò)線(xiàn)性變換方式生成:
1.查詢(xún)向量 (Q): Q=XWQ
2.鍵向量 (K): K=XWK
3.值向量 (V): V=XWV
WQ ,WK和WV是可學(xué)習(xí)的權(quán)重矩陣,分別對(duì)應(yīng)于查詢(xún)、鍵和值。這些矩陣的維度取決于模型的設(shè)計(jì),通常它們的輸出維度(列數(shù)) 是預(yù)先定義的,以滿(mǎn)足特定的模型架構(gòu)要求。 在Transformer模型中,使用不同的權(quán)重矩陣WQ ,WK和WV來(lái)分別生成查詢(xún)向量Q、鍵向量K和值向量V的目的是為了允許模型在不同的表示空間中學(xué)習(xí)和抽取特征。這樣做增加了模型的靈活性和表達(dá)能力,允許模型分別優(yōu)化用于匹配(Q 和K)和用于輸出信息合成(V)的表示。
在自注意力和多頭注意力機(jī)制中,使用
作為縮放因子進(jìn)行縮放操作是為了防止在計(jì)算點(diǎn)積時(shí)由于維度較高導(dǎo)致的數(shù)值穩(wěn)定性問(wèn)題。這里的dk是鍵向量的維度。如果不進(jìn)行縮放,當(dāng)dk較大時(shí),點(diǎn)積的結(jié)果可能會(huì)變得非常大,這會(huì)導(dǎo)致在應(yīng)用softmax函數(shù)時(shí)產(chǎn)生的梯度非常小。因?yàn)閟oftmax函數(shù)是通過(guò)指數(shù)函數(shù)計(jì)算的,大的輸入值會(huì)使得部分輸出接近于1,而其他接近于0,從而導(dǎo)致梯度消失,這會(huì)在反向傳播過(guò)程中造成梯度非常小,使得學(xué)習(xí)變得非常緩慢。
通過(guò)點(diǎn)積結(jié)果除以
,可以調(diào)整這些值的范圍,使得它們不會(huì)太大。這樣,softmax的輸入在一個(gè)合適的范圍內(nèi),有助于避免極端的指數(shù)運(yùn)算結(jié)果,從而保持?jǐn)?shù)值穩(wěn)定性和更有效的梯度流。這個(gè)操作確保了即使在dk很大的情況下, 注意力機(jī)制也能穩(wěn)定并有效地學(xué)習(xí)。
代碼實(shí)現(xiàn)
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, seq_length):
super(SelfAttention, self).__init__()
self.input_size = seq_length
# 定義三個(gè)權(quán)重矩陣:Wq、Wk、Wv
self.Wq = nn.Linear(seq_length, seq_length) # 線(xiàn)性變換
self.Wk = nn.Linear(seq_length, seq_length)
self.Wv = nn.Linear(seq_length, seq_length)
def forward(self, input):
# 計(jì)算Q,K,V 三個(gè)矩陣
q = self.Wq(input)
k = self.Wk(input)
v = self.Wv(input)
# 計(jì)算QK^T,即向量之間的相關(guān)度
attention_scores = torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(torch.tensor(float(self.input_size)))
# 計(jì)算向量權(quán)重,softmax歸一化
attention_weight = F.softmax(attention_scores, dim=-1)
# 計(jì)算輸出
output = torch.matmul(attention_weight, v)
return output
x = torch.randn(2, 3, 4)
Self_Attention = SelfAttention(4) # 傳入輸入向量的維度
output = Self_Attention(x)
print(output.shape)
MHA(多頭注意力)
Transformer 編碼器塊內(nèi)的縮放點(diǎn)積注意力機(jī)制和多頭注意力機(jī)制
MHA計(jì)算過(guò)程
代碼實(shí)現(xiàn)
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.wq = nn.Linear(embed_dim, embed_dim)
self.wk = nn.Linear(embed_dim, embed_dim)
self.wv = nn.Linear(embed_dim, embed_dim)
self.wo = nn.Linear(embed_dim, embed_dim)
def mh_split(self, hidden):
batch_size = hidden.shape[0]
x = hidden.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
return x
def forward(self, hidden_states, mask=None):
batch_size = hidden_states.size(0)
# 線(xiàn)性變換
q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
# 多頭切分
q, k, v = self.mh_split(q), self.mh_split(k), self.mh_split(v)
# 注意力計(jì)算
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(scores, dim=-1)
output = torch.matmul(attention, v)
# 拼接多頭
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
# 線(xiàn)性變換
output = self.wo(output)
return output
x = torch.rand(2, 3, 36)
print(x)
output = MultiHeadAttention(36, 6)
y = output(x)
print(y.shape)
MHA 能夠理解輸入不同部分之間的關(guān)系。然而,這種復(fù)雜性是有代價(jià)的——對(duì)內(nèi)存帶寬的需求很大,尤其是在解碼器推理期間。主要問(wèn)題的關(guān)鍵在于內(nèi)存開(kāi)銷(xiāo)。在自回歸模型中,每個(gè)解碼步驟都需要加載解碼器權(quán)重以及所有注意鍵和值。這個(gè)過(guò)程不僅計(jì)算量大,而且內(nèi)存帶寬也大。隨著模型規(guī)模的擴(kuò)大,這種開(kāi)銷(xiāo)也會(huì)增加,使得擴(kuò)展變得越來(lái)越艱巨。
因此,多查詢(xún)注意 (MQA) 應(yīng)運(yùn)而生,成為緩解這一瓶頸的解決方案。其理念簡(jiǎn)單而有效:使用多個(gè)查詢(xún)頭,但只使用一個(gè)鍵和值頭。這種方法顯著減少了內(nèi)存負(fù)載,提高了推理速度。
MQA(多查詢(xún)注意力)
圖2 MHA和MQA的差別
MQA是MHA的一種變體,也是用于自回歸解碼的一種注意力機(jī)制。,圖1、圖2很形象的描繪了MHA和MQA的對(duì)比,與MHA 不同的是,MQA 讓所有的Head之間共享同樣的一份 K 和 V 矩陣(意味K和V的計(jì)算唯一),只讓 Q 保留了原始多頭的性質(zhì)(每個(gè)Head存在不同的轉(zhuǎn)換),從而大大減少 K 和 V 矩陣的參數(shù)量以及KV Cache的顯存占用,以此來(lái)達(dá)到提升推理速度,但是會(huì)帶來(lái)精度上的損失。MQA被大量應(yīng)用于LLM中,如ChatGLM2。
左 - 多頭注意力,中 - 多查詢(xún)注意力,右 - 將現(xiàn)有的 MHA 檢查點(diǎn)轉(zhuǎn)換為 MQA
如何將現(xiàn)有的預(yù)訓(xùn)練多頭注意力模型轉(zhuǎn)換為多查詢(xún)注意力模型 (MQA)?從現(xiàn)有的多頭模型創(chuàng)建多查詢(xún)注意力模型涉及兩個(gè)步驟:模型結(jié)構(gòu)的轉(zhuǎn)換和隨后的預(yù)訓(xùn)練。
- 模型結(jié)構(gòu)的轉(zhuǎn)換:此步驟將多頭模型的結(jié)構(gòu)轉(zhuǎn)換為多查詢(xún)模型。它是通過(guò)將原始模型的多個(gè)頭的鍵和值的投影矩陣(線(xiàn)性層)合并(均值池化)為鍵和值的單個(gè)投影矩陣來(lái)實(shí)現(xiàn)的。這種均值池化方法被發(fā)現(xiàn)比選擇現(xiàn)有鍵和值頭之一或從頭開(kāi)始初始化新的鍵和值頭更有效。生成的結(jié)構(gòu)具有合并的鍵和值投影,這是多查詢(xún)模型的特征。
- 對(duì)轉(zhuǎn)換后的模型進(jìn)行預(yù)訓(xùn)練:結(jié)構(gòu)轉(zhuǎn)換后,模型將接受額外的訓(xùn)練。此訓(xùn)練不像原始模型訓(xùn)練那樣廣泛;它只是原始模型訓(xùn)練步驟的一小部分(表示為 α)。此預(yù)訓(xùn)練階段的目的是讓模型根據(jù)其新的簡(jiǎn)化注意力機(jī)制調(diào)整和優(yōu)化其性能。訓(xùn)練遵循與原始相同的方法,確保學(xué)習(xí)動(dòng)態(tài)的一致性。
代碼實(shí)現(xiàn)
import torch
import torch.nn as nn
class MultiQuerySelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(MultiQuerySelfAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.wq = nn.Linear(embed_dim, embed_dim)
# MHA
# self.wk = nn.Linear(embed_dim, embed_dim)
# self.wv = nn.Linear(embed_dim, embed_dim)
# MQA
self.wk = nn.Linear(embed_dim, self.head_dim)
self.wv = nn.Linear(embed_dim, self.head_dim)
self.wo = nn.Linear(embed_dim, embed_dim)
def q_h_split(self, hidden, head_num=None):
batch_size, seq_len = hidden.size()[:2]
# q拆分多頭
if head_num == None:
x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
return x
else:
# 這是MQA: 需要拆分k和v,這里面的head_num =1 的
# 最終返回維度(batch_size, 1, seq_len, head_dim)
return hidden.view(batch_size, seq_len, head_num, self.head_dim).transpose(1, 2)
def forward(self, hidden_states, mask=None):
batch_size = hidden_states.size(0)
# 線(xiàn)性變換
q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
# 多頭切分
# 這是MHA的
# q, k ,v = self.split(q), self.split(k), self.split(v)
# 這是MQA的
q, k, v = self.q_h_split(q), self.q_h_split(k, 1), self.q_h_split(v, 1)
# 注意力計(jì)算
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
print("scores:", scores.shape)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(scores, dim=-1)
output = torch.matmul(attention, v)
# 多頭合并
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
# 線(xiàn)性變換
output = self.wo(output)
return output
x = torch.rand(3, 12, 512)
atten = MultiQuerySelfAttention(512, 8)
y = atten(x)
print(y.shape)
GQA(分組查詢(xún)注意力)
雖然MQA方式大幅減小了參數(shù)數(shù)量,但是,帶來(lái)推理加速的同時(shí)會(huì)造成模型性能損失,且在訓(xùn)練過(guò)程使得模型變得不穩(wěn)定(復(fù)雜度的降低可能會(huì)導(dǎo)致質(zhì)量下降和訓(xùn)練不穩(wěn)定),因此在此基礎(chǔ)上提出了GQA,它將Query進(jìn)行分組,每個(gè)組內(nèi)共享一組Key、Value。(GQA在LLaMA-2 和 Mistral7B得到應(yīng)用)
GQA 的數(shù)學(xué)原理:
分組:在 GQA 中,傳統(tǒng)多頭模型中的查詢(xún)頭 (Q) 被分成 G 組。每組分配一個(gè)鍵 (K) 和值 (V) 頭。此配置表示為 GQA-G,其中 G 表示組數(shù)。
GQA 的特殊情況:
- GQA-1 = MQA:只有一個(gè)組(G = 1),GQA 等同于 MQA,因?yàn)樗胁樵?xún)頭只有一個(gè)鍵和值頭。
- GQA-H = MHA:當(dāng)組數(shù)等于頭數(shù)(G = H)時(shí),GQA 退化為 MHA,每個(gè)查詢(xún)頭都有其唯一的鍵和值頭。
對(duì)每個(gè)組中原始頭部的鍵和值投影矩陣進(jìn)行均值池化,以將MHA模型轉(zhuǎn)換為 GQA 模型。此技術(shù)對(duì)組中每個(gè)頭部的投影矩陣進(jìn)行平均,從而為該組生成單個(gè)鍵和值投影。
通過(guò)利用 GQA,該模型在 MHA 質(zhì)量和 MQA 速度之間保持平衡。由于鍵值對(duì)較少,內(nèi)存帶寬和數(shù)據(jù)加載需求被最小化。G 的選擇代表了一種權(quán)衡:更多的組(更接近 MHA)可帶來(lái)更高的質(zhì)量但性能較慢,而更少的組(接近 MQA)可提高速度但有犧牲質(zhì)量的風(fēng)險(xiǎn)。此外,隨著模型規(guī)模的擴(kuò)大,GQA 允許內(nèi)存帶寬和模型容量按比例減少,與模型規(guī)模相對(duì)應(yīng)。相比之下,對(duì)于更大的模型,在 MQA 中減少到單個(gè)鍵和值頭可能會(huì)過(guò)于嚴(yán)重。
代碼實(shí)現(xiàn)
import torch
import torch.nn as nn
class GroupedQueryAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(GroupedQueryAttention, self).__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.wq = nn.Linear(embed_dim, embed_dim)
# 這是MHA的
# self.wk = nn.Linear(embed_dim, embed_dim)
# self.wv = nn.Linear(embed_dim, embed_dim)
# 這是MQA的
# self.wk = nn.Linear(embed_dim, self.head_dim)
# self.wv = nn.Linear(embed_dim, self.head_dim)
# 這是GQA的
self.group_num = 4 # 這是4個(gè)組
self.wk = nn.Linear(embed_dim, self.group_num * self.head_dim)
self.wv = nn.Linear(embed_dim, self.group_num * self.head_dim)
self.wo = nn.Linear(embed_dim, embed_dim)
def split(self, hidden, group_num=None):
batch_size, seq_len = hidden.size()[:2]
# q需要拆分多頭
if group_num == None:
x = hidden.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
return x
else:
# 這是kv需要拆分的多頭
x = hidden.view(batch_size, seq_len, group_num, self.head_dim).transpose(1, 2)
x = x[:, :, None, :, :].expand(batch_size, group_num, self.num_heads // group_num, seq_len,
self.head_dim).reshape(batch_size, self.num_heads, seq_len, self.head_dim)
return x
def forward(self, hidden_states, mask=None):
batch_size = hidden_states.size(0)
# 線(xiàn)性變換
q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
# 多頭切分
# 這是MHA的
# q, k ,v = self.split(q), self.split(k), self.split(v)
# 這是MQA的
# q, k ,v = self.split(q), self.split(k, 1), self.split(v, 1)
# 這是GQA的
q, k, v = self.split(q), self.split(k, self.group_num), self.split(v, self.group_num)
# 注意力計(jì)算
scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
print("scores:", scores.shape)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attention = torch.softmax(scores, dim=-1)
output = torch.matmul(attention, v)
# 合并多頭
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
# 線(xiàn)性變換
output = self.wo(output)
return output
x = torch.ones(3, 12, 512)
atten = GroupedQueryAttention(512, 8)
y = atten(x)
print(y.shape)
參考文獻(xiàn)
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,https://arxiv.org/pdf/2305.13245
- Attention Is All You Need,https://arxiv.org/pdf/1706.03762
- Fast Transformer Decoding: One Write-Head is All You Need,https://arxiv.org/pdf/1911.02150v1
本文轉(zhuǎn)載自公眾號(hào)大模型自然語(yǔ)言處理 作者:余俊暉
原文鏈接:??https://mp.weixin.qq.com/s/72fGm-qYV5DdCGz-bNjuXQ??
