自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

CNN與RNN對中文文本進行分類--基于TENSORFLOW實現(xiàn)

人工智能 深度學習
本文是基于TensorFlow在中文數(shù)據(jù)集上的簡化實現(xiàn),使用了字符級CNN和RNN對中文文本進行分類,達到了較好的效果。

[[211015]]

如今,TensorFlow大版本已經(jīng)升級到了1.3,對很多的網(wǎng)絡層實現(xiàn)了更高層次的封裝和實現(xiàn),甚至還整合了如Keras這樣優(yōu)秀的一些高層次框架,使得其易用性大大提升。相比早起的底層代碼,如今的實現(xiàn)更加簡潔和優(yōu)雅。

本文是基于TensorFlow在中文數(shù)據(jù)集上的簡化實現(xiàn),使用了字符級CNN和RNN對中文文本進行分類,達到了較好的效果。

數(shù)據(jù)集

本文采用了清華NLP組提供的THUCNews新聞文本分類數(shù)據(jù)集的一個子集(原始的數(shù)據(jù)集大約74萬篇文檔,訓練起來需要花較長的時間)。數(shù)據(jù)集請自行到THUCTC:一個高效的中文文本分類工具包下載,請遵循數(shù)據(jù)提供方的開源協(xié)議。

本次訓練使用了其中的10個分類,每個分類6500條數(shù)據(jù)。

類別如下:

體育, 財經(jīng), 房產(chǎn), 家居, 教育, 科技, 時尚, 時政, 游戲, 娛樂

這個子集可以在此下載:鏈接:http://pan.baidu.com/s/1bpq9Eub 密碼:ycyw

數(shù)據(jù)集劃分如下:

  • 訓練集: 5000*10
  • 驗證集: 500*10
  • 測試集: 1000*10

從原數(shù)據(jù)集生成子集的過程請參看helper下的兩個腳本。其中,copy_data.sh用于從每個分類拷貝6500個文件,cnews_group.py用于將多個文件整合到一個文件中。執(zhí)行該文件后,得到三個數(shù)據(jù)文件:

  • cnews.train.txt: 訓練集(50000條)
  • cnews.val.txt: 驗證集(5000條)
  • cnews.test.txt: 測試集(10000條)

預處理

data/cnews_loader.py為數(shù)據(jù)的預處理文件。

  • read_file(): 讀取文件數(shù)據(jù);
  • build_vocab(): 構(gòu)建詞匯表,使用字符級的表示,這一函數(shù)會將詞匯表存儲下來,避免每一次重復處理;
  • read_vocab(): 讀取上一步存儲的詞匯表,轉(zhuǎn)換為{詞:id}表示;
  • read_category(): 將分類目錄固定,轉(zhuǎn)換為{類別: id}表示;
  • to_words(): 將一條由id表示的數(shù)據(jù)重新轉(zhuǎn)換為文字;
  • preocess_file(): 將數(shù)據(jù)集從文字轉(zhuǎn)換為固定長度的id序列表示;
  • batch_iter(): 為神經(jīng)網(wǎng)絡的訓練準備經(jīng)過shuffle的批次的數(shù)據(jù)。

經(jīng)過數(shù)據(jù)預處理,數(shù)據(jù)的格式如下:

 

CNN卷積神經(jīng)網(wǎng)絡

配置項

CNN可配置的參數(shù)如下所示,在cnn_model.py中。

  1. class TCNNConfig(object): 
  2.     """CNN配置參數(shù)""" 
  3.  
  4.     embedding_dim = 64      # 詞向量維度 
  5.     seq_length = 600        # 序列長度 
  6.     num_classes = 10        # 類別數(shù) 
  7.     num_filters = 128        # 卷積核數(shù)目 
  8.     kernel_size = 5         # 卷積核尺寸 
  9.     vocab_size = 5000       # 詞匯表達小 
  10.  
  11.     hidden_dim = 128        # 全連接層神經(jīng)元 
  12.  
  13.     dropout_keep_prob = 0.5 # dropout保留比例 
  14.     learning_rate = 1e-3    # 學習率 
  15.  
  16.     batch_size = 64         # 每批訓練大小 
  17.     num_epochs = 10         # 總迭代輪次 
  18.  
  19.     print_per_batch = 100    # 每多少輪輸出一次結(jié)果 
  20.     save_per_batch = 10      # 每多少輪存入tensorboard 

