使用 OCR 識別手寫文本
本文實現了基于微調TrOCR模型進行手寫文本識別。
1.GNHK手寫筆記數據集
GNHK(GoodNotes Handwriting Kollection)手寫筆記數據集由GoodNotes提供,包含來自世界各地學生的數百份英文手寫筆記。
下載數據集
訪問GNHK數據集官方網站:
(https://www.goodnotes.com/gnhk),滾動到底部,同意使用條款和條件;點擊第二個鏈接下載數據集。
下載后會得到兩個文件:train_data.zip 和 test_data.zip。解壓這兩個文件后,數據集的目錄結構如下:
├── test_data
│ └── test
│ ├── eng_AF_004.jpg
│ ├── eng_AF_004.json
│ ├── eng_AF_007.jpg
│ ├── eng_AF_007.json
│ ...
│ ├── eng_NA_142.jpg
│ └── eng_NA_142.json
├── train_data
└── train
├── eng_AF_001.jpg
├── eng_AF_001.json
├── eng_AF_002.jpg
├── eng_AF_002.json
...
├── eng_NA_146.jpg
└── eng_NA_146.json
4 directories, 1375 files
- 訓練集:包含515個樣本
- 測試集:包含172個樣本
- 圖像文件:從1080p到4K的高分辨率圖像
- 標注文件:每個圖像文件對應一個JSON文件,包含圖像中每個單詞的標注信息
以下是數據集中的一些手寫筆記圖像樣本。
每個圖像文件對應一個JSON文件,文件內容格式如下:
[
{
"text": "%math%",
"polygon": {
"x0": 112, "y0": 556,
"x1": 285, "y1": 563,
"x2": 245, "y2": 776,
"x3": 112, "y3": 783
},
"line_idx": 1,
"type": "H"
},
{
"text": "%math%",
"polygon": {
"x0": 2365, "y0": 202,
"x1": 2350, "y1": 509,
"x2": 2588, "y2": 527,
"x3": 2632, "y3": 195
},
"line_idx": 0,
"type": "H"
},
...
{
"text": "ownership",
"polygon": {
"x0": 1347, "y0": 1606,
"x1": 2238, "y1": 1574,
"x2": 2170, "y2": 1884,
"x3": 1300, "y3": 1747
},
"line_idx": 4,
"type": "H"
}
]
其中:
- text:表示單詞的內容。如果單詞是數學符號、特殊字符或不可理解的內容(例如劃線),則用%%符號包裹的特殊詞表示。否則,text鍵包含實際的單詞。
- polygon:表示單詞的多邊形坐標,用于精確標注單詞的位置。
- line_idx:表示單詞所在的行索引。
- type:表示單詞的類型,通常為"H"(手寫)。
2.項目目錄結構
├── input
│ └── gnhk_dataset
│ ├── test_data
│ ├── test_processed
│ ├── train_data
│ ├── train_processed
│ ├── test_processed.csv
│ └── train_processed.csv
├── pretrained_model_inference [10066 entries exceeds filelimit, not opening dir]
├── trocr_handwritten
│ ├── checkpoint-6093
│ │ ├── config.json
│ │ ├── generation_config.json
│ │ ├── model.safetensors
│ │ ├── optimizer.pt
│ │ ├── preprocessor_config.json
│ │ ├── rng_state.pth
│ │ ├── scheduler.pt
│ │ ├── trainer_state.json
│ │ └── training_args.bin
│ ├── checkpoint-6770
│ │ ├── config.json
│ │ ├── generation_config.json
│ │ ├── model.safetensors
│ │ ├── optimizer.pt
│ │ ├── preprocessor_config.json
│ │ ├── rng_state.pth
│ │ ├── scheduler.pt
│ │ ├── trainer_state.json
│ │ └── training_args.bin
│ └── runs
│ └── Aug27_11-30-05_f57a2dab37c7
├── Fine_Tune_TrOCR_Handwritten.ipynb
├── preprocess_gnhk_dataset.py
└── Pretrained_Model_Inference.ipynb
目錄說明:
- input/gnhk_dataset:包含下載并解壓的數據集
- pretrained_model_inference:包含使用預訓練的TrOCR手寫模型對驗證數據集進行推理的結果。
- trocr_handwritten:包含微調TrOCR模型后的結果。
- Fine_Tune_TrOCR_Handwritten.ipynb:用于微調TrOCR模型的Jupyter Notebook
- preprocess_gnhk_dataset.py:包含預處理GNHK數據集的Python腳本
- Pretrained_Model_Inference.ipynb:用于使用預訓練模型進行推理的Jupyter Notebook
3.安裝依賴項
在繼續(xù)進行數據預處理、推理和訓練之前,我們需要安裝以下依賴項。
pip install transformers
pip install sentencepiece
pip install jiwer
pip install datasets
pip install evaluate
pip install -U accelerate
pip install matplotlib
pip install protobuf==3.20.1
pip install tensorboard
4.GNHK數據集預處理
預訓練的TrOCR模型只能識別單個單詞或單行句子,而GNHK數據集中的圖像是整個文檔的圖像。因此需要對數據集進行預處理,以便模型能夠更好地處理這些圖像。
數據集預處理的關鍵步驟如下:
- 轉換多邊形坐標為四點邊界框坐標。
- 裁剪每個單詞并存儲在單獨的目錄中。
- 創(chuàng)建兩個 CSV 文件,一個用于訓練集,一個用于測試集。這些文件將包含裁剪后的圖像名稱和標簽文本。
代碼實現:
import os
import json
import csv
import cv2
import numpy as np
from tqdm import tqdm
def create_directories():
"""
創(chuàng)建必要的目錄
"""
dirs = [
'input/gnhk_dataset/train_processed/images',
'input/gnhk_dataset/test_processed/images',
]
for dir_path in dirs:
os.makedirs(dir_path, exist_ok=True)
def polygon_to_bbox(polygon):
"""
將多邊形坐標轉換為四點邊界框坐標
"""
points = np.array([(polygon[f'x{i}'], polygon[f'y{i}']) for i in range(4)], dtype=np.int32)
x, y, w, h = cv2.boundingRect(points)
return x, y, w, h
def process_dataset(input_folder, output_folder, csv_path):
"""
處理數據集,裁剪圖像并生成 CSV 文件
"""
with open(csv_path, 'w', newline='') as csvfile:
csv_writer = csv.writer(csvfile)
csv_writer.writerow(['image_filename', 'text'])
for filename in tqdm(os.listdir(input_folder), desc=f"Processing {os.path.basename(input_folder)}"):
if filename.endswith('.json'):
json_path = os.path.join(input_folder, filename)
img_path = os.path.join(input_folder, filename.replace('.json', '.jpg'))
with open(json_path, 'r') as f:
data = json.load(f)
img = cv2.imread(img_path)
for idx, item in enumerate(data):
text = item['text']
if text.startswith('%') and text.endswith('%'):
text = 'SPECIAL_CHARACTER'
x, y, w, h = polygon_to_bbox(item['polygon'])
cropped_img = img[y:y+h, x:x+w]
output_filename = f"{filename.replace('.json', '')}_{idx}.jpg"
output_path = os.path.join(output_folder, output_filename)
cv2.imwrite(output_path, cropped_img)
csv_writer.writerow([output_filename, text])
def main():
"""
主函數,創(chuàng)建目錄并處理數據集
"""
create_directories()
process_dataset(
'input/gnhk_dataset/train_data/train',
'input/gnhk_dataset/train_processed/images',
'input/gnhk_dataset/train_processed.csv'
)
process_dataset(
'input/gnhk_dataset/test_data/test',
'input/gnhk_dataset/test_processed/images',
'input/gnhk_dataset/test_processed.csv'
)
if __name__ == '__main__':
main()
將上述代碼保存為preprocess_gnhk_dataset.py文件。在終端中運行腳本。
python preprocess_gnhk_dataset.py
運行腳本后,將會在 input/gnhk_dataset 目錄下創(chuàng)建以下子目錄和文件:
子目錄:
- train_processed/images:存儲訓練集的裁剪圖像。
- test_processed/images:存儲測試集的裁剪圖像。
CSV 文件:
- train_processed.csv:包含訓練集的圖像文件名和對應的標簽文本。
- test_processed.csv:包含測試集的圖像文件名和對應的標簽文本。
以下是一些經過處理后的裁剪圖像示例:
csv文件示例如下圖所示:
每個csv文件包括裁剪后的圖像文件名和對應圖像的標簽文本。每一行表示一個裁剪后的圖像及其對應的標簽文本。
處理后的數據集包括:
- 訓練集:32495張裁剪圖像
- 測試集:10066張裁剪圖像
5.微調TrOCR模型
首先,導入必要的庫,并定義一些全局設置。
import os
import torch
import evaluate
import numpy as np
import pandas as pd
import glob as glob
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from PIL import Image
from tqdm.notebook import tqdm
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import (
VisionEncoderDecoderModel,
TrOCRProcessor,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
default_data_collator
)
block_plot = False
plt.rcParams['figure.figsize'] = (12, 9)
os.environ["TOKENIZERS_PARALLELISM"] = 'false'
接著,為確保實驗的可重復性,設置隨機種子,并初始化計算設備。
def seed_everything(seed_value):
np.random.seed(seed_value)
torch.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
seed_everything(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
定義一些重要的配置項,包括訓練和數據集的路徑。這里設置批次大小batch size為48,訓練輪數10,基礎學習率0.00005。
@dataclass(frozen=True)
class TrainingConfig:
BATCH_SIZE: int = 48
EPOCHS: int = 10
LEARNING_RATE: float = 0.00005
@dataclass(frozen=True)
class DatasetConfig:
DATA_ROOT: str = 'input/gnhk_dataset'
@dataclass(frozen=True)
class ModelConfig:
MODEL_NAME: str = 'microsoft/trocr-small-handwritten'
可視化訓練樣本,以幫助我們驗證路徑、CSV文件準備和標簽是否正確。
def visualize(dataset_path, df):
all_images = df.image_filename
all_labels = df.text
plt.figure(figsize=(15, 3))
for i in range(15):
plt.subplot(3, 5, i+1)
image = plt.imread(f"{dataset_path}/test_processed/images/{all_images[i]}")
label = all_labels[i]
plt.imshow(image)
plt.axis('off')
plt.title(label)
plt.show()
sample_df = pd.read_csv(
os.path.join(DatasetConfig.DATA_ROOT, 'test_processed.csv'),
header=None,
skiprows=1,
names=['image_filename', 'text'],
nrows=50
)
visualize(DatasetConfig.DATA_ROOT, sample_df)
GNHK手寫文本識別數據集具有自定義的目錄結構和CSV文件,我們需要編寫自定義的數據集準備代碼。
- 讀取csv文件
train_df = pd.read_csv(
os.path.join(DatasetConfig.DATA_ROOT, 'train_processed.csv'),
header=None,
skiprows=1,
names=['image_filename', 'text']
)
test_df = pd.read_csv(
os.path.join(DatasetConfig.DATA_ROOT, 'test_processed.csv'),
header=None,
skiprows=1,
names=['image_filename', 'text']
)
- 為了減少過擬合,應用一些輕微的數據增強,主要包括顏色抖動和高斯模糊。
# 定義數據增強
train_transforms = transforms.Compose([
transforms.ColorJitter(brightness=0.5, hue=0.3),
transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 5)),
])
- 需要創(chuàng)建一個自定義的PyTorch數據集類。
class CustomOCRDataset(Dataset):
def __init__(self, root_dir, df, processor, max_target_length=128):
self.root_dir = root_dir
self.df = df
self.processor = processor
self.max_target_length = max_target_length
# 填充空值
self.df['text'] = self.df['text'].fillna('')
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# 圖像文件名
file_name = self.df['image_filename'][idx]
# 文本(標簽)
text = self.df['text'][idx]
# 讀取圖像,應用數據增強,并獲取轉換后的像素值
image = Image.open(os.path.join(self.root_dir, file_name)).convert('RGB')
image = train_transforms(image)
pixel_values = self.processor(image, return_tensors='pt').pixel_values
# 通過分詞器對文本進行分詞,并獲取標簽
labels = self.processor.tokenizer(
text,
padding='max_length',
max_length=self.max_target_length
).input_ids
# 使用 -100 作為填充標記
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
encoding = {
"pixel_values": pixel_values.squeeze(),
"labels": torch.tensor(labels)
}
return encoding
- 初始化TrOCR處理器,并準備訓練和驗證數據集。
# 初始化處理器
processor = TrOCRProcessor.from_pretrained(ModelConfig['MODEL_NAME'])
# 準備訓練數據集
train_dataset = CustomOCRDataset(
root_dir=os.path.join(DatasetConfig['DATA_ROOT'], 'train_processed/images/'),
df=train_df,
processor=processor
)
# 準備驗證數據集
valid_dataset = CustomOCRDataset(
root_dir=os.path.join(DatasetConfig['DATA_ROOT'], 'test_processed/images/'),
df=test_df,
processor=processor
)
初始化和配置模型,并統計模型的參數數量。
- 加載模型
# 初始化模型
model = VisionEncoderDecoderModel.from_pretrained(ModelConfig['MODEL_NAME'])
model.to(device)
print(model)
# 統計總參數和可訓練參數
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
print(f"{total_trainable_params:,} training parameters.")
- 手動設置一些配置。
# 設置特殊 token 用于從標簽創(chuàng)建 decoder_input_ids
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# 設置正確的詞匯表大小
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id
# 設置最大輸出長度
model.config.max_length = 64
# 啟用提前停止
model.config.early_stopping = True
# 設置不重復 n-gram 的大小
model.config.no_repeat_ngram_size = 3
# 設置長度懲罰
model.config.length_penalty = 2.0
# 設置 beam search 的束寬
model.config.num_beams = 4
# 打印模型配置
print(model.config)
- 定義AdamW優(yōu)化器,并配置學習率和權重衰減。
# 定義 AdamW 優(yōu)化器
optimizer = optim.AdamW(
model.parameters(), lr=TrainingConfig['LEARNING_RATE'], weight_decay=0.0005
)
- 使用字符錯誤率CER對模型進行評估。
cer_metric = evaluate.load('cer')
def compute_cer(pred):
# 提取標簽的 ID
labels_ids = pred.label_ids
# 提取預測的 ID
pred_ids = pred.predictions
# 將預測的 ID 解碼為字符串,跳過特殊 token
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
# 將標簽中的 -100 轉換為 pad_token_id,以避免影響評估結果
labels_ids[labels_ids == -100] = processor.tokenizer.
# 將標簽的 ID 解碼為字符串,跳過特殊 token
label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
# 使用 cer_metric 計算 CER
cer = cer_metric.compute(predictions=pred_str, references=label_str)
return {"cer": cer}
訓練和驗證模型。在開始訓練之前,需要初始化訓練參數和 Trainer API。
- 定義 Seq2SeqTrainingArguments 對象,設置訓練和驗證的相關參數。
# 初始化訓練參數
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
evaluation_strategy='epoch',
per_device_train_batch_size=TrainingConfig['BATCH_SIZE'],
per_device_eval_batch_size=TrainingConfig['BATCH_SIZE'],
fp16=True,
output_dir='trocr_handwritten/',
logging_strategy='epoch',
save_strategy='epoch',
save_total_limit=2,
report_to='tensorboard',
num_train_epochs=TrainingConfig['EPOCHS'],
dataloader_num_workers=8
)
- 使用 Seq2SeqTrainer API 初始化訓練器。Seq2SeqTrainer 接受模型、處理器、訓練參數、數據集和數據收集器作為參數。
# 初始化訓練器
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
compute_metrics=compute_cer,
train_dataset=train_dataset,
eval_dataset=valid_dataset,
data_collator=default_data_collator
)
- 開始微調模型。
# 開始訓練
trainer.train()
以下是訓練10個epoch后的日志示例:
在訓練完成后,我們得到了最佳的驗證 CER 值。接下來,我們將使用最后一個epoch的檢查點對驗證集進行推理。
如圖所示,驗證CER圖表在整個訓練過程中持續(xù)下降,直到最后一個 epoch。這表明模型仍在學習,并且可能通過適當的學習率調度進一步訓練幾個 epoch 以獲得更好的性能。
6.使用訓練好的TrOCR模型推理
接下來,將使用訓練好的trOCR模型對一組圖像進行推理。
- 加載處理器和訓練好的模型檢查點。
# 定義模型和處理器
processor = TrOCRProcessor.from_pretrained(ModelConfig.MODEL_NAME)
trained_model = VisionEncoderDecoderModel.from_pretrained('trocr_handwritten/checkpoint-'+str(res.global_step)).to(device)
- 定義一些輔助函數,用于讀取圖像、通過模型進行前向傳播以及繪制結果。
def read_and_show(image_path):
"""
:param image_path: String, path to the input image.
Returns:
image: PIL Image.
"""
image = Image.open(image_path).convert('RGB')
return image
def ocr(image, processor, model):
"""
:param image: PIL Image.
:param processor: Huggingface OCR processor.
:param model: Huggingface OCR model.
Returns:
generated_text: the OCR'd text string.
"""
pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return generated_text
def eval_new_data(data_path=None, num_samples=50, df=None):
all_images = df.image_filename
all_labels = df.text
plt.figure(figsize=(15, 3))
for i in range(num_samples):
plt.subplot(3, 5, i+1)
image = read_and_show(os.path.join(data_path, all_images[i]))
text = ocr(image, processor, trained_model)
plt.imshow(image)
plt.title(text)
plt.axis('off')
plt.show()
- 運行推理并可視化結果
# 運行推理并可視化結果
eval_new_data(
data_path=data_path,
num_samples=num_samples,
df=sample_df
)
推理結果如下圖所示。
由此可以看出,模型成功地正確預測了所有單詞。這表明經過微調后,模型在驗證集上的表現非常出色。
附錄(完整代碼)
鏈接:https://pan.baidu.com/s/1R5-JB7zKTeb1pJ0kS2Tmnw
提取碼:d388