自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

從理論到實現(xiàn),手把手實現(xiàn)Attention網(wǎng)絡(luò)

網(wǎng)絡(luò) 網(wǎng)絡(luò)優(yōu)化
Transformer當(dāng)中也有attention結(jié)構(gòu),它就是正兒八經(jīng)地利用向量之間的相似度來計算的。常理上來說,按照向量相似度來計算權(quán)重,這種做法應(yīng)該更容易理解一些。

作者 | 梁唐

出品 | 公眾號:Coder梁(ID:Coder_LT)

大家好,我是老梁。

我們之前介紹了Transformer的核心——attention網(wǎng)絡(luò),我們之前只是介紹了它的原理,并且沒有詳細(xì)解釋它的實現(xiàn)方法。光聊理論難免顯得有些空洞,所以我們來談?wù)勊膶崿F(xiàn)。

為了幫助大家更好地理解, 這里我選了電商場景中的DIN模型來做切入點。

一方面可以幫助大家理解現(xiàn)在電商系統(tǒng)中的推薦和廣告系統(tǒng)中的商品排序都是怎么做的,另外我個人感覺DIN要比直接去硬啃transformer容易理解一些。

我們可以先從attention網(wǎng)絡(luò)的數(shù)據(jù)入手,它的輸入數(shù)據(jù)有兩個:一個是用戶的歷史行為序列,一個是待打分的item(以下稱為target item)。用戶的歷史行為序列本質(zhì)上其實就是一個用戶歷史上有過交互的item的數(shù)組。這里為了簡化,我們假設(shè)已經(jīng)完成了從item到embedding的轉(zhuǎn)換。

首先是target item,它的shape應(yīng)該是[B, E]。這里的B指的是batch_size,即訓(xùn)練時候一個批量的大小。這里的E指的是embedding的長度。也有一些文章里使用別的字母表示,這也沒有一個硬性的標(biāo)準(zhǔn),能看懂就行。

我們再來看用戶行為序列,除了batch_size和embedding長度之外,還需要一個額外的參數(shù)來表示行為序列的長度,通常我們用字母T。對于所有的樣本,我們都需要保證它的行為序列長度是T,如果不足T的,則使用默認(rèn)值補足。如果超過T的,則進(jìn)行截斷。如此,它的shape應(yīng)該是[B, T, E]。

根據(jù)attention網(wǎng)絡(luò)的原理,我們需要根據(jù)行為序列中的每個item與target item的相似度,再根據(jù)相似度計算權(quán)重。最后對這T個item的embedding進(jìn)行加權(quán)求和。求和之后,這T個item根據(jù)計算得到的權(quán)重合并得到一個embedding。論文中說這個集成T個行為序列的embedding就是用戶興趣的表達(dá),我們只需要將它和目標(biāo)item拼接在一起發(fā)送到神經(jīng)網(wǎng)絡(luò)即可,就可以幫助模型更好地決策了。這里用戶興趣的shape應(yīng)該和item是一樣的,也是[B, E]。

簡單總結(jié)一下,我們現(xiàn)在需要一個模塊,它接收兩個輸入,一個是item的embedding,一個是用戶行為序列的embedding。它的輸出應(yīng)該是[B, T],對應(yīng)行為序列中T個item的權(quán)重。剩下的問題就是怎么生成這個結(jié)果。

原理講完了,接下來講講實現(xiàn),我們可以結(jié)合一下下面這兩張論文中的結(jié)構(gòu)圖幫助理解。

圖片

圖片圖片

首先,我們來統(tǒng)一一下輸入的維度,手動將item的embedding這個二維的向量變成三維,即shape變成[B, 1, E]。

這里一種做法是,手動循環(huán)T次,每次從行為序列中拿出一個item embedding,和目標(biāo)item的embedding拼接在一起丟進(jìn)一個神經(jīng)網(wǎng)絡(luò)中得到一個分?jǐn)?shù)。

這種做法非常不推薦,一般在神經(jīng)網(wǎng)絡(luò)當(dāng)中,我們不到萬不得已,不手動循環(huán),因為循環(huán)是線性計算,沒辦法利用GPU的并行計算來加速。

