自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

熬了一晚上,我從零實(shí)現(xiàn)了Transformer模型,把代碼講給你聽(tīng)

新聞 前端
本文詳細(xì)介紹了每個(gè)模塊的實(shí)現(xiàn)思路以及筆者在Coding過(guò)程中的感悟。

 [[428029]]

自從徹底搞懂 Self_Attention 機(jī)制之后,筆者對(duì) Transformer 模型的理解直接從地下一層上升到大氣層,任督二脈呼之欲出。夜夜入睡之前,那句柔情百轉(zhuǎn)的"Attention is all you need"時(shí)常在耳畔環(huán)繞,情到深處不禁拍床叫好。于是在腎上腺素的驅(qū)使下,筆者熬了一個(gè)晚上,終于實(shí)現(xiàn)了 Transformer 模型。

1. 模型總覽

代碼講解之前,首先放出這張經(jīng)典的模型架構(gòu)圖。下面的內(nèi)容中,我會(huì)將每個(gè)模塊的實(shí)現(xiàn)思路以及筆者在Coding過(guò)程中的感悟知無(wú)不答。沒(méi)有代碼基礎(chǔ)的讀者不要慌張,筆者也是最近才入門(mén)的,所寫(xiě)Pytorch代碼沒(méi)有花里胡哨,所用變量名詞盡量保持與論文一致,對(duì)新手十分友好。

我們觀察模型的結(jié)構(gòu)圖,Transformer模型包含哪些模塊?筆者將其分為以下幾個(gè)部分:

接下來(lái)我們首先逐個(gè)講解,最后將其拼接完成模型的復(fù)現(xiàn)。

2. config

下面是這個(gè)Demo所用的庫(kù)文件以及一些超參的信息。 單獨(dú)實(shí)現(xiàn)一個(gè)Config類(lèi)保存的原因是,方便日后復(fù)用。直接將模型部分復(fù)制,所用超參保存在新項(xiàng)目的Config類(lèi)中即可 。這里不過(guò)多贅述。

  1. import torch 
  2. import torch.nn as nn 
  3. import numpy as np 
  4. import math 
  5.  
  6.  
  7. class Config(object): 
  8.     def __init__(self): 
  9.         self.vocab_size = 6 
  10.  
  11.         self.d_model = 20 
  12.         self.n_heads = 2 
  13.  
  14.         assert self.d_model % self.n_heads == 0 
  15.         dim_k  = d_model % n_heads 
  16.         dim_v = d_model % n_heads 
  17.  
  18.         self.padding_size = 30 
  19.         self.UNK = 5 
  20.         self.PAD = 4 
  21.  
  22.         self.N = 6 
  23.         self.p = 0.1 
  24.  
  25. config = Config() 

3. Embedding

Embedding部分接受原始的文本輸入(batch_size*seq_len,例:[[1,3,10,5],[3,4,5],[5,3,1,1]]),疊加一個(gè)普通的Embedding層以及一個(gè)Positional Embedding層,輸出最后結(jié)果。

在這一層中,輸入的是一個(gè)list: [batch_size * seq_len],輸出的是一個(gè)tensor:[batch_size * seq_len * d_model]

普通的 Embedding 層想說(shuō)兩點(diǎn):

  •  
    1. torch.nn.Embedding 
    2. torch.nn.Embedding 
    3. padding_idx 
  • 在padding過(guò)程中,短補(bǔ)長(zhǎng)截

  1. class Embedding(nn.Module): 
  2.     def __init__(self,vocab_size): 
  3.         super(Embedding, self).__init__() 
  4.         # 一個(gè)普通的 embedding層,我們可以通過(guò)設(shè)置padding_idx=config.PAD 來(lái)實(shí)現(xiàn)論文中的 padding_mask 
  5.         self.embedding = nn.Embedding(vocab_size,config.d_model,padding_idx=config.PAD) 
  6.  
  7.  
  8.     def forward(self,x): 
  9.         # 根據(jù)每個(gè)句子的長(zhǎng)度,進(jìn)行padding,短補(bǔ)長(zhǎng)截 
  10.         for i in range(len(x)): 
  11.             if len(x[i]) < config.padding_size: 
  12.                 x[i].extend([config.UNK] * (config.padding_size - len(x[i]))) # 注意 UNK是你詞表中用來(lái)表示oov的token索引,這里進(jìn)行了簡(jiǎn)化,直接假設(shè)為6 
  13.             else
  14.                 x[i] = x[i][:config.padding_size] 
  15.         x = self.embedding(torch.tensor(x)) # batch_size * seq_len * d_model 
  16.         return x 

