Classifying Chinese Text with CNN and RNN: A TensorFlow Implementation
TensorFlow has now reached version 1.3. This release wraps many network layers in higher-level APIs and even integrates excellent high-level frameworks such as Keras, which makes it far easier to use. Compared with the earlier low-level code, today's implementations are much more concise and elegant.
This article presents a simplified TensorFlow implementation on a Chinese dataset: a character-level CNN and RNN are used to classify Chinese text, with good results.
Dataset
This article uses a subset of the THUCNews news text classification dataset provided by the Tsinghua NLP group (the full dataset contains roughly 740,000 documents and takes a long time to train on). Please download the data from THUCTC (an efficient Chinese text classification toolkit) and comply with the data provider's open-source license.
This experiment uses 10 of the categories, with 6,500 samples per category.
The categories are:
體育 (sports), 財經 (finance), 房產 (real estate), 家居 (home), 教育 (education), 科技 (technology), 時尚 (fashion), 時政 (politics), 游戲 (games), 娛樂 (entertainment)
The subset can be downloaded here: http://pan.baidu.com/s/1bpq9Eub password: ycyw
The dataset is split as follows:
- Training set: 5000 * 10
- Validation set: 500 * 10
- Test set: 1000 * 10
To see how the subset was generated from the original data, refer to the two scripts under helper: copy_data.sh copies 6,500 files from each category, and cnews_group.py merges the many files into a single file per split (a minimal sketch of this merging step follows the list below). Running cnews_group.py produces three data files:
- cnews.train.txt: training set (50,000 samples)
- cnews.val.txt: validation set (5,000 samples)
- cnews.test.txt: test set (10,000 samples)
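The merging done by cnews_group.py is essentially a concatenation of many small article files into one large file per split. Below is a minimal sketch of such a step, assuming (hypothetically) that the articles sit in one sub-directory per category and that each output line stores the category label and the article text separated by a tab; the actual script may use a different layout.

```python
import os

def merge_category_files(data_dir, out_path):
    """Concatenate per-category article files into a single file.
    Assumed (hypothetical) output format: one article per line, "label<TAB>content"."""
    with open(out_path, 'w', encoding='utf-8') as out:
        for category in sorted(os.listdir(data_dir)):   # one sub-directory per category
            cat_dir = os.path.join(data_dir, category)
            if not os.path.isdir(cat_dir):
                continue
            for fname in sorted(os.listdir(cat_dir)):
                with open(os.path.join(cat_dir, fname), encoding='utf-8') as f:
                    content = f.read().replace('\n', '').replace('\t', ' ')
                out.write('%s\t%s\n' % (category, content))

# Example usage (paths are placeholders):
# merge_category_files('data/thucnews_subset/train', 'data/cnews.train.txt')
```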
Preprocessing
data/cnews_loader.py handles data preprocessing.
- read_file(): reads the data file;
- build_vocab(): builds the vocabulary using a character-level representation, and saves it to disk so it does not have to be rebuilt every time;
- read_vocab(): reads the stored vocabulary and converts it into a {word: id} mapping;
- read_category(): fixes the category list and converts it into a {category: id} mapping;
- to_words(): converts a sample represented by ids back into text;
- process_file(): converts the datasets from text into fixed-length id sequences;
- batch_iter(): prepares shuffled batches of data for training the network.
經(jīng)過數(shù)據(jù)預處理,數(shù)據(jù)的格式如下:
Convolutional Neural Network (CNN)
Configuration
The configurable CNN parameters are shown below and live in cnn_model.py.
```python
class TCNNConfig(object):
    """CNN configuration parameters."""

    embedding_dim = 64        # embedding dimension
    seq_length = 600          # sequence length
    num_classes = 10          # number of classes
    num_filters = 128         # number of convolution filters
    kernel_size = 5           # convolution kernel size
    vocab_size = 5000         # vocabulary size

    hidden_dim = 128          # size of the fully-connected layer
    dropout_keep_prob = 0.5   # dropout keep probability
    learning_rate = 1e-3      # learning rate

    batch_size = 64           # training batch size
    num_epochs = 10           # total number of epochs

    print_per_batch = 100     # print results every this many batches
    save_per_batch = 10       # write to TensorBoard every this many batches
```
CNN Model
See cnn_model.py for the full implementation.
The overall structure is roughly as follows:
Training and Validation
Run python run_cnn.py train to start training.
If you have trained before, delete tensorboard/textcnn first so that results from multiple runs do not overlap in TensorBoard.
```
Configuring CNN model...
Configuring TensorBoard and Saver...
Loading training and validation data...
Time usage: 0:00:14
Training and evaluating...
Epoch: 1
Iter: 0, Train Loss: 2.3, Train Acc: 10.94%, Val Loss: 2.3, Val Acc: 8.92%, Time: 0:00:01 *
Iter: 100, Train Loss: 0.88, Train Acc: 73.44%, Val Loss: 1.2, Val Acc: 68.46%, Time: 0:00:04 *
Iter: 200, Train Loss: 0.38, Train Acc: 92.19%, Val Loss: 0.75, Val Acc: 77.32%, Time: 0:00:07 *
Iter: 300, Train Loss: 0.22, Train Acc: 92.19%, Val Loss: 0.46, Val Acc: 87.08%, Time: 0:00:09 *
Iter: 400, Train Loss: 0.24, Train Acc: 90.62%, Val Loss: 0.4, Val Acc: 88.62%, Time: 0:00:12 *
Iter: 500, Train Loss: 0.16, Train Acc: 96.88%, Val Loss: 0.36, Val Acc: 90.38%, Time: 0:00:15 *
Iter: 600, Train Loss: 0.084, Train Acc: 96.88%, Val Loss: 0.35, Val Acc: 91.36%, Time: 0:00:17 *
Iter: 700, Train Loss: 0.21, Train Acc: 93.75%, Val Loss: 0.26, Val Acc: 92.58%, Time: 0:00:20 *
Epoch: 2
Iter: 800, Train Loss: 0.07, Train Acc: 98.44%, Val Loss: 0.24, Val Acc: 94.12%, Time: 0:00:23 *
Iter: 900, Train Loss: 0.092, Train Acc: 96.88%, Val Loss: 0.27, Val Acc: 92.86%, Time: 0:00:25
Iter: 1000, Train Loss: 0.17, Train Acc: 95.31%, Val Loss: 0.28, Val Acc: 92.82%, Time: 0:00:28
Iter: 1100, Train Loss: 0.2, Train Acc: 93.75%, Val Loss: 0.23, Val Acc: 93.26%, Time: 0:00:31
Iter: 1200, Train Loss: 0.081, Train Acc: 98.44%, Val Loss: 0.25, Val Acc: 92.96%, Time: 0:00:33
Iter: 1300, Train Loss: 0.052, Train Acc: 100.00%, Val Loss: 0.24, Val Acc: 93.58%, Time: 0:00:36
Iter: 1400, Train Loss: 0.1, Train Acc: 95.31%, Val Loss: 0.22, Val Acc: 94.12%, Time: 0:00:39
Iter: 1500, Train Loss: 0.12, Train Acc: 98.44%, Val Loss: 0.23, Val Acc: 93.58%, Time: 0:00:41
Epoch: 3
Iter: 1600, Train Loss: 0.1, Train Acc: 96.88%, Val Loss: 0.26, Val Acc: 92.34%, Time: 0:00:44
Iter: 1700, Train Loss: 0.018, Train Acc: 100.00%, Val Loss: 0.22, Val Acc: 93.46%, Time: 0:00:47
Iter: 1800, Train Loss: 0.036, Train Acc: 100.00%, Val Loss: 0.28, Val Acc: 92.72%, Time: 0:00:50
No optimization for a long time, auto-stopping...
```
The best validation accuracy was 94.12%, and training stopped after only 3 epochs.
The accuracy and loss curves are shown in the corresponding figures.
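The "auto-stopping" message above comes from a patience-style early-stopping check: training stops once validation accuracy has not improved for a while. Below is a minimal sketch of that kind of check; the patience of 1000 batches is an assumption, not necessarily the value used in run_cnn.py.

```python
def should_stop(acc_val, state, require_improvement=1000):
    """Return True when validation accuracy has not improved for
    `require_improvement` consecutive batches. `state` tracks progress."""
    state['total_batch'] += 1
    if acc_val > state['best_acc']:
        # New best model: remember it (the real script would also save a checkpoint here).
        state['best_acc'] = acc_val
        state['last_improved'] = state['total_batch']
        return False
    return state['total_batch'] - state['last_improved'] > require_improvement

# Inside the training loop (sketch):
# state = {'best_acc': 0.0, 'last_improved': 0, 'total_batch': 0}
# if should_stop(current_val_acc, state):
#     print('No optimization for a long time, auto-stopping...')
```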
Test
Run python run_cnn.py test to evaluate on the test set.
```
Configuring CNN model...
Loading test data...
Testing...
Test Loss: 0.14, Test Acc: 96.04%
Precision, Recall and F1-Score...
             precision    recall  f1-score   support

         體育       0.99      0.99      0.99      1000
         財經       0.96      0.99      0.97      1000
         房產       1.00      1.00      1.00      1000
         家居       0.95      0.91      0.93      1000
         教育       0.95      0.89      0.92      1000
         科技       0.94      0.97      0.95      1000
         時尚       0.95      0.97      0.96      1000
         時政       0.94      0.94      0.94      1000
         游戲       0.97      0.96      0.97      1000
         娛樂       0.95      0.98      0.97      1000

avg / total       0.96      0.96      0.96     10000

Confusion Matrix...
[[991   0   0   0   2   1   0   4   1   1]
 [  0 992   0   0   2   1   0   5   0   0]
 [  0   1 996   0   1   1   0   0   0   1]
 [  0  14   0 912   7  15   9  29   3  11]
 [  2   9   0  12 892  22  18  21  10  14]
 [  0   0   0  10   1 968   4   3  12   2]
 [  1   0   0   9   4   4 971   0   2   9]
 [  1  16   0   4  18  12   1 941   1   6]
 [  2   4   1   5   4   5  10   1 962   6]
 [  1   0   1   6   4   3   5   0   1 979]]
Time usage: 0:00:05
```
The accuracy on the test set reaches 96.04%, and precision, recall and F1-score exceed 0.9 for every class.
The confusion matrix also shows that the classification quality is very good.
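The precision/recall/F1 table and confusion matrix above are the kind of report scikit-learn produces. Below is a minimal sketch of generating such a report; the dummy y_test and y_pred arrays are placeholders for the real class-index predictions, and the repository's evaluation code may differ.

```python
from sklearn import metrics

# y_test and y_pred would be 1-D arrays of class indices over the 10 categories;
# tiny dummy values are used here only so the snippet runs on its own.
categories = ['體育', '財經', '房產', '家居', '教育', '科技', '時尚', '時政', '游戲', '娛樂']
y_test = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
y_pred = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

print(metrics.classification_report(y_test, y_pred, target_names=categories))
print(metrics.confusion_matrix(y_test, y_pred))
```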
Recurrent Neural Network (RNN)
Configuration
The configurable RNN parameters are shown below and live in rnn_model.py.
```python
class TRNNConfig(object):
    """RNN configuration parameters."""

    # Model parameters
    embedding_dim = 64        # embedding dimension
    seq_length = 600          # sequence length
    num_classes = 10          # number of classes
    vocab_size = 5000         # vocabulary size

    num_layers = 2            # number of hidden layers
    hidden_dim = 128          # size of the hidden layers
    rnn = 'gru'               # 'lstm' or 'gru'

    dropout_keep_prob = 0.8   # dropout keep probability
    learning_rate = 1e-3      # learning rate

    batch_size = 128          # training batch size
    num_epochs = 10           # total number of epochs

    print_per_batch = 100     # print results every this many batches
    save_per_batch = 10       # write to TensorBoard every this many batches
```
RNN Model
See rnn_model.py for the full implementation.
The overall structure is roughly as follows:
Training and Validation
This part of the code is almost identical to run_cnn.py; only the model and a few directory names need to change.
Run python run_rnn.py train to start training.
If you have trained before, delete tensorboard/textrnn first so that results from multiple runs do not overlap in TensorBoard.
```
Configuring RNN model...
Configuring TensorBoard and Saver...
Loading training and validation data...
Time usage: 0:00:14
Training and evaluating...
Epoch: 1
Iter: 0, Train Loss: 2.3, Train Acc: 8.59%, Val Loss: 2.3, Val Acc: 11.96%, Time: 0:00:08 *
Iter: 100, Train Loss: 0.95, Train Acc: 64.06%, Val Loss: 1.3, Val Acc: 53.06%, Time: 0:01:15 *
Iter: 200, Train Loss: 0.61, Train Acc: 79.69%, Val Loss: 0.94, Val Acc: 69.88%, Time: 0:02:22 *
Iter: 300, Train Loss: 0.49, Train Acc: 85.16%, Val Loss: 0.63, Val Acc: 81.44%, Time: 0:03:29 *
Epoch: 2
Iter: 400, Train Loss: 0.23, Train Acc: 92.97%, Val Loss: 0.6, Val Acc: 82.86%, Time: 0:04:36 *
Iter: 500, Train Loss: 0.27, Train Acc: 92.97%, Val Loss: 0.47, Val Acc: 86.72%, Time: 0:05:43 *
Iter: 600, Train Loss: 0.13, Train Acc: 98.44%, Val Loss: 0.43, Val Acc: 87.46%, Time: 0:06:50 *
Iter: 700, Train Loss: 0.24, Train Acc: 91.41%, Val Loss: 0.46, Val Acc: 87.12%, Time: 0:07:57
Epoch: 3
Iter: 800, Train Loss: 0.11, Train Acc: 96.09%, Val Loss: 0.49, Val Acc: 87.02%, Time: 0:09:03
Iter: 900, Train Loss: 0.15, Train Acc: 96.09%, Val Loss: 0.55, Val Acc: 85.86%, Time: 0:10:10
Iter: 1000, Train Loss: 0.17, Train Acc: 96.09%, Val Loss: 0.43, Val Acc: 89.44%, Time: 0:11:18 *
Iter: 1100, Train Loss: 0.25, Train Acc: 93.75%, Val Loss: 0.42, Val Acc: 88.98%, Time: 0:12:25
Epoch: 4
Iter: 1200, Train Loss: 0.14, Train Acc: 96.09%, Val Loss: 0.39, Val Acc: 89.82%, Time: 0:13:32 *
Iter: 1300, Train Loss: 0.2, Train Acc: 96.09%, Val Loss: 0.43, Val Acc: 88.68%, Time: 0:14:38
Iter: 1400, Train Loss: 0.012, Train Acc: 100.00%, Val Loss: 0.37, Val Acc: 90.58%, Time: 0:15:45 *
Iter: 1500, Train Loss: 0.15, Train Acc: 96.88%, Val Loss: 0.39, Val Acc: 90.58%, Time: 0:16:52
Epoch: 5
Iter: 1600, Train Loss: 0.075, Train Acc: 97.66%, Val Loss: 0.41, Val Acc: 89.90%, Time: 0:17:59
Iter: 1700, Train Loss: 0.042, Train Acc: 98.44%, Val Loss: 0.41, Val Acc: 90.08%, Time: 0:19:06
Iter: 1800, Train Loss: 0.08, Train Acc: 97.66%, Val Loss: 0.38, Val Acc: 91.36%, Time: 0:20:13 *
Iter: 1900, Train Loss: 0.089, Train Acc: 98.44%, Val Loss: 0.39, Val Acc: 90.18%, Time: 0:21:20
Epoch: 6
Iter: 2000, Train Loss: 0.092, Train Acc: 96.88%, Val Loss: 0.36, Val Acc: 91.42%, Time: 0:22:27 *
Iter: 2100, Train Loss: 0.062, Train Acc: 98.44%, Val Loss: 0.39, Val Acc: 90.56%, Time: 0:23:34
Iter: 2200, Train Loss: 0.053, Train Acc: 98.44%, Val Loss: 0.39, Val Acc: 90.02%, Time: 0:24:41
Iter: 2300, Train Loss: 0.12, Train Acc: 96.09%, Val Loss: 0.37, Val Acc: 90.84%, Time: 0:25:48
Epoch: 7
Iter: 2400, Train Loss: 0.014, Train Acc: 100.00%, Val Loss: 0.41, Val Acc: 90.38%, Time: 0:26:55
Iter: 2500, Train Loss: 0.14, Train Acc: 96.88%, Val Loss: 0.37, Val Acc: 91.22%, Time: 0:28:01
Iter: 2600, Train Loss: 0.11, Train Acc: 96.88%, Val Loss: 0.43, Val Acc: 89.76%, Time: 0:29:08
Iter: 2700, Train Loss: 0.089, Train Acc: 97.66%, Val Loss: 0.37, Val Acc: 91.18%, Time: 0:30:15
Epoch: 8
Iter: 2800, Train Loss: 0.0081, Train Acc: 100.00%, Val Loss: 0.44, Val Acc: 90.66%, Time: 0:31:22
Iter: 2900, Train Loss: 0.017, Train Acc: 100.00%, Val Loss: 0.44, Val Acc: 89.62%, Time: 0:32:29
Iter: 3000, Train Loss: 0.061, Train Acc: 96.88%, Val Loss: 0.43, Val Acc: 90.04%, Time: 0:33:36
No optimization for a long time, auto-stopping...
```
The best validation accuracy was 91.42%, training stopped after 8 epochs, and it is much slower than the CNN.
The accuracy and loss curves are shown in the corresponding figures.
Test
Run python run_rnn.py test to evaluate on the test set.
```
Testing...
Test Loss: 0.21, Test Acc: 94.22%
Precision, Recall and F1-Score...
             precision    recall  f1-score   support

         體育       0.99      0.99      0.99      1000
         財經       0.91      0.99      0.95      1000
         房產       1.00      1.00      1.00      1000
         家居       0.97      0.73      0.83      1000
         教育       0.91      0.92      0.91      1000
         科技       0.93      0.96      0.94      1000
         時尚       0.89      0.97      0.93      1000
         時政       0.93      0.93      0.93      1000
         游戲       0.95      0.97      0.96      1000
         娛樂       0.97      0.96      0.97      1000

avg / total       0.94      0.94      0.94     10000

Confusion Matrix...
[[988   0   0   0   4   0   2   0   5   1]
 [  0 990   1   1   1   1   0   6   0   0]
 [  0   2 996   1   1   0   0   0   0   0]
 [  2  71   1 731  51  20  88  28   3   5]
 [  1   3   0   7 918  23   4  31   9   4]
 [  1   3   0   3   0 964   3   5  21   0]
 [  1   0   1   7   1   3 972   0   6   9]
 [  0  16   0   0  22  26   0 931   2   3]
 [  2   3   0   0   2   2  12   0 972   7]
 [  0   3   1   1   7   3  11   5   9 960]]
Time usage: 0:00:33
```
The accuracy on the test set reaches 94.22%, and precision, recall and F1-score exceed 0.9 for every class except 家居 (home).
The confusion matrix likewise shows good classification quality.
Comparing the two models, the RNN performs noticeably worse on the 家居 category but is otherwise close to the CNN on the remaining classes.
Further parameter tuning could yield even better results.