譯者 | 朱先忠
審校 | 重樓
大家好!在這篇文章中,我們將使用MONAI最新的開源擴(kuò)展——MONAI生成式模型來創(chuàng)建一個潛在擴(kuò)散模型(Latent Diffusion Model),并通過此模型生成胸部X射線圖像!
簡介
生成式人工智能在醫(yī)療保健方面展現(xiàn)出巨大的潛力,因?yàn)樗试S我們創(chuàng)建模型來學(xué)習(xí)訓(xùn)練數(shù)據(jù)集的基本模式和結(jié)構(gòu)。通過這種方式,我們可以使用這些生成式模型來創(chuàng)建無數(shù)的合成數(shù)據(jù),這些數(shù)據(jù)具有與真實(shí)數(shù)據(jù)相同的細(xì)節(jié)和特征,并且不受到這些內(nèi)容的具體限制。鑒于其重要性,我們創(chuàng)建了MONAI生成模型,這是包含時下眾多最新模型(如擴(kuò)散模型、自回歸轉(zhuǎn)換器和生成對抗性網(wǎng)絡(luò))和有助于訓(xùn)練和評估生成模型相應(yīng)組件的MONAI平臺的一個開源擴(kuò)展。
MONAI生成模型
在這篇文章中,我們將通過一個完整的項(xiàng)目來創(chuàng)建一個潛在擴(kuò)散模型(與Stable Diffusion相同類型的模型),該模型能夠從放射學(xué)報(bào)告中生成胸部X射線(CXR)圖像。在本文中,我們試圖使代碼易于理解并適應(yīng)不同的開發(fā)環(huán)境;所以,盡管它可能不是最高效的,但我希望您喜歡它!
您可以在GitHub存儲庫中找到文中完整的開源項(xiàng)目。另外,請注意:在本文中,我們引用的是其版本0.2。
下載數(shù)據(jù)集
首先,我們從數(shù)據(jù)集開始討論。在這個項(xiàng)目中,我們使用的是MIMIC數(shù)據(jù)集。要訪問此數(shù)據(jù)集,需要先在Physionet門戶網(wǎng)站上創(chuàng)建一個帳戶。我們將使用MIMIC-CXR-JPG(包含JPG文件)和MIMIC-CXR(包含放射學(xué)報(bào)告)。這兩個數(shù)據(jù)集都是在PhysioNet認(rèn)證健康數(shù)據(jù)許可證1.5.0下發(fā)行。值得注意的是,在您完成免費(fèi)的培訓(xùn)課程后,您即可以根據(jù)數(shù)據(jù)集頁面底部的說明自由下載數(shù)據(jù)集。最初,CXR圖像的像素約為+1000x1000。因此,下載過程可能需要一段時間。
胸部X光圖像是一種重要的工具,可以提供有關(guān)胸腔內(nèi)包括肺、心臟和血管等結(jié)構(gòu)和器官的極為寶貴的信息,下載后,我們應(yīng)該有超過35萬張圖像!這些圖像對應(yīng)于三種不同角度的投影之一:后-前(PA)、前-后(AP)和側(cè)面(LAT)。
對于本文中這個試驗(yàn)項(xiàng)目來說,我們只對PA投影感興趣,這是最常見角度的投影,我們可以在其中可視化放射學(xué)報(bào)告中提到的大多數(shù)特征(包括96162張圖像)。關(guān)于這些報(bào)告,我們共有85882個文件,每個文件都包含幾個文本部分。在這里,我們將使用調(diào)查結(jié)果(主要解釋圖像中的內(nèi)容)和印象(總結(jié)報(bào)告的內(nèi)容,就像下結(jié)論一樣)數(shù)據(jù)。
為了使我們的模型和訓(xùn)練過程更易于管理,我們將重新調(diào)整一下圖像的大小,使其在最小軸上具有512個像素。自動執(zhí)行這些初始步驟的腳本列表可以在鏈接https://github.com/Warvito/generative_chestxray#preprocessing處找到。
開發(fā)模型
潛在擴(kuò)散模型(Latent Diffusion Model)架構(gòu):自動編碼器將輸入圖像x壓縮為潛在表示z,然后通過擴(kuò)散模型來估計(jì)z的概率分布。
我們所要使用的潛在擴(kuò)散模型由如下幾個部分組成:
- 自動編碼器:用于將輸入的圖像壓縮為較小的潛在表示;
- 擴(kuò)散模型:其將學(xué)習(xí)CXR的潛在表示的概率數(shù)據(jù)分布;
- 一個文本編碼器:它創(chuàng)建一個嵌入向量,用于調(diào)節(jié)采樣過程。在本文這個例子中,我們使用的是一個經(jīng)過預(yù)訓(xùn)練的例子。
借助于MONAI生成模型,我們可以輕松創(chuàng)建和訓(xùn)練這些模型。所以,讓我們從自動編碼器(Autoencoder)開始介紹!
使用具有KL正則化的自動編碼器
具有KL正則化的自動編碼器(AE-KL,在一些項(xiàng)目中,簡稱為VAE)的主要目標(biāo)是能夠創(chuàng)建小的潛在表示,并以高保真度重建圖像(保留盡可能多的細(xì)節(jié))。在這個項(xiàng)目中,我們正在創(chuàng)建一個具有四個級別的自動編碼器,包括64128128128個通道,在每個級別之間應(yīng)用下采樣塊,使特征圖在進(jìn)入最深層時更小。
盡管我們的自動編碼器可能有自我注意功能的塊,但在這個例子中,我們采用了與我們之前對大腦圖像的研究類似的結(jié)構(gòu),并且不使用注意力機(jī)制來節(jié)省內(nèi)存使用。最后,我們的潛在表示有三個通道。
from generative.networks.nets import AutoencoderKL
...
model = AutoencoderKL(
spatial_dims=2,
in_channels=1,
out_channels=1,
num_channels=[64, 128, 128, 128],
latent_channels=3,
num_res_blocks=2,
attention_levels=[False, False, False, False],
with_encoder_nonlocal_attn=False,
with_decoder_nonlocal_attn=False,
)
注意:在我們的腳本中,我們使用OmegaConf包來存儲我們模型的超參數(shù)。您可以在此文件中看到以前的配置??傊?,OmegaConf是一個強(qiáng)大的工具,用于管理Python項(xiàng)目中的配置,特別是那些涉及深度學(xué)習(xí)或其他復(fù)雜軟件系統(tǒng)的項(xiàng)目。OmegaConf允許我們方便地組織.yaml文件中的超參數(shù),并在腳本中讀取它們。
訓(xùn)練AE-KL
接下來,我們定義了訓(xùn)練過程的幾個組成部分。首先,我們使用KL規(guī)則化。該部分負(fù)責(zé)評估擴(kuò)散模型的潛在空間分布與高斯分布之間的距離。正如Rombach等人所提出的,這將用于限制潛在空間的方差,當(dāng)我們在其上訓(xùn)練擴(kuò)散模型時,這是非常有用的(稍后將詳細(xì)介紹)。我們模型中的forward方法返回重構(gòu),以及我們用于計(jì)算KL散度的潛在表示的μ和σ向量。
#位于訓(xùn)練循環(huán)內(nèi)
reconstruction, z_mu, z_sigma = model(x=images)
…
kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
其次,我們的模型中使用了像素級損失算法。在這個項(xiàng)目中,我們采用L1距離來評估我們的AE-kl重構(gòu)與原始圖像的差異。
l1_loss = F.l1_loss(reconstruction.float(), images.float())
接下來,我們的模型中還使用了感知級別層(Perceptual-level)的損失計(jì)算。感知損失的想法是,我們不是在像素級別評估輸入圖像和重構(gòu)之間的差異,而是通過預(yù)先訓(xùn)練的模型來傳遞這兩個圖像。然后,我們測量內(nèi)部激活和特征圖的距離。在MONAI生成式模型中,我們可以輕松使用基于醫(yī)學(xué)圖像預(yù)訓(xùn)練網(wǎng)絡(luò)的感知網(wǎng)絡(luò)(可從鏈接https://github.com/Project-MONAI/GenerativeModels/blob/main/generative/losses/perceptual.py處找到)。我們可以訪問RadImageNet研究中的2D網(wǎng)絡(luò)(來自Mei等人),該研究對130多萬張醫(yī)學(xué)圖像進(jìn)行了訓(xùn)練!我們實(shí)現(xiàn)了2.5D方法,使用2D預(yù)訓(xùn)練網(wǎng)絡(luò)通過評估切片來評估3D圖像。最后,我們可以訪問MedicalNet,以3D純方法評估我們的3D圖像。在這個項(xiàng)目中,我們使用了與Pinaya等人類似的方法,并使用了學(xué)習(xí)感知圖像塊相似性(LPIPS:Learned Perceptual Image Patch Similarity)度量(也可在MONAI生成式模型上獲得)。
# 實(shí)例化感知損失類
perceptual_loss = PerceptualLoss(
spatial_dims=2,
network_type="squeeze",
)
...
#在訓(xùn)練循環(huán)內(nèi)部
...
p_loss = perceptual_loss(reconstruction.float(), images.float())
最后,我們使用對抗性損失計(jì)算來處理重構(gòu)的細(xì)節(jié)。對抗性網(wǎng)絡(luò)是一個補(bǔ)丁鑒別器(Patch-Discriminator,最初由Pix2Pix研究提出),我們對圖像中的幾個補(bǔ)丁進(jìn)行預(yù)測,而不是對整個圖像是真是假只有一個預(yù)測。
與原始的潛在擴(kuò)散模型和穩(wěn)定擴(kuò)散不同,我們使用了來自最小二乘GAN的鑒別器損失。盡管這不是更高級的對抗性損失,但在3D醫(yī)學(xué)圖像上訓(xùn)練時,它也顯示出了有效性和穩(wěn)定性(但仍有改進(jìn)的空間)。盡管對抗性損失可能非常不穩(wěn)定,但它們與感知損失的結(jié)合也有助于穩(wěn)定鑒別器和生成器的損失。
我們的訓(xùn)練循環(huán)和評估步驟可以從鏈接https://github.com/Warvito/generative_chestxray/blob/83f6c0892c63a1cdbf308ff654d40afc0af51bab/src/python/training/training_functions.py#L129和鏈接https://github.com/Warvito/generative_chestxray/blob/83f6c0892c63a1cdbf308ff654d40afc0af51bab/src/python/training/training_functions.py#L236處找到。在訓(xùn)練了共75次之后,我們用MLflow包保存我們的模型。我們使用MLflow包來更好地監(jiān)控我們的實(shí)驗(yàn),因?yàn)樗M織了git hash和相應(yīng)參數(shù)等信息,并可以將具有唯一ID的不同運(yùn)行存儲在實(shí)驗(yàn)組中,并更容易比較不同的結(jié)果(類似于其他工具,如權(quán)重和偏差)。AE-KL的日志文件可以從鏈接https://drive.google.com/drive/folders/1Ots9ujg4dSaTABKkyaUnFLpCKQ7kCgdK?usp=sharing處找到。
采用Diffusion模型
接下來,我們需要訓(xùn)練我們的擴(kuò)散模型了。
擴(kuò)散模型是一個類似U-Net的網(wǎng)絡(luò)。傳統(tǒng)上,這種模型接收有噪聲的圖像(或潛在表示)作為輸入,并預(yù)測其噪聲分量。這些模型使用迭代去噪機(jī)制,通過幾個步驟就能夠從馬爾可夫鏈上的噪聲中生成圖像。出于這個原因,該模型還以定義該模型處于采樣過程的哪個階段的時間步長為條件。
借助于DiffusionModelUNet類,我們可以為我們的擴(kuò)散模型創(chuàng)建類似U-Net的網(wǎng)絡(luò)。我們的項(xiàng)目使用了配置文件中所定義的配置參數(shù),其中它定義了具有3個通道的輸入和輸出(因?yàn)槲覀兊腁E-kl具有3個通道的潛在空間),以及具有256512768個通道的3個不同級別。每個級別有2個殘差塊。
如前所述,這里非常重要的一點(diǎn)是傳遞所使用的模型的時間步長值,從而使模型便于調(diào)節(jié)這些殘差塊的行為。最后,我們定義了網(wǎng)絡(luò)內(nèi)部的注意力機(jī)制。在我們的案例中,我們在第二級和第三級有注意力塊(由attention_levels參數(shù)表示),每個注意力頭有512和768個通道(換句話說,我們在每個級別都有一個注意力頭)。這些注意力機(jī)制很重要,因?yàn)樗鼈冊试S我們通過交叉注意力方法將外部條件(放射性報(bào)告)應(yīng)用于網(wǎng)絡(luò)。
外部條件(或“上下文”)應(yīng)用于U-Net的注意力塊
在我們的項(xiàng)目中,我們使用了一個已經(jīng)經(jīng)過訓(xùn)練的文本編碼器。為了簡單起見,我們使用了Stable Diffusion v2.1模型中的相同模型(“stabilityai/Stable-Diffusion-2–1-base”)將我們的文本標(biāo)記轉(zhuǎn)換為文本嵌入,該文本嵌入將用作DiffusionModel UNet交叉關(guān)注層中的Key和Value向量。我們的文本嵌入的每個令牌都有1024個維度,我們在“with_conditing”和“cross_attention_dim”參數(shù)中定義它。
from generative.networks.nets import DiffusionModelUNet
...
diffusion = DiffusionModelUNet(
spatial_dims=2,
in_channels=3,
out_channels=3,
num_res_blocks=2,
num_channels=[256, 512, 768],
attention_levels=[False, True, True],
with_cnotallow=True,
cross_attention_dim=1024,
num_head_channels=[0, 512, 768],
)
除了我們的模型定義之外,定義擴(kuò)散模型的噪聲將如何在訓(xùn)練期間添加到輸入圖像并在采樣期間去除也是很重要的。為此,我們在MONAI生成模型中實(shí)現(xiàn)了Schedulers類,以定義噪聲調(diào)度器。在本例中,我們將使用DDPMScheduler,具有1000個時間步長和以下超參數(shù)。
from generative.networks.schedulers import DDPMScheduler
...
scheduler = DDPMScheduler(
beta_schedule="scaled_linear",
num_train_timesteps=1000,
beta_start=0.0015,
beta_end=0.0205,
prediction_type="v_prediction",
)
在這里,我們選擇了“v預(yù)測”方法,其中我們的U-Net將嘗試預(yù)測速度分量(原始圖像和添加的噪聲的組合),而不僅僅是添加的噪聲。這種方法已被證明具有更穩(wěn)定的訓(xùn)練和更快的收斂(也用于鏈接https://arxiv.org/abs/2210.02303處相應(yīng)的算法模型)。
訓(xùn)練Diffusion模型
在訓(xùn)練擴(kuò)散模型之前,我們需要找到一個合適的比例因子。如Rombach等人所述,如果潛在空間分布的標(biāo)準(zhǔn)偏差過高,信噪比可能會影響LDM獲得的結(jié)果。如果潛在表示的值太高,我們添加到其中的最大高斯噪聲量可能不足以破壞所有信息。這樣,在訓(xùn)練過程中,原始潛在表示的信息可能在不應(yīng)該出現(xiàn)的時候出現(xiàn),這使得以后不可能從純噪聲中對圖像進(jìn)行采樣。KL正則化可以在這方面有所幫助,但最好使用比例因子來調(diào)整潛在的表示值。在這個腳本中,我們驗(yàn)證了一批訓(xùn)練集中潛在空間分量的標(biāo)準(zhǔn)偏差的大小。我們發(fā)現(xiàn)我們的比例因子應(yīng)該至少為0.8221。在我們的案例中,我們使用了一個更保守的值0.3(類似于穩(wěn)定擴(kuò)散的值)。
通過定義比例因子,我們可以訓(xùn)練我們的模型。在下面代碼中,我們來看一下對應(yīng)的訓(xùn)練循環(huán)。
#位于訓(xùn)練循環(huán)內(nèi)部
...
with torch.no_grad():
e = stage1(images) * scale_factor
prompt_embeds = text_encoder(reports.squeeze(1))[0]
timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],), device=device).long()
noise = torch.randn_like(e).to(device)
noisy_e = scheduler.add_noise(original_samples=e, noise=noise, timesteps=timesteps)
noise_pred = model(x=noisy_e, timesteps=timesteps, cnotallow=prompt_embeds)
if scheduler.prediction_type == "v_prediction":
# 使用v-預(yù)測參數(shù)法
target = scheduler.get_velocity(e, noise, timesteps)
elif scheduler.prediction_type == "epsilon":
target = noise
loss = F.mse_loss(noise_pred.float(), target.float())
正如您所看到的,我們首先從數(shù)據(jù)加載器中獲取圖像和報(bào)告。為了處理我們的圖像,我們使用了MONAI的轉(zhuǎn)換,并添加了一些自定義轉(zhuǎn)換,從放射學(xué)報(bào)告中提取隨機(jī)句子,并對輸入的文本進(jìn)行標(biāo)記。在大約10%的情況下,我們使用空字符串(“”——這是一個帶有句子開始標(biāo)記(值=49406)和填充標(biāo)記(值=49407)的向量),以便能夠在采樣期間使用無分類器引導(dǎo)。
接下來,我們獲得了潛在表示和提示嵌入。我們創(chuàng)建要添加的噪聲、要在該迭代中使用的隨機(jī)時間步長以及所需的目標(biāo)(速度分量)。最后,我們使用均方誤差來計(jì)算我們的損失。
這種訓(xùn)練共持續(xù)了500次,有關(guān)的訓(xùn)練日志數(shù)據(jù)您可以在鏈接處https://drive.google.com/drive/folders/1Ots9ujg4dSaTABKkyaUnFLpCKQ7kCgdK?usp=sharing找到。
采樣圖像
在我們訓(xùn)練了兩個模型之后,我們可以對合成圖像進(jìn)行采樣。本實(shí)驗(yàn)中,我們使用了鏈接https://github.com/Warvito/generative_chestxray/blob/main/src/python/testing/sample_images.py處所對應(yīng)的腳本。
該腳本使用Ho等人提出的無分類器引導(dǎo)方法,以強(qiáng)制執(zhí)行圖像生成中使用的文本提示。在這種方法中,我們使用了一個指導(dǎo)量表,可以用來犧牲生成數(shù)據(jù)的多樣性,以獲得對文本提示具有更高保真度的樣本。7.0是默認(rèn)值。
在下圖中,我們可以看到經(jīng)過訓(xùn)練的模型是如何了解臨床特征的,以及它們的位置和嚴(yán)重程度。
評估
在本節(jié)中,我們將展示如何使用MONAI的指標(biāo)來評估我們的生成模型在幾個方面的性能。
基于MS-SSIM方法的自動編碼器重構(gòu)質(zhì)量評估
首先,我們來驗(yàn)證我們的自動編碼器kl重建輸入圖像的效果。這是我們在開發(fā)模型時的一個重要考慮點(diǎn),因?yàn)閴嚎s和重建數(shù)據(jù)的質(zhì)量將定義我們樣本質(zhì)量的上限。如果模型沒有很好地學(xué)習(xí)如何從潛在表示中解碼圖像,或者如果它沒有很好的對我們的潛在空間建模,那么就不可能以現(xiàn)實(shí)的方式解碼合成表示。在這個腳本中,我們使用測試集中的共5000張圖像來評估我們的模型。我們可以使用多尺度結(jié)構(gòu)相似性指數(shù)測量(MS-SSIM:Multiscale Structural Similarity Index Measure)來驗(yàn)證我們的重建效果。MS-SSIM是一種廣泛使用的圖像質(zhì)量評估方法,用于測量兩幅圖像之間的相似性。與傳統(tǒng)的圖像質(zhì)量評估方法(如PSNR和SSIM)不同,MS-SSIM能夠捕捉不同尺度的圖像的結(jié)構(gòu)信息。
在我們的模型情況下,值越高,模型就越好。對于我們當(dāng)前發(fā)布的版本(版本0.2),我們觀察到我們的模型的平均MS-SSIM重建指標(biāo)值為0.9789。
MS-SSIM樣本的多樣性
我們將首先評估由我們的模型生成的樣本的多樣性。為此,我們計(jì)算了不同生成圖像之間的多尺度結(jié)構(gòu)相似性指數(shù)測度。在我們的實(shí)驗(yàn)項(xiàng)目中,我們假設(shè),如果我們的生成模型能夠生成不同的圖像,那么在比較合成圖像對時,它將呈現(xiàn)較低的平均MS-SSIM值。例如,如果我們遇到模式崩潰之類的問題,我們生成的圖像看起來會相似,并且MS-SSIM值會比我們在真實(shí)數(shù)據(jù)集中觀察到的值低得多。
在我們的項(xiàng)目中,我們使用非條件樣本(以“”(空字符串)作為文本提示生成的樣本)來保持原始數(shù)據(jù)集的自然比例。如本腳本所示,我們選擇了1000個模型的合成樣本,并使用MONAI的數(shù)據(jù)加載器來幫助加載所有可能的圖像對。我們使用嵌套循環(huán)來遍歷所有可能的圖像對,并忽略在兩個數(shù)據(jù)加載器中選擇的圖像相同的情況。在這里,我們可以觀察到MS-SSIM值為0.4083。我們可以在來自測試集的真實(shí)圖像中執(zhí)行相同的評估作為參考值。使用這個腳本,我們獲得了測試集的MS-SSIM值為0.4046,這表明我們的模型生成的圖像具有與在真實(shí)數(shù)據(jù)中觀察到的圖像相似的多樣性。
然而,多樣性并不意味著圖像看起來好看或逼真。所以,我們將在下一步檢查圖像質(zhì)量!
使用FID指標(biāo)評估合成圖像的質(zhì)量
最后,我們來測量項(xiàng)目所生成的樣本的FID指標(biāo)值。FID是一種評估兩組之間分布的指標(biāo),顯示它們的相似程度。為此,我們需要一個預(yù)先訓(xùn)練的神經(jīng)網(wǎng)絡(luò),從中我們可以提取用于計(jì)算距離的特征(類似于感知損失)。在這個例子中,我們選擇使用torchxrayvision軟件包中所提供的神經(jīng)網(wǎng)絡(luò)。
我們使用了Dense121網(wǎng)絡(luò)(“densenet121-res224-all”),我們選擇這個網(wǎng)絡(luò)是為了接近文獻(xiàn)中用于CXR合成圖像的網(wǎng)絡(luò)。從這個網(wǎng)絡(luò)中,我們獲得了一個具有1024個維度的特征向量。正如FID原始文件中建議的那樣,與特征數(shù)量相比,使用類似數(shù)量的樣本是很重要的。出于這個原因,我們使用了1000張無條件圖像,并將它們與測試集中的1000張圖像進(jìn)行比較。對于FID指標(biāo)來說,值越低越好,這里我們獲得了合理的FID=9.0237。
結(jié)論
在這篇文章中,我們介紹了一種使用MONAI生成模型開發(fā)項(xiàng)目的方法,內(nèi)容包括從下載數(shù)據(jù)到評估生成模型和最后的合成數(shù)據(jù)等。盡管這個項(xiàng)目版本可能更高效,也有更好的超參數(shù),但是其實(shí)我們希望它能很好地說明我們開發(fā)的擴(kuò)展所提供的不同功能。如果您對如何改進(jìn)我們的CXR模型有任何想法,或者如果您想為我們的軟件包做出貢獻(xiàn),請?jiān)阪溄觝ttps://github.com/Warvito/generative_chestxray/issues或鏈接https://github.com/Project-MONAI/GenerativeModels/issues處添加對我們問題部分的評論。
最后說明一點(diǎn),本文中我們經(jīng)過訓(xùn)練的模型可以在MONAI模型動物園網(wǎng)站找到,還有我們的3D大腦生成器和其他模型等內(nèi)容均可找到。我們所提供的模型動物園使下載模型的權(quán)重和執(zhí)行推理的代碼變得更加容易。
最后,如果您想要了解更多教程資源和關(guān)于我們工作的更多信息,請?jiān)L問鏈接https://github.com/Project-MONAI/GenerativeModels/tree/main/tutorials查看我們的教程頁面,并請關(guān)注我本人,以獲取最新的更新和更多類似內(nèi)容!
[注]除非另有說明;否則,本文中所有圖片均由作者提供。
譯者介紹
朱先忠,51CTO社區(qū)編輯,51CTO專家博客、講師,濰坊一所高校計(jì)算機(jī)教師,自由編程界老兵一枚。
原文標(biāo)題:Generating Medical Images with MONAI,作者:Walter Hugo Lopez Pinaya