對于當(dāng)前問題來說,我們完全可以使用矩陣運算來代替。通過使用expand/tile函數(shù),將[B, 1, E]的item embedding復(fù)制T份,形狀也變成[B, T, E]。這樣一來,兩個輸入的shape都變成了[B, T, E],我們就可以把它們拼接到一起變成[B, T, 2E]。

然后經(jīng)過一個輸入是2E,輸出是1的神經(jīng)網(wǎng)絡(luò),最終得到[B, T, 1]的結(jié)果,我們把它調(diào)換一下維度,變成[B, 1, T],這個就是我們想要的權(quán)重了。

這里我找來一份Pytorch的代碼,大家代入一下上面的邏輯去看一下,應(yīng)該不難看懂。

class LocalActivationUnit(nn.Module):

def __init__(self, hidden_units=(64, 32), embedding_dim=4, activation='sigmoid', dropout_rate=0, dice_dim=3,
             l2_reg=0, use_bn=False):
    super(LocalActivationUnit, self).__init__()

    self.dnn = DNN(inputs_dim=4 * embedding_dim,
                   hidden_units=hidden_units,
                   activation=activation,
                   l2_reg=l2_reg,
                   dropout_rate=dropout_rate,
                   dice_dim=dice_dim,
                   use_bn=use_bn)

    self.dense = nn.Linear(hidden_units[-1], 1)

def forward(self, query, user_behavior): # query ad : size -> batch_size * 1 * embedding_size # user behavior : size -> batch_size * time_seq_len * embedding_size user_behavior_len = user_behavior.size(1)

queries = query.expand(-1, user_behavior_len, -1)

    attention_input = torch.cat([queries, user_behavior, queries - user_behavior, queries * user_behavior],
                                dim=-1)  # as the source code, subtraction simulates verctors' difference
    attention_output = self.dnn(attention_input)

    attention_score = self.dense(attention_output)  # [B, T, 1]

    return attention_score

其實這一段代碼就是attention網(wǎng)絡(luò)的核心,它生成的是attention中最重要的權(quán)重。權(quán)重有了之后,我們只需要將它和用戶行為序列的embedding相乘。利用矩陣乘法的特性,一個[B, 1. T]的矩陣乘上一個[B, T, E]的矩陣,會得到[B, 1, E]的結(jié)果。這個相乘之后的結(jié)果其實就是我們需要的加權(quán)求和,只不過是通過矩陣乘法來實現(xiàn)了。

我們再看下源碼加深一下理解:

class AttentionSequencePoolingLayer(nn.Module): """The Attentional sequence pooling operation used in DIN & DIEN.

Arguments
      - **att_hidden_units**:list of positive integer, the attention net layer number and units in each layer.

      - **att_activation**: Activation function to use in attention net.

      - **weight_normalization**: bool.Whether normalize the attention score of local activation unit.

      - **supports_masking**:If True,the input need to support masking.

    References
      - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
  """

def __init__(self, att_hidden_units=(80, 40), att_activation='sigmoid', weight_normalization=False,
             return_score=False, supports_masking=False, embedding_dim=4, **kwargs):
    super(AttentionSequencePoolingLayer, self).__init__()
    self.return_score = return_score
    self.weight_normalization = weight_normalization
    self.supports_masking = supports_masking
    self.local_att = LocalActivationUnit(hidden_units=att_hidden_units, embedding_dim=embedding_dim,
                                         activation=att_activation,
                                         dropout_rate=0, use_bn=False)

