徹底了解 BiLSTM 和 CRF 算法
CRF 是一種常用的序列標注算法,可用于詞性標注,分詞,命名實體識別等任務。BiLSTM+CRF 是目前比較流行的序列標注算法,其將 BiLSTM 和 CRF 結合在一起,使模型即可以像 CRF 一樣考慮序列前后之間的關聯性,又可以擁有 LSTM 的特征抽取及擬合能力。
1.前言
在之前的文章《CRF 條件隨機場》中,介紹了條件隨機場 CRF,描述了 CRF 和 LSTM 的區(qū)別。我們以分詞為例,每個字對應的標簽可以是 s, b, m, e 四種。
給定一個句子 "什么是地攤經濟",其正確的分詞方式是 "什么 / 是 / 地攤 / 經濟",每個字對應的分詞標簽是 "be / s / be / be"。從下面的圖片可以看出 LSTM 在做序列標注時的問題。

BiLSTM 分詞
BiLSTM 可以預測出每一個字屬于不同標簽的概率,然后使用 Softmax 得到概率最大的標簽,作為該位置的預測值。這樣在預測的時候會忽略了標簽之間的關聯性,如上圖中 BiLSTM 把第一個詞預測成 s,把第二個詞預測成 e。但是實際上在分詞時 s 后面是不會出現 e 的,因此 BiLSTM 沒有考慮標簽間聯系。
因此 BiLSTM+CRF 在 BiLSTM 的輸出層加上一個 CRF,使得模型可以考慮類標之間的相關性,標簽之間的相關性就是 CRF 中的轉移矩陣,表示從一個狀態(tài)轉移到另一個狀態(tài)的概率。假設 CRF 的轉移矩陣如下圖所示。

CRF 狀態(tài)轉移矩陣
則對于前兩個字 "什么",其標簽為 "se" 的概率 =0.8×0×0.7=0,而標簽為 "be" 的概率=0.6×0.5×0.7=0.21。
因此,BiLSTM+CRF 考慮的是整個類標路徑的概率而不僅僅是單個類標的概率,在 BiLSTM 輸出層加上 CRF 后,如下所示。

BiLSTM+CRF 分詞
最終算得所有路徑中,besbebe 的概率最大,因此預測結果為 besbebe。
2.BiLSTM+CRF 模型
CRF 包括兩種特征函數,不熟悉的童鞋可以看下之前的文章。第一種特征函數是狀態(tài)特征函數,也稱為發(fā)射概率,表示字 x 對應標簽 y 的概率。

CRF 狀態(tài)特征函數
在 BiLSTM+CRF 中,這一個特征函數 (發(fā)射概率) 直接使用 LSTM 的輸出計算得到,如第一小節(jié)中的圖所示,LSTM 可以計算出每一時刻位置對應不同標簽的概率。
CRF 的第二個特征函數是狀態(tài)轉移特征函數,表示從一個狀態(tài) y1 轉移到另一個狀態(tài) y2 的概率。