CNN模型

具體參看cnn_model.py的實現(xiàn)。

大致結(jié)構(gòu)如下:

 

訓練與驗證

運行 python run_cnn.py train,可以開始訓練。

若之前進行過訓練,請把tensorboard/textcnn刪除,避免TensorBoard多次訓練結(jié)果重疊。

  1. Configuring CNN model... 
  2. Configuring TensorBoard and Saver... 
  3. Loading training and validation data... 
  4. Time usage: 0:00:14 
  5. Training and evaluating... 
  6. Epoch: 1 
  7. Iter:      0, Train Loss:    2.3, Train Acc:  10.94%, Val Loss:    2.3, Val Acc:   8.92%, Time: 0:00:01 * 
  8. Iter:    100, Train Loss:   0.88, Train Acc:  73.44%, Val Loss:    1.2, Val Acc:  68.46%, Time: 0:00:04 * 
  9. Iter:    200, Train Loss:   0.38, Train Acc:  92.19%, Val Loss:   0.75, Val Acc:  77.32%, Time: 0:00:07 * 
  10. Iter:    300, Train Loss:   0.22, Train Acc:  92.19%, Val Loss:   0.46, Val Acc:  87.08%, Time: 0:00:09 * 
  11. Iter:    400, Train Loss:   0.24, Train Acc:  90.62%, Val Loss:    0.4, Val Acc:  88.62%, Time: 0:00:12 * 
  12. Iter:    500, Train Loss:   0.16, Train Acc:  96.88%, Val Loss:   0.36, Val Acc:  90.38%, Time: 0:00:15 * 
  13. Iter:    600, Train Loss:  0.084, Train Acc:  96.88%, Val Loss:   0.35, Val Acc:  91.36%, Time: 0:00:17 * 
  14. Iter:    700, Train Loss:   0.21, Train Acc:  93.75%, Val Loss:   0.26, Val Acc:  92.58%, Time: 0:00:20 * 
  15. Epoch: 2 
  16. Iter:    800, Train Loss:   0.07, Train Acc:  98.44%, Val Loss:   0.24, Val Acc:  94.12%, Time: 0:00:23 * 
  17. Iter:    900, Train Loss:  0.092, Train Acc:  96.88%, Val Loss:   0.27, Val Acc:  92.86%, Time: 0:00:25 
  18. Iter:   1000, Train Loss:   0.17, Train Acc:  95.31%, Val Loss:   0.28, Val Acc:  92.82%, Time: 0:00:28 
  19. Iter:   1100, Train Loss:    0.2, Train Acc:  93.75%, Val Loss:   0.23, Val Acc:  93.26%, Time: 0:00:31 
  20. Iter:   1200, Train Loss:  0.081, Train Acc:  98.44%, Val Loss:   0.25, Val Acc:  92.96%, Time: 0:00:33 
  21. Iter:   1300, Train Loss:  0.052, Train Acc: 100.00%, Val Loss:   0.24, Val Acc:  93.58%, Time: 0:00:36 
  22. Iter:   1400, Train Loss:    0.1, Train Acc:  95.31%, Val Loss:   0.22, Val Acc:  94.12%, Time: 0:00:39 
  23. Iter:   1500, Train Loss:   0.12, Train Acc:  98.44%, Val Loss:   0.23, Val Acc:  93.58%, Time: 0:00:41 
  24. Epoch: 3 
  25. Iter:   1600, Train Loss:    0.1, Train Acc:  96.88%, Val Loss:   0.26, Val Acc:  92.34%, Time: 0:00:44 
  26. Iter:   1700, Train Loss:  0.018, Train Acc: 100.00%, Val Loss:   0.22, Val Acc:  93.46%, Time: 0:00:47 
  27. Iter:   1800, Train Loss:  0.036, Train Acc: 100.00%, Val Loss:   0.28, Val Acc:  92.72%, Time: 0:00:50 
  28. No optimization for a long time, auto-stopping... 

在驗證集上的最佳效果為94.12%,且只經(jīng)過了3輪迭代就已經(jīng)停止。

準確率和誤差如圖所示:

 

 

測試

