基于視覺 Transformer(ViT)進(jìn)行圖像分類
近年來,Transformer 架構(gòu)徹底改變了自然語言處理(NLP)任務(wù)。視覺Transformer(ViT)將這一創(chuàng)新更進(jìn)一步,將變換器架構(gòu)適應(yīng)于圖像分類任務(wù)。本教程將指導(dǎo)您使用ViT對花卉圖像進(jìn)行分類。
一、先決條件
要跟隨本教程,您應(yīng)該具備以下基礎(chǔ)知識:
- Python編程
- 深度學(xué)習(xí)概念
- TensorFlow和Keras
二、數(shù)據(jù)集概覽
在本教程中,我們將使用一個包含3670張圖像的花卉數(shù)據(jù)集,這些圖像被歸類為五個類別:雛菊、蒲公英、玫瑰、向日葵和郁金香。數(shù)據(jù)集已預(yù)先分割為訓(xùn)練和測試集,以方便使用。
第1步:理解視覺Transformer架構(gòu)
視覺Transformer(ViT)是由谷歌研究引入的一種新穎架構(gòu),它將最初為自然語言處理(NLP)開發(fā)的Transformer架構(gòu)應(yīng)用于計算機(jī)視覺任務(wù)。與傳統(tǒng)的卷積神經(jīng)網(wǎng)絡(luò)(CNN)不同,ViT將圖像分割成塊,并處理這些塊作為標(biāo)記序列,類似于NLP任務(wù)中處理單詞的方式。
ViT的關(guān)鍵優(yōu)勢:
- 能夠有效處理大規(guī)模數(shù)據(jù)集。
- 在圖像分類任務(wù)中實(shí)現(xiàn)最先進(jìn)的性能。
- 具有高效的遷移學(xué)習(xí)能力。
讓我們深入了解ViT架構(gòu)的關(guān)鍵組成部分:
(1) 將輸入圖像分割成塊
與傳統(tǒng)的卷積神經(jīng)網(wǎng)絡(luò)(CNN)不同,ViT將輸入圖像分割成固定大小的塊。然后每個塊被展平成一維向量。例如,一個來自3通道圖像(RGB)的16x16塊將產(chǎn)生一個768維向量(16 * 16 * 3)。
圖表:圖像到塊
+-----------------+
| Image |
| (224 x 224) |
+-----------------+
|
V
+---------------------+
| Patch 1 |
| (16 x 16) |
+---------------------+
|
V
+---------------------+
| Patch 2 |
| (16 x 16) |
+---------------------+
|
...
|
V
+---------------------+
| Patch n |
| (16 x 16) |
+---------------------+
(2) 塊的線性embedding
每個展平的塊被線性embedding到固定大小的向量中。這一步類似于NLP中使用的詞embedding,將塊轉(zhuǎn)換為適合Transformer處理的格式。
圖表:塊embedding
+---------------------+
| Flattened Patch 1 |
| [p1, p2, ..., pn] |
+---------------------+
|
V
+---------------------------+
| Linear Embedding |
| [e1, e2, ..., em] |
+---------------------------+
(3) 添加位置embedding
為了保留空間信息,將位置embedding添加到每個塊embedding中。這有助于模型理解每個塊在原始圖像中的相對位置。
圖表:位置embedding
+---------------------------+ +-----------------------+
| Linear Embedded Patches | + | Positional Embeddings |
| [e1, e2, ..., em] | | [pe1, pe2, ..., pem] |
+---------------------------+ +-----------------------+
| |
V V
+---------------------------+
| Embedded Patches + |
| Positional Embeddings |
| [e1+pe1, e2+pe2, ..., |
| em+pem] |
+---------------------------+
(4) 類別標(biāo)記
在embedding塊的序列前添加一個可學(xué)習(xí)的分類標(biāo)記([CLS])。這個標(biāo)記用于聚合所有塊的信息,并最終用于分類。
圖表:添加類別標(biāo)記
+---------------------------+
| Class Token |
| [cls] |
+---------------------------+
|
V
+---------------------------+ +---------------------------+
| Embedded Patches + | | Class Token + |
| Positional Embeddings | --> | Embedded Patches + |
| [e1+pe1, e2+pe2, ..., | | Positional Embeddings |
| em+pem] | | [cls, e1+pe1, e2+pe2, ... |
+---------------------------+ | em+pem] |
+---------------------------+
(5) Transformer編碼器
將向量序列(類別標(biāo)記+embedding塊)傳遞過一系列變換器編碼器層。每一層由多頭自注意力和MLP塊組成。
圖表:Transformer編碼器
+------------------------------------+
| Transformer Encoder Layer |
| |
| +------------------------------+ |
| | Multi-Headed Self-Attention | |
| +------------------------------+ |
| |
| +------------------------------+ |
| | MLP Block | |
| +------------------------------+ |
| |
+------------------------------------+
|
V
+------------------------------------+
| Output Sequence |
| [cls, e1', e2', ..., em'] |
+------------------------------------+
每個編碼器層處理輸入序列并產(chǎn)生相同長度和維度的輸出序列。自注意力機(jī)制允許每個塊關(guān)注所有其他塊,使模型能夠捕捉塊之間的長期依賴性和交互。
(6) 分類頭
用于分類的[CLS]標(biāo)記的最終隱藏狀態(tài)。將全連接層應(yīng)用于[CLS]標(biāo)記的輸出以預(yù)測類別概率。
圖表:分類頭
+------------------------------------+
| Output Sequence |
| [cls, e1', e2', ..., em'] |
+------------------------------------+
|
V
+---------------------------+
| Fully Connected Layer |
| [class probabilities] |
+---------------------------+
第2步:實(shí)現(xiàn)視覺Transformer
讓我們逐一了解vit.py文件中ViT實(shí)現(xiàn)的主要組成部分:
(1) 類別標(biāo)記
這個類創(chuàng)建一個可學(xué)習(xí)的分類標(biāo)記,該標(biāo)記被添加到塊嵌入序列的前面。
class ClassToken(Layer):
def __init__(self):
super().__init__()
def build(self, input_shape):
w_init = tf.random_normal_initializer()
self.w = tf.Variable(
initial_value=w_init(shape=(1, 1, input_shape[-1]), dtype=tf.float32),
trainable=True
)
def call(self, inputs):
batch_size = tf.shape(inputs)[0]
hidden_dim = self.w.shape[-1]
cls = tf.broadcast_to(self.w, [batch_size, 1, hidden_dim])
cls = tf.cast(cls, dtype=inputs.dtype)
return cls
(2) MLP塊
這個函數(shù)實(shí)現(xiàn)了Transformer編碼器中使用的MLP塊。
def mlp(x, cf):
x = Dense(cf["mlp_dim"], activation="gelu")(x)
x = Dropout(cf["dropout_rate"])(x)
x = Dense(cf["hidden_dim"])(x)
x = Dropout(cf["dropout_rate"])(x)
return x
(3) Transformer編碼器
這個函數(shù)實(shí)現(xiàn)了一個Transformer編碼器層,包括自注意力和MLP塊。
def transformer_encoder(x, cf):
skip_1 = x
x = LayerNormalization()(x)
x = MultiHeadAttention(
num_heads=cf["num_heads"], key_dim=cf["hidden_dim"]
)(x, x)
x = Add()([x, skip_1])
skip_2 = x
x = LayerNormalization()(x)
x = mlp(x, cf)
x = Add()([x, skip_2])
return x
(4) 視覺Transformer模型
這個函數(shù)組裝完整的視覺Transformer模型。
def ViT(cf):
inputs = Input(shape=cf["input_shape"])
patches = Patches(cf["patch_size"])(inputs)
x = PatchEncoder(num_patches=cf["num_patches"], projection_dim=cf["projection_dim"])(patches)
cls_token = ClassToken()(x)
x = Concatenate(axis=1)([cls_token, x])
for _ in range(cf["num_layers"]):
x = transformer_encoder(x, cf)
x = LayerNormalization()(x)
x = x[:, 0]
x = Dense(cf["num_classes"], activation="softmax")(x)
model = Model(inputs, x)
return model
第3步:數(shù)據(jù)準(zhǔn)備和加載
在train.py文件中,我們處理數(shù)據(jù)準(zhǔn)備和加載:
(1) 加載和分割數(shù)據(jù)集
這個函數(shù)加載數(shù)據(jù)集并將其分割為訓(xùn)練、驗證和測試集。
from glob import glob
import os
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
def load_data(path, split=0.1):
images = shuffle(glob(os.path.join(path, "*", "*.jpg")))
split_size = int(len(images) * split)
train_x, valid_x = train_test_split(images, test_size=split_size, random_state=42)
train_x, test_x = train_test_split(train_x, test_size=split_size, random_state=42)
return train_x, valid_x, test_x
(2) 處理圖像和創(chuàng)建塊
這個函數(shù)處理圖像、調(diào)整大小并創(chuàng)建塊。
import cv2
import numpy as np
from patchify import patchify
def process_image_label(path):
image = cv2.imread(path)
image = cv2.resize(image, (hp["image_size"], hp["image_size"]))
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = image / 255.0
patch_shape = (hp["patch_size"], hp["patch_size"], hp["num_channels"])
patches = patchify(image, patch_shape, hp["patch_size"])
patches = np.reshape(patches, hp["flat_patches_shape"])
patches = patches.astype(np.float32)
label = os.path.basename(os.path.dirname(path))
class_idx = hp["class_names"].index(label)
class_idx = np.array(class_idx, dtype=np.float32)
return patches, class_idx
(3) 創(chuàng)建TensorFlow數(shù)據(jù)集
這個函數(shù)從處理過的圖像創(chuàng)建TensorFlow數(shù)據(jù)集。
import tensorflow as tf
def tf_dataset(images, batch=32):
ds = tf.data.Dataset.from_tensor_slices((images))
ds = ds.map(parse).batch(batch).prefetch(8)
return ds
第4步:模型訓(xùn)練
在train.py文件中,我們設(shè)置訓(xùn)練過程:
(1) 編譯模型
這個函數(shù)使用指定的優(yōu)化器和損失函數(shù)編譯ViT模型。
model.compile(
loss="categorical_crossentropy",
optimizer=tf.keras.optimizers.Adam(hp["lr"], clipvalue=1.0),
metrics=["acc"]
)
(2) 設(shè)置回調(diào)
這個函數(shù)設(shè)置各種回調(diào),用于在訓(xùn)練期間監(jiān)控和保存模型。
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping
callbacks = [
ModelCheckpoint(model_path, monitor='val_loss', verbose=1, save_best_only=True),
ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, min_lr=1e-10, verbose=1),
CSVLogger(csv_path),
EarlyStopping(monitor='val_loss', patience=50, restore_best_weights=False),
]
(3) 訓(xùn)練模型
這個函數(shù)使用訓(xùn)練和驗證數(shù)據(jù)集訓(xùn)練ViT模型。
model.fit(
train_ds,
epochs=hp["num_epochs"],
validation_data=valid_ds,
callbacks=callbacks
)
第5步:模型評估
在test.py文件中,我們加載訓(xùn)練好的模型并在測試集上評估它:
model = ViT(hp)
model.load_weights(model_path)
model.compile(
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
optimizer=tf.keras.optimizers.Adam(hp["lr"]),
metrics=["acc"]
)
model. Evaluate(test_ds)
結(jié)論
在本教程中,我們實(shí)現(xiàn)了一個用于花卉圖像分類的視覺Transformer(ViT)。我們涵蓋了以下關(guān)鍵點(diǎn):
- 視覺Transformer的架構(gòu)
- 使用TensorFlow和Keras實(shí)現(xiàn)ViT模型
- 為花卉數(shù)據(jù)集準(zhǔn)備和加載數(shù)據(jù)
- 模型訓(xùn)練過程
- 在測試集上評估模型
視覺Transformer展示了注意力機(jī)制在計算機(jī)視覺任務(wù)中的強(qiáng)大能力,可能取代或補(bǔ)充傳統(tǒng)的CNN架構(gòu)。通過遵循本教程,您將獲得有關(guān)圖像分類的尖端深度學(xué)習(xí)模型的實(shí)踐經(jīng)驗。
進(jìn)一步探索
為了進(jìn)一步提高您的理解和結(jié)果,您可以嘗試:
- 嘗試不同的超參數(shù)
- 嘗試數(shù)據(jù)增強(qiáng)技術(shù)
- 比較ViT與基于CNN的模型的性能
- 可視化注意力圖以了解模型關(guān)注的內(nèi)容
請記住,視覺Transformer通常在大型數(shù)據(jù)集上預(yù)訓(xùn)練并在較小的特定任務(wù)數(shù)據(jù)集上微調(diào)時表現(xiàn)最佳。在本教程中,我們在相對較小的數(shù)據(jù)集上從頭開始訓(xùn)練,但原理保持不變。通過遵循這些步驟,您將能夠?qū)崿F(xiàn)并訓(xùn)練一個用于花卉圖像分類的視覺Transformer模型,深入了解現(xiàn)代深度學(xué)習(xí)技術(shù)在計算機(jī)視覺中的應(yīng)用。
論文鏈接:https://arxiv.org/pdf/2010.11929.pdf
GitHub鏈接:https://github.com/sanjay-dutta/Computer-Vision-Practice/tree/main/Vit_flower
數(shù)據(jù)集鏈接:https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz