終于把 Unet 算法搞懂了!!
今天給大家分享一個超強(qiáng)的算法模型,Unet
UNet 是一種經(jīng)典的卷積神經(jīng)網(wǎng)絡(luò)(CNN)架構(gòu),最初由 Olaf Ronneberger 等人在 2015 年提出,專為生物醫(yī)學(xué)圖像分割設(shè)計。
它的獨特之處在于其編碼器-解碼器對稱結(jié)構(gòu),能夠有效地在多尺度上提取特征并生成精確的像素級分割結(jié)果。
UNet 算法在圖像分割任務(wù)中表現(xiàn)優(yōu)異,尤其是在需要精細(xì)邊界的場景中廣泛應(yīng)用,如醫(yī)學(xué)影像分割、衛(wèi)星圖像分割等。
圖片
UNet 架構(gòu)
UNet 模型由兩部分組成:編碼器和解碼器,中間通過跳躍連接(Skip Connections)相連。
UNet 的設(shè)計理念是將輸入圖像經(jīng)過一系列卷積和下采樣操作逐漸提取高層次特征(編碼路徑),然后通過上采樣逐步恢復(fù)原始的分辨率(解碼路徑),并將編碼路徑中對應(yīng)的特征與解碼路徑進(jìn)行跳躍連接(skip connection)。這種跳躍連接能夠幫助網(wǎng)絡(luò)結(jié)合低層次細(xì)節(jié)信息和高層次語義信息,實現(xiàn)精確的像素級分割。
編碼器
類似傳統(tǒng)的卷積神經(jīng)網(wǎng)絡(luò),編碼器的主要任務(wù)是逐漸壓縮輸入圖像的空間分辨率,提取更高層次的特征。
這個部分包含一系列卷積層和最大池化層(max pooling),每次池化操作都會將圖像的空間維度減少一半。
圖片
解碼器
解碼器的任務(wù)是通過逐漸恢復(fù)圖像的空間分辨率,將編碼器部分提取到的高層次特征映射回原始的圖像分辨率。
解碼器包含反卷積(上采樣)操作,并結(jié)合來自編碼器的相應(yīng)特征層,以實現(xiàn)精細(xì)的邊界恢復(fù)。
圖片
跳躍連接
跳躍連接是 UNet 的一個關(guān)鍵創(chuàng)新點。
每個編碼器層的輸出特征圖與解碼器中對應(yīng)層的特征圖進(jìn)行拼接,形成跳躍連接。
這樣可以將編碼器中的局部信息和解碼器中的全局信息進(jìn)行融合,從而提高分割結(jié)果的精度。
圖片
UNet 算法工作流程
- 輸入圖像
- 編碼階段
每個編碼塊包含兩個 3x3 卷積層(帶有 ReLU 激活函數(shù))和一個 2x2 最大池化層,池化層用于下采樣。
經(jīng)過每個編碼塊后,特征圖的空間尺寸減少一半,但通道數(shù)量翻倍。
- 瓶頸層
在網(wǎng)絡(luò)的最底部,這部分用來提取最深層次的特征。
- 解碼階段
每個解碼塊包含一個 2x2 轉(zhuǎn)置卷積(或上采樣操作)和兩個 3x3 卷積層(帶有 ReLU 激活函數(shù))。
- 與編碼路徑不同的是,解碼過程中每次上采樣時,還將相應(yīng)的編碼層的特征拼接(跳躍連接)到解碼層。
輸出層
最后一層通過 1x1 卷積將輸出通道數(shù)映射為類別數(shù),用于生成分割掩碼。
最終輸出的是一個大小與輸入圖像相同的分割圖。
代碼示例
下面是一個使用 UNet 進(jìn)行圖像分割的簡單示例代碼。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision import datasets
import matplotlib.pyplot as plt
# UNet 模型定義
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
def conv_block(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.encoder1 = conv_block(1, 64)
self.encoder2 = conv_block(64, 128)
self.encoder3 = conv_block(128, 256)
self.encoder4 = conv_block(256, 512)
self.pool = nn.MaxPool2d(2)
self.bottleneck = conv_block(512, 1024)
self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.decoder4 = conv_block(1024, 512)
self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.decoder3 = conv_block(512, 256)
self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.decoder2 = conv_block(256, 128)
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.decoder1 = conv_block(128, 64)
self.conv_last = nn.Conv2d(64, 1, kernel_size=1)
def forward(self, x):
# Encoder
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool(enc1))
enc3 = self.encoder3(self.pool(enc2))
enc4 = self.encoder4(self.pool(enc3))
# Bottleneck
bottleneck = self.bottleneck(self.pool(enc4))
# Decoder
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
return torch.sigmoid(self.conv_last(dec1))
# 創(chuàng)建數(shù)據(jù)集
class RandomDataset(Dataset):
def __init__(self, num_samples, image_size):
self.num_samples = num_samples
self.image_size = image_size
self.transform = transforms.Compose([transforms.ToTensor()])
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
image = torch.randn(1, self.image_size, self.image_size) # 隨機(jī)生成圖像
mask = (image > 0).float() # 隨機(jī)生成掩碼
return image, mask
# 訓(xùn)練模型
def train_model():
image_size = 128
batch_size = 8
num_epochs = 10
learning_rate = 1e-3
# 實例化模型、損失函數(shù)和優(yōu)化器
model = UNet()
criterion = nn.BCELoss() # 使用二元交叉熵?fù)p失
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
dataset = RandomDataset(num_samples=100, image_size=image_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
for images, masks in dataloader:
outputs = model(images)
loss = criterion(outputs, masks)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")
# 測試一個隨機(jī)樣本
test_image, test_mask = dataset[0]
model.eval()
with torch.no_grad():
prediction = model(test_image.unsqueeze(0))
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.title('Input Image')
plt.imshow(test_image.squeeze().numpy(), cmap='gray')
plt.subplot(1, 3, 2)
plt.title('Ground Truth Mask')
plt.imshow(test_mask.squeeze().numpy(), cmap='gray')
plt.subplot(1, 3, 3)
plt.title('Predicted Mask')
plt.imshow(prediction.squeeze().numpy(), cmap='gray')
plt.show()
train_model()
UNet 的成功源于其有效的特征提取與恢復(fù)機(jī)制,特別是跳躍連接的設(shè)計,使得編碼過程中丟失的細(xì)節(jié)能夠通過解碼階段恢復(fù)。
UNet 在醫(yī)學(xué)圖像分割等任務(wù)上有著廣泛的應(yīng)用,能夠生成高精度的像素級分割結(jié)果。