自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

注意力機制的變體之MLA 原創(chuàng)

發(fā)布于 2024-10-15 13:54
瀏覽
0收藏

?本文介紹注意力機制的變體-MLA。

MLA(Multi-head Latent Attention),是由杭州深度求索人工智能在DeepSeekV2提出的一種注意力機制變體。MLA主要旨在解決推理過程中由于attention機制中KV Cache占用過多內(nèi)存而導致的性能瓶頸問題。為此,MLA引入了低秩KV壓縮技術(shù),有效減少了KV Cache的大小,從而緩解了這一問題。

有興趣小伙伴可以看官方技術(shù)報告的介紹:??https://arxiv.org/pdf/2405.04434v2??

原理介紹

注意力機制的變體之MLA-AI.x社區(qū)

上圖為MHA、GQA、MQA、MLA的原理對比圖。從上圖可知傳統(tǒng)Transformer采用MHA,但KV Cache在推理過程中可能成為性能瓶頸。MQA和GQA雖然在一定程度上可以減少KV Cache的占用,但其效果通常不如MHA。MLA通過低秩的Key-Value聯(lián)合壓縮技術(shù),不僅實現(xiàn)了比MHA更優(yōu)的效果,還大幅減少了所需的KV Cache大小。

具體來說,MLA通過低秩聯(lián)合壓縮key和value來減少kv cache。從注意力機制的步驟來分析:

  • 通過輸入x乘以不同的矩陣參數(shù)Wq、Wk、Wv得到不同的QKV向量
  • 在轉(zhuǎn)換到QKV向量時候,將x乘以一個低秩矩陣,得到低階矩陣表示
  • 再通過一個高階矩陣來恢復原來的特征空間。由于矩陣是模型的權(quán)重參數(shù)已經(jīng)保存,所以只需要保存一個低秩的潛層特征就可以恢復成KV,而不是像之前需要同時緩存KV。

代碼實現(xiàn)


bsz, q_len, _ = hidden_states.size()
        
# 計算壓縮后的Q,再還原成高維
# [B, q_len, hidden_size]
# 即[B, q_len, num_head * q_head_dim]
q = self.w_uq(self.q_a_layernorm(self.w_dq(hidden_states)))
# [B, num_head, q_len, q_head_dim]
q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
# 包含當前位置可用上下文的長度
kv_seq_len = q.size(-2)
if past_key_value is not None:
    if self.layer_idx is None:
        raise ValueError(
            f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
            "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
            "with a layer index."
        )
    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# 得到當前壓縮后的kv, c_t^{kv}
# [B, q_len, d_c]
compressed_kv = self.w_dkv(hidden_states)

# 將當前位置之前的壓縮后的kv拼接到前面
if past_key_value is not None:
    # 得到的應該是[B, kv_seq_len, d_c], c^{kv}
    compressed_kv = past_key_value.update(compressed_kv)
# 計算得到k^C和v^C
# [B, num_head, kv_seq_len, q_head_dim]
k = self.w_uk(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)
v = self.w_uv(compressed_kv).view(bsz, -1, self.num_heads, self.q_head_dim).transpose(1, 2)

# 注意力權(quán)重
# [B, num_head, q_len, kv_seq_len]
attn_weights = (
    torch.matmul(q, k.transpose(2, 3)) * self.softmax_scale
)
...
attn_weights = nn.functional.softmax(
    attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
    attn_weights, p=self.attention_dropout, training=self.training
)
# [B, num_head, q_len, q_head_dim]
attn_output = torch.matmul(attn_weights, v)
...

以上為MLA的核心部分代碼實現(xiàn),里面有相應的代碼注釋。


本文轉(zhuǎn)載自公眾號瓦力算法學研所,作者:喜歡瓦力的卷卷

原文鏈接:??https://mp.weixin.qq.com/s/dWZk8TBY89re207ZL3GjfA???

?著作權(quán)歸作者所有,如需轉(zhuǎn)載,請注明出處,否則將追究法律責任
收藏
回復
舉報
回復
相關(guān)推薦