A Complete Deep Learning Code Example for Medical Images: Segmenting MRI Brain Scans with PyTorch
Image segmentation is one of the most important tasks in medical image analysis, and in many clinical applications it is the first and most critical step. In brain MRI analysis, segmentation is commonly used to measure and visualize anatomical structures, analyze brain changes, delineate pathological regions, and support surgical planning and image-guided interventions; it is a prerequisite for most morphological analyses.
In this article we show how to segment images of the human brain with QuickNAT, using MONAI, PyTorch, and common Python libraries for data visualization and computation such as NumPy, TorchIO, and matplotlib.
This article mainly covers the following:
- Setting up and exploring the dataset
- Processing and preparing the dataset for model training
- Creating a training loop
- Evaluating the model and analyzing the results
The complete code is linked at the end of this article.
Setting up the data directory
The first step when using MONAI is to set the MONAI_DATA_DIRECTORY environment variable to specify a directory; if it is not set, a temporary directory is used.
import os
import tempfile

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
Setting up the dataset
One of the main challenges in scaling CNN models to brain segmentation is the limited availability of manually annotated training data. The authors introduce a training strategy that leverages a large dataset without manual labels together with a small manually labeled one.
First, automatically generated segmentations are obtained for the large unlabeled dataset with existing software tools (for example FreeSurfer), and the network is pre-trained on them. In a second step, the network is fine-tuned on the smaller manually annotated dataset [2].
The IXI dataset consists of unlabeled T1 MRI scans from 581 healthy subjects, collected at three different hospitals in London. Its main drawback is that labels are not publicly available, so to follow the same approach as the research paper one would use FreeSurfer to generate segmentations for these T1 scans.
FreeSurfer is a software package for analyzing and visualizing structural neuroimaging data; download and installation instructions can be found on its website. The 'recon-all' command performs all of the cortical reconstruction steps.
Although FreeSurfer is a very useful tool for exploiting large amounts of unlabeled data and training a network in a supervised manner, generating these labels takes up to five hours per scan. So here we instead train the model directly on the OASIS dataset, a smaller dataset with publicly available manual annotations.
OASIS is a project that makes brain neuroimaging datasets freely available to the scientific community. OASIS-1 is a cross-sectional dataset; the first disc, used here, covers 39 subjects and is obtained as follows:
from monai.apps import download_and_extract

resource = "https://download.nrg.wustl.edu/data/oasis_cross-sectional_disc1.tar.gz"
md5 = "c83e216ef8654a7cc9e2a30a4cdbe0cc"

compressed_file = os.path.join(root_dir, "oasis_cross-sectional_disc1.tar.gz")
data_dir = os.path.join(root_dir, "Oasis_Data")

if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, data_dir, md5)
Data exploration
If you open 'oasis_cross-sectional_disc1.tar.gz', you will find a separate folder for each subject. For example, for subject OAS1_0001_MR1 it looks like this:
Image data file path: disc1\OAS1_0001_MR1\PROCESSED\MPRAGE\T88_111\OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc.img
Label file path: disc1\OAS1_0001_MR1\FSL_SEG\OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.img
Data loading and preprocessing
After downloading the dataset and extracting it to the temporary directory, it needs to be restructured; we want the directory to look like this:
So the data is loaded with the following steps:
Convert the .img files to .nii files and save them into new folders: create two new folders, Oasis_Data_Processed for each subject's processed MRI T1 scan and Oasis_Labels_Processed for the corresponding labels.
new_path_data = root_dir + '/Oasis_Data_Processed/'
if not os.path.exists(new_path_data):
    os.makedirs(new_path_data)

new_path_labels = root_dir + '/Oasis_Labels_Processed/'
if not os.path.exists(new_path_labels):
    os.makedirs(new_path_labels)
Then iterate over the subjects and convert each scan:
import nibabel as nib

for i in [x for x in range(1, 43) if x != 8 and x != 24 and x != 36]:
    if i < 7 or i == 9:
        filename = root_dir + '/Oasis_Data/disc1/OAS1_000' + str(i) + '_MR1/PROCESSED/MPRAGE/T88_111/OAS1_000' + str(i) + '_MR1_mpr_n4_anon_111_t88_masked_gfc.img'
    elif i == 7:
        filename = root_dir + '/Oasis_Data/disc1/OAS1_000' + str(i) + '_MR1/PROCESSED/MPRAGE/T88_111/OAS1_000' + str(i) + '_MR1_mpr_n3_anon_111_t88_masked_gfc.img'
    elif i == 15 or i == 16 or i == 20 or i == 26 or i == 34 or i == 38 or i == 39:
        filename = root_dir + '/Oasis_Data/disc1/OAS1_00' + str(i) + '_MR1/PROCESSED/MPRAGE/T88_111/OAS1_00' + str(i) + '_MR1_mpr_n3_anon_111_t88_masked_gfc.img'
    else:
        filename = root_dir + '/Oasis_Data/disc1/OAS1_00' + str(i) + '_MR1/PROCESSED/MPRAGE/T88_111/OAS1_00' + str(i) + '_MR1_mpr_n4_anon_111_t88_masked_gfc.img'
    img = nib.load(filename)
    nib.save(img, filename.replace('.img', '.nii'))
The remaining conversion code (for the labels) is not pasted here; if you are interested, see the complete code at the end. The next step is to read the image and label file names:
from glob import glob

image_files = sorted(glob(os.path.join(root_dir + '/Oasis_Data_Processed', '*.nii')))
label_files = sorted(glob(os.path.join(root_dir + '/Oasis_Labels_Processed', '*.nii')))
files = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(image_files, label_files)]
To visualize an image with its corresponding label we can use TorchIO, a Python library for loading, preprocessing, augmenting, and sampling multi-dimensional medical images in deep learning.
import torchio

image_filename = root_dir + '/Oasis_Data_Processed/OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc.nii'
label_filename = root_dir + '/Oasis_Labels_Processed/OAS1_0001_MR1_mpr_n4_anon_111_t88_masked_gfc_fseg.nii'
subject = torchio.Subject(image=torchio.ScalarImage(image_filename), label=torchio.LabelMap(label_filename))
subject.plot()
Next the data is split into three parts: training, validation, and test. The purpose of splitting the data into three separate sets is to build a reliable machine learning model and avoid overfitting.
We split the entire dataset into three parts:
Train: 80%, Validation: 10%, Test: 10%
import numpy as np
from monai.data import partition_dataset

train_inds, val_inds, test_inds = partition_dataset(data = np.arange(len(files)), ratios = [8, 1, 1], shuffle = True)

train = [files[i] for i in sorted(train_inds)]
val = [files[i] for i in sorted(val_inds)]
test = [files[i] for i in sorted(test_inds)]

print(f"Training count: {len(train)}, Validation count: {len(val)}, Test count: {len(test)}")
Because the model requires two-dimensional slices, we save each slice in a separate folder, as shown below. These two code cells save every slice of each training MRI volume in '.png' format.
import matplotlib.pyplot as plt

# Save coronal slices for training images
dir = root_dir + '/TrainData'
os.makedirs(os.path.join(dir, "Coronal"))
path = root_dir + '/TrainData/Coronal/'

for file in sorted(glob(os.path.join(root_dir + '/TrainData', '*.nii'))):
    image = torchio.ScalarImage(file)
    data = image.data
    filename = os.path.basename(file)
    filename = os.path.splitext(filename)
    for i in range(0, 208):
        slice = data[0, :, i]
        array = slice.numpy()
        data_dir = root_dir + '/TrainData/Coronal/' + filename[0] + '_slice' + str(i) + '.png'
        plt.imsave(fname = data_dir, arr = array, format = 'png', cmap = plt.cm.gray)
Similarly, the labels are saved with:
dir = root_dir + '/TrainLabels'
os.makedirs(os.path.join(dir, "Coronal"))
path = root_dir + '/TrainLabels/Coronal/'

for file in sorted(glob(os.path.join(root_dir + '/TrainLabels', '*.nii'))):
    label = torchio.LabelMap(file)
    data = label.data
    filename = os.path.basename(file)
    filename = os.path.splitext(filename)
    for i in range(0, 208):
        slice = data[0, :, i]
        array = slice.numpy()
        data_dir = root_dir + '/TrainLabels/Coronal/' + filename[0] + '_slice' + str(i) + '.png'
        plt.imsave(fname = data_dir, arr = array, format = 'png')
Defining image transforms for training and validation
In this example we use dictionary transforms, where the data items are Python dictionaries.
import natsort

train_images_coronal = []
for file in sorted(glob(os.path.join(root_dir + '/TrainData/Coronal', '*.png'))):
    train_images_coronal.append(file)
train_images_coronal = natsort.natsorted(train_images_coronal)

train_labels_coronal = []
for file in sorted(glob(os.path.join(root_dir + '/TrainLabels/Coronal', '*.png'))):
    train_labels_coronal.append(file)
train_labels_coronal = natsort.natsorted(train_labels_coronal)

val_images_coronal = []
for file in sorted(glob(os.path.join(root_dir + '/ValData/Coronal', '*.png'))):
    val_images_coronal.append(file)
val_images_coronal = natsort.natsorted(val_images_coronal)

val_labels_coronal = []
for file in sorted(glob(os.path.join(root_dir + '/ValLabels/Coronal', '*.png'))):
    val_labels_coronal.append(file)
val_labels_coronal = natsort.natsorted(val_labels_coronal)
train_files_coronal = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(train_images_coronal, train_labels_coronal)]
val_files_coronal = [{'image': image_name, 'label': label_name} for image_name, label_name in zip(val_images_coronal, val_labels_coronal)]
Now we apply the following transforms:
- LoadImaged: loads the image data and metadata. We use 'PILReader' to load the image and label files, with ensure_channel_first set to True to convert the image arrays to channel-first shape.
- Rotate90d: rotates the images and labels by 90 degrees, since their orientation is incorrect as downloaded.
- ToTensord: converts the input images and labels to tensors.
- NormalizeIntensityd: normalizes the input intensities.
from monai.data import PILReader
from monai.transforms import Compose, LoadImaged, NormalizeIntensityd, Rotate90d, ToTensord

train_transforms = Compose(
    [
        LoadImaged(keys = ['image', 'label'], reader=PILReader(converter=lambda image: image.convert("L")), ensure_channel_first = True),
        Rotate90d(keys = ['image', 'label'], k = 2),
        ToTensord(keys = ['image', 'label']),
        NormalizeIntensityd(keys = ['image'])
    ]
)

val_transforms = Compose(
    [
        LoadImaged(keys = ['image', 'label'], reader=PILReader(converter=lambda image: image.convert("L")), ensure_channel_first = True),
        Rotate90d(keys = ['image', 'label'], k = 2),
        ToTensord(keys = ['image', 'label']),
        NormalizeIntensityd(keys = ['image'])
    ]
)
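As a quick sanity check (not part of the original pipeline), the composed transforms can be applied to a single training pair to verify the output shapes:

sample = train_transforms(train_files_coronal[0])
print(sample['image'].shape, sample['label'].shape)  # channel-first 2D tensors, e.g. (1, H, W)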
Next we define a new transform, MaskColorMap, that maps the pixel values of the label images to a multi-label format. This kind of transform is essential in semantic segmentation, where we must provide a binary feature for each possible class: one-hot encoding assigns the value 1 to the feature corresponding to each sample's original class.
The OASIS-1 dataset only has labels for three brain structures; for a more detailed segmentation, the ideal would be to annotate the 28 cortical structures as in the research paper. The OASIS-1 download instructions point to labels for many more brain structures obtained with FreeSurfer.
This article therefore segments more neuroanatomical structures. We need to change the model's num_classes parameter to the corresponding number of labels, so that the model's output is a feature map with N channels equal to num_classes.
To simplify this tutorial, we use the following labels, more than the basic OASIS-1 tissue segmentation provides but far fewer than FreeSurfer's:
- Label 0: Background
- Label 1: LeftCerebralExterior
- Label 2: LeftWhiteMatter
- Label 3: LeftCerebralCortex
The MaskColorMap code is then:
from enum import Enum

class MaskColorMap(Enum):
    Background = (30)
    LeftCerebralExterior = (91)
    LeftWhiteMatter = (137)
    LeftCerebralCortex = (215)
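To illustrate how these gray values map to class indices (the same logic appears in the training loop below), here is a small hypothetical example:

label = np.array([[30, 137], [215, 91]])  # hypothetical 2x2 grayscale label patch
encoded = np.zeros_like(label)
for j, cls in enumerate(MaskColorMap):
    encoded[label == cls.value] = j
print(encoded)  # [[0 2] [3 1]]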
Datasets and data loaders
Datasets and data loaders pull data from storage and feed it to the training loop in batches. Here we use monai.data.Dataset to load the training and validation dictionaries defined earlier and apply the corresponding transforms to the input data. A DataLoader is used to load the dataset into memory; we define one dataset and one data loader for training and for validation, for each view.
To keep the demonstration fast, we use torch.utils.data.Subset to create subsets at specified indices, training on only part of the data.
import torch
from monai.data import DataLoader, Dataset

train_dataset_coronal = Dataset(data=train_files_coronal, transform = train_transforms)
train_loader_coronal = DataLoader(train_dataset_coronal, batch_size = 1, shuffle = True)

val_dataset_coronal = Dataset(data = val_files_coronal, transform = val_transforms)
val_loader_coronal = DataLoader(val_dataset_coronal, batch_size = 1, shuffle = False)

# We will use a subset of the dataset
subset_train = list(range(90, len(train_dataset_coronal), 120))
train_dataset_coronal_subset = torch.utils.data.Subset(train_dataset_coronal, subset_train)
train_loader_coronal_subset = DataLoader(train_dataset_coronal_subset, batch_size = 1, shuffle = True)

subset_val = list(range(90, len(val_dataset_coronal), 50))
val_dataset_coronal_subset = torch.utils.data.Subset(val_dataset_coronal, subset_val)
val_loader_coronal_subset = DataLoader(val_dataset_coronal_subset, batch_size = 1, shuffle = False)
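As an illustrative check (not in the original article), one batch can be drawn from the subset loader to confirm what the training loop below receives:

batch = next(iter(train_loader_coronal_subset))
print(batch['image'].shape, batch['label'].shape)  # (batch_size, 1, H, W) tensors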
Defining the model
Given a set of MRI brain scans I = {I1, ..., In} with corresponding segmentations S = {S1, ..., Sn}, we want to learn a function fseg: I -> S. We represent this function as an F-CNN model called QuickNAT:
QuickNAT consists of three 2D F-CNNs operating on the coronal, axial, and sagittal views respectively, followed by an aggregation step that infers the final segmentation by combining the probability maps of the three networks. Each F-CNN has an encoder/decoder architecture with four encoders and four decoders separated by a bottleneck layer. The last layer is a classifier block with softmax. The architecture also includes residual links within each encoder/decoder block.
import torch.nn as nn

# Assumed imports: EncoderBlock, DecoderBlock, DenseBlock, ClassifierBlock and the
# squeeze-and-excitation module `se` come from the helper packages of the original
# QuickNAT implementation (see the complete code linked at the end).
from nn_common_modules.modules import ClassifierBlock, DecoderBlock, DenseBlock, EncoderBlock
from squeeze_and_excitation import squeeze_and_excitation as se

class QuickNat(nn.Module):
    """
    A PyTorch implementation of QuickNAT
    """

    def __init__(self, params):
        """
        :param params: {'num_channels': 1,
                        'num_filters': 64,
                        'kernel_h': 5,
                        'kernel_w': 5,
                        'stride_conv': 1,
                        'pool': 2,
                        'stride_pool': 2,
                        'num_classes': 28,
                        'se_block': False,
                        'drop_out': 0.2}
        """
        super(QuickNat, self).__init__()

        self.encode1 = EncoderBlock(params, se_block_type=se.SELayer.CSSE)
        params["num_channels"] = params["num_filters"]
        self.encode2 = EncoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.encode3 = EncoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.encode4 = EncoderBlock(params, se_block_type=se.SELayer.CSSE)

        self.bottleneck = DenseBlock(params, se_block_type=se.SELayer.CSSE)

        params["num_channels"] = params["num_filters"] * 2
        self.decode1 = DecoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.decode2 = DecoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.decode3 = DecoderBlock(params, se_block_type=se.SELayer.CSSE)
        self.decode4 = DecoderBlock(params, se_block_type=se.SELayer.CSSE)

        params["num_channels"] = params["num_filters"]
        self.classifier = ClassifierBlock(params)

    def forward(self, input):
        """
        :param input: X
        :return: probability map
        """
        e1, out1, ind1 = self.encode1.forward(input)
        e2, out2, ind2 = self.encode2.forward(e1)
        e3, out3, ind3 = self.encode3.forward(e2)
        e4, out4, ind4 = self.encode4.forward(e3)

        bn = self.bottleneck.forward(e4)

        d4 = self.decode4.forward(bn, out4, ind4)
        d3 = self.decode1.forward(d4, out3, ind3)
        d2 = self.decode2.forward(d3, out2, ind2)
        d1 = self.decode3.forward(d2, out1, ind1)
        prob = self.classifier.forward(d1)

        return prob

    def enable_test_dropout(self):
        """
        Enables test-time dropout for uncertainty estimation.
        """
        attr_dict = self.__dict__["_modules"]
        for i in range(1, 5):
            encode_block, decode_block = (
                attr_dict["encode" + str(i)],
                attr_dict["decode" + str(i)],
            )
            encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train)
            decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train)

    @property
    def is_cuda(self):
        """
        Check if model parameters are allocated on the GPU.
        """
        return next(self.parameters()).is_cuda

    def save(self, path):
        """
        Save the model with its parameters to the given path. Conventionally the
        path should end with '*.model'.

        Inputs:
        - path: path string
        """
        print("Saving model... %s" % path)
        torch.save(self.state_dict(), path)

    def predict(self, X, device=0, enable_dropout=False):
        """
        Predicts the output after the model is trained.

        Inputs:
        - X: volume to be predicted
        """
        self.eval()
        print("tensor size before transformation", X.shape)

        if type(X) is np.ndarray:
            X = (
                torch.tensor(X, requires_grad=False)
                .type(torch.FloatTensor)
                .cuda(device, non_blocking=True)
            )
        elif type(X) is torch.Tensor and not X.is_cuda:
            X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)

        print("tensor size ", X.shape)

        if enable_dropout:
            self.enable_test_dropout()

        with torch.no_grad():
            out = self.forward(X)

        max_val, idx = torch.max(out, 1)
        idx = idx.data.cpu().numpy()
        prediction = np.squeeze(idx)
        print("prediction shape", prediction.shape)
        del X, out, idx, max_val
        return prediction
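The article does not show the model being instantiated; below is a plausible sketch based on the parameter docstring above, with num_classes set to our four labels (the exact key names depend on the block implementations in the complete code):

params = {
    'num_channels': 1,
    'num_filters': 64,
    'kernel_h': 5,
    'kernel_w': 5,
    'stride_conv': 1,
    'pool': 2,
    'stride_pool': 2,
    'num_classes': 4,  # Background + 3 left-hemisphere structures
    'se_block': False,
    'drop_out': 0.2,
}
model_coronal = QuickNat(params)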
Loss function
Training a neural network requires a loss function to compute the model error. The goal of training is to minimize the loss between the predicted and target outputs. Our model is optimized with a joint loss function combining Dice loss and weighted logistic loss, where the weights compensate for the high class imbalance in the data and encourage correct segmentation of anatomical boundaries.
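The loss definition is not shown at this point in the article; here is a minimal sketch of such a joint objective, assuming MONAI's DiceLoss plus a class-weighted cross-entropy as the "weighted logistic" term:

from monai.losses import DiceLoss

class CombinedLoss(nn.Module):
    def __init__(self, class_weights=None):
        super().__init__()
        self.dice = DiceLoss(to_onehot_y=True, softmax=True)  # expects raw logits
        self.ce = nn.CrossEntropyLoss(weight=class_weights)   # weights counter class imbalance

    def forward(self, logits, target):
        # logits: (N, C, H, W); target: (N, H, W) integer class map
        return self.dice(logits, target.unsqueeze(1)) + self.ce(logits, target)

loss_function = CombinedLoss()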
Optimizer
An optimization algorithm lets us iteratively update the model parameters to minimize the value of the loss function. We set the following hyperparameters (a setup sketch follows the list):
- Learning rate: initially set to 0.1 and reduced by an order of magnitude every 10 epochs, which can be implemented with a learning-rate scheduler.
- Weight decay: 0.0001.
- Batch size: 1.
- Momentum: set to the high value of 0.95 to compensate for the noisy gradients caused by the small batch size.
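A sketch of an optimizer/scheduler setup matching these hyperparameters (the exact schedule in the complete code may differ; note that the training loop below calls scheduler.step() once per batch, so a step-based scheduler would count iterations rather than epochs there):

import torch.optim as optim

optimizer = optim.SGD(
    model_coronal.parameters(),
    lr=0.1,               # initial learning rate
    momentum=0.95,        # high momentum to smooth noisy small-batch gradients
    weight_decay=0.0001,
)
# Drop the learning rate by a factor of 10 every 10 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)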
Training the network
Now we can train the model. For QuickNAT we need to train three models on 2D slices from the three views (coronal, axial, sagittal), then combine their probability maps in an aggregation step to produce the final result. This article only demonstrates training one F-CNN model on 2D slices from the coronal view, since the other two are trained in the same way.
num_epochs = 20
start_epoch = 1
val_interval = 1
train_loss_epoch_values = []
val_loss_epoch_values = []
best_ds_mean = -1
best_ds_mean_epoch = -1
ds_mean_train_values = []
ds_mean_val_values = []
# ds_LCE_values = []
# ds_LWM_values = []
# ds_LCC_values = []
print("START TRAINING. : model name = ", "quicknat")
for epoch in range(start_epoch, num_epochs):
    print("==== Epoch [" + str(epoch) + " / " + str(num_epochs) + "] DONE ====")

    checkpoint_name = CHECKPOINT_DIR + "/checkpoint_epoch_" + str(epoch) + "." + CHECKPOINT_EXTENSION
    print(checkpoint_name)
    state = {
        "epoch": epoch,
        "arch": "quicknat",
        "state_dict": model_coronal.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }
    save_checkpoint(state = state, filename = checkpoint_name)

    print("\n==== Epoch [ %d / %d ] START ====" % (epoch, num_epochs))

    steps_per_epoch = len(train_dataset_coronal_subset) / train_loader_coronal_subset.batch_size

    model_coronal.train()
    train_loss_epoch = 0
    val_loss_epoch = 0
    step = 0

    predictions_train = []
    labels_train = []
    predictions_val = []
    labels_val = []

    for i_batch, sample_batched in enumerate(train_loader_coronal_subset):
        inputs = sample_batched['image'].type(torch.FloatTensor)
        labels = sample_batched['label'].type(torch.LongTensor)
        # print(f"Train Input Shape: {inputs.shape}")

        labels = labels.squeeze(1)
        _img_channels, _img_height, _img_width = labels.shape
        encoded_label = np.zeros((_img_height, _img_width, 1)).astype(int)
        for j, cls in enumerate(MaskColorMap):
            encoded_label[np.all(labels == cls.value, axis = 0)] = j
        labels = encoded_label
        labels = torch.from_numpy(labels)
        labels = torch.permute(labels, (2, 1, 0))

        # print(f"Train Label Shape: {labels.shape}")
        # plt.title("Train Label")
        # plt.imshow(labels[0, :, :])
        # plt.show()

        optimizer.zero_grad()
        outputs = model_coronal(inputs)
        loss = loss_function(outputs, labels)

        loss.backward()
        optimizer.step()
        scheduler.step()

        with torch.no_grad():
            _, batch_output = torch.max(outputs, dim = 1)
            # print(f"Train Prediction Shape: {batch_output.shape}")
            # plt.title("Train Prediction")
            # plt.imshow(batch_output[0, :, :])
            # plt.show()

            predictions_train.append(batch_output.cpu())
            labels_train.append(labels.cpu())
            train_loss_epoch += loss.item()
            print(f"{step}/{len(train_dataset_coronal_subset) // train_loader_coronal_subset.batch_size}, Training_loss: {loss.item():.4f}")
            step += 1

    predictions_train_arr, labels_train_arr = torch.cat(predictions_train), torch.cat(labels_train)
    # print(predictions_train_arr.shape)

    dice_metric(predictions_train_arr, labels_train_arr)
    ds_mean_train = dice_metric.aggregate().item()
    ds_mean_train_values.append(ds_mean_train)
    dice_metric.reset()

    train_loss_epoch /= step
    train_loss_epoch_values.append(train_loss_epoch)
    print(f"Epoch {epoch + 1} Train Average Loss: {train_loss_epoch:.4f}")

    if (epoch + 1) % val_interval == 0:
        model_coronal.eval()
        step = 0

        with torch.no_grad():
            for i_batch, sample_batched in enumerate(val_loader_coronal_subset):
                inputs = sample_batched['image'].type(torch.FloatTensor)
                labels = sample_batched['label'].type(torch.LongTensor)
                # print(f"Val Input Shape: {inputs.shape}")

                labels = labels.squeeze(1)
                integer_encoded_labels = []
                _img_channels, _img_height, _img_width = labels.shape
                encoded_label = np.zeros((_img_height, _img_width, 1)).astype(int)
                for j, cls in enumerate(MaskColorMap):
                    encoded_label[np.all(labels == cls.value, axis = 0)] = j
                labels = encoded_label
                labels = torch.from_numpy(labels)
                labels = torch.permute(labels, (2, 1, 0))

                # print(f"Val Label Shape: {labels.shape}")
                # plt.title("Val Label")
                # plt.imshow(labels[0, :, :])
                # plt.show()

                val_outputs = model_coronal(inputs)
                val_loss = loss_function(val_outputs, labels)

                predicted = torch.argmax(val_outputs, dim = 1)
                # print(f"Val Prediction Shape: {predicted.shape}")
                # plt.title("Val Prediction")
                # plt.imshow(predicted[0, :, :])
                # plt.show()

                predictions_val.append(predicted)
                labels_val.append(labels)
                val_loss_epoch += val_loss.item()
                print(f"{step}/{len(val_dataset_coronal_subset) // val_loader_coronal_subset.batch_size}, Validation_loss: {val_loss.item():.4f}")
                step += 1

            predictions_val_arr, labels_val_arr = torch.cat(predictions_val), torch.cat(labels_val)

            dice_metric(predictions_val_arr, labels_val_arr)
            # dice_metric_batch(predictions_val_arr, labels_val_arr)

            ds_mean_val = dice_metric.aggregate().item()
            ds_mean_val_values.append(ds_mean_val)
            # ds_mean_val_batch = dice_metric_batch.aggregate()
            # ds_LCE = ds_mean_val_batch[0].item()
            # ds_LCE_values.append(ds_LCE)
            # ds_LWM = ds_mean_val_batch[1].item()
            # ds_LWM_values.append(ds_LWM)
            # ds_LCC = ds_mean_val_batch[2].item()
            # ds_LCC_values.append(ds_LCC)

            dice_metric.reset()
            # dice_metric_batch.reset()

            if ds_mean_val > best_ds_mean:
                best_ds_mean = ds_mean_val
                best_ds_mean_epoch = epoch + 1
                torch.save(model_coronal.state_dict(), os.path.join(BESTMODEL_DIR, "best_metric_model_coronal.pth"))
                print("Saved new best metric model coronal")

            print(
                f"Current Epoch: {epoch + 1} Current Mean Dice score is: {ds_mean_val:.4f}"
                f"\nBest Mean Dice score: {best_ds_mean:.4f} "
                # f"\nMean Dice score Left Cerebral Exterior: {ds_LCE:.4f} Mean Dice score Left White Matter: {ds_LWM:.4f} Mean Dice score Left Cerebral Cortex: {ds_LCC:.4f} "
                f"at Epoch: {best_ds_mean_epoch}"
            )

        val_loss_epoch /= step
        val_loss_epoch_values.append(val_loss_epoch)
        print(f"Epoch {epoch + 1} Average Validation Loss: {val_loss_epoch:.4f}")

print("FINISH.")
The code is a conventional PyTorch training loop, so we won't explain it in detail.
Plotting the loss and accuracy curves
The training curves show how well the model is learning, while the validation curves show how well it generalizes to unseen instances. We use matplotlib to draw the plots. You could also use TensorBoard, which makes understanding and debugging deep learning programs easier, in real time.
epoch = range(1, len(train_loss_epoch_values) + 1)  # one x value per recorded epoch
# Plot Loss Curves
plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 1)
plt.plot(epoch, train_loss_epoch_values, label='Training Loss')
plt.plot(epoch, val_loss_epoch_values, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()
# Plot Train Dice Coefficient Curve
plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 2)
x = [(i + 1) for i in range(len(ds_mean_train_values))]
plt.plot(x, ds_mean_train_values, 'blue', label = 'Train Mean Dice Score')
plt.title("Training Mean Dice Coefficient")
plt.xlabel('Epoch')
plt.ylabel('Mean Dice Score')
plt.show()
# Plot Validation Dice Coefficient Curve
plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 3)
x = [(i + 1) for i in range(len(ds_mean_val_values))]
plt.plot(x, ds_mean_val_values, 'orange', label = 'Validation Mean Dice Score')
plt.title("Validation Mean Dice Coefficient")
plt.xlabel('Epoch')
plt.ylabel('Mean Dice Score')
plt.show()
In the curves we can see that the model is overfitting: the validation loss rises while the training loss falls. This is a common pitfall in deep learning, where the model ends up memorizing the training data and fails to generalize to unseen data.
Tips to avoid overfitting:
- Train with more data: larger datasets reduce overfitting.
- Data augmentation: if we cannot collect more data, we can apply augmentation to artificially increase the size of the dataset (see the sketch after this list).
- Add regularization: regularization is a technique that constrains the network from learning an overly complex model that is likely to overfit.
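For example, random flips and small affine perturbations could be appended to the existing dictionary transforms; a sketch, not part of the original pipeline:

from monai.transforms import RandAffined, RandFlipd

augmented_train_transforms = Compose(
    [
        LoadImaged(keys = ['image', 'label'], reader=PILReader(converter=lambda image: image.convert("L")), ensure_channel_first = True),
        Rotate90d(keys = ['image', 'label'], k = 2),
        RandFlipd(keys = ['image', 'label'], prob = 0.5, spatial_axis = 1),
        RandAffined(keys = ['image', 'label'], prob = 0.5, rotate_range = 0.1, mode = ['bilinear', 'nearest']),
        ToTensord(keys = ['image', 'label']),
        NormalizeIntensityd(keys = ['image'])
    ]
)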
Evaluating the network
我們?nèi)绾味攘磕P偷男阅?一個(gè)成功的預(yù)測(cè)是一個(gè)最大限度地?cái)U(kuò)大預(yù)測(cè)和真實(shí)之間的重疊。
Two related but different metrics for this goal are the Dice and Intersection over Union (IoU) coefficients, the latter also known as the Jaccard coefficient. Both metrics range between 0 (no overlap) and 1 (perfect overlap).
Both can be used in similar situations; the difference is that the Dice score tends to reflect average performance, while IoU helps you understand worst-case performance.
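As a small illustration (not from the original article) of how the two overlap metrics relate, here they are computed for a pair of binary masks:

def dice_score(pred: torch.Tensor, target: torch.Tensor) -> float:
    intersection = (pred & target).sum().item()
    return 2 * intersection / (pred.sum().item() + target.sum().item())

def iou_score(pred: torch.Tensor, target: torch.Tensor) -> float:
    intersection = (pred & target).sum().item()
    union = (pred | target).sum().item()
    return intersection / union

pred = torch.tensor([[1, 1, 0], [0, 1, 0]], dtype=torch.bool)
target = torch.tensor([[1, 0, 0], [0, 1, 1]], dtype=torch.bool)
print(dice_score(pred, target), iou_score(pred, target))  # 0.667 vs 0.5: Dice >= IoU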
We can check the metric per class or take the average over all classes. Here we use monai.metrics.DiceMetric to compute the score. A more general approach would be torchmetrics, but since we already use the MONAI framework we stick with its built-in function.
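The instantiation of dice_metric is not shown at this point in the article; presumably it looks something like the following (an assumption, mirroring MONAI's documented defaults):

from monai.metrics import DiceMetric

dice_metric = DiceMetric(include_background=True, reduction="mean")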
We can see that the Dice score curve behaves rather unusually, mainly because the validation mean Dice score rises above 1, which should be impossible for a metric bounded between 0 and 1. We could not determine the main cause of this behavior (one possible culprit is that DiceMetric expects one-hot encoded inputs, while the loop feeds it integer class maps), but we recommend computing the metric per class in multi-class problems and always including visual examples for qualitative assessment.
Analyzing the results
Finally, let's look at how the model generalizes to unseen data. The model predicts almost everything as left white matter, with some pixels as left cerebral cortex. Although its predictions seem plausible, there is still plenty of room for improvement: our model is very small, and choosing a deeper model could give better results.
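A sketch (not in the original article) of one way to inspect a single validation prediction next to its input and raw label slice:

sample = val_dataset_coronal_subset[0]
image = sample['image'].unsqueeze(0).type(torch.FloatTensor)  # add a batch dimension

model_coronal.eval()
with torch.no_grad():
    pred = torch.argmax(model_coronal(image), dim=1)[0]

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(image[0, 0], cmap='gray'); axes[0].set_title('Input')
axes[1].imshow(sample['label'][0]);       axes[1].set_title('Label (raw gray values)')
axes[2].imshow(pred);                     axes[2].set_title('Prediction')
plt.show()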
Summary
In this article we showed how to train QuickNAT for the challenging task of brain segmentation. We followed the learning strategy the authors describe in their research paper as closely as possible, though for convenience this tutorial demonstrated only the simplest steps. The complete code for this article:
https://github.com/inesdv26/Brain-Segmentation