機(jī)器學(xué)習(xí)|從0開始大模型之位置編碼
1、什么是位置編碼
在語言中,一句話是由詞組成的,詞與詞之間是有順序的,如果順序亂了或者重排,其實(shí)整個(gè)句子的意思就變了,所以詞與詞之間是有順序的。在循環(huán)神經(jīng)網(wǎng)絡(luò)中,序列與序列之間也是有順序的,所以循環(huán)神經(jīng)網(wǎng)絡(luò)中,序列與序列之間也是有順序的,不需要處理這種問題。但是在Transformer中,每個(gè)詞是獨(dú)立的,所以需要將詞的位置信息添加到模型中,讓模型維護(hù)順序關(guān)系。
位置編碼
位置編碼就是將hello world! 的token和位置關(guān)系通過向量表示出來,作為訓(xùn)練的輸入數(shù)據(jù),如上圖,位置編碼最終會變成:
[
[P00, P01, P02 ... P0d],
[P10, P11, P12 ... P1d],
[P20, P21, P22 ... P2d],
]
2、計(jì)算位置編碼
計(jì)算位置編碼有多種方式:固定位置編碼,相對位置編碼,絕對位置編碼,其中Transformer的作者設(shè)計(jì)了一種三角函數(shù)位置編碼方式,通過三角函數(shù)計(jì)算輸出位置編碼向量。
為什么三角函數(shù)可以作為計(jì)算位置編碼的函數(shù)?
- 首先我們來回顧一下三角函數(shù)的基本性質(zhì):函數(shù)具有周期性,取值范圍是[-1, 1]。
sin
- 其次,如果用絕對位置編碼計(jì)算最大序列為3的位置(0-7),二進(jìn)制表示如下:
[
[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[0, 1, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
[1, 1, 1]
]
從上可以表示看出,較高比特位的交替頻率低于較低比特位,存在周期性bit位變化,符合三角函數(shù)的周期性,而且三角函數(shù)的取值范圍是[-1, 1],輸出浮點(diǎn)數(shù),并且數(shù)據(jù)連續(xù),比直接使用二進(jìn)制更節(jié)省空間。
3、Transformer中的位置編碼層
假設(shè)你有一個(gè)長度為L的輸入序列,要計(jì)算第K個(gè)元素的位置編碼,位置編碼由不同頻率的正弦和余弦函數(shù)給出:
函數(shù)
- k:詞序列中的第K個(gè)元素
- d:詞向量維度,比如512,1024,8K等
- P(k, i):位置函數(shù),輸出位置編碼向量
- n:定義的標(biāo)量,Attention Is All You Need 的作者設(shè)置為 10,000
- i:映射到列索引,范圍是0~d/2(由于輸入是2i表示,如果用i表示,范圍可以是0~d)
按照上述Hello world!的例子,計(jì)算位置編碼結(jié)果如下:
計(jì)算結(jié)果
那么用代碼實(shí)現(xiàn)一個(gè)簡化版本的位置編碼:
import numpy as np
def getPositionEncoding(seq_len, d, n=10000):
P = np.zeros((seq_len, d))
for k in range(seq_len):
for i in np.arange(int(d/2)):
denominator = np.power(n, 2*i/d)
P[k, 2*i] = np.sin(k/denominator)
P[k, 2*i+1] = np.cos(k/denominator)
return P
P = getPositionEncoding(seq_len=3, d=3, n=100)
print(P)
# 輸出結(jié)果:
[[ 0. 1. 0. ]
[ 0.84147098 0.54030231 0. ]
[ 0.90929743 -0.41614684 0. ]]
4、大模型訓(xùn)練中的位置編碼代碼
在我們從0訓(xùn)練大模型中,其位置編碼的實(shí)現(xiàn)如下:
def precompute_pos_cis(dim: int, seq_len: int, theta: float = 10000.0):
"""預(yù)計(jì)算相對位置編碼的復(fù)數(shù)形式,用于旋轉(zhuǎn)位置編碼(RoPE)。"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # 計(jì)算頻率
t = torch.arange(seq_len, device=freqs.device) # 創(chuàng)建時(shí)間步長
freqs = torch.outer(t, freqs).float() # 計(jì)算頻率的外積
pos_cis = torch.polar(torch.ones_like(freqs), freqs) # 生成復(fù)數(shù)形式的頻率
return pos_cis # 返回預(yù)計(jì)算的復(fù)數(shù)位置編碼
def apply_rotary_emb(xq, xk, pos_cis):
"""應(yīng)用旋轉(zhuǎn)位置編碼到查詢和鍵。"""
def unite_shape(pos_cis, x):
"""調(diào)整位置編碼的形狀以匹配輸入張量的形狀。"""
ndim = x.ndim # 獲取輸入的維度
assert 0 <= 1 < ndim # 確保維度有效
assert pos_cis.shape == (x.shape[1], x.shape[-1]) # 確保位置編碼形狀匹配
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # 生成新形狀
return pos_cis.reshape(*shape) # 調(diào)整位置編碼的形狀
# 將查詢和鍵轉(zhuǎn)換為復(fù)數(shù)形式
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
pos_cis = unite_shape(pos_cis, xq_) # 調(diào)整位置編碼形狀
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3) # 應(yīng)用位置編碼并轉(zhuǎn)換回實(shí)數(shù)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3) # 同上
return xq_out.type_as(xq), xk_out.type_as(xk) # 返回與輸入類型一致的輸出
這里使用的是RoPE旋轉(zhuǎn)位置編碼,和相對位置編碼相比,RoPE 具有更好的外推性,Meta 的 LLAMA 和 清華的 ChatGLM 都使用該編碼,目前是大模型相對位置編碼中應(yīng)用最廣的方式之一,具體原理由于篇幅原因就不講了,可以看看這篇文章:https://cloud.tencent.com/developer/article/2327751。
參考
(1)http://www.bimant.com/blog/transformer-positional-encoding-illustration/(2)https://hub.baai.ac.cn/view/29979
本文轉(zhuǎn)載自 ??周末程序猿??,作者: 周末程序猿