運行 python run_cnn.py test 在測試集上進行測試。

  1. Configuring CNN model... 
  2. Loading test data... 
  3. Testing... 
  4. Test Loss:   0.14, Test Acc:  96.04% 
  5. Precision, Recall and F1-Score... 
  6.              precision    recall  f1-score   support 
  7.  
  8.          體育       0.99      0.99      0.99      1000 
  9.          財經(jīng)       0.96      0.99      0.97      1000 
  10.          房產(chǎn)       1.00      1.00      1.00      1000 
  11.          家居       0.95      0.91      0.93      1000 
  12.          教育       0.95      0.89      0.92      1000 
  13.          科技       0.94      0.97      0.95      1000 
  14.          時尚       0.95      0.97      0.96      1000 
  15.          時政       0.94      0.94      0.94      1000 
  16.          游戲       0.97      0.96      0.97      1000 
  17.          娛樂       0.95      0.98      0.97      1000 
  18.  
  19. avg / total       0.96      0.96      0.96     10000 
  20.  
  21. Confusion Matrix... 
  22. [[991   0   0   0   2   1   0   4   1   1] 
  23.  [  0 992   0   0   2   1   0   5   0   0] 
  24.  [  0   1 996   0   1   1   0   0   0   1] 
  25.  [  0  14   0 912   7  15   9  29   3  11] 
  26.  [  2   9   0  12 892  22  18  21  10  14] 
  27.  [  0   0   0  10   1 968   4   3  12   2] 
  28.  [  1   0   0   9   4   4 971   0   2   9] 
  29.  [  1  16   0   4  18  12   1 941   1   6] 
  30.  [  2   4   1   5   4   5  10   1 962   6] 
  31.  [  1   0   1   6   4   3   5   0   1 979]] 
  32. Time usage: 0:00:05 

在測試集上的準確率達到了96.04%,且各類的precision, recall和f1-score都超過了0.9。

從混淆矩陣也可以看出分類效果非常優(yōu)秀。

RNN循環(huán)神經(jīng)網(wǎng)絡

配置項

RNN可配置的參數(shù)如下所示,在rnn_model.py中。

  1. class TRNNConfig(object): 
  2.     """RNN配置參數(shù)""" 
  3.  
  4.     # 模型參數(shù) 
  5.     embedding_dim = 64      # 詞向量維度 
  6.     seq_length = 600        # 序列長度 
  7.     num_classes = 10        # 類別數(shù) 
  8.     vocab_size = 5000       # 詞匯表達小 
  9.  
  10.     num_layers= 2           # 隱藏層層數(shù) 
  11.     hidden_dim = 128        # 隱藏層神經(jīng)元 
  12.     rnn = 'gru'             # lstm 或 gru 
  13.  
  14.     dropout_keep_prob = 0.8 # dropout保留比例 
  15.     learning_rate = 1e-3    # 學習率 
  16.  
  17.     batch_size = 128         # 每批訓練大小 
  18.     num_epochs = 10          # 總迭代輪次 
  19.  
  20.     print_per_batch = 100    # 每多少輪輸出一次結(jié)果 
  21.     save_per_batch = 10      # 每多少輪存入tensorboard 

RNN模型

具體參看rnn_model.py的實現(xiàn)。

大致結(jié)構(gòu)如下:

 

訓練與驗證

這部分的代碼與 run_cnn.py極為相似,只需要將模型和部分目錄稍微修改。

運行 python run_rnn.py train,可以開始訓練。