CRF 狀態(tài)轉移特征函數
CRF 的狀態(tài)轉移特征函數可以用一個狀態(tài)轉移矩陣表示,在訓練時需要調整狀態(tài)轉移矩陣的元素值。因此 BiLSTM+CRF 需要在 BiLSTM 的模型內增加一個狀態(tài)轉移矩陣。在代碼中如下。
- class BiLSTM_CRF(nn.Module):
- def __init__(self, vocab_size, tag2idx, embedding_dim, hidden_dim):
- self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
- self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
- num_layers=1, bidirectional=True)
- # 對應 CRF 的發(fā)射概率,即每一個位置對應不同類標的概率
- self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
- # 轉移矩陣,維度等于標簽數量,表示從一個標簽轉移到另一標簽的概率
- self.transitions = nn.Parameter(
- torch.randn(len(tag2idx), len(tag2idx))
給定句子 x,其標簽序列為 y 的概率用下面的公式計算。

p(y|x)
公式中的 score 用下面的式子計算,其中 Emit 對應發(fā)射概率 (即 LSTM 輸出的概率),而 Trans 對應了轉移概率 (即 CRF 轉移矩陣對應的數值)

score 的計算公式
BiLSTM+CRF 采用最大似然法訓練,對應的損失函數如下:

損失函數
其中 score(x,y) 比較容易計算,而 Z(x) 是所有標簽序列 (y) 打分的指數之和,如果序列的長度是 l,標簽個數是 k,則序列的數量為 (k^l)。無法直接計算,因此要用前向算法進行計算。
用目前主流的深度學習框架,對 loss 進行求導和梯度下降,即可優(yōu)化 BiLSTM+CRF。訓練好模型之后可以采用 viterbi 算法 (動態(tài)規(guī)劃) 找出最優(yōu)的路徑。
3.損失函數計算
計算 BiLSTM+CRF 損失函數的難點在于計算 log Z(x),用 F 表示 log Z(x),如下公式所示。

我們將 score 拆分,變成發(fā)射概率 p 和轉移概率 T 的和。為了簡化問題,我們假設序列的長度為3,則可以分別計算寫出長度為 1、2、3 時候的 log Z 值,如下所示。

上式中 p 表示發(fā)射概率,T 表示轉移概率,Start 表示開始,End 表示句子結束。F(3) 即是最終得到的 log Z(x) 值。通過對上式進行變換,可以將 F(3) 轉成遞歸的形式,如下。

可以看到上式中每一步的操作都是一樣的,操作包括 log_sum_exp,例如 F(1):
- 首先需要計算 exp,對于所有 y1,計算 exp(p(y1)+T(Start,y1))
- 求和,對上一步得到的 exp 值進行求和
- 求 log,對求和的結果計算 log
因此可以寫出前向算法計算 log Z 的代碼,如下所示:
- def forward_algorithm(self, probs):
- def forward_algorithm(probs):
- """
- probs: LSTM 輸出的概率值,尺寸為 [seq_len, num_tags],num_tags 是標簽的個數
- """
- # forward_var (可以理解為文章中的 F) 保存前一時刻的值,是一個向量,維度等于 num_tags
- # 初始時只有 Start 為 0,其他的都取一個很小的值 (-10000.)
- forward_var = torch.full((1, num_tags), -10000.0) # [1, num_tags]
- forward_var[0][Start] = 0.0
- for p in probs: # probs [seq_len, num_tags],遍歷序列
- alphas_t = [] # alphas_t 保存下一時刻取不同標簽的累積概率值
- for next_tag in range(num_tags): # 遍歷標簽
- # 下一時刻發(fā)射 next_tag 的概率
- emit_score = p[next_tag].view(1, -1).expand(1, num_tags)
- # 從所有標簽轉移到 next_tag 的概率, transitions 是一個矩陣,長寬都是 num_tags
- trans_score = transitions[next_tag].view(1, -1)
- # next_tag_ver = F(i-1) + p + T
- next_tag_var = forward_var + trans_score + emit_score
- alphas_t.append(log_sum_exp(next_tag_var).view(1))
- forward_var = torch.cat(alphas_t).view(1, -1)
- terminal_var = forward_var + self.transitions[Stop] # 最后轉移到 Stop 表示句子結束
- alpha = log_sum_exp(terminal_var)
- return alpha
4.viterbi 算法解碼
訓練好模型后,預測過程需要用 viterbi 算法對序列進行解碼,感興趣的童鞋可以參看《統(tǒng)計學習方法》。下面介紹一下 viterbi 的公式,首先是一些符號的意義,如下:

然后可以得到 viterbi 算法的遞推公式

最終可以根據 viterbi 計算得到的值,往前查找最合適的序列

最后推薦大家閱讀 pytorch 官網的 BiLSTM+CRF 代碼,通過代碼更容易理解。
5.參考文獻
ADVANCED: MAKING DYNAMIC DECISIONS AND THE BI-LSTM CRF