自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

機(jī)器學(xué)習(xí)|從0開始大模型之位置編碼

發(fā)布于 2025-1-20 12:07
瀏覽
0收藏

1、什么是位置編碼

在語言中,一句話是由詞組成的,詞與詞之間是有順序的,如果順序亂了或者重排,其實(shí)整個(gè)句子的意思就變了,所以詞與詞之間是有順序的。在循環(huán)神經(jīng)網(wǎng)絡(luò)中,序列與序列之間也是有順序的,所以循環(huán)神經(jīng)網(wǎng)絡(luò)中,序列與序列之間也是有順序的,不需要處理這種問題。但是在Transformer中,每個(gè)詞是獨(dú)立的,所以需要將詞的位置信息添加到模型中,讓模型維護(hù)順序關(guān)系。

機(jī)器學(xué)習(xí)|從0開始大模型之位置編碼-AI.x社區(qū)

位置編碼

位置編碼就是將hello world! 的token和位置關(guān)系通過向量表示出來,作為訓(xùn)練的輸入數(shù)據(jù),如上圖,位置編碼最終會變成:

[
    [P00, P01, P02 ... P0d],
    [P10, P11, P12 ... P1d],
    [P20, P21, P22 ... P2d],
]

2、計(jì)算位置編碼

計(jì)算位置編碼有多種方式:固定位置編碼,相對位置編碼,絕對位置編碼,其中Transformer的作者設(shè)計(jì)了一種三角函數(shù)位置編碼方式,通過三角函數(shù)計(jì)算輸出位置編碼向量。

為什么三角函數(shù)可以作為計(jì)算位置編碼的函數(shù)?

  • 首先我們來回顧一下三角函數(shù)的基本性質(zhì):函數(shù)具有周期性,取值范圍是[-1, 1]。

機(jī)器學(xué)習(xí)|從0開始大模型之位置編碼-AI.x社區(qū)

sin

  • 其次,如果用絕對位置編碼計(jì)算最大序列為3的位置(0-7),二進(jìn)制表示如下:

[
    [0, 0, 0], 
    [0, 0, 1], 
    [0, 1, 0], 
    [0, 1, 1], 
    [1, 0, 0], 
    [1, 0, 1], 
    [1, 1, 0], 
    [1, 1, 1]
]

從上可以表示看出,較高比特位的交替頻率低于較低比特位,存在周期性bit位變化,符合三角函數(shù)的周期性,而且三角函數(shù)的取值范圍是[-1, 1],輸出浮點(diǎn)數(shù),并且數(shù)據(jù)連續(xù),比直接使用二進(jìn)制更節(jié)省空間。

3、Transformer中的位置編碼層

假設(shè)你有一個(gè)長度為L的輸入序列,要計(jì)算第K個(gè)元素的位置編碼,位置編碼由不同頻率的正弦和余弦函數(shù)給出:

機(jī)器學(xué)習(xí)|從0開始大模型之位置編碼-AI.x社區(qū)

函數(shù)

  • k:詞序列中的第K個(gè)元素
  • d:詞向量維度,比如512,1024,8K等
  • P(k, i):位置函數(shù),輸出位置編碼向量
  • n:定義的標(biāo)量,Attention Is All You Need 的作者設(shè)置為 10,000
  • i:映射到列索引,范圍是0~d/2(由于輸入是2i表示,如果用i表示,范圍可以是0~d)

按照上述Hello world!的例子,計(jì)算位置編碼結(jié)果如下:

機(jī)器學(xué)習(xí)|從0開始大模型之位置編碼-AI.x社區(qū)

計(jì)算結(jié)果

那么用代碼實(shí)現(xiàn)一個(gè)簡化版本的位置編碼:

import numpy as np

def getPositionEncoding(seq_len, d, n=10000):
    P = np.zeros((seq_len, d))
    for k in range(seq_len):
        for i in np.arange(int(d/2)):
            denominator = np.power(n, 2*i/d)
            P[k, 2*i] = np.sin(k/denominator)
            P[k, 2*i+1] = np.cos(k/denominator)
    return P

P = getPositionEncoding(seq_len=3, d=3, n=100)
print(P)

# 輸出結(jié)果:
[[ 0.          1.          0.        ]
 [ 0.84147098  0.54030231  0.        ]
 [ 0.90929743 -0.41614684  0.        ]]

4、大模型訓(xùn)練中的位置編碼代碼

在我們從0訓(xùn)練大模型中,其位置編碼的實(shí)現(xiàn)如下:

def precompute_pos_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """預(yù)計(jì)算相對位置編碼的復(fù)數(shù)形式,用于旋轉(zhuǎn)位置編碼(RoPE)。"""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # 計(jì)算頻率
    t = torch.arange(seq_len, device=freqs.device)  # 創(chuàng)建時(shí)間步長
    freqs = torch.outer(t, freqs).float()  # 計(jì)算頻率的外積
    pos_cis = torch.polar(torch.ones_like(freqs), freqs)  # 生成復(fù)數(shù)形式的頻率
    return pos_cis # 返回預(yù)計(jì)算的復(fù)數(shù)位置編碼

def apply_rotary_emb(xq, xk, pos_cis):
    """應(yīng)用旋轉(zhuǎn)位置編碼到查詢和鍵。"""
    def unite_shape(pos_cis, x):
        """調(diào)整位置編碼的形狀以匹配輸入張量的形狀。"""
        ndim = x.ndim # 獲取輸入的維度
        assert 0 <= 1 < ndim # 確保維度有效
        assert pos_cis.shape == (x.shape[1], x.shape[-1])  # 確保位置編碼形狀匹配
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # 生成新形狀
        return pos_cis.reshape(*shape) # 調(diào)整位置編碼的形狀

    # 將查詢和鍵轉(zhuǎn)換為復(fù)數(shù)形式
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    pos_cis = unite_shape(pos_cis, xq_) # 調(diào)整位置編碼形狀
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) # 應(yīng)用位置編碼并轉(zhuǎn)換回實(shí)數(shù)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) # 同上
    return xq_out.type_as(xq), xk_out.type_as(xk)         # 返回與輸入類型一致的輸出

這里使用的是RoPE旋轉(zhuǎn)位置編碼,和相對位置編碼相比,RoPE 具有更好的外推性,Meta 的 LLAMA 和 清華的 ChatGLM 都使用該編碼,目前是大模型相對位置編碼中應(yīng)用最廣的方式之一,具體原理由于篇幅原因就不講了,可以看看這篇文章:https://cloud.tencent.com/developer/article/2327751。

參考

(1)http://www.bimant.com/blog/transformer-positional-encoding-illustration/(2)https://hub.baai.ac.cn/view/29979

本文轉(zhuǎn)載自 ??周末程序猿??,作者: 周末程序猿

收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