[docs] def forward(self, query, keys, keys_length, mask=None): """ Input shape - A list of three tensor: [query,keys,keys_length]

- query is a 3D tensor with shape:  ``(batch_size, 1, embedding_size)``

      - keys is a 3D tensor with shape:   ``(batch_size, T, embedding_size)``

      - keys_length is a 2D tensor with shape: ``(batch_size, 1)``

    Output shape
      - 3D tensor with shape: ``(batch_size, 1, embedding_size)``.
    """
    batch_size, max_length, _ = keys.size()

    # Mask
    if self.supports_masking:
        if mask is None:
            raise ValueError("When supports_masking=True,input must support masking")
        keys_masks = mask.unsqueeze(1)
    else:
        keys_masks = torch.arange(max_length, device=keys_length.device, dtype=keys_length.dtype).repeat(batch_size,1)  # [B, T]
        keys_masks = keys_masks < keys_length.view(-1, 1)  # 0, 1 mask
        keys_masks = keys_masks.unsqueeze(1)  # [B, 1, T]

    attention_score = self.local_att(query, keys)  # [B, T, 1]

    outputs = torch.transpose(attention_score, 1, 2)  # [B, 1, T]

    if self.weight_normalization:
        paddings = torch.ones_like(outputs) * (-2 ** 32 + 1)
    else:
        paddings = torch.zeros_like(outputs)

    outputs = torch.where(keys_masks, outputs, paddings)  # [B, 1, T]

    # Scale
    # outputs = outputs / (keys.shape[-1] ** 0.05)

    if self.weight_normalization:
        outputs = F.softmax(outputs, dim=-1)  # [B, 1, T]

    if not self.return_score:
        # Weighted sum
        outputs = torch.matmul(outputs, keys)  # [B, 1, E]

    return outputs

這段代碼當(dāng)中加入了mask以及normalization等邏輯,全部忽略掉的話,真正核心的主干代碼就只有三行:

attention_score = self.local_att(query, keys) # [B, T, 1] outputs = torch.transpose(attention_score, 1, 2) # [B, 1, T] outputs = torch.matmul(outputs, keys) # [B, 1, E] 到這里我們關(guān)于attention網(wǎng)絡(luò)的實現(xiàn)方法就算是講完了,對于DIN這篇論文也就理解差不多了,不過還有一個細(xì)節(jié)值得聊聊。就是關(guān)于attention權(quán)重的問題。

在DIN這篇論文當(dāng)中,我們是使用了一個單獨的LocalActivationUnit來學(xué)習(xí)的兩個embedding拼接之后的權(quán)重,也就是上圖代碼當(dāng)中這個部分:

圖片圖片

我們通過一個單獨的神經(jīng)網(wǎng)絡(luò)來對兩個向量打分給出權(quán)重,這個權(quán)重的運算邏輯并不一定是根據(jù)向量的相似度來計算的。畢竟神經(jīng)網(wǎng)絡(luò)是一個黑盒,我們無從猜測內(nèi)部邏輯。只不過從邏輯上或者經(jīng)驗上來說,我們更傾向于它是根據(jù)向量的相似度來計算的。

而Transformer當(dāng)中也有attention結(jié)構(gòu),它就是正兒八經(jīng)地利用向量之間的相似度來計算的。常理上來說,按照向量相似度來計算權(quán)重,這種做法應(yīng)該更容易理解一些。但實際上學(xué)習(xí)的過程當(dāng)中的感受卻并不一定如此,這也是為什么我先來分享DIN而不是直接上transformer self-attention的原因。

責(zé)任編輯:武曉燕 來源: Coder梁
相關(guān)推薦

2023-04-26 12:46:43

DockerSpringKubernetes

2009-11-09 14:57:37

WCF上傳文件

2011-01-06 10:39:25

.NET程序打包

2020-05-15 08:07:33

JWT登錄單點

2021-03-12 10:01:24

JavaScript 前端表單驗證

2011-04-21 10:32:44

MySQL雙機同步

2011-10-06 14:32:43

2011-02-22 14:36:40

ASP.NETmsdnC#

2015-07-15 13:18:27

附近的人開發(fā)

2016-05-12 11:54:39

2022-07-28 16:06:08

app分身

2021-01-15 13:28:53

RNNPyTorch神經(jīng)網(wǎng)絡(luò)

2021-11-09 09:01:36

Python網(wǎng)絡(luò)爬蟲Python基礎(chǔ)

2021-04-01 09:02:38

Python小說下載網(wǎng)絡(luò)爬蟲

2021-12-11 20:20:19

Python算法線性

2025-01-13 09:07:12

2021-04-02 10:01:00

JavaScript前端Web項目

2023-11-24 16:57:53

2019-06-17 16:47:54

網(wǎng)絡(luò)協(xié)議DNS

2023-03-03 14:07:06

點贊
收藏

51CTO技術(shù)棧公眾號