注意力機制的變體之MLA 原創(chuàng)
?本文介紹注意力機制的變體-MLA。
MLA(Multi-head Latent Attention),是由杭州深度求索人工智能在DeepSeekV2提出的一種注意力機制變體。MLA主要旨在解決推理過程中由于attention機制中KV Cache占用過多內(nèi)存而導致的性能瓶頸問題。為此,MLA引入了低秩KV壓縮技術(shù),有效減少了KV Cache的大小,從而緩解了這一問題。
有興趣小伙伴可以看官方技術(shù)報告的介紹:??https://arxiv.org/pdf/2405.04434v2??
原理介紹
上圖為MHA、GQA、MQA、MLA的原理對比圖。從上圖可知傳統(tǒng)Transformer采用MHA,但KV Cache在推理過程中可能成為性能瓶頸。MQA和GQA雖然在一定程度上可以減少KV Cache的占用,但其效果通常不如MHA。MLA通過低秩的Key-Value聯(lián)合壓縮技術(shù),不僅實現(xiàn)了比MHA更優(yōu)的效果,還大幅減少了所需的KV Cache大小。
具體來說,MLA通過低秩聯(lián)合壓縮key和value來減少kv cache。從注意力機制的步驟來分析:
- 通過輸入x乘以不同的矩陣參數(shù)Wq、Wk、Wv得到不同的QKV向量
- 在轉(zhuǎn)換到QKV向量時候,將x乘以一個低秩矩陣,得到低階矩陣表示
- 再通過一個高階矩陣來恢復原來的特征空間。由于矩陣是模型的權(quán)重參數(shù)已經(jīng)保存,所以只需要保存一個低秩的潛層特征就可以恢復成KV,而不是像之前需要同時緩存KV。
代碼實現(xiàn)
bsz, q_len, _ = hidden_states.size()
# 計算壓縮后的Q,再還原成高維
# [B, q_len, hidden_size]
# 即[B, q_len, num_head * q_head_dim]
q = self.w_uq(self.q_a_layernorm(self.w_dq(hidden_states)))
# [B, num_head, q_len, q_head_dim]
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
# 包含當前位置可用上下文的長度
kv_seq_len = q.size(-2)
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# 得到當前壓縮后的kv, c_t^{kv}
# [B, q_len, d_c]
compressed_kv = self.w_dkv(hidden_states)
# 將當前位置之前的壓縮后的kv拼接到前面
if past_key_value is not None:
# 得到的應該是[B, kv_seq_len, d_c], c^{kv}
compressed_kv = past_key_value.update(compressed_kv)
# 計算得到k^C和v^C
# [B, num_head, kv_seq_len, q_head_dim]
k = self.w_uk(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)
v = self.w_uv(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)
# 注意力權(quán)重
# [B, num_head, q_len, kv_seq_len]
attn_weights = (
torch.matmul(q, k.transpose(2, 3)) * self.softmax_scale
)
...
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
# [B, num_head, q_len, q_head_dim]
attn_output = torch.matmul(attn_weights, v)
...
以上為MLA的核心部分代碼實現(xiàn),里面有相應的代碼注釋。
本文轉(zhuǎn)載自公眾號瓦力算法學研所,作者:喜歡瓦力的卷卷
?著作權(quán)歸作者所有,如需轉(zhuǎn)載,請注明出處,否則將追究法律責任
贊
收藏
回復
分享
微博
QQ
微信
舉報

回復
相關(guān)推薦