文本文字識別、公式識別、表格文字識別核心算法及思路及實踐-DBNet、CRNN、TrOCR 原創(chuàng)
前言
OCR技術作為文檔智能解析鏈路中的核心組件之一,貫穿整個技術鏈路,包括:文字識別、表格文字識別、公式識別,參看下面這張架構圖:
前期介紹了很多關于文檔智能解析相關核心技術及思路,本著連載的目的,本次迎來介紹整個鏈路中的最后一塊拼圖-OCR。本文簡要介紹OCR常見落地的算法模型-DBNet、CRNN,并基于這兩個模型,簡單介紹文字識別在表格識別中參與的角色;并且額外介紹TrOCR這個端到端的模型,基于這個模型引入公式識別解析的思路及微調方法。
DBNet
DBNet是一種基于分割的文本檢測算法,算法將可微分二值化模塊(Differentiable Binarization)引入了分割模型,使得模型能夠通過自適應的閾值圖進行二值化,并且自適應閾值圖可以計算損失,能夠在模型訓練過程中起到輔助效果優(yōu)化的效果。DBNet在效果和性能上都有比較大的優(yōu)勢,是目前最常用的文本檢測算法之一。
模型結構
DB文本檢測模型由三個部分組成:Backbone網絡、FPN網絡、Head網絡
1.Backbone網絡
負責提取圖像的特征。Backbone部分采用的是圖像分類網絡,論文中分別使用了ResNet50和ResNet18網絡。輸入圖像[1,3,640, 640] ,進入Backbone網絡,先經過一次卷積計算尺寸變?yōu)樵瓉淼?/2, 而后經過四次下采樣,輸出四個尺度特征圖:
2.FPN網絡
特征金字塔網絡
特征金子塔,F(xiàn)eaturized image pyramid。負責結構增強特征。特征金字塔結構FPN是一種卷積網絡來高效提取圖片中各維度特征的常用方法。FPN網絡的輸入為Backbone部分的輸出,經FPN計算后輸出的特征圖的高度和寬度為原圖的1/4, 即[1, 256, 160, 160] 。
DBNet中FPN特征提取融合過程:
- 1/32特征圖: [1, N, 20, 20] ===> 卷積 + 8倍上采樣 ===> [1, 64, 160, 160]
- 1/16特征圖:[1, N, 40, 40] ===> 加1/32特征圖的兩倍上采樣 ===> 新1/16特征圖 ==> 卷積 + 4倍上采樣 ===> [1, 64, 160, 160]
- 1/8特征圖:[1, N, 80, 80] ===> 加新1/16特征圖的兩倍上采樣 ===>新1/8特征圖 ===> 卷積 + 2倍上采樣 ===> [1, 64, 160, 160]
- 1/4特征圖:[1, N, 160, 160] ===> 加新1/8特征圖的兩倍上采樣 ===> 新1/4特征圖 ===> 卷積 ===> [1, 64, 160, 160]
- 融合特征圖:[1, 256, 160, 160] # 將1/4,1/8, 1/16, 1/32特征圖按通道層合并在一起
3.Head網絡
負責計算文本區(qū)域概率圖的 Head 網絡,基于 FPN 特征進行上采樣,將 FPN 特征從原來的 1/4 尺寸映射回原圖尺寸。最終,Head 網絡會生成文本區(qū)域概率圖、文本區(qū)域閾值圖,并將這些圖合并,得到一個輸出大小為 [1, 3, 640, 640] 的結果。
標簽生成
DB算法在進行模型訓練的時,需要根據(jù)標注框生成兩幅圖像:概率圖和閾值圖。生成過程如下圖所示:
image圖像中的紅線是文本的標注框,文本標注框的點集合用如下形式表示:
其中,n表示頂點的數(shù)量。
1.概率圖標簽Gs
在polygon圖像中,將紅色的標注框外擴distance得到綠色的polygon框,內縮distance得到藍色的polygon框。論文中標注框內縮和外擴使用相同的distance,其計算公式為:
L代表周長,A代表面積,r代表縮放比例,通常r=0.4
多邊形輪廓的周長L和面積A通過Polygon庫計算獲得。
得到上述效果簡單示例:
code ref:https://blog.csdn.net/yewumeng123/article/details/127503815
import cv2
import pyclipper
import numpy as np
from shapely.geometry import Polygon
def draw_img(subject, canvas, color=(255,0,0)):
"""作圖函數(shù)"""
for i in range(len(subject)):
j = (i+1)%len(subject)
cv2.line(canvas, subject[i], subject[j], color)
# 論文默認shrink值
r=0.4
# 假定標注框
subject = ((100, 100), (250, 100), (250, 200), (100, 200))
# 創(chuàng)建Polygon對象
polygon = Polygon(subject)
# 計算偏置distance
distance = polygon.area*(1-np.power(r, 2))/polygon.length
print(distance)
# 25.2
# 創(chuàng)建PyclipperOffset對象
padding = pyclipper.PyclipperOffset()
# 向ClipperOffset對象添加一個路徑用來準備偏置
# padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
# adding.AddPath(subject, pyclipper.JT_SQUARE, pyclipper.ET_CLOSEDPOLYGON)
padding.AddPath(subject, pyclipper.JT_MITER, pyclipper.ET_CLOSEDPOLYGON)
# polygon外擴
polygon_expand = padding.Execute(distance)[0]
polygon_expand = [tuple(l) for l in polygon_expand]
print(polygon_expand)
# [(75, 75), (275, 75), (275, 225), (75, 225)]
# polygon內縮
polygon_shrink = padding.Execute(-distance)[0]
polygon_shrink = [tuple(l) for l in polygon_shrink]
print(polygon_shrink)
# [(125, 125), (225, 125), (225, 175), (125, 175)]
# 作圖
canvas = np.zeros((300,350,3), dtype=np.uint8)
# 原輪廓用紅色線條展示
draw_img(subject, canvas, color=(0,0,255))
# 外擴輪廓用綠色線條展示
draw_img(polygon_expand, canvas, color=(0,255,0))
# 內縮輪廓用藍色線條展示
draw_img(polygon_shrink, canvas, color=(255,0,0))
cv2.imshow("Canvas", canvas)
cv2.waitKey(0)
2.閾值圖標簽Gd
通過類似的過程,為閾值圖生成標簽。在 DBNet 中,閾值圖標簽的生成過程可以概述如下:
- 多邊形文字區(qū)域擴張與收縮:首先,將原始的多邊形文字區(qū)域G 進行擴張,得到擴張后的區(qū)域Gd(綠線表示)。同時,將文字區(qū)域以偏移量D為標準進行收縮,得到收縮后的區(qū)域Gs(藍線表示)。其中,偏移量D與概率圖中的偏移量相同。
- 邊界區(qū)域定義:收縮區(qū)域Gs和擴張區(qū)域Gd之間的間隙被視為文本區(qū)域的邊界。在這個邊界區(qū)域內,計算每個像素點到原始文字區(qū)域邊界G(紅線表示)的歸一化距離。這個距離是通過找到該像素點到最近的G邊界線段的距離來計算的。
- 歸一化處理:
計算得到的歸一化距離顯示出:擴張框Gd和收縮框Gs 上的像素點距離值最大,而原始邊界G上的距離值最?。?0)。
隨著從G向Gs和Gd移動,距離值逐漸變大。
- 距離轉換與二次歸一化:
- 計算完距離后,對這些距離值進行歸一化,即除以偏移量D,此時,Gs和Gd上的距離值均變?yōu)?1,而G 上的值為 0,表現(xiàn)為以G為中心向Gs和Gd 兩側的距離逐漸減小。
- 接著,對這些歸一化值再進行一次變換,使用 減去這些距離值,最終得到:G上的值為 1,Gs和Gd上的值為 0。這樣,Gs和 Gd區(qū)域內的像素值范圍變?yōu)?[0, 1]。
- 最終標簽生成:在完成上述操作后,我們再次對這些值進行歸一化,得到像 [0.2, 0.8] 等不同范圍的標簽值。最終,這些值將作為閾值圖的標簽,幫助模型更好地學習文本區(qū)域與邊界之間的精確關系。
通過該過程生成的閾值圖標簽,結合 DBNet 的差分二值化策略,能夠有效提升文本檢測的邊界處理能力。
損失函數(shù)
損失函數(shù)為概率圖的損失 Ls、二值化圖的損失Lb和閾值圖的損失的和Lt:
采用BCE Loss,為平衡正負樣本的比例,使用OHEM進行困難樣本挖掘,正樣本:負樣本=1:3:
損失函數(shù)代碼ref:https://github.com/MhLiao/DB/blob/master/decoders/seg_detector_loss.py
文本識別算法-CRNN
模型包括三個部分,分別稱作卷積層、循環(huán)層以及轉錄層。
CRNN網絡結構
卷積層由CNN構成,它的作用是從輸入的圖像中提取特征。提取的特征圖將會輸入到接下來的循環(huán)層中,循環(huán)層由RNN構成,它將輸出對特征序列每一幀的預測。最后轉錄層將得到的預測概率分布轉換成標記序列,得到最終的識別結果,它實際上就是模型中的損失函數(shù)。通過最小化損失函數(shù),訓練由CNN和RNN組成的網絡。
1.卷積層 CRNN模型中的卷積層由一系列的卷積層、池化層、BN層構造而成。就像其他的CNN模型一樣,它將輸入的圖片轉化為具有特征信息的特征圖,作為后面循環(huán)層的輸入。當然,為了使提取的特征圖尺寸相同,輸入的圖像事先要縮放到固定的大小。
由于卷積神經網絡中卷積層和最大池化層的存在,使其具有平移不變性的特點。卷積神經網絡中的感受野指的是經過卷積層輸出的特征圖中每個像素對應的原輸入圖像區(qū)域的大小,它與特征圖上的像素從左到右,從上到下是一一對應的,如下圖所示。因此,可以將特征圖作為圖像特征的表示。
2.循環(huán)層
在CRNN(Convolutional Recurrent Neural Network)模型中,循環(huán)層通常位于卷積層(Convolutional Layers)之后,用于處理卷積層提取的特征序列。
作用:
- 捕捉序列中的上下文信息:循環(huán)層能夠記住之前處理過的信息,并利用這些信息來輔助當前的預測任務。
- 處理任意長度的序列:與卷積神經網絡(CNN)不同,循環(huán)層可以處理變長的序列數(shù)據(jù),因為它們可以逐個時間點地處理輸入。
3.轉錄層
轉錄層的作用是將前面通過CNN層和RNN層得到的預測序列轉換成標記序列,得到最終的識別結果。簡單來說,就是選取預測序列中每個分量中概率最大的索引對應的符號作為識別結果,最終組成序列作為最終的識別序列。CRNN轉錄算法使用的是CTC算法,涉及原理可以自行查閱。
端到端的文本識別模型-TrOCR
TrOCR使用Transformer架構構建,包括用于提取視覺特征的圖像Transformer和用于語言建模的文本Transformer。在TrOCR中采用了普通的Transformer編碼器-解碼器結構。編碼器被設計為獲得圖像塊的表示,而解碼器被設計為在視覺特征和先前預測的指導下生成單詞片段序列。
TrOCR的架構,其中編碼器-解碼器模型設計為預訓練圖像Transformer作為編碼器,預訓練文本Transformer作為解碼器。
1.編碼器
編碼器接收輸入圖像,并將其大小調整為固定大小(H,W)。由于Transformer編碼 器無法處理原始圖像,除非它們是一系列輸入tokens,因此編碼器將輸入圖像分解為一批 個固定大小為 (P,P) 的正方形patches,同時保證調整大小的圖像的寬度W和 高度H可被patch大小P整除,將patches展平為向量并線性投影到D維向量,即patch embeddings。D是Transformer在其所有層中的隱藏尺寸。
與ViT 和DeiT類似,論文保留了通常用于圖像分類任務的特殊標記"[CLS]"。"[CLS]"標記將來自所有patch embeddings的所有信息匯集在一起,并表示整個圖像。同時,當使用DeiT預訓練模型進行編碼器初始化時,論文還將蒸餾token保持在輸入序列中,這允許模型從教師模型學習。patch embeddings和兩個特殊tokens根據(jù)其絕對位置被賦予可學習的1D位置嵌入。
2.解碼器
使用TrOCR的原始Transformer解碼器。標準Transformer解碼器也有一個相同的層堆棧,其結構與編碼器中的層相似,只是解碼器在多頭自注意力和前饋網絡之間插入“encoder-decoder attention”,以在編碼器的輸出上分配不同的關注。在編碼器-解碼器注意模塊中,鍵(K)和值(V)來自編碼器輸出,而查詢(Q)來自解碼器輸入。此外,解碼器利用自注意力中的注意力掩碼來防止自己在訓練期間獲得未來的信息?;诮獯a器的輸出將從解碼器的輸入向右移位一個位置,注意力掩碼需要確保位置i的輸出只和先前的輸出相關,即小于i的位置上的輸入:
來自解碼器的隱藏狀態(tài)通過線性層從模型維度投影到詞匯大小V的維度,而詞匯上的概率通過softmax函數(shù)計算。使用beam search來獲得最終輸出。
性能效果
微調
- 數(shù)據(jù)處理,helper.py
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset
from PIL import Image
from transformers import TrOCRProcessor
df = pd.read_csv('dataset/tex.csv')
train_df, test_df = train_test_split(df, test_size=0.1, shuffle=True, random_state=42)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)
model_path = 'microsoft/trocr-base-stage1'
class IAMDataset(Dataset):
def __init__(self, root_dir, df, processor, max_target_length=512):
self.root_dir = root_dir
self.df = df
self.processor = processor
self.max_target_length = max_target_length
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# get file name + text
file_name = self.df['file_name'][idx]
text = self.df['text'][idx]
# prepare image (i.e. resize + normalize)
image = Image.open(self.root_dir + file_name).convert("RGB")
pixel_values = self.processor(image, return_tensors="pt").pixel_values
# add labels (input_ids) by encoding the text
labels = self.processor.tokenizer(
text,
padding="max_length",
max_length=self.max_target_length,
truncatinotallow=True
).input_ids
# important: make sure that PAD tokens are ignored by the loss function
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
return encoding
processor = TrOCRProcessor.from_pretrained(model_path)
train_dataset = IAMDataset(root_dir='./dataset/trainning_set/',
df=train_df,
processor=processor,
max_target_length=512
)
eval_dataset = IAMDataset(root_dir='./dataset/trainning_set/',
df=test_df,
processor=processor,
max_target_length=512
)
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))
- train.py
from helper import *
from transformers import VisionEncoderDecoderModel
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
evaluation_strategy="steps",
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
fp16=True,
output_dir="./",
logging_steps=2,
save_steps=1000,
eval_steps=200,
)
from datasets import load_metric
cer_metric = load_metric("cer")
def compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
cer = cer_metric.compute(predictinotallow=pred_str, references=label_str)
return {"cer": cer}
from transformers import default_data_collator
# instantiate trainer
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=default_data_collator,
)
trainer.train()
應用思路
1.文字識別在表格識別中的角色
表格識別主要包含三個模型
- 單行文本檢測-DB
- 單行文本識別-CRNN
- 表格結構和cell坐標預測-SLANet
流程說明:
- 圖片由單行文字檢測模型檢測到單行文字的坐標,然后送入識別模型拿到識別結果。
- 圖片由SLANet模型拿到表格的結構信息和單元格的坐標信息。
- 由單行文字的坐標、識別結果和單元格的坐標一起組合出單元格的識別結果。
- 單元格的識別結果和表格結構一起構造表格的html字符串。
2.TrOCR公式識別微調后效果 通過微調TrOCR模型,端到端的得到公式latex格式字符串。
參考文獻
- Real-time Scene Text Detection with Differentiable Binarization,https://arxiv.org/pdf/1911.08947
- An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition,https://arxiv.org/pdf/1507.05717
- TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models,https://arxiv.org/pdf/2109.10282
- ??https://github.com/PaddlePaddle/PaddleOCR??
本文轉載自公眾號大模型自然語言處理 作者:余俊暉
原文鏈接:??https://mp.weixin.qq.com/s/F67wKhbYbPNVkc8Bg-ZsZA??