關(guān)于Positional Embedding,我們需要參考論文給出的公式。說(shuō)一句題外話(huà),在作者的實(shí)驗(yàn)中對(duì)比了Positional Embedding與單獨(dú)采用一個(gè)Embedding訓(xùn)練模型對(duì)位置的感知兩種方式,模型效果相差無(wú)幾。

  1. class Positional_Encoding(nn.Module): 
  2.  
  3.     def __init__(self,d_model): 
  4.         super(Positional_Encoding,self).__init__() 
  5.         self.d_model = d_model 
  6.  
  7.  
  8.     def forward(self,seq_len,embedding_dim): 
  9.         positional_encoding = np.zeros((seq_len,embedding_dim)) 
  10.         for pos in range(positional_encoding.shape[0]): 
  11.             for i in range(positional_encoding.shape[1]): 
  12.                 positional_encoding[pos][i] = math.sin(pos/(10000**(2*i/self.d_model))) if i % 2 == 0 else math.cos(pos/(10000**(2*i/self.d_model))) 
  13.         return torch.from_numpy(positional_encoding) 

4. Encoder

Muti_head_Attention

這一部分是模型的核心內(nèi)容,理論部分就不過(guò)多講解了,讀者可以參考文章開(kāi)頭的第一個(gè)傳送門(mén),文中有基礎(chǔ)的代碼實(shí)現(xiàn)。

Encoder 中的 Muti_head_Attention 不需要Mask,因此與我們上一篇文章中的實(shí)現(xiàn)方式相同。

為了避免模型信息泄露的問(wèn)題,Decoder 中的 Muti_head_Attention 需要Mask。這一節(jié)中我們重點(diǎn)講解Muti_head_Attention中Mask機(jī)制的實(shí)現(xiàn)。