若之前進行過訓練,請把tensorboard/textrnn刪除,避免TensorBoard多次訓練結(jié)果重疊。

  1. Configuring RNN model... 
  2. Configuring TensorBoard and Saver... 
  3. Loading training and validation data... 
  4. Time usage: 0:00:14 
  5. Training and evaluating... 
  6. Epoch: 1 
  7. Iter:      0, Train Loss:    2.3, Train Acc:   8.59%, Val Loss:    2.3, Val Acc:  11.96%, Time: 0:00:08 * 
  8. Iter:    100, Train Loss:   0.95, Train Acc:  64.06%, Val Loss:    1.3, Val Acc:  53.06%, Time: 0:01:15 * 
  9. Iter:    200, Train Loss:   0.61, Train Acc:  79.69%, Val Loss:   0.94, Val Acc:  69.88%, Time: 0:02:22 * 
  10. Iter:    300, Train Loss:   0.49, Train Acc:  85.16%, Val Loss:   0.63, Val Acc:  81.44%, Time: 0:03:29 * 
  11. Epoch: 2 
  12. Iter:    400, Train Loss:   0.23, Train Acc:  92.97%, Val Loss:    0.6, Val Acc:  82.86%, Time: 0:04:36 * 
  13. Iter:    500, Train Loss:   0.27, Train Acc:  92.97%, Val Loss:   0.47, Val Acc:  86.72%, Time: 0:05:43 * 
  14. Iter:    600, Train Loss:   0.13, Train Acc:  98.44%, Val Loss:   0.43, Val Acc:  87.46%, Time: 0:06:50 * 
  15. Iter:    700, Train Loss:   0.24, Train Acc:  91.41%, Val Loss:   0.46, Val Acc:  87.12%, Time: 0:07:57 
  16. Epoch: 3 
  17. Iter:    800, Train Loss:   0.11, Train Acc:  96.09%, Val Loss:   0.49, Val Acc:  87.02%, Time: 0:09:03 
  18. Iter:    900, Train Loss:   0.15, Train Acc:  96.09%, Val Loss:   0.55, Val Acc:  85.86%, Time: 0:10:10 
  19. Iter:   1000, Train Loss:   0.17, Train Acc:  96.09%, Val Loss:   0.43, Val Acc:  89.44%, Time: 0:11:18 * 
  20. Iter:   1100, Train Loss:   0.25, Train Acc:  93.75%, Val Loss:   0.42, Val Acc:  88.98%, Time: 0:12:25 
  21. Epoch: 4 
  22. Iter:   1200, Train Loss:   0.14, Train Acc:  96.09%, Val Loss:   0.39, Val Acc:  89.82%, Time: 0:13:32 * 
  23. Iter:   1300, Train Loss:    0.2, Train Acc:  96.09%, Val Loss:   0.43, Val Acc:  88.68%, Time: 0:14:38 
  24. Iter:   1400, Train Loss:  0.012, Train Acc: 100.00%, Val Loss:   0.37, Val Acc:  90.58%, Time: 0:15:45 * 
  25. Iter:   1500, Train Loss:   0.15, Train Acc:  96.88%, Val Loss:   0.39, Val Acc:  90.58%, Time: 0:16:52 
  26. Epoch: 5 
  27. Iter:   1600, Train Loss:  0.075, Train Acc:  97.66%, Val Loss:   0.41, Val Acc:  89.90%, Time: 0:17:59 
  28. Iter:   1700, Train Loss:  0.042, Train Acc:  98.44%, Val Loss:   0.41, Val Acc:  90.08%, Time: 0:19:06 
  29. Iter:   1800, Train Loss:   0.08, Train Acc:  97.66%, Val Loss:   0.38, Val Acc:  91.36%, Time: 0:20:13 * 
  30. Iter:   1900, Train Loss:  0.089, Train Acc:  98.44%, Val Loss:   0.39, Val Acc:  90.18%, Time: 0:21:20 
  31. Epoch: 6 
  32. Iter:   2000, Train Loss:  0.092, Train Acc:  96.88%, Val Loss:   0.36, Val Acc:  91.42%, Time: 0:22:27 * 
  33. Iter:   2100, Train Loss:  0.062, Train Acc:  98.44%, Val Loss:   0.39, Val Acc:  90.56%, Time: 0:23:34 
  34. Iter:   2200, Train Loss:  0.053, Train Acc:  98.44%, Val Loss:   0.39, Val Acc:  90.02%, Time: 0:24:41 
  35. Iter:   2300, Train Loss:   0.12, Train Acc:  96.09%, Val Loss:   0.37, Val Acc:  90.84%, Time: 0:25:48 
  36. Epoch: 7 
  37. Iter:   2400, Train Loss:  0.014, Train Acc: 100.00%, Val Loss:   0.41, Val Acc:  90.38%, Time: 0:26:55 
  38. Iter:   2500, Train Loss:   0.14, Train Acc:  96.88%, Val Loss:   0.37, Val Acc:  91.22%, Time: 0:28:01 
  39. Iter:   2600, Train Loss:   0.11, Train Acc:  96.88%, Val Loss:   0.43, Val Acc:  89.76%, Time: 0:29:08 
  40. Iter:   2700, Train Loss:  0.089, Train Acc:  97.66%, Val Loss:   0.37, Val Acc:  91.18%, Time: 0:30:15 
  41. Epoch: 8 
  42. Iter:   2800, Train Loss: 0.0081, Train Acc: 100.00%, Val Loss:   0.44, Val Acc:  90.66%, Time: 0:31:22 
  43. Iter:   2900, Train Loss:  0.017, Train Acc: 100.00%, Val Loss:   0.44, Val Acc:  89.62%, Time: 0:32:29 
  44. Iter:   3000, Train Loss:  0.061, Train Acc:  96.88%, Val Loss:   0.43, Val Acc:  90.04%, Time: 0:33:36 
  45. No optimization for a long time, auto-stopping... 

在驗證集上的最佳效果為91.42%,經(jīng)過了8輪迭代停止,速度相比CNN慢很多。

準確率和誤差如圖所示:

 

測試

運行 python run_rnn.py test 在測試集上進行測試。

  1. Testing... 
  2. Test Loss:   0.21, Test Acc:  94.22% 
  3. Precision, Recall and F1-Score... 
  4.              precision    recall  f1-score   support 
  5.  
  6.          體育       0.99      0.99      0.99      1000 
  7.          財經(jīng)       0.91      0.99      0.95      1000 
  8.          房產(chǎn)       1.00      1.00      1.00      1000 
  9.          家居       0.97      0.73      0.83      1000 
  10.          教育       0.91      0.92      0.91      1000 
  11.          科技       0.93      0.96      0.94      1000 
  12.          時尚       0.89      0.97      0.93      1000 
  13.          時政       0.93      0.93      0.93      1000 
  14.          游戲       0.95      0.97      0.96      1000 
  15.          娛樂       0.97      0.96      0.97      1000 
  16.  
  17. avg / total       0.94      0.94      0.94     10000 
  18.  
  19. Confusion Matrix... 
  20. [[988   0   0   0   4   0   2   0   5   1] 
  21.  [  0 990   1   1   1   1   0   6   0   0] 
  22.  [  0   2 996   1   1   0   0   0   0   0] 
  23.  [  2  71   1 731  51  20  88  28   3   5] 
  24.  [  1   3   0   7 918  23   4  31   9   4] 
  25.  [  1   3   0   3   0 964   3   5  21   0] 
  26.  [  1   0   1   7   1   3 972   0   6   9] 
  27.  [  0  16   0   0  22  26   0 931   2   3] 
  28.  [  2   3   0   0   2   2  12   0 972   7] 
  29.  [  0   3   1   1   7   3  11   5   9 960]] 
  30. Time usage: 0:00:33 

在測試集上的準確率達到了94.22%,且各類的precision, recall和f1-score,除了家居這一類別,都超過了0.9。

從混淆矩陣可以看出分類效果非常優(yōu)秀。

對比兩個模型,可見RNN除了在家居分類的表現(xiàn)不是很理想,其他幾個類別較CNN差別不大。

 

還可以通過進一步的調(diào)節(jié)參數(shù),來達到更好的效果。 

責任編輯:龐桂玉 來源: 36大數(shù)據(jù)
相關(guān)推薦

2017-09-12 16:31:21

TensorFlowLSTMCNN

2018-11-01 09:14:42

CNNRNN神經(jīng)網(wǎng)絡

2024-12-16 08:06:42

2022-08-08 12:54:32

中文文本糾錯糾錯錯別字

2017-04-13 09:18:02

深度學習文本分類

2020-08-20 07:00:00

人工智能深度學習技術(shù)

2024-09-20 10:02:13

2017-08-02 10:43:39

深度學習TensorFlowRNN

2022-06-29 09:00:00

前端圖像分類模型SQL

2022-08-15 15:16:20

機器學習圖片深度學習

2024-11-21 16:06:02

2018-04-09 10:20:32

深度學習

2016-12-09 13:45:21

RNN大數(shù)據(jù)深度學習

2011-12-22 12:37:17

JavaJFreeChart

2017-08-04 14:23:04

機器學習神經(jīng)網(wǎng)絡TensorFlow

2017-08-25 14:23:44

TensorFlow神經(jīng)網(wǎng)絡文本分類

2023-03-03 08:17:28

神經(jīng)網(wǎng)絡RNN網(wǎng)絡

2020-12-31 05:37:05

HiveUDFSQL

2010-03-16 16:11:41

交換機堆疊技術(shù)

2011-09-14 13:18:59

Android API
點贊
收藏

51CTO技術(shù)棧公眾號