基于 Faster ViT 進行圖像分類
Faster Vision Transformer(FVT)是 Vision Transformer(ViT)架構的一個變體,這是一種為計算機視覺任務設計的神經網絡。FVT 是原始 ViT 模型的更快、更高效版本,原始模型由 Dosovitskiy 等人在 2020 年的論文 “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”(一幅圖像值 16x16 個詞:用于大規模圖像識別的轉換器)中引入。
FVT 的關鍵特性
- 高效架構:FVT 旨在比原始 ViT 模型更快、更高效。它通過減少參數數量和計算複雜性,同時保持類似的性能來實現這一點。
- 多尺度視覺轉換器:FVT 使用多尺度視覺轉換器架構,允許它以多種尺度和分辨率處理圖像。這是通過使用層次結構實現的,其中較小的轉換器用于處理圖像的較小區域。
- 自注意力機制:FVT 使用自注意力機制,允許它對圖像的不同部分之間的複雜關係進行建模。這是通過使用在訓練過程中學習到的注意力權重來實現的。
- 位置編碼:FVT 使用位置編碼來保留圖像的空間信息。這是通過使用學習到的位置嵌入來實現的,它們被添加到輸入令牌(token)中。
首先,讓我們開始實現在自定義數據集上訓練視覺轉換器。為此,我們需要通過 pip 安裝 fastervit。
# Install the fastervit package, which provides the `create_model` factory imported later.
pip install fastervit
讓我們導入 PyTorch 庫,以及加載數據時會用到的 torchvision 工具;剛剛通過 pip 安裝的 fastervit 庫會在加載模型時導入。
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
在這個實現中,我從 Kaggle 下載了損壞道路數據集(可在這里查看,原鏈接未保留)。然后將它們分割為訓練和驗證數據集,分別放在 sih_road_dataset/train 和 sih_road_dataset/val 目錄下,每個類別一個子文件夾。之后加載數據集并應用數據轉換。
# Root folder of the dataset. It must contain `train/` and `val/` sub-folders,
# each holding one directory per class (the layout ImageFolder expects).
data_dir = 'sih_road_dataset'

# Per-phase preprocessing. Training uses random crop + flip augmentation;
# validation uses a deterministic resize + center crop so metrics are repeatable.
# Both normalize with the mean/std commonly used for ImageNet-pretrained models.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# One ImageFolder dataset per phase, each with its matching transform.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']
}

# Shuffle only the training split: shuffling helps SGD, while a fixed order
# keeps validation deterministic and easier to debug.
dataloaders = {
    x: DataLoader(
        image_datasets[x],
        batch_size=32,
        shuffle=(x == 'train'),
        num_workers=4,
    )
    for x in ['train', 'val']
}

# Number of samples per phase.
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Class labels as discovered by ImageFolder from the training sub-folder names.
class_names = image_datasets['train'].classes
接下來我們將加載更快視覺轉換器(FasterViT)模型。
# Build the pretrained FasterViT-0 (224x224 input) backbone. Its classification
# head is replaced afterwards to match the custom dataset.
from fastervit import create_model

model = create_model(
    "faster_vit_0_224",
    pretrained=True,
    model_path="faster_vit_0.pth.tar",
)

# Dump the layer structure so the final `head` layer can be inspected.
print(model)
接下來我們將加載更快視覺轉換器(FasterViT)模型。
# NOTE(review): this snippet repeats the model-loading step shown just above,
# statement for statement; running it again simply rebuilds the same model.
from fastervit import create_model
# Build the 'faster_vit_0_224' variant with pretrained weights enabled.
# `model_path` names a local checkpoint file — presumably where the pretrained
# weights are stored/read; confirm against the fastervit documentation.
model = create_model('faster_vit_0_224',
pretrained=True,
model_path="faster_vit_0.pth.tar")
# Print the model so its layers (including the final `head`) can be inspected.
print(model)
當我們打印模型時,我們可以看到末尾的頭部層(head),這是需要修改以進行微調的部分。
為了針對您的自定義分類任務修改這一層,您應該用一個具有適當數量輸出類別的新線性層替換頭部層。
# Pick the compute device first: the first CUDA GPU when present, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Swap the pretrained classifier for a fresh linear layer with one output per
# dataset class, keeping the input width of the original head.
num_ftrs = model.head.in_features
model.head = torch.nn.Linear(in_features=num_ftrs, out_features=len(class_names))

# Place the whole network (including the new head) on the chosen device.
model = model.to(device)
接下來指定損失函數、優化器和學習率調度器:
import torch.optim as optim
from torch.optim import lr_scheduler

# Classification loss computed from the model outputs and integer class labels.
criterion = torch.nn.CrossEntropyLoss()

# Plain SGD with momentum over every parameter, i.e. the whole network is
# fine-tuned, not just the new head.
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