如果讀者閱讀了我們的上一篇文章,可以發(fā)現(xiàn)下面的代碼有一點(diǎn)小小的不同,主要體現(xiàn)在 forward 函數(shù)的參數(shù)。

  •  
  • requires_mask:是否采用Mask機(jī)制,在Decoder中設(shè)置為T(mén)rue

  1. class Mutihead_Attention(nn.Module): 
  2.     def __init__(self,d_model,dim_k,dim_v,n_heads): 
  3.         super(Mutihead_Attention, self).__init__() 
  4.         self.dim_v = dim_v 
  5.         self.dim_k = dim_k 
  6.         self.n_heads = n_heads 
  7.  
  8.         self.q = nn.Linear(d_model,dim_k) 
  9.         self.k = nn.Linear(d_model,dim_k) 
  10.         self.v = nn.Linear(d_model,dim_v) 
  11.  
  12.         self.o = nn.Linear(dim_v,d_model) 
  13.         self.norm_fact = 1 / math.sqrt(d_model) 
  14.  
  15.     def generate_mask(self,dim): 
  16.         # 此處是 sequence mask ,防止 decoder窺視后面時(shí)間步的信息。 
  17.         # padding mask 在數(shù)據(jù)輸入模型之前完成。 
  18.         matirx = np.ones((dim,dim)) 
  19.         mask = torch.Tensor(np.tril(matirx)) 
  20.  
  21.         return mask==1 
  22.  
  23.     def forward(self,x,y,requires_mask=False): 
  24.         assert self.dim_k % self.n_heads == 0 and self.dim_v % self.n_heads == 0 
  25.         # size of x : [batch_size * seq_len * batch_size] 
  26.         # 對(duì) x 進(jìn)行自注意力 
  27.         Q = self.q(x).reshape(-1,x.shape[0],x.shape[1],self.dim_k // self.n_heads) # n_heads * batch_size * seq_len * dim_k 
  28.         K = self.k(x).reshape(-1,x.shape[0],x.shape[1],self.dim_k // self.n_heads) # n_heads * batch_size * seq_len * dim_k 
  29.         V = self.v(y).reshape(-1,y.shape[0],y.shape[1],self.dim_v // self.n_heads) # n_heads * batch_size * seq_len * dim_v 
  30.         # print("Attention V shape : {}".format(V.shape)) 
  31.         attention_score = torch.matmul(Q,K.permute(0,1,3,2)) * self.norm_fact 
  32.         if requires_mask: 
  33.             mask = self.generate_mask(x.shape[1]) 
  34.             attention_score.masked_fill(mask,value=float("-inf")) # 注意這里的小Trick,不需要將Q,K,V 分別MASK,只MASKSoftmax之前的結(jié)果就好了 
  35.         output = torch.matmul(attention_score,V).reshape(y.shape[0],y.shape[1],-1
  36.         # print("Attention output shape : {}".format(output.shape)) 
  37.  
  38.         output = self.o(output) 
  39.         return output 

Feed Forward

這一部分實(shí)現(xiàn)很簡(jiǎn)單,兩個(gè)Linear中連接Relu即可,目的是為模型增添非線(xiàn)性信息,提高模型的擬合能力。

  1. class Feed_Forward(nn.Module): 
  2.     def __init__(self,input_dim,hidden_dim=2048): 
  3.         super(Feed_Forward, self).__init__() 
  4.         self.L1 = nn.Linear(input_dim,hidden_dim) 
  5.         self.L2 = nn.Linear(hidden_dim,input_dim) 
  6.  
  7.     def forward(self,x): 
  8.         output = nn.ReLU()(self.L1(x)) 
  9.         output = self.L2(output) 
  10.         return output 

Add & LayerNorm

這一節(jié)我們實(shí)現(xiàn)論文中提出的殘差連接以及LayerNorm。

論文中關(guān)于這部分給出公式:

代碼中的dropout,在論文中也有所解釋?zhuān)瑢?duì)輸入layer_norm的tensor進(jìn)行dropout,對(duì)模型的性能影響還是蠻大的。

代碼中的參數(shù) sub_layer ,可以是Feed Forward,也可以是Muti_head_Attention。

  1. class Add_Norm(nn.Module): 
  2.     def __init__(self): 
  3.         self.dropout = nn.Dropout(config.p) 
  4.         super(Add_Norm, self).__init__() 
  5.  
  6.     def forward(self,x,sub_layer,**kwargs): 
  7.         sub_output = sub_layer(x,**kwargs) 
  8.         # print("{} output : {}".format(sub_layer,sub_output.size())) 
  9.         x = self.dropout(x + sub_output) 
  10.  
  11.         layer_norm = nn.LayerNorm(x.size()[1:]) 
  12.         out = layer_norm(x) 
  13.         return out 

OK,Encoder中所有模塊我們已經(jīng)講解完畢,接下來(lái)我們將其拼接作為Encoder

  1. class Encoder(nn.Module): 
  2.     def __init__(self): 
  3.         super(Encoder, self).__init__() 
  4.         self.positional_encoding = Positional_Encoding(config.d_model) 
  5.         self.muti_atten = Mutihead_Attention(config.d_model,config.dim_k,config.dim_v,config.n_heads) 
  6.         self.feed_forward = Feed_Forward(config.d_model) 
  7.  
  8.         self.add_norm = Add_Norm() 
  9.  
  10.  
  11.     def forward(self,x): # batch_size * seq_len 并且 x 的類(lèi)型不是tensor,是普通list 
  12.  
  13.         x += self.positional_encoding(x.shape[1],config.d_model) 
  14.         # print("After positional_encoding: {}".format(x.size())) 
  15.         output = self.add_norm(x,self.muti_atten,y=x) 
  16.         output = self.add_norm(output,self.feed_forward) 
  17.  
  18.         return output 

5.Decoder

在 Encoder 部分的講解中,我們已經(jīng)實(shí)現(xiàn)了大部分Decoder的模塊。Decoder的Muti_head_Attention引入了Mask機(jī)制,Decoder與Encoder 中模塊的拼接方式不同。以上兩點(diǎn)讀者在Coding的時(shí)候需要注意。

  1. class Decoder(nn.Module): 
  2.     def __init__(self): 
  3.         super(Decoder, self).__init__() 
  4.         self.positional_encoding = Positional_Encoding(config.d_model) 
  5.         self.muti_atten = Mutihead_Attention(config.d_model,config.dim_k,config.dim_v,config.n_heads) 
  6.         self.feed_forward = Feed_Forward(config.d_model) 
  7.         self.add_norm = Add_Norm() 
  8.  
  9.     def forward(self,x,encoder_output): # batch_size * seq_len 并且 x 的類(lèi)型不是tensor,是普通list 
  10.         # print(x.size()) 
  11.         x += self.positional_encoding(x.shape[1],config.d_model) 
  12.         # print(x.size()) 
  13.         # 第一個(gè) sub_layer 
  14.         output = self.add_norm(x,self.muti_atten,y=x,requires_mask=True) 
  15.         # 第二個(gè) sub_layer 
  16.         output = self.add_norm(output,self.muti_atten,y=encoder_output,requires_mask=True) 
  17.         # 第三個(gè) sub_layer 
  18.         output = self.add_norm(output,self.feed_forward) 
  19.  
  20.  
  21.         return output 

6.Transformer

至此,所有內(nèi)容已經(jīng)鋪墊完畢,我們開(kāi)始組裝Transformer模型。論文中提到,Transformer中堆疊了6個(gè)我們上文中實(shí)現(xiàn)的Encoder 和 Decoder。這里筆者采用 nn.Sequential 實(shí)現(xiàn)了堆疊操作。

Output模塊的 Linear 和 Softmax 的實(shí)現(xiàn)也包含在下面的代碼中

  1. class Transformer_layer(nn.Module): 
  2.     def __init__(self): 
  3.         super(Transformer_layer, self).__init__() 
  4.         self.encoder = Encoder() 
  5.         self.decoder = Decoder() 
  6.  
  7.     def forward(self,x): 
  8.         x_input,x_output = x 
  9.         encoder_output = self.encoder(x_input) 
  10.         decoder_output = self.decoder(x_output,encoder_output) 
  11.         return (encoder_output,decoder_output) 
  12.  
  13.  
  14. class Transformer(nn.Module): 
  15.     def __init__(self,N,vocab_size,output_dim): 
  16.         super(Transformer, self).__init__() 
  17.         self.embedding_input = Embedding(vocab_size=vocab_size) 
  18.         self.embedding_output = Embedding(vocab_size=vocab_size) 
  19.  
  20.         self.output_dim = output_dim 
  21.         self.linear = nn.Linear(config.d_model,output_dim) 
  22.         self.softmax = nn.Softmax(dim=-1
  23.         self.model = nn.Sequential(*[Transformer_layer() for _ in range(N)]) 
  24.  
  25.  
  26.     def forward(self,x): 
  27.         x_input , x_output = x 
  28.         x_input = self.embedding_input(x_input) 
  29.         x_output = self.embedding_output(x_output) 
  30.  
  31.         _ , output = self.model((x_input,x_output)) 
  32.  
  33.         output = self.linear(output) 
  34.         output = self.softmax(output) 
  35.  
  36.         return output 

完整代碼

  1. @Author:Yifx 
  2. @Contact: Xxuyifan1999@163.com 
  3. @Time:2021/9/16 20:02 
  4. @Software: PyCharm 
  5.  
  6. ""
  7. 文件說(shuō)明: 
  8. ""
  9.  
  10. import torch 
  11. import torch.nn as nn 
  12. import numpy as np 
  13. import math 
  14.  
  15.  
  16. class Config(object): 
  17.     def __init__(self): 
  18.         self.vocab_size = 6 
  19.  
  20.         self.d_model = 20 
  21.         self.n_heads = 2 
  22.  
  23.         assert self.d_model % self.n_heads == 0 
  24.         dim_k  = self.d_model // self.n_heads 
  25.         dim_v = self.d_model // self.n_heads 
  26.  
  27.  
  28.  
  29.         self.padding_size = 30 
  30.         self.UNK = 5 
  31.         self.PAD = 4 
  32.  
  33.         self.N = 6 
  34.         self.p = 0.1 
  35.  
  36. config = Config() 
  37.  
  38.  
  39. class Embedding(nn.Module): 
  40.     def __init__(self,vocab_size): 
  41.         super(Embedding, self).__init__() 
  42.         # 一個(gè)普通的 embedding層,我們可以通過(guò)設(shè)置padding_idx=config.PAD 來(lái)實(shí)現(xiàn)論文中的 padding_mask 
  43.         self.embedding = nn.Embedding(vocab_size,config.d_model,padding_idx=config.PAD) 
  44.  
  45.  
  46.     def forward(self,x): 
  47.         # 根據(jù)每個(gè)句子的長(zhǎng)度,進(jìn)行padding,短補(bǔ)長(zhǎng)截 
  48.         for i in range(len(x)): 
  49.             if len(x[i]) < config.padding_size: 
  50.                 x[i].extend([config.UNK] * (config.padding_size - len(x[i]))) # 注意 UNK是你詞表中用來(lái)表示oov的token索引,這里進(jìn)行了簡(jiǎn)化,直接假設(shè)為6 
  51.             else
  52.                 x[i] = x[i][:config.padding_size] 
  53.         x = self.embedding(torch.tensor(x)) # batch_size * seq_len * d_model 
  54.         return x 
  55.  
  56.  
  57.  
  58. class Positional_Encoding(nn.Module): 
  59.  
  60.     def __init__(self,d_model): 
  61.         super(Positional_Encoding,self).__init__() 
  62.         self.d_model = d_model 
  63.  
  64.  
  65.     def forward(self,seq_len,embedding_dim): 
  66.         positional_encoding = np.zeros((seq_len,embedding_dim)) 
  67.         for pos in range(positional_encoding.shape[0]): 
  68.             for i in range(positional_encoding.shape[1]): 
  69.                 positional_encoding[pos][i] = math.sin(pos/(10000**(2*i/self.d_model))) if i % 2 == 0 else math.cos(pos/(10000**(2*i/self.d_model))) 
  70.         return torch.from_numpy(positional_encoding) 
  71.  
  72.  
  73. class Mutihead_Attention(nn.Module): 
  74.     def __init__(self,d_model,dim_k,dim_v,n_heads): 
  75.         super(Mutihead_Attention, self).__init__() 
  76.         self.dim_v = dim_v 
  77.         self.dim_k = dim_k 
  78.         self.n_heads = n_heads 
  79.  
  80.         self.q = nn.Linear(d_model,dim_k) 
  81.         self.k = nn.Linear(d_model,dim_k) 
  82.         self.v = nn.Linear(d_model,dim_v) 
  83.  
  84.         self.o = nn.Linear(dim_v,d_model) 
  85.         self.norm_fact = 1 / math.sqrt(d_model) 
  86.  
  87.     def generate_mask(self,dim): 
  88.         # 此處是 sequence mask ,防止 decoder窺視后面時(shí)間步的信息。 
  89.         # padding mask 在數(shù)據(jù)輸入模型之前完成。 
  90.         matirx = np.ones((dim,dim)) 
  91.         mask = torch.Tensor(np.tril(matirx)) 
  92.  
  93.         return mask==1 
  94.  
  95.     def forward(self,x,y,requires_mask=False): 
  96.         assert self.dim_k % self.n_heads == 0 and self.dim_v % self.n_heads == 0 
  97.         # size of x : [batch_size * seq_len * batch_size] 
  98.         # 對(duì) x 進(jìn)行自注意力 
  99.         Q = self.q(x).reshape(-1,x.shape[0],x.shape[1],self.dim_k // self.n_heads) # n_heads * batch_size * seq_len * dim_k 
  100.         K = self.k(x).reshape(-1,x.shape[0],x.shape[1],self.dim_k // self.n_heads) # n_heads * batch_size * seq_len * dim_k 
  101.         V = self.v(y).reshape(-1,y.shape[0],y.shape[1],self.dim_v // self.n_heads) # n_heads * batch_size * seq_len * dim_v 
  102.         # print("Attention V shape : {}".format(V.shape)) 
  103.         attention_score = torch.matmul(Q,K.permute(0,1,3,2)) * self.norm_fact 
  104.         if requires_mask: 
  105.             mask = self.generate_mask(x.shape[1]) 
  106.             # masked_fill 函數(shù)中,對(duì)Mask位置為T(mén)rue的部分進(jìn)行Mask 
  107.             attention_score.masked_fill(mask,value=float("-inf")) # 注意這里的小Trick,不需要將Q,K,V 分別MASK,只MASKSoftmax之前的結(jié)果就好了 
  108.         output = torch.matmul(attention_score,V).reshape(y.shape[0],y.shape[1],-1
  109.         # print("Attention output shape : {}".format(output.shape)) 
  110.  
  111.         output = self.o(output) 
  112.         return output 
  113.  
  114.  
  115. class Feed_Forward(nn.Module): 
  116.     def __init__(self,input_dim,hidden_dim=2048): 
  117.         super(Feed_Forward, self).__init__() 
  118.         self.L1 = nn.Linear(input_dim,hidden_dim) 
  119.         self.L2 = nn.Linear(hidden_dim,input_dim) 
  120.  
  121.     def forward(self,x): 
  122.         output = nn.ReLU()(self.L1(x)) 
  123.         output = self.L2(output) 
  124.         return output 
  125.  
  126. class Add_Norm(nn.Module): 
  127.     def __init__(self): 
  128.         self.dropout = nn.Dropout(config.p) 
  129.         super(Add_Norm, self).__init__() 
  130.  
  131.     def forward(self,x,sub_layer,**kwargs): 
  132.         sub_output = sub_layer(x,**kwargs) 
  133.         # print("{} output : {}".format(sub_layer,sub_output.size())) 
  134.         x = self.dropout(x + sub_output) 
  135.  
  136.         layer_norm = nn.LayerNorm(x.size()[1:]) 
  137.         out = layer_norm(x) 
  138.         return out 
  139.  
  140.  
  141. class Encoder(nn.Module): 
  142.     def __init__(self): 
  143.         super(Encoder, self).__init__() 
  144.         self.positional_encoding = Positional_Encoding(config.d_model) 
  145.         self.muti_atten = Mutihead_Attention(config.d_model,config.dim_k,config.dim_v,config.n_heads) 
  146.         self.feed_forward = Feed_Forward(config.d_model) 
  147.  
  148.         self.add_norm = Add_Norm() 
  149.  
  150.  
  151.     def forward(self,x): # batch_size * seq_len 并且 x 的類(lèi)型不是tensor,是普通list 
  152.  
  153.         x += self.positional_encoding(x.shape[1],config.d_model) 
  154.         # print("After positional_encoding: {}".format(x.size())) 
  155.         output = self.add_norm(x,self.muti_atten,y=x) 
  156.         output = self.add_norm(output,self.feed_forward) 
  157.  
  158.         return output 
  159.  
  160. # 在 Decoder 中,Encoder的輸出作為Query和KEy輸出的那個(gè)東西。即 Decoder的Input作為V。此時(shí)是可行的 
  161. # 因?yàn)樵谳斎脒^(guò)程中,我們有一個(gè)padding操作,將Inputs和Outputs的seq_len這個(gè)維度都拉成一樣的了 
  162. # 我們知道,QK那個(gè)過(guò)程得到的結(jié)果是 batch_size * seq_len * seq_len .既然 seq_len 一樣,那么我們可以這樣操作 
  163. # 這樣操作的意義是,Outputs 中的 token 分別對(duì)于 Inputs 中的每個(gè)token作注意力 
  164.  
  165. class Decoder(nn.Module): 
  166.     def __init__(self): 
  167.         super(Decoder, self).__init__() 
  168.         self.positional_encoding = Positional_Encoding(config.d_model) 
  169.         self.muti_atten = Mutihead_Attention(config.d_model,config.dim_k,config.dim_v,config.n_heads) 
  170.         self.feed_forward = Feed_Forward(config.d_model) 
  171.         self.add_norm = Add_Norm() 
  172.  
  173.     def forward(self,x,encoder_output): # batch_size * seq_len 并且 x 的類(lèi)型不是tensor,是普通list 
  174.         # print(x.size()) 
  175.         x += self.positional_encoding(x.shape[1],config.d_model) 
  176.         # print(x.size()) 
  177.         # 第一個(gè) sub_layer 
  178.         output = self.add_norm(x,self.muti_atten,y=x,requires_mask=True) 
  179.         # 第二個(gè) sub_layer 
  180.         output = self.add_norm(x,self.muti_atten,y=encoder_output,requires_mask=True) 
  181.         # 第三個(gè) sub_layer 
  182.         output = self.add_norm(output,self.feed_forward) 
  183.         return output 
  184.  
  185. class Transformer_layer(nn.Module): 
  186.     def __init__(self): 
  187.         super(Transformer_layer, self).__init__() 
  188.         self.encoder = Encoder() 
  189.         self.decoder = Decoder() 
  190.  
  191.     def forward(self,x): 
  192.         x_input,x_output = x 
  193.         encoder_output = self.encoder(x_input) 
  194.         decoder_output = self.decoder(x_output,encoder_output) 
  195.         return (encoder_output,decoder_output) 
  196.  
  197. class Transformer(nn.Module): 
  198.     def __init__(self,N,vocab_size,output_dim): 
  199.         super(Transformer, self).__init__() 
  200.         self.embedding_input = Embedding(vocab_size=vocab_size) 
  201.         self.embedding_output = Embedding(vocab_size=vocab_size) 
  202.  
  203.         self.output_dim = output_dim 
  204.         self.linear = nn.Linear(config.d_model,output_dim) 
  205.         self.softmax = nn.Softmax(dim=-1
  206.         self.model = nn.Sequential(*[Transformer_layer() for _ in range(N)]) 
  207.  
  208.  
  209.     def forward(self,x): 
  210.         x_input , x_output = x 
  211.         x_input = self.embedding_input(x_input) 
  212.         x_output = self.embedding_output(x_output) 
  213.  
  214.         _ , output = self.model((x_input,x_output)) 
  215.  
  216.         output = self.linear(output) 
  217.         output = self.softmax(output) 
  218.  
  219.         return output 

 

責(zé)任編輯:張燕妮 來(lái)源: 知乎
相關(guān)推薦

2019-07-03 15:01:30

戴爾

2021-09-30 18:22:46

VSCode插件API

2021-12-03 11:57:27

代碼##語(yǔ)言

2021-04-23 13:46:06

Python標(biāo)準(zhǔn)庫(kù)協(xié)議

2020-04-14 10:06:20

微服務(wù)Netflix語(yǔ)言

2021-09-14 05:47:44

微信K歌騰訊

2024-01-04 14:16:05

騰訊紅黑樹(shù)Socket

2024-11-19 08:36:16

2024-05-14 08:20:59

線(xiàn)程CPU場(chǎng)景

2024-11-11 14:57:56

JWTSession微服務(wù)

2022-12-30 17:52:44

分布式容錯(cuò)架構(gòu)

2020-11-09 10:25:32

數(shù)據(jù)分析雙十一手機(jī)

2024-04-08 10:12:20

GPT4AgentAI

2024-09-12 16:25:34

2022-09-13 08:03:41

python編號(hào)數(shù)據(jù)

2020-08-06 16:55:37

虛擬化底層計(jì)算機(jī)

2019-06-17 08:21:06

RPC框架服務(wù)

2019-09-27 09:13:55

Redis內(nèi)存機(jī)制

2023-04-06 08:01:30

RustMutex

2021-09-28 13:42:55

Chrome Devwebsocket網(wǎng)絡(luò)協(xié)議
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)