如何用PyTorch實(shí)現(xiàn)遞歸神經(jīng)網(wǎng)絡(luò)?
從 Siri 到谷歌翻譯,深度神經(jīng)網(wǎng)絡(luò)已經(jīng)在機(jī)器理解自然語言方面取得了巨大突破。這些模型大多數(shù)將語言視為單調(diào)的單詞或字符序列,并使用一種稱為循環(huán)神經(jīng)網(wǎng)絡(luò)(recurrent neural network/RNN)的模型來處理該序列。但是許多語言學(xué)家認(rèn)為語言最好被理解為具有樹形結(jié)構(gòu)的層次化詞組,一種被稱為遞歸神經(jīng)網(wǎng)絡(luò)(recursive neural network)的深度學(xué)習(xí)模型考慮到了這種結(jié)構(gòu),這方面已經(jīng)有大量的研究。雖然這些模型非常難以實(shí)現(xiàn)且效率很低,但是一個(gè)全新的深度學(xué)習(xí)框架 PyTorch 能使它們和其它復(fù)雜的自然語言處理模型變得更加容易。
雖然遞歸神經(jīng)網(wǎng)絡(luò)很好地顯示了 PyTorch 的靈活性,但它也廣泛支持其它的各種深度學(xué)習(xí)框架,特別的是,它能夠?qū)τ?jì)算機(jī)視覺(computer vision)計(jì)算提供強(qiáng)大的支撐。PyTorch 是 Facebook AI Research 和其它幾個(gè)實(shí)驗(yàn)室的開發(fā)人員的成果,該框架結(jié)合了 Torch7 高效靈活的 GPU 加速后端庫與直觀的 Python 前端,它的特點(diǎn)是快速成形、代碼可讀和支持最廣泛的深度學(xué)習(xí)模型。
開始 SPINN
鏈接中的文章(https://github.com/jekbradbury/examples/tree/spinn/snli)詳細(xì)介紹了一個(gè)遞歸神經(jīng)網(wǎng)絡(luò)的 PyTorch 實(shí)現(xiàn),它具有一個(gè)循環(huán)跟蹤器(recurrent tracker)和 TreeLSTM 節(jié)點(diǎn),也稱為 SPINN——SPINN 是深度學(xué)習(xí)模型用于自然語言處理的一個(gè)例子,它很難通過許多流行的框架構(gòu)建。這里的模型實(shí)現(xiàn)部分運(yùn)用了批處理(batch),所以它可以利用 GPU 加速,使得運(yùn)行速度明顯快于不使用批處理的版本。
SPINN 的意思是堆棧增強(qiáng)的解析器-解釋器神經(jīng)網(wǎng)絡(luò)(Stack-augmented Parser-Interpreter Neural Network),由 Bowman 等人于 2016 年作為解決自然語言推理任務(wù)的一種方法引入,該論文中使用了斯坦福大學(xué)的 SNLI 數(shù)據(jù)集。
該任務(wù)是將語句對(duì)分為三類:假設(shè)語句 1 是一幅看不見的圖像的準(zhǔn)確標(biāo)題,那么語句 2(a)肯定(b)可能還是(c)絕對(duì)不是一個(gè)準(zhǔn)確的標(biāo)題?(這些類分別被稱為蘊(yùn)含(entailment)、中立(neutral)和矛盾(contradiction))。例如,假設(shè)一句話是「兩只狗正跑過一片場(chǎng)地」,蘊(yùn)含可能會(huì)使這個(gè)語句對(duì)變成「戶外的動(dòng)物」,中立可能會(huì)使這個(gè)語句對(duì)變成「一些小狗正在跑并試圖抓住一根棍子」,矛盾能會(huì)使這個(gè)語句對(duì)變成「寵物正坐在沙發(fā)上」。
特別地,研究 SPINN 的初始目標(biāo)是在確定語句的關(guān)系之前將每個(gè)句子編碼(encoding)成固定長(zhǎng)度的向量表示(也有其它方式,例如注意模型(attention model)中將每個(gè)句子的每個(gè)部分用一種柔焦(soft focus)的方法相互比較)。
數(shù)據(jù)集是用句法解析樹(syntactic parse tree)方法由機(jī)器生成的,句法解析樹將每個(gè)句子中的單詞分組成具有獨(dú)立意義的短語和子句,每個(gè)短語由兩個(gè)詞或子短語組成。許多語言學(xué)家認(rèn)為,人類通過如上面所說的樹的分層方式來組合詞意并理解語言,所以用相同的方式嘗試構(gòu)建一個(gè)神經(jīng)網(wǎng)絡(luò)是值得的。下面的例子是數(shù)據(jù)集中的一個(gè)句子,其解析樹由嵌套括號(hào)表示:
- ( ( The church ) ( ( has ( cracks ( in ( the ceiling ) ) ) ) . ) )
這個(gè)句子進(jìn)行編碼的一種方式是使用含有解析樹的神經(jīng)網(wǎng)絡(luò)構(gòu)建一個(gè)神經(jīng)網(wǎng)絡(luò)層 Reduce,這個(gè)神經(jīng)網(wǎng)絡(luò)層能夠組合詞語對(duì)(用詞嵌入(word embedding)表示,如 GloVe)、 和/或短語,然后遞歸地應(yīng)用此層(函數(shù)),將最后一個(gè) Reduce 產(chǎn)生的結(jié)果作為句子的編碼:
- X = Reduce(“the”, “ceiling”)
- Y = Reduce(“in”, X)
- ... etc.
但是,如果我希望網(wǎng)絡(luò)以更類似人類的方式工作,從左到右閱讀并保留句子的語境,同時(shí)仍然使用解析樹組合短語?或者,如果我想訓(xùn)練一個(gè)網(wǎng)絡(luò)來構(gòu)建自己的解析樹,讓解析樹根據(jù)它看到的單詞讀取句子?這是一個(gè)同樣的但方式略有不同的解析樹的寫法:
- The church ) has cracks in the ceiling ) ) ) ) . ) )
或者用第 3 種方式表示,如下:
- WORDS: The church has cracks in the ceiling .
- PARSES: S S R S S S S S R R R R S R R
我所做的只是刪除開括號(hào),然后用「S」標(biāo)記「shift」,并用「R」替換閉括號(hào)用于「reduce」。但是現(xiàn)在可以從左到右讀取信息作為一組指令來操作一個(gè)堆棧(stack)和一個(gè)類似堆棧的緩沖區(qū)(buffer),能得到與上述遞歸方法完全相同的結(jié)果:
- 將單詞放入緩沖區(qū)。
- 從緩沖區(qū)的前部彈出「The」,將其推送(push)到堆棧上層,緊接著是「church」。
- 彈出前 2 個(gè)堆棧值,應(yīng)用于 Reduce,然后將結(jié)果推送回堆棧。
- 從緩沖區(qū)彈出「has」,然后推送到堆棧,然后是「cracks」,然后是「in」,然后是「the」,然后是「ceiling」。
- 重復(fù)四次:彈出 2 個(gè)堆棧值,應(yīng)用于 Reduce,然后推送結(jié)果。
- 從緩沖區(qū)彈出「.」,然后推送到堆棧上層。
- 重復(fù)兩次:彈出 2 個(gè)堆棧值,應(yīng)用于 Reduce,然后推送結(jié)果。
- 彈出剩余的堆棧值,并將其作為句子編碼返回。
我還想保留句子的語境,以便在對(duì)句子的后半部分應(yīng)用 Reduce 層時(shí)考慮系統(tǒng)已經(jīng)讀取的句子部分的信息。所以我將用一個(gè)三參數(shù)函數(shù)替換雙參數(shù)的 Reduce 函數(shù),該函數(shù)的輸入值為一個(gè)左子句、一個(gè)右子句和當(dāng)前句的上下文狀態(tài)。該狀態(tài)由神經(jīng)網(wǎng)絡(luò)的第二層(稱為循環(huán)跟蹤器(Tracker)的單元)創(chuàng)建。Tracker 在給定當(dāng)前句子上下文狀態(tài)、緩沖區(qū)中的頂部條目 b 和堆棧中前兩個(gè)條目 s1\s2 時(shí),在堆棧操作的每個(gè)步驟(即,讀取每個(gè)單詞或閉括號(hào))后生成一個(gè)新狀態(tài):
- context[t+1] = Tracker(context[t], b, s1, s2)
容易設(shè)想用你最喜歡的編程語言來編寫代碼做這些事情。對(duì)于要處理的每個(gè)句子,它將從緩沖區(qū)加載下一個(gè)單詞,運(yùn)行跟蹤器,檢查是否將單詞推送入堆?;驁?zhí)行 Reduce 函數(shù),執(zhí)行該操作;然后重復(fù),直到對(duì)整個(gè)句子完成處理。通過對(duì)單個(gè)句子的應(yīng)用,該過程構(gòu)成了一個(gè)大而復(fù)雜的深度神經(jīng)網(wǎng)絡(luò),通過堆棧操作的方式一遍又一遍地應(yīng)用它的兩個(gè)可訓(xùn)練層。但是,如果你熟悉 TensorFlow 或 Theano 等傳統(tǒng)的深度學(xué)習(xí)框架,就知道它們很難實(shí)現(xiàn)這樣的動(dòng)態(tài)過程。你值得花點(diǎn)時(shí)間回顧一下,探索為什么 PyTorch 能有所不同。
圖 1:一個(gè)函數(shù)的圖結(jié)構(gòu)表示
深度神經(jīng)網(wǎng)絡(luò)本質(zhì)上是有大量參數(shù)的復(fù)雜函數(shù)。深度學(xué)習(xí)的目的是通過計(jì)算以損失函數(shù)(loss)度量的偏導(dǎo)數(shù)(梯度)來優(yōu)化這些參數(shù)。如果函數(shù)表示為計(jì)算圖結(jié)構(gòu)(圖 1),則向后遍歷該圖可實(shí)現(xiàn)這些梯度的計(jì)算,而無需冗余工作。每個(gè)現(xiàn)代深度學(xué)習(xí)框架都是基于此反向傳播(backpropagation)的概念,因此每個(gè)框架都需要一個(gè)表示計(jì)算圖的方式。
在許多流行的框架中,包括 TensorFlow、Theano 和 Keras 以及 Torch7 的 nngraph 庫,計(jì)算圖是一個(gè)提前構(gòu)建的靜態(tài)對(duì)象。該圖是用像數(shù)學(xué)表達(dá)式的代碼定義的,但其變量實(shí)際上是尚未保存任何數(shù)值的占位符(placeholder)。圖中的占位符變量被編譯進(jìn)函數(shù),然后可以在訓(xùn)練集的批處理上重復(fù)運(yùn)行該函數(shù)來產(chǎn)生輸出和梯度值。
這種靜態(tài)計(jì)算圖(static computation graph)方法對(duì)于固定結(jié)構(gòu)的卷積神經(jīng)網(wǎng)絡(luò)效果很好。但是在許多其它應(yīng)用中,有用的做法是令神經(jīng)網(wǎng)絡(luò)的圖結(jié)構(gòu)根據(jù)數(shù)據(jù)而有所不同。在自然語言處理中,研究人員通常希望通過每個(gè)時(shí)間步驟中輸入的單詞來展開(確定)循環(huán)神經(jīng)網(wǎng)絡(luò)。上述 SPINN 模型中的堆棧操作很大程度上依賴于控制流程(如 for 和 if 語句)來定義特定句子的計(jì)算圖結(jié)構(gòu)。在更復(fù)雜的情況下,你可能需要構(gòu)建結(jié)構(gòu)依賴于模型自身的子網(wǎng)絡(luò)輸出的模型。
這些想法中的一些(雖然不是全部)可以被生搬硬套到靜態(tài)圖系統(tǒng)中,但幾乎總是以降低透明度和增加代碼的困惑度為代價(jià)。該框架必須在其計(jì)算圖中添加特殊的節(jié)點(diǎn),這些節(jié)點(diǎn)代表如循環(huán)和條件的編程原語(programming primitive),而用戶必須學(xué)習(xí)和使用這些節(jié)點(diǎn),而不僅僅是編程代碼語言中的 for 和 if 語句。這是因?yàn)槌绦騿T使用的任何控制流程語句將僅運(yùn)行一次,當(dāng)構(gòu)建圖時(shí)程序員需要硬編碼(hard coding)單個(gè)計(jì)算路徑。
例如,通過詞向量(從初始狀態(tài) h0 開始)運(yùn)行循環(huán)神經(jīng)網(wǎng)絡(luò)單元(rnn_unit)需要 TensorFlow 中的特殊控制流節(jié)點(diǎn) tf.while_loop。需要一個(gè)額外的特殊節(jié)點(diǎn)來獲取運(yùn)行時(shí)的詞長(zhǎng)度,因?yàn)樵谶\(yùn)行代碼時(shí)它只是一個(gè)占位符。
- # TensorFlow
- # (this code runs once, during model initialization)
- # “words” is not a real list (it’s a placeholder variable) so
- # I can’t use “len”
- cond = lambda i, h: i < tf.shape(words)[0]
- cell = lambda i, h: rnn_unit(words[i], h)
- i = 0
- _, h = tf.while_loop(cond, cell, (i, h0))
基于動(dòng)態(tài)計(jì)算圖(dynamic computation graph)的方法與之前的方法有根本性不同,它有幾十年的學(xué)術(shù)研究歷史,其中包括了哈佛的 Kayak、自動(dòng)微分庫(autograd)以及以研究為中心的框架 Chainer和 DyNet。在這樣的框架(也稱為運(yùn)行時(shí)定義(define-by-run))中,計(jì)算圖在運(yùn)行時(shí)被建立和重建,使用相同的代碼為前向通過(forward pass)執(zhí)行計(jì)算,同時(shí)也為反向傳播(backpropagation)建立所需的數(shù)據(jù)結(jié)構(gòu)。這種方法能產(chǎn)生更直接的代碼,因?yàn)榭刂屏鞒痰木帉懣梢允褂脴?biāo)準(zhǔn)的 for 和 if。它還使調(diào)試更容易,因?yàn)檫\(yùn)行時(shí)斷點(diǎn)(run-time breakpoint)或堆棧跟蹤(stack trace)將追蹤到實(shí)際編寫的代碼,而不是執(zhí)行引擎中的編譯函數(shù)。可以在動(dòng)態(tài)框架中使用簡(jiǎn)單的 Python 的 for 循環(huán)來實(shí)現(xiàn)有相同變量長(zhǎng)度的循環(huán)神經(jīng)網(wǎng)絡(luò)。
- # PyTorch (also works in Chainer)
- # (this code runs on every forward pass of the model)
- # “words” is a Python list with actual values in it
- h = h0
- for word in words: h = rnn_unit(word, h)
PyTorch 是第一個(gè) define-by-run 的深度學(xué)習(xí)框架,它與靜態(tài)圖框架(如 TensorFlow)的功能和性能相匹配,使其能很好地適合從標(biāo)準(zhǔn)卷積神經(jīng)網(wǎng)絡(luò)(convolutional network)到最瘋狂的強(qiáng)化學(xué)習(xí)(reinforcement learning)等思想。所以讓我們來看看 SPINN 的實(shí)現(xiàn)。
代碼
在開始構(gòu)建網(wǎng)絡(luò)之前,我需要設(shè)置一個(gè)數(shù)據(jù)加載器(data loader)。通過深度學(xué)習(xí),模型可以通過數(shù)據(jù)樣本的批處理進(jìn)行操作,通過并行化(parallelism)加快訓(xùn)練,并在每一步都有一個(gè)更平滑的梯度變化。我想在這里可以做到這一點(diǎn)(稍后我將解釋上述堆棧操作過程如何進(jìn)行批處理)。以下 Python 代碼使用內(nèi)置于 PyTorch 的文本庫的系統(tǒng)來加載數(shù)據(jù),它可以通過連接相似長(zhǎng)度的數(shù)據(jù)樣本自動(dòng)生成批處理。運(yùn)行此代碼之后,train_iter、dev_iter 和 test_itercontain 循環(huán)遍歷訓(xùn)練集、驗(yàn)證集和測(cè)試集分塊 SNLI 的批處理。
- from torchtext import data, datasets TEXT = datasets.snli.ParsedTextField(lower=True)
- TRANSITIONS = datasets.snli.ShiftReduceField()
- LABELS = data.Field(sequential=False)train, dev, test = datasets.SNLI.splits( TEXT, TRANSITIONS, LABELS, wv_type='glove.42B')TEXT.build_vocab(train, dev, test)
- train_iter, dev_iter, test_iter = data.BucketIterator.splits( (train, dev, test), batch_size=64)
你可以在 train.py中找到設(shè)置訓(xùn)練循環(huán)和準(zhǔn)確性(accuracy)測(cè)量的其余代碼。讓我們繼續(xù)。如上所述,SPINN 編碼器包含參數(shù)化的 Reduce 層和可選的循環(huán)跟蹤器來跟蹤句子上下文,以便在每次網(wǎng)絡(luò)讀取單詞或應(yīng)用 Reduce 時(shí)更新隱藏狀態(tài);以下代碼代表的是,創(chuàng)建一個(gè) SPINN 只是意味著創(chuàng)建這兩個(gè)子模塊(我們將很快看到它們的代碼),并將它們放在一個(gè)容器中以供稍后使用。
- import torchfrom torch import nn
- # subclass the Module class from PyTorch’s neural network package
- class SPINN(nn.Module):
- def __init__(self, config):
- super(SPINN, self).__init__()
- self.config = config
- self.reduce = Reduce(config.d_hidden, config.d_tracker)
- if config.d_tracker is not None:
- self.tracker = Tracker(config.d_hidden, config.d_tracker)
當(dāng)創(chuàng)建模型時(shí),SPINN.__init__ 被調(diào)用了一次;它分配和初始化參數(shù),但不執(zhí)行任何神經(jīng)網(wǎng)絡(luò)操作或構(gòu)建任何類型的計(jì)算圖。在每個(gè)新的批處理數(shù)據(jù)上運(yùn)行的代碼由 SPINN.forward 方法定義,它是用戶實(shí)現(xiàn)的方法中用于定義模型向前過程的標(biāo)準(zhǔn) PyTorch 名稱。上面描述的是堆棧操作算法的一個(gè)有效實(shí)現(xiàn),即在一般 Python 中,在一批緩沖區(qū)和堆棧上運(yùn)行,每一個(gè)例子都對(duì)應(yīng)一個(gè)緩沖區(qū)和堆棧。我使用轉(zhuǎn)移矩陣(transition)包含的「shift」和「reduce」操作集合進(jìn)行迭代,運(yùn)行 Tracker(如果存在),并遍歷批處理中的每個(gè)樣本來應(yīng)用「shift」操作(如果請(qǐng)求),或?qū)⑵涮砑拥叫枰竢educe」操作的樣本列表中。然后在該列表中的所有樣本上運(yùn)行 Reduce 層,并將結(jié)果推送回到它們各自的堆棧。
- def forward(self, buffers, transitions):
- # The input comes in as a single tensor of word embeddings;
- # I need it to be a list of stacks, one for each example in
- # the batch, that we can pop from independently. The words in
- # each example have already been reversed, so that they can
- # be read from left to right by popping from the end of each
- # list; they have also been prefixed with a null value.
- buffers = [list(torch.split(b.squeeze(1), 1, 0))
- for b in torch.split(buffers, 1, 1)]
- # we also need two null values at the bottom of each stack,
- # so we can copy from the nulls in the input; these nulls
- # are all needed so that the tracker can run even if the
- # buffer or stack is empty
- stacks = [[buf[0], buf[0]] for buf in buffers]
- if hasattr(self, 'tracker'):
- self.tracker.reset_state()
- for trans_batch in transitions:
- if hasattr(self, 'tracker'):
- # I described the Tracker earlier as taking 4
- # arguments (context_t, b, s1, s2), but here I
- # provide the stack contents as a single argument
- # while storing the context inside the Tracker
- # object itself.
- tracker_states, _ = self.tracker(buffers, stacks)
- else:
- tracker_states = itertools.repeat(None)
- lefts, rights, trackings = [], [], []
- batch = zip(trans_batch, buffers, stacks, tracker_states)
- for transition, buf, stack, tracking in batch:
- if transition == SHIFT:
- stack.append(buf.pop())
- elif transition == REDUCE:
- rights.append(stack.pop())
- lefts.append(stack.pop())
- trackings.append(tracking)
- if rights:
- reduced = iter(self.reduce(lefts, rights, trackings))
- for transition, stack in zip(trans_batch, stacks):
- if transition == REDUCE:
- stack.append(next(reduced))
- return [stack.pop() for stack in stacks]
在調(diào)用 self.tracker 或 self.reduce 時(shí)分別運(yùn)行 Tracker 或 Reduce 子模塊的向前方法,該方法需要在樣本列表上應(yīng)用前向操作。在主函數(shù)的向前方法中,在不同的樣本上進(jìn)行獨(dú)立的操作是有意義的,即為批處理中每個(gè)樣本提供分離的緩沖區(qū)和堆棧,因?yàn)樗惺芤嬗谂幚韴?zhí)行的重度使用數(shù)學(xué)和需要 GPU 加速的操作都在 Tracker 和 Reduce 中進(jìn)行。為了更干凈地編寫這些函數(shù),我將使用一些 helper(稍后將定義)將這些樣本列表轉(zhuǎn)化成批處理張量(tensor),反之亦然。
我希望 Reduce 模塊自動(dòng)批處理其參數(shù)以加速計(jì)算,然后解批處理(unbatch)它們,以便可以單獨(dú)推送和彈出。用于將每對(duì)左、右子短語表達(dá)組合成父短語(parent phrase)的實(shí)際組合函數(shù)是 TreeLSTM,它是普通循環(huán)神經(jīng)網(wǎng)絡(luò)單元 LSTM 的變型。該組合函數(shù)要求每個(gè)子短語的狀態(tài)實(shí)際上由兩個(gè)張量組成,一個(gè)隱藏狀態(tài) h 和一個(gè)存儲(chǔ)單元(memory cell)狀態(tài) c,而函數(shù)是使用在子短語的隱藏狀態(tài)操作的兩個(gè)線性層(nn.Linear)和將線性層的結(jié)果與子短語的存儲(chǔ)單元狀態(tài)相結(jié)合的非線性組合函數(shù) tree_lstm。在 SPINN 中,這種方式通過添加在 Tracker 的隱藏狀態(tài)下運(yùn)行的第 3 個(gè)線性層進(jìn)行擴(kuò)展。
圖 2:TreeLSTM 組合函數(shù)增加了第 3 個(gè)輸入(x,在這種情況下為 Tracker 狀態(tài))。在下面所示的 PyTorch 實(shí)現(xiàn)中,5 組的三種線性變換(由藍(lán)色、黑色和紅色箭頭的三元組表示)組合為三個(gè) nn.Linear 模塊,而 tree_lstm 函數(shù)執(zhí)行位于框內(nèi)的所有計(jì)算。圖來自 Chen et al. (2016)。
- def tree_lstm(c1, c2, lstm_in):
- # Takes the memory cell states (c1, c2) of the two children, as
- # well as the sum of linear transformations of the children’s
- # hidden states (lstm_in)
- # That sum of transformed hidden states is broken up into a
- # candidate output a and four gates (i, f1, f2, and o).
- a, i, f1, f2, o = lstm_in.chunk(5, 1)
- c = a.tanh() * i.sigmoid() + f1.sigmoid() * c1 + f2.sigmoid() * c2
- h = o.sigmoid() * c.tanh()
- return h, cclass Reduce(nn.Module):
- def __init__(self, size, tracker_size=None):
- super(Reduce, self).__init__()
- self.left = nn.Linear(size, 5 * size)
- self.right = nn.Linear(size, 5 * size, bias=False)
- if tracker_size is not None:
- self.track = nn.Linear(tracker_size, 5 * size, bias=False)
- def forward(self, left_in, right_in, tracking=None):
- left, right = batch(left_in), batch(right_in)
- tracking = batch(tracking)
- lstm_in = self.left(left[0])
- lstm_in += self.right(right[0])
- if hasattr(self, 'track'):
- lstm_in += self.track(tracking[0])
- return unbatch(tree_lstm(left[1], right[1], lstm_in))
由于 Reduce 層和類似實(shí)現(xiàn)方法的 Tracker 都用 LSTM 進(jìn)行工作,所以批處理和解批處理幫助函數(shù)在隱藏狀態(tài)和存儲(chǔ)狀態(tài)對(duì)(h,c)上運(yùn)行。
- def batch(states):
- if states is None:
- return None
- states = tuple(states)
- if states[0] is None:
- return None
- # states is a list of B tensors of dimension (1, 2H)
- # this returns two tensors of dimension (B, H)
- return torch.cat(states, 0).chunk(2, 1)def unbatch(state):
- if state is None:
- return itertools.repeat(None)
- # state is a pair of tensors of dimension (B, H)
- # this returns a list of B tensors of dimension (1, 2H)
- return torch.split(torch.cat(state, 1), 1, 0)
這就是所有。其余的必要代碼(包括 Tracker),在 spinn.py中,同時(shí)分類器層可以從兩個(gè)句子編碼中計(jì)算 SNLI 類別,并在給出最終損失(loss)變量的情況下將此結(jié)果與目標(biāo)進(jìn)行比較,代碼在 model.py中。SPINN 及其子模塊的向前代碼產(chǎn)生了非常復(fù)雜的計(jì)算圖(圖 3),最終計(jì)算出損失函數(shù),其細(xì)節(jié)在數(shù)據(jù)集中的每個(gè)批處理中都完全不同,但是每次只需很少的開銷(overhead)即可自動(dòng)反向傳播,通過調(diào)用 loss.backward(),一個(gè)內(nèi)置于 PyTorch 中的函數(shù),它可以從圖中的任何一點(diǎn)執(zhí)行反向傳播。
完整代碼中的模型和超參數(shù)可以與原始 SPINN 論文中報(bào)告的性能相匹配,但是充分利用了批處理加工和 PyTorch 的效率后,在 GPU 上訓(xùn)練的速度要快幾倍。雖然原始實(shí)現(xiàn)需要21 分鐘來編譯計(jì)算圖(意味著實(shí)施過程中的調(diào)試周期至少要那么長(zhǎng)),然后訓(xùn)練大約 5 天時(shí)間,這里描述的版本沒有編譯步驟,它的訓(xùn)練在 Tesla K40 的 GPU 上需要約 13 個(gè)小時(shí),或者在 Quadro GP100 上約 9 小時(shí)。
圖 3:具有批大小為 2 的 SPINN 計(jì)算圖的一小部分,它運(yùn)行的是本文中提供的 Chainer 代碼版本。
調(diào)用所有的強(qiáng)化學(xué)習(xí)
上述沒有跟蹤器(Tracker)的模型版本實(shí)際上非常適合 TensorFlow 的新 tf.fold 域特定語言,它針對(duì)動(dòng)態(tài)圖形的特殊情況,但是有跟蹤器的版本將難以實(shí)現(xiàn)。這是因?yàn)樘砑痈櫰饕馕吨鴱倪f歸(recursive)方法切換到基于堆棧的方法。這(如上面的代碼)是最直接地使用依賴于輸入值的條件分支(conditional branch)來實(shí)現(xiàn)的。但是 Fold 缺少內(nèi)置的條件分支操作,所以使用它構(gòu)建的模型中的圖形結(jié)構(gòu)只能取決于輸入的結(jié)構(gòu)而不是其數(shù)值。此外,構(gòu)建一個(gè)其跟蹤器在讀取輸入句子時(shí)就決定如何解析輸入句子的 SPINN 的版本是完全沒有可能的,因?yàn)橐坏┘虞d了一個(gè)輸入樣本 Fold 中的圖結(jié)構(gòu)必須完全固定(圖結(jié)構(gòu)依賴于輸入樣本的結(jié)構(gòu))。
DeepMind 和谷歌大腦的研究人員研究了一個(gè)這樣的模型,他們應(yīng)用強(qiáng)化學(xué)習(xí)來訓(xùn)練一個(gè) SPINN 的跟蹤器解析輸入句子,而不使用任何外部解析數(shù)據(jù)。本質(zhì)上,這樣一個(gè)模型從隨機(jī)猜測(cè)開始,當(dāng)它的解析在整體分類任務(wù)上恰好產(chǎn)生良好的準(zhǔn)確性時(shí),它產(chǎn)生一個(gè)自我獎(jiǎng)勵(lì)(reward)并通過獎(jiǎng)勵(lì)來學(xué)習(xí)。研究人員寫道,他們「使用的批處理大小為 1,因?yàn)樵诿看蔚杏?jì)算圖需要根據(jù)每個(gè)來自策略網(wǎng)絡(luò)(policy network)的樣本重新構(gòu)建 [Tracker]」——但 PyTorch 使得在像這樣一個(gè)復(fù)雜的、結(jié)構(gòu)隨機(jī)變化的網(wǎng)絡(luò)上進(jìn)行批處理訓(xùn)練成為可能。
PyTorch 也是第一個(gè)以隨機(jī)計(jì)算圖(stochastic computation graph)形式建立強(qiáng)化學(xué)習(xí)(RL)庫的框架,使得策略梯度(policy gradient)強(qiáng)化學(xué)習(xí)如反向傳播一樣易于使用。要將其添加到上述模型中,你只需重新編寫主 SPINN 的 for 循環(huán)的前幾行,如下所示,使得 Tracker 能夠定義進(jìn)行每種解析轉(zhuǎn)移矩陣的概率。
- !# nn.functional contains neural network operations without parametersfrom torch.nn import functional as F
- transitions = []for i in range(len(buffers[0]) * 2 - 3):
- # we know how many steps
- # obtain raw scores for each kind of parser transition
- tracker_states, transition_scores = self.tracker(buffers, stacks)
- # use a softmax function to normalize scores into probabilities,
- # then sample from the distribution these probabilities define
- transition_batch = F.softmax(transition_scores).multinomial()
- transitions.append(transition_batch
然后,隨著批處理一直運(yùn)行,模型會(huì)得出它預(yù)測(cè)的類別的準(zhǔn)確程度,我可以通過這些隨機(jī)計(jì)算圖的節(jié)點(diǎn)發(fā)出獎(jiǎng)勵(lì)信號(hào),另外在圖的其余部分以傳統(tǒng)方式進(jìn)行反向傳播:
- # losses should contain a loss per example, while mean and std
- # represent averages across many batches
- rewards = (-losses - mean) / std
- for transition in transitions:
- transition.reinforce(rewards)
- # connect the stochastic nodes to the final loss variable
- # so that backpropagation can find them, multiplying by zero
- # because this trick shouldn’t change the loss value
- loss = losses.mean() + 0 * sum(transitions).sum()
- # perform backpropagation through deterministic nodes and
- # policy gradient RL for stochastic nodesloss.backward()
谷歌研究人員報(bào)告了包含強(qiáng)化學(xué)習(xí)的 SPINN 的結(jié)果,比原始的在 SNLI 數(shù)據(jù)集的 SPINN 上獲得的結(jié)果要好一些——盡管強(qiáng)化學(xué)習(xí)版本不使用預(yù)先計(jì)算的解析樹信息。自然語言處理的深度強(qiáng)化學(xué)習(xí)領(lǐng)域是全新的,該領(lǐng)域的研究課題非常廣泛;通過將強(qiáng)化學(xué)習(xí)構(gòu)建到框架中,PyTorch 大大降低了進(jìn)入門檻。
從今天起開始使用 PyTorch
遵循 pytorch.org 網(wǎng)站上的說明安裝,選擇安裝平臺(tái)(即將推出 Windows 版本的支持)。PyTorch 支持 Python 2 和 3 以及使用 CUDA 7.5 或 8.0 和 CUDNN 5.1 或 6.0 的 CPU 或 NVIDIA GPU 上的計(jì)算。它有針對(duì) conda 和 pip 的 Linux 二進(jìn)制文件甚至含有 CUDA 本身,因此你不需要自己設(shè)置它。
官方教程包含 60 分鐘的介紹和深度 Q-學(xué)習(xí)(Deep Q-Learning,一種現(xiàn)代強(qiáng)化學(xué)習(xí)模型)的演練。斯坦福的 Justin Johnson 教授有一個(gè)非常全面的教程,官方示例還包括——深度卷積生成對(duì)抗網(wǎng)絡(luò)(DCGAN)、ImageNet 模型和神經(jīng)機(jī)器翻譯模型(neural machine translation)。新加坡國(guó)立大學(xué)的 Richie Ng 制作了最新的其它 PyTorch 實(shí)現(xiàn)、示例和教程的列表。PyTorch 開發(fā)人員和用戶社區(qū)在討論論壇上的第一時(shí)間回答問題,但你應(yīng)該首先檢查 API 文檔。
盡管 PyTorch 僅使用了較短時(shí)間,但三篇研究論文已經(jīng)使用了它,幾個(gè)學(xué)術(shù)實(shí)驗(yàn)室和業(yè)界實(shí)驗(yàn)室也采用了 PyTorch?;厮葸^去,在當(dāng)時(shí)動(dòng)態(tài)計(jì)算圖比現(xiàn)在模糊時(shí),我與同事在 Salesforce Research(https://metamind.io/research.html)也曾經(jīng)考慮過 Chainer 作為我們的秘密配方;現(xiàn)在,在大公司的支持下,我們很高興 PyTorch 將這一水平的能力和靈活性帶入了主流。