入門 Transformer:概念、代碼與流程詳解
引言
論文《Attention is All You Need》(Vaswani等,2017)提出了Transformer架構(gòu),這一模型通過完全摒棄標(biāo)準(zhǔn)的循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)組件,徹底改變了自然語言處理(NLP)領(lǐng)域。相反,它利用了一種稱為“注意力”的機(jī)制,讓模型在生成輸出時(shí)決定如何關(guān)注輸入的特定部分(如句子中的單詞)。在Transformer之前,基于RNN的模型(如LSTM)主導(dǎo)了NLP領(lǐng)域。這些模型一次處理一個(gè)詞元,難以有效捕捉長(zhǎng)距離依賴關(guān)系。而Transformer則通過并行化數(shù)據(jù)流,并依賴注意力機(jī)制來識(shí)別詞元之間的重要關(guān)系。這一步驟帶來了巨大的影響,推動(dòng)了機(jī)器翻譯、文本生成(如GPT)甚至計(jì)算機(jī)視覺任務(wù)等領(lǐng)域的巨大飛躍。
在深入Transformer模型的實(shí)現(xiàn)之前,強(qiáng)烈建議對(duì)深度學(xué)習(xí)概念有基本的了解。熟悉神經(jīng)網(wǎng)絡(luò)、embedding、激活函數(shù)和優(yōu)化技術(shù)等主題,將更容易理解代碼并掌握Transformer的工作原理。如果你對(duì)這些概念不熟悉,可以考慮探索深度學(xué)習(xí)框架的入門資源,以及反向傳播等基礎(chǔ)主題。
導(dǎo)入庫(kù)
我們將使用PyTorch作為深度學(xué)習(xí)框架。PyTorch提供了構(gòu)建和訓(xùn)練神經(jīng)網(wǎng)絡(luò)所需的所有基本工具:
import torch
import torch.nn as nn
import math
這些導(dǎo)入包括:
- torch:PyTorch主庫(kù)。
- torch.nn:包含與神經(jīng)網(wǎng)絡(luò)相關(guān)的類和函數(shù),如nn.Linear、nn.Dropout等。
- math:用于常見的數(shù)學(xué)操作。
Transformer架構(gòu)
輸入embedding
什么是embedding?
embedding是單詞或詞元的密集向量表示。與將單詞表示為獨(dú)熱編碼向量不同,embedding將每個(gè)單詞映射到低維連續(xù)向量空間。這些embedding捕捉了單詞之間的語義關(guān)系。例如,“男人”和“女人”的詞embedding在向量空間中的距離可能比“男人”和“狗”更近。
以下是embedding層的代碼:
class InputEmbedding(nn.Module):
def __init__(self, d_model: int, vocab_size: int):
super().__init__()
self.d_model = d_model
self.vocab_size = vocab_size
self.embedding = nn.Embedding(vocab_size, d_model)
def forward(self, x):
return self.embedding(x) * math.sqrt(self.d_model)
解釋:
- nn.Embedding:將單詞索引轉(zhuǎn)換為大小為d_model的密集向量。
- 縮放因子sqrt(d_model):論文中使用此方法在訓(xùn)練期間穩(wěn)定梯度。
示例:
如果詞匯表大小為6(例如,詞元如[“Bye”, “Hello”等]),且d_model為512,embedding層將每個(gè)詞元映射到一個(gè)512維向量。
位置編碼
什么是位置編碼?
Transformer并行處理輸入序列,因此它們?nèi)狈逃械捻樞蚋拍睿ㄅcRNN不同,RNN是順序處理詞元的)。位置編碼被添加到embedding中,以向模型提供詞元在序列中的相對(duì)或絕對(duì)位置信息。
以下是位置編碼的代碼:
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
super().__init__()
self.d_model = d_model
self.seq_len = seq_len
self.dropout = nn.Dropout(dropout)
pe = torch.zeros(seq_len, d_model)
position = torch.arange(0, seq_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
return self.dropout(x)
解釋:
- 正弦函數(shù):編碼在偶數(shù)和奇數(shù)維度之間交替使用正弦和余弦函數(shù)。
- 為什么使用正弦函數(shù)?這些函數(shù)允許模型泛化到比訓(xùn)練時(shí)更長(zhǎng)的序列。
- register_buffer:確保位置編碼與模型一起保存,但在訓(xùn)練期間不更新。
我們只需要計(jì)算一次位置編碼,然后為每個(gè)句子重復(fù)使用。
層歸一化
什么是層歸一化?
層歸一化是一種通過跨特征維度歸一化輸入來穩(wěn)定和加速訓(xùn)練的技術(shù)。它確保每個(gè)輸入向量的均值為0,方差為1。
我們還引入了兩個(gè)參數(shù),gamma(乘法)和beta(加法),它們會(huì)在數(shù)據(jù)中引入波動(dòng)。網(wǎng)絡(luò)將學(xué)習(xí)調(diào)整這兩個(gè)可學(xué)習(xí)參數(shù),以在需要時(shí)引入波動(dòng)。
以下是代碼:
class LayerNormalization(nn.Module):
def __init__(self, eps: float = 1e-6) -> None:
super().__init__()
self.eps = eps
self.alpha = nn.Parameter(torch.ones(1))
self.bias = nn.Parameter(torch.zeros(1))
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
std = x.std(dim=-1, keepdim=True)
return self.alpha * (x - mean) / (std + self.eps) + self.bias
解釋:
- alpha和bias:可學(xué)習(xí)參數(shù),用于縮放和偏移歸一化后的輸出。
- eps:添加到分母中的小值,以防止除以零。
前饋塊
什么是前饋塊?
它是一個(gè)簡(jiǎn)單的兩層神經(jīng)網(wǎng)絡(luò),獨(dú)立應(yīng)用于序列中的每個(gè)位置。它幫助模型學(xué)習(xí)復(fù)雜的變換。以下是代碼:
class FeedForwardBlock(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
super().__init__()
self.linear_1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear_2 = nn.Linear(d_ff, d_model)
def forward(self, x):
return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
解釋:
- 第一層線性層:將輸入維度從d_model擴(kuò)展到d_ff,即512 → 2048。
- ReLU激活:增加非線性。
- 第二層線性層:投影回d_model。
多頭注意力
什么是注意力?
注意力機(jī)制允許模型在做出預(yù)測(cè)時(shí)關(guān)注輸入的相關(guān)部分。它計(jì)算值(V)的加權(quán)和,其中權(quán)重由查詢(Q)和鍵(K)之間的相似性決定。
什么是多頭注意力?
與計(jì)算單一注意力分?jǐn)?shù)不同,多頭注意力將輸入分成多個(gè)“頭h”,以學(xué)習(xí)不同類型的關(guān)系。
注意力
按頭計(jì)算的注意力
所有頭的注意力
- Q(查詢):代表當(dāng)前單詞或詞元。
- K(鍵):代表序列中的所有單詞或詞元。
- V(值):代表與每個(gè)單詞或詞元相關(guān)的信息。
- Softmax:將相似性分?jǐn)?shù)轉(zhuǎn)換為概率,使其總和為1。
- 縮放因子sqrt(d_k):防止點(diǎn)積變得過大,從而破壞softmax函數(shù)的穩(wěn)定性。
每個(gè)頭計(jì)算自己的注意力,結(jié)果被拼接并投影回原始維度。以下是多頭注意力塊的代碼:
class MultiHeadAttentionBlock(nn.Module):
def __init__(self, d_model: int, h: int, dropout: float) -> None: # h is number of heads
super().__init__()
self.d_model = d_model
self.h = h
# Check if d_model is divisible by num_heads
assert d_model % h == 0, "d_model must be divisible by num_heads"
self.d_k = d_model // h
# Define matrices W_q, W_k, W_v , W_o
self.w_q = nn.Linear(d_model, d_model) # W_q
self.w_k = nn.Linear(d_model, d_model) # W_k
self.w_v = nn.Linear(d_model, d_model) # W_v
self.w_o = nn.Linear(d_model, d_model) # W_o
self.dropout = nn.Dropout(dropout)
@staticmethod
def attention(query, key, value, d_k, mask=None, dropout=nn.Dropout):
# Compute attention scores
attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
attention_scores.masked_fill(mask == 0, -1e9) # Mask padding tokens
attention_scores = torch.softmax(attention_scores, dim=-1)
if dropout is not None:
attention_scores = dropout(attention_scores)
return (attention_scores @ value), attention_scores
def forward(self, q, k, v, mask):
# Compute Q, K, V
query = self.w_q(q)
key = self.w_k(k)
value = self.w_v(v)
# Split into multiple heads
query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)
# Compute attention
x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, self.d_k, mask, self.dropout)
# Concatenate heads
x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
# Final linear projection
return self.w_o(x)
解釋:
- 線性層(W_q, W_k, W_v):這些層將輸入轉(zhuǎn)換為查詢、鍵和值。
- 分割成頭:輸入被分割成h個(gè)頭,每個(gè)頭的維度較?。╠_k = d_model / h)。
- 拼接:所有頭的輸出被拼接,并使用w_o投影回原始維度。
示例:
如果d_model=512且h=8,每個(gè)頭的維度為d_k=64。輸入被分割成8個(gè)頭,每個(gè)頭獨(dú)立計(jì)算注意力。結(jié)果被拼接回一個(gè)512維向量。
殘差連接和層歸一化
什么是殘差連接?
殘差連接將層的輸入添加到其輸出中。這有助于防止“梯度消失”問題。以下是殘差連接的代碼:
class ResidualConnection(nn.Module):
def __init__(self, dropout: float) -> None:
super().__init__()
self.dropout = nn.Dropout(dropout)
self.norm = LayerNormalization()
def forward(self, x, sublayer):
return x + self.dropout(sublayer(self.norm(x)))
解釋:
- x:層的輸入。
- sublayer:表示層的函數(shù)(如注意力或前饋塊)。
- 輸出是輸入和子層輸出的和,隨后進(jìn)行dropout。
1. 編碼器塊
編碼器塊結(jié)合了我們迄今為止討論的所有組件:多頭注意力、前饋塊、殘差連接和層歸一化。以下是編碼器塊的代碼:
class EncoderBlock(nn.Module):
def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
super().__init__()
self.self_attention_block = self_attention_block
self.feed_forward_block = feed_forward_block
self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])
def forward(self, x, src_mask):
x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
x = self.residual_connections[1](x, self.feed_forward_block)
return x
解釋:
- 自注意力:輸入對(duì)自身進(jìn)行注意力計(jì)算,以捕捉詞元之間的關(guān)系。
- 前饋塊:對(duì)每個(gè)詞元應(yīng)用全連接網(wǎng)絡(luò)。
- 殘差連接:將輸入添加回每個(gè)子層的輸出。
2. 解碼器塊
解碼器塊與編碼器塊類似,但包含一個(gè)額外的交叉注意力層。這使得解碼器能夠關(guān)注編碼器的輸出。以下是解碼器塊的代碼:
class DecoderBlock(nn.Module):
def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
super().__init__()
self.self_attention_block = self_attention_block
self.cross_attention_block = cross_attention_block
self.feed_forward_block = feed_forward_block
self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])
def forward(self, x, encoder_output, src_mask, tgt_mask):
x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
x = self.residual_connections[2](x, self.feed_forward_block)
return x
解釋:
- 自注意力:解碼器對(duì)其自身輸出進(jìn)行注意力計(jì)算。
- 交叉注意力:解碼器對(duì)編碼器的輸出進(jìn)行注意力計(jì)算。
- 前饋塊:對(duì)每個(gè)詞元應(yīng)用全連接網(wǎng)絡(luò)。
3. 最終線性層
該層將解碼器的輸出投影到詞匯表大小,將embedding為每個(gè)單詞的概率。該層將包含一個(gè)線性層和一個(gè)softmax層。這里使用log_softmax以避免下溢并使此步驟在數(shù)值上更穩(wěn)定。以下是最終線性層的代碼:
class ProjectionLayer(nn.Module):
def __init__(self, d_model: int, vocab_size: int) -> None:
super().__init__()
self.proj = nn.Linear(d_model, vocab_size)
def forward(self, x):
return torch.log_softmax(self.proj(x), dim=-1)
解釋:
- 輸入:解碼器的輸出(形狀:[batch, seq_len, d_model])。
- 輸出:詞匯表中每個(gè)單詞的對(duì)數(shù)概率(形狀:[batch, seq_len, vocab_size])。
Transformer模型
Transformer將所有內(nèi)容結(jié)合在一起:編碼器、解碼器、embedding、位置編碼和投影層。以下是代碼:
class Transformer(nn.Module):
def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbedding, tgt_embed: InputEmbedding, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed
self.tgt_embed = tgt_embed
self.src_pos = src_pos
self.tgt_pos = tgt_pos
self.projection_layer = projection_layer
def encode(self, src, src_mask):
src = self.src_embed(src)
src = self.src_pos(src)
return self.encoder(src, src_mask)
def decode(self, encoder_output, tgt, src_mask, tgt_mask):
tgt = self.tgt_embed(tgt)
tgt = self.tgt_pos(tgt)
return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
def project(self, decoder_output):
return self.projection_layer(decoder_output)
最終構(gòu)建函數(shù)塊
最后一個(gè)塊是一個(gè)輔助函數(shù),通過組合我們迄今為止看到的所有組件來構(gòu)建整個(gè)Transformer模型。它允許我們指定論文中使用的超參數(shù)。以下是代碼:
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int,
d_model: int = 512, N: int = 6, h: int = 8, dropout: float = 0.1, d_ff: int = 2048) -> Transformer:
# Create the embedding layers for source and target
src_embed = InputEmbedding(d_model, src_vocab_size)
tgt_embed = InputEmbedding(d_model, tgt_vocab_size)
# Create positional encoding layers for source and target
src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)
# Create the encoder blocks
encoder_blocks = []
for _ in range(N):
encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
encoder_blocks.append(encoder_block)
# Create the decoder blocks
decoder_blocks = []
for _ in range(N):
decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
decoder_blocks.append(decoder_block)
# Create the encoder and decoder
encoder = Encoder(nn.ModuleList(encoder_blocks))
decoder = Decoder(nn.ModuleList(decoder_blocks))
# Create the projection layer
projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
# Create the Transformer model
transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)
# Initialize the parameters using Xavier initialization
for p in transformer.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
return transformer
該函數(shù)接受以下超參數(shù)作為輸入:
- src_vocab_size:源詞匯表的大?。ㄔ凑Z言中唯一詞元的數(shù)量)。
- tgt_vocab_size:目標(biāo)詞匯表的大?。繕?biāo)語言中唯一詞元的數(shù)量)。
- src_seq_len:源輸入的最大序列長(zhǎng)度。
- tgt_seq_len:目標(biāo)輸入的最大序列長(zhǎng)度。
- d_model:模型的維度(默認(rèn):512)。
- N:編碼器和解碼器塊的數(shù)量(默認(rèn):6)。
- h:多頭注意力機(jī)制中的注意力頭數(shù)量(默認(rèn):8)。
- dropout:防止過擬合的dropout率(默認(rèn):0.1)。
- d_ff:前饋網(wǎng)絡(luò)的維度(默認(rèn):2048)。
我們可以使用此函數(shù)創(chuàng)建具有所需超參數(shù)的Transformer模型。例如:
src_vocab_size = 10000
tgt_vocab_size = 10000
src_seq_len = 50
tgt_seq_len = 50
transformer = build_transformer(src_vocab_size, tgt_vocab_size, src_seq_len, tgt_seq_len)
總結(jié) — “Attention is All You Need” 核心貢獻(xiàn)
- 并行化:RNN按順序處理單詞,而Transformer并行處理整個(gè)句子。這大大減少了訓(xùn)練時(shí)間。
- 多功能性:注意力機(jī)制可以適應(yīng)各種任務(wù),如翻譯、文本分類、問答、計(jì)算機(jī)視覺、語音識(shí)別等。
- 為基礎(chǔ)模型鋪路:該架構(gòu)為BERT、GPT和T5等大規(guī)模語言模型鋪平了道路。