# Multiply the learning rate by 0.1 after every 7 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.1)
好的,現在一切都已定義,接下來編寫訓練函數,用它在自定義數據集上訓練我們的模型。
import time
import copy
def train_model(model, criterion, optimizer, scheduler, num_epochs=5, *,
                loaders=None, device=None):
    """Fine-tune ``model`` and return it loaded with its best-validation weights.

    Every epoch runs a ``'train'`` pass (gradients enabled, optimizer stepped
    per batch, scheduler stepped once after the pass) followed by a ``'val'``
    pass (gradients disabled). The weights that achieve the highest validation
    accuracy are snapshotted and restored into ``model`` before returning.

    Args:
        model: Network to train; it is modified in place.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer built over ``model``'s parameters.
        scheduler: Learning-rate scheduler; ``step()`` is called once per epoch.
        num_epochs: Number of train/val epochs to run.
        loaders: Optional mapping with ``'train'`` and ``'val'`` keys, each an
            iterable of ``(inputs, labels)`` batches. Defaults to the
            module-level ``dataloaders``.
        device: Optional device the batches are moved to. Defaults to the
            device of the model's first parameter (CPU if it has none).

    Returns:
        The same ``model`` object, holding the best-validation weights.

    Raises:
        ValueError: If a phase yields no samples, since its metrics would be
            undefined.
    """
    if loaders is None:
        loaders = dataloaders
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and a validation phase.
        for phase in ('train', 'val'):
            training = phase == 'train'
            model.train(training)  # train(False) is equivalent to eval()

            running_loss = 0.0
            running_corrects = 0
            seen = 0  # samples actually processed in this phase

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track gradients only while training.
                with torch.set_grad_enabled(training):
                    outputs = model(inputs)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, labels)
                    if training:
                        loss.backward()
                        optimizer.step()

                batch_size = inputs.size(0)
                # Weight the batch loss by batch size so the epoch average
                # stays correct when the last batch is smaller.
                running_loss += loss.item() * batch_size
                running_corrects += int((preds == labels).sum().item())
                seen += batch_size

            if seen == 0:
                raise ValueError(f"dataloader for phase {phase!r} yielded no samples")

            if training:
                scheduler.step()

            # Average over the samples actually seen, which stays correct even
            # when the loader drops or subsamples data.
            epoch_loss = running_loss / seen
            epoch_acc = running_corrects / seen
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Snapshot the weights whenever validation accuracy improves.
            if not training and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Restore the best-performing weights before handing the model back.
    model.load_state_dict(best_model_wts)
    return model
下一步是啟動訓練過程!
# Fine-tune for five epochs; train_model hands back the model with its
# best-validation weights restored.
model = train_model(
    model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=exp_lr_scheduler,
    num_epochs=5,
)

# Persist only the learned parameters (state_dict), not the full module object.
torch.save(model.state_dict(), 'faster_vit_custom_model.pth')
請注意,這不是最好的模型,因為我們可以看到模型在訓練數據集上過擬合了。本文的主要目的是演示如何實現 Faster Vision Transformer 并在自定義數據集上訓練它。還有其他方法可以解決過擬合問題。
讓我們用訓練好的模型對下面的圖像做一次快速測試:
import torch
from torchvision import transforms
from PIL import Image
from fastervit import create_model
# Number of output classes of the fine-tuned head. It must equal the class
# count used during training, otherwise the saved weights will not fit.
num_classes = 4  # Replace with your actual number of classes

# Rebuild the bare architecture; pretrained=False because the fine-tuned
# weights are loaded from disk just below.
model = create_model('faster_vit_0_224', pretrained=False)

# Recreate the custom classification head so the parameter shapes match the checkpoint.
model.head = torch.nn.Linear(model.head.in_features, num_classes)

# Move the model to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Load the trained weights. map_location remaps tensors onto the current
# device, so a checkpoint saved on a GPU still loads on a CPU-only machine.
model.load_state_dict(torch.load('faster_vit_custom_model.pth', map_location=device))
model.eval()  # Inference mode: disables dropout and freezes normalization statistics

# Deterministic preprocessing for a single input image: resize, center crop,
# tensor conversion and channel normalization.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def load_image(image_path):
    """Read an image file and return it as a preprocessed one-image batch.

    The image is converted to RGB, passed through the module-level
    ``preprocess`` pipeline, given a leading batch dimension and moved to the
    module-level ``device``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The preprocessed image tensor with a batch dimension of size 1,
        located on ``device``.
    """
    # The context manager closes the underlying file deterministically, even
    # for formats that Pillow keeps open after decoding (e.g. multi-frame
    # images). convert() returns an independent, fully loaded copy.
    with Image.open(image_path) as source:
        rgb_image = source.convert('RGB')
    batch = preprocess(rgb_image).unsqueeze(0)  # Add batch dimension
    return batch.to(device)
def predict(image_path, model, class_names):
    """Classify a single image file and return the name of the top-scoring class.

    Args:
        image_path: Path of the image to classify; loaded via ``load_image``.
        model: Network called on the one-image batch.
        class_names: Sequence mapping a class index to its label.

    Returns:
        The entry of ``class_names`` at the index with the highest model output.
    """
    batch = load_image(image_path)
    # Inference only: no gradient tracking is needed.
    with torch.no_grad():
        scores = model(batch)
        best_index = scores.max(dim=1).indices
    return class_names[best_index.item()]
# Class labels in index order; they must line up with the indices used during
# training (ImageFolder assigns indices by sorted folder name).
class_names = [
    'good',
    'poor',
    'satisfactory',
    'very_poor',
]  # Replace with your actual class names

# Run the fine-tuned model on one sample image and show the predicted label.
image_path = 'test_img.jpg'
predicted_class = predict(image_path, model=model, class_names=class_names)
print(predicted_class)
預測的類別是: