我們一起快速學(xué)會(huì)一個(gè)算法-UNet
大家好,我是小寒。
今天給大家分享一個(gè)超強(qiáng)的算法模型,UNet
UNet 是一種專門(mén)用于圖像分割任務(wù)的卷積神經(jīng)網(wǎng)絡(luò)(CNN)架構(gòu),最早由 Olaf Ronneberger 等人在 2015 年提出。
UNet 的名字來(lái)源于其結(jié)構(gòu)的對(duì)稱性,類似于字母“U”。UNet 模型由于其優(yōu)越的分割性能,被廣泛應(yīng)用于各種圖像分割任務(wù),如醫(yī)學(xué)圖像分割等。
圖片
Unet 模型架構(gòu)
UNet 模型由兩部分組成:編碼器(Contracting Path)和解碼器(Expanding Path),中間通過(guò)跳躍連接(Skip Connections)相連。
編碼器(收縮路徑)
編碼器部分主要用于提取輸入圖像的特征。
它由一系列的卷積層、ReLU激活函數(shù)、最大池化層(Max Pooling)組成。
- 每個(gè)卷積層通常包含兩次卷積操作(使用 3x3 卷積核),每次卷積操作后接一個(gè) ReLU 激活函數(shù)。
- 之后,采用一個(gè) 2x2 的最大池化層(Max Pooling)進(jìn)行下采樣,以減少特征圖的空間維度。
- 每次下采樣后,特征圖的空間尺寸減小,而通道數(shù)增加,以提取更高層次的特征。
解碼器(擴(kuò)展路徑)
解碼器部分用于恢復(fù)圖像的空間信息,最終輸出與輸入圖像相同大小的分割結(jié)果。
它由上采樣(up-sampling)操作和卷積層組成。
- 上采樣(Upsampling),通常通過(guò)反卷積將特征圖的空間分辨率逐步恢復(fù)。
- 上采樣后,通過(guò)跳躍連接(Skip Connection)將對(duì)應(yīng)層的編碼器特征與解碼器特征拼接在一起,這樣可以保留輸入圖像的細(xì)節(jié)。
- 拼接后的特征圖經(jīng)過(guò)兩次卷積操作(同樣使用 3x3 卷積核)和 ReLU 激活函數(shù)進(jìn)行處理。
- 最終,經(jīng)過(guò)逐步上采樣和卷積,恢復(fù)到與輸入圖像相同的分辨率。
跳躍連接 (Skip Connections)
在UNet中,跳躍連接將編碼器中每一層的輸出與解碼器中相應(yīng)層的輸入相連,確保模型在還原圖像分辨率時(shí)保留更多的細(xì)節(jié)信息。
這種連接允許網(wǎng)絡(luò)在進(jìn)行上采樣時(shí)參考編碼器部分的特征,從而更好地復(fù)原高分辨率特征。
UNet模型的優(yōu)點(diǎn)
- 高效處理小樣本數(shù)據(jù)集
UNet 最初設(shè)計(jì)用于生物醫(yī)學(xué)圖像分割,具有高效利用小樣本數(shù)據(jù)集的能力。 - 精細(xì)的分割結(jié)果
通過(guò)跳躍連接,UNet 能夠很好地保留高分辨率的細(xì)節(jié),使得分割結(jié)果更為精確。 - 靈活性強(qiáng)
UNet 結(jié)構(gòu)簡(jiǎn)單且有效,容易擴(kuò)展和調(diào)整,適應(yīng)不同類型的分割任務(wù)。
案例分享
下面是一個(gè)使用 PyTorch 實(shí)現(xiàn) UNet 模型的代碼示例。這個(gè)示例展示了一個(gè)簡(jiǎn)化版的UNet模型,并應(yīng)用于圖像分割任務(wù)。
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
# 編碼器部分
self.encoder1 = self.double_conv(in_channels, 64)
self.encoder2 = self.double_conv(64, 128)
self.encoder3 = self.double_conv(128, 256)
self.encoder4 = self.double_conv(256, 512)
# 最底部的卷積
self.bottleneck = self.double_conv(512, 1024)
# 解碼器部分
self.upconv4 = self.upconv(1024, 512)
self.decoder4 = self.double_conv(1024, 512)
self.upconv3 = self.upconv(512, 256)
self.decoder3 = self.double_conv(512, 256)
self.upconv2 = self.upconv(256, 128)
self.decoder2 = self.double_conv(256, 128)
self.upconv1 = self.upconv(128, 64)
self.decoder1 = self.double_conv(128, 64)
# 最終的1x1卷積,用于生成分割圖
self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
def double_conv(self, in_channels, out_channels):
"""兩次卷積操作"""
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def upconv(self, in_channels, out_channels):
"""上采樣操作"""
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
def forward(self, x):
# 編碼器部分
enc1 = self.encoder1(x)
enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2))
enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2))
enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2))
# Bottleneck
bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2))
# 解碼器部分
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, self.crop_tensor(enc4, dec4)), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, self.crop_tensor(enc3, dec3)), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, self.crop_tensor(enc2, dec2)), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, self.crop_tensor(enc1, dec1)), dim=1)
dec1 = self.decoder1(dec1)
# 最后的1x1卷積生成輸出
return self.final_conv(dec1)
def crop_tensor(self, encoder_tensor, decoder_tensor):
"""裁剪編碼器張量,使其與解碼器張量大小匹配"""
_, _, H, W = decoder_tensor.size()
encoder_tensor = self.center_crop(encoder_tensor, H, W)
return encoder_tensor
def center_crop(self, tensor, target_height, target_width):
"""中心裁剪函數(shù)"""
_, _, h, w = tensor.size()
crop_y = (h - target_height) // 2
crop_x = (w - target_width) // 2
return tensor[:, :, crop_y:crop_y + target_height, crop_x:crop_x + target_width]
# 使用示例
model = UNet(in_channels=1, out_channels=1) # 輸入和輸出均為1通道(例如用于灰度圖像)
input_image = torch.randn(1, 1, 572, 572) # 隨機(jī)生成一個(gè)輸入圖像
output = model(input_image)
print(output.shape)