自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

超越邊界:ControlNets改變醫(yī)學(xué)圖像生成的游戲規(guī)則

原創(chuàng) 精選
人工智能 機(jī)器視覺
文章介紹了如何使用ControlNets來控制Latent Diffusion Models生成醫(yī)學(xué)圖像的過程。首先,討論了擴(kuò)散模型的發(fā)展和其在生成過程中的控制挑戰(zhàn),然后介紹了ControlNets的概念和優(yōu)勢(shì)。

作者 | 崔皓

審校 | 重樓

摘要

文章介紹了如何使用ControlNets來控制Latent Diffusion Models生成醫(yī)學(xué)圖像的過程。首先討論了擴(kuò)散模型的發(fā)展和其在生成過程中的控制挑戰(zhàn),然后介紹了ControlNets的概念和優(yōu)勢(shì)。接著,文章詳細(xì)解釋了如何訓(xùn)練Latent Diffusion Model和ControlNet,以及如何使用ControlNet進(jìn)行采樣和評(píng)估。最后,文章展示了使用ControlNets生成的醫(yī)學(xué)圖像,并提供了性能評(píng)估結(jié)果。文章的目標(biāo)是展示這些模型在將腦圖像轉(zhuǎn)換為各種對(duì)比度的能力,并鼓勵(lì)讀者在自己的項(xiàng)目中使用這些工具。

開篇

本文將介紹如何訓(xùn)練一個(gè)ControlNet,使用戶能夠精確地控制Latent Diffusion Model(如Stable Diffusion!)的生成過程。我們的目標(biāo)是展示模型將腦圖像轉(zhuǎn)換為各種對(duì)比度的能力。為了實(shí)現(xiàn)這一目標(biāo),我們將利用最近推出的MONAI開源擴(kuò)展,即MONAI Generative Models!

項(xiàng)目代碼可以在這個(gè)公共倉庫中找到:https://github.com/Warvito/generative_brain_controlnet

引言

近年來,文本到圖像的擴(kuò)散模型取得了顯著的進(jìn)步,人們可以根據(jù)開放領(lǐng)域的文本描述生成高逼真的圖像。這些生成的圖像具有豐富的細(xì)節(jié),清晰的輪廓,連貫的結(jié)構(gòu)和有意義的場(chǎng)景。然而,盡管擴(kuò)散模型取得了重大的成就,但在生成過程的精確控制方面仍存在挑戰(zhàn)。即使是內(nèi)容豐富的文本描述,也可能無法準(zhǔn)確地捕捉到用戶的想法。

正如Lvmin Zhang和Maneesh Agrawala在他們的開創(chuàng)性論文“Adding Conditional Control to Text-to-Image Diffusion Models”(2023)中所提出的那樣,ControlNets的引入顯著提高了擴(kuò)散模型的可控性和定制性。這些神經(jīng)網(wǎng)絡(luò)充當(dāng)輕量級(jí)的適配器,可以精確地控制和定制,同時(shí)保留擴(kuò)散模型的原始生成能力。通過對(duì)這些適配器進(jìn)行微調(diào),同時(shí)保持原始擴(kuò)散模型不被修改,可以有效地增強(qiáng)文本到圖像模型的多樣性。

【編者:擴(kuò)展模型允許我們?cè)诒3衷寄P筒蛔兊耐瑫r(shí),對(duì)模型的行為進(jìn)行微調(diào)。這意味著我們可以利用原始模型的生成能力,同時(shí)通過ControlNets來調(diào)整生成過程,以滿足特定的需求或改進(jìn)性能?!?/span>

ControlNet的獨(dú)特之處在于它解決了空間一致性的問題。與以往的方法不同,ControlNet允許對(duì)結(jié)構(gòu)空間、結(jié)構(gòu)和幾何方面進(jìn)行明確的控制,同時(shí)保留從文本標(biāo)題中獲得的語義控制。原始研究引入了各種模型,使得可以基于邊緣、姿態(tài)、語義掩碼和深度圖進(jìn)行條件生成,為計(jì)算機(jī)視覺領(lǐng)域激動(dòng)人心的進(jìn)步鋪平了道路。

【編者:空間一致性是指在圖像生成或處理過程中,生成的圖像應(yīng)該在空間上保持一致性,即相鄰的像素或區(qū)域之間應(yīng)該有合理的關(guān)系和連續(xù)性。例如,在生成一個(gè)人臉的圖像時(shí),眼睛、鼻子和嘴巴的相對(duì)位置應(yīng)該是一致的,不能隨機(jī)分布。在傳統(tǒng)的生成模型中,保持空間一致性可能是一個(gè)挑戰(zhàn),因?yàn)槟P涂赡軙?huì)在嘗試生成復(fù)雜的圖像時(shí)產(chǎn)生不一致的結(jié)果。

例如,你正在使用一個(gè)文本到圖像的模型來生成一張貓的圖像,來看看如何使用邊緣、姿態(tài)、語義掩碼和深度圖這四個(gè)條件。

1. 邊緣:創(chuàng)建一個(gè)邊緣圖來表示貓的輪廓。包括貓的頭部、身體、尾巴等主要部分的邊緣信息。

2.姿態(tài):創(chuàng)建姿態(tài)圖來表示貓的姿勢(shì)。例如,一只正在跳躍的貓,可以在姿態(tài)圖中表示出這個(gè)跳躍的動(dòng)作。

3.語義掩碼:創(chuàng)建一個(gè)語義掩碼來表示貓的各個(gè)部分。例如,在語義掩碼中標(biāo)出貓的眼睛、耳朵、鼻子等部分。

4.深度圖:創(chuàng)建深度圖來表示貓的三維形狀。例如,表示出貓的頭部比尾巴更接近觀察者。

通過這四個(gè)步驟,就可以指導(dǎo)模型生成一只符合我們需求的貓的圖像?!?/span>

在醫(yī)學(xué)成像領(lǐng)域,經(jīng)常會(huì)遇到圖像轉(zhuǎn)換的應(yīng)用場(chǎng)景,因此ControlNet的使用就非常有價(jià)值。在這些應(yīng)用場(chǎng)景中,有一個(gè)場(chǎng)景需要將圖像在不同的領(lǐng)域之間進(jìn)行轉(zhuǎn)換,例如將計(jì)算機(jī)斷層掃描(CT)轉(zhuǎn)換為磁共振成像(MRI),或者將圖像在不同的對(duì)比度之間進(jìn)行轉(zhuǎn)換,例如從T1加權(quán)到T2加權(quán)的MRI圖像。在這篇文章中,我們將關(guān)注一個(gè)特定的案例:使用從FLAIR圖像獲取的腦圖像的2D切片來生成相應(yīng)的T1加權(quán)圖像。我們的目標(biāo)是展示MONAI擴(kuò)展(MONAI Generative Models)以及ControlNets如何進(jìn)行醫(yī)學(xué)數(shù)據(jù)訓(xùn)練,并生成評(píng)估模型。通過深入研究這個(gè)例子,我們能提供關(guān)于這些技術(shù)在醫(yī)學(xué)成像領(lǐng)域的最佳實(shí)踐。

FLAIR到T1w轉(zhuǎn)換FLAIR到T1w轉(zhuǎn)換

Latent Diffusion Model訓(xùn)練

Latent Diffusion Model架構(gòu)Latent Diffusion Model架構(gòu)

為了從FLAIR圖像生成T1加權(quán)(T1w)圖像,首先需要訓(xùn)練一個(gè)能夠生成T1w圖像的擴(kuò)散模型。在我們的例子中,我們使用從英國生物銀行數(shù)據(jù)集(根據(jù)這個(gè)數(shù)據(jù)協(xié)議可用)中提取的腦MRI圖像的2D切片。然后,使用你最喜歡的方法(例如,ANTs或UniRes)將原始的3D腦部圖像注冊(cè)到MNI空間后,我們從腦部的中心部分提取五個(gè)2D切片。之所以選擇這個(gè)區(qū)域,因?yàn)樗鞣N組織,使得我們更容易評(píng)估圖像轉(zhuǎn)換。使用這個(gè)腳本,我們最終得到了大約190,000個(gè)切片,空間尺寸為224 × 160像素。接下來,使用這個(gè)腳本將我們的圖像劃分為訓(xùn)練集(約180,000個(gè)切片)、驗(yàn)證集(約5,000個(gè)切片)和測(cè)試集(約5,000個(gè)切片)。準(zhǔn)備好數(shù)據(jù)集后,我們可以開始訓(xùn)練我們的Latent Diffusion Model了!

為了優(yōu)化計(jì)算資源,潛在擴(kuò)散模型使用一個(gè)編碼器將輸入圖像x轉(zhuǎn)換為一個(gè)低維的潛在空間z,然后可以通過一個(gè)解碼器進(jìn)行重構(gòu)。這種方法使得即使在計(jì)算能力有限的情況下也能訓(xùn)練擴(kuò)散模型,同時(shí)保持了它們的原始質(zhì)量和靈活性。與我們?cè)谥暗奈恼轮兴龅念愃疲ㄊ褂肕ONAI生成醫(yī)學(xué)圖像),我們使用MONAI Generative models中的KL-regularization模型來創(chuàng)建壓縮模型。通過使用這個(gè)配置加上L1損失、KL-regularisation,感知損失以及對(duì)抗性損失,我們創(chuàng)建了一個(gè)能夠以高保真度編碼和解碼腦圖像的自編碼器(使用這個(gè)腳本)。自編碼器的重構(gòu)質(zhì)量對(duì)于Latent Diffusion Model的性能至關(guān)重要,因?yàn)樗x了生成圖像的質(zhì)量上限。如果自編碼器的解碼器產(chǎn)生模糊或低質(zhì)量的圖像,我們的生成模型將無法生成更高質(zhì)量的圖像。

【編者:KL-regularization,或稱為Kullback-Leibler正則化,是一種在機(jī)器學(xué)習(xí)和統(tǒng)計(jì)中常用的技術(shù),用于在模型復(fù)雜性和模型擬合數(shù)據(jù)的好壞之間找到一個(gè)平衡。這種正則化方法的名字來源于它使用的Kullback-Leibler散度,這是一種衡量?jī)蓚€(gè)概率分布之間差異的度量

上面這段話用一個(gè)例子來解釋一下,或許會(huì)更加清楚一些。

例如,假設(shè)你是一個(gè)藝術(shù)家,你的任務(wù)是畫出一系列的貓的畫像。你可以自由地畫任何貓,但是你的老板希望你畫的貓看起來都是"普通的"貓,而不是太奇特的貓。

這就是你的任務(wù):你需要?jiǎng)?chuàng)造新的貓的畫像,同時(shí)還要確保這些畫像都符合"普通的"貓的特征。這就像是變分自編碼器(VAE)的任務(wù):它需要生成新的數(shù)據(jù),同時(shí)還要確保這些數(shù)據(jù)符合某種預(yù)設(shè)的分布(也就是"普通的"貓的分布)。

現(xiàn)在,假設(shè)你開始畫貓。你可能會(huì)發(fā)現(xiàn),有些時(shí)候你畫的貓看起來太奇特了,比如它可能有六只眼睛,或者它的尾巴比普通的貓長(zhǎng)很多。這時(shí),你的老板可能會(huì)提醒你,讓你畫的貓更接近"普通的"貓。

這就像是KL-regularization的作用:它是一種"懲罰",當(dāng)你生成的數(shù)據(jù)偏離預(yù)設(shè)的分布時(shí),它就會(huì)提醒你。如果你畫的貓?zhí)嫣?,你的老板就?huì)提醒你。在VAE中,如果生成的數(shù)據(jù)偏離預(yù)設(shè)的分布,KL-regularization就會(huì)通過增加損失函數(shù)的值來提醒模型。

通過這種方式,你可以在創(chuàng)造新的貓的畫像的同時(shí),還能確保這些畫像都符合"普通的"貓的特征。同樣,VAE也可以在生成新的數(shù)據(jù)的同時(shí),確保這些數(shù)據(jù)符合預(yù)設(shè)的分布。這就是KL-regularization的主要作用?!?/span>

【編者:對(duì)三種損失函數(shù)進(jìn)行簡(jiǎn)要說明:

1. L1損失:L1損失,也稱為絕對(duì)值損失,是預(yù)測(cè)值和真實(shí)值之間差異的絕對(duì)值的平均。它的公式為 L1 = 1/n Σ|yi - xi|,其中yi是真實(shí)值,xi是預(yù)測(cè)值,n是樣本數(shù)量。L1損失對(duì)異常值不敏感,因?yàn)樗粫?huì)過度懲罰預(yù)測(cè)錯(cuò)誤的樣本。例如,如果我們預(yù)測(cè)一個(gè)房?jī)r(jià)為100萬,但實(shí)際價(jià)格為110萬,L1損失就是10萬。

2. 感知損失:感知損失是一種在圖像生成任務(wù)中常用的損失函數(shù),它衡量的是生成圖像和真實(shí)圖像在感知層面的差異。感知損失通常通過比較圖像在某個(gè)預(yù)訓(xùn)練模型(如VGG網(wǎng)絡(luò))的某一層的特征表示來計(jì)算。這種方法可以捕捉到圖像的高級(jí)特性,如紋理和形狀,而不僅僅是像素級(jí)的差異。例如,如果我們生成的貓的圖像和真實(shí)的貓的圖像在顏色上有細(xì)微的差異,但在形狀和紋理上是一致的,那么感知損失可能就會(huì)很小。

3. 對(duì)抗性損失:對(duì)抗性損失是在生成對(duì)抗網(wǎng)絡(luò)(GAN)中使用的一種損失函數(shù)。在GAN中,生成器的任務(wù)是生成看起來像真實(shí)數(shù)據(jù)的假數(shù)據(jù),而判別器的任務(wù)是區(qū)分真實(shí)數(shù)據(jù)和假數(shù)據(jù)。對(duì)抗性損失就是用來衡量生成器生成的假數(shù)據(jù)能否欺騙判別器。例如,如果我們的生成器生成了一張貓的圖像,而判別器幾乎無法區(qū)分這張圖像和真實(shí)的貓的圖像,那么對(duì)抗性損失就會(huì)很小。】

使用這個(gè)腳本,我們可以通過使用原始圖像和它們的重構(gòu)之間的多尺度結(jié)構(gòu)相似性指數(shù)測(cè)量(MS-SSIM)來量化自編碼器的保真度。在這個(gè)例子中,我們得到了一個(gè)高性能的MS-SSIM指標(biāo),等于0.9876。

【編者:多尺度結(jié)構(gòu)相似性指數(shù)測(cè)量(Multi-Scale Structural Similarity Index, MS-SSIM)是一種用于衡量?jī)煞鶊D像相似度的指標(biāo)。它是結(jié)構(gòu)相似性指數(shù)(Structural Similarity Index, SSIM)的擴(kuò)展,考慮了圖像的多尺度信息。

SSIM是一種比傳統(tǒng)的均方誤差(Mean Squared Error, MSE)或峰值信噪比(Peak Signal-to-Noise Ratio, PSNR)更符合人眼視覺感知的圖像質(zhì)量評(píng)價(jià)指標(biāo)。它考慮了圖像的亮度、對(duì)比度和結(jié)構(gòu)三個(gè)方面的信息,而不僅僅是像素級(jí)的差異。

MS-SSIM則進(jìn)一步考慮了圖像的多尺度信息。它通過在不同的尺度(例如,不同的分辨率)上計(jì)算SSIM,然后將這些SSIM值進(jìn)行加權(quán)平均,得到最終的MS-SSIM值。這樣可以更好地捕捉到圖像的細(xì)節(jié)和結(jié)構(gòu)信息。

例如,假設(shè)我們有兩幅貓的圖像,一幅是原始圖像,另一幅是我們的模型生成的圖像。我們可以在不同的尺度(例如,原始尺度、一半尺度、四分之一尺度等)上計(jì)算這兩幅圖像的SSIM值,然后將這些SSIM值進(jìn)行加權(quán)平均,得到MS-SSIM值。如果MS-SSIM值接近1,那么說明生成的圖像與原始圖像非常相似;如果MS-SSIM值遠(yuǎn)離1,那么說明生成的圖像與原始圖像有較大的差異。

MS-SSIM常用于圖像處理和計(jì)算機(jī)視覺的任務(wù)中,例如圖像壓縮、圖像增強(qiáng)、圖像生成等,用于評(píng)價(jià)處理或生成的圖像的質(zhì)量。文中MS-SSIM被用來量化自編碼器的保真度,即自編碼器重構(gòu)的圖像與原始圖像的相似度。得到的MS-SSIM指標(biāo)為0.9876,接近1,說明自編碼器的重構(gòu)質(zhì)量非常高,重構(gòu)的圖像與原始圖像非常相似?!?/span>

在我們訓(xùn)練了自編碼器之后,我們將在潛在空間z上訓(xùn)練diffusion model(擴(kuò)散模型)。擴(kuò)散模型是一個(gè)能夠通過在一系列時(shí)間步上迭代地去噪來從純?cè)肼晥D像生成圖像的模型。它通常使用一個(gè)U-Net架構(gòu)(具有編碼器-解碼器格式),其中我們有編碼器的層跳過連接到解碼器部分的層(通過長(zhǎng)跳躍連接),使得特征可重用并穩(wěn)定訓(xùn)練和收斂。

【編者:這段話描述的是訓(xùn)練擴(kuò)散模型的過程,以及擴(kuò)散模型的基本工作原理。

首先,作者提到在訓(xùn)練了自編碼器之后,他們將在潛在空間z上訓(xùn)練擴(kuò)散模型。這意味著他們首先使用自編碼器將輸入圖像編碼為一個(gè)低維的潛在空間z,然后在這個(gè)潛在空間上訓(xùn)練擴(kuò)散模型。

擴(kuò)散模型是一種生成模型,它的工作原理是從純?cè)肼晥D像開始,然后通過在一系列時(shí)間步上迭代地去噪,最終生成目標(biāo)圖像。這個(gè)過程就像是將一個(gè)模糊的圖像逐漸清晰起來,直到生成一個(gè)清晰的、與目標(biāo)圖像相似的圖像。

擴(kuò)散模型通常使用一個(gè)U-Net架構(gòu)。U-Net是一種特殊的卷積神經(jīng)網(wǎng)絡(luò),它有一個(gè)編碼器部分和一個(gè)解碼器部分,編碼器部分用于將輸入圖像編碼為一個(gè)潛在空間,解碼器部分用于將潛在空間解碼為一個(gè)輸出圖像。U-Net的特點(diǎn)是它有一些跳躍連接,這些連接將編碼器部分的某些層直接連接到解碼器部分的對(duì)應(yīng)層。這些跳躍連接可以使得編碼器部分的特征被重用在解碼器部分,這有助于穩(wěn)定訓(xùn)練過程并加速模型的收斂?!?/span>

在訓(xùn)練過程中,Latent Diffusion Model學(xué)習(xí)了給定這些提示的條件噪聲預(yù)測(cè)。再次,我們使用MONAI來創(chuàng)建和訓(xùn)練這個(gè)網(wǎng)絡(luò)。在這個(gè)腳本中,我們使用這個(gè)配置來實(shí)例化模型,其中訓(xùn)練和評(píng)估在代碼的這個(gè)部分進(jìn)行。由于我們?cè)谶@個(gè)教程中對(duì)文本提示不太感興趣,所以我們對(duì)所有的圖像使用了相同的提示( “腦部的T1加權(quán)圖像”)。

我們的Latent Diffusion Model生成的合成腦圖像我們的Latent Diffusion Model生成的合成腦圖像

再次,我們可以量化生成模型從而提升其性能,這次我們?cè)u(píng)估樣本的質(zhì)量(使用Fréchet inception distance (FID))和模型的多樣性(計(jì)算一組1,000個(gè)樣本的所有樣本對(duì)之間的MS-SSIM)。使用這對(duì)腳本(12),我們得到了FID = 2.1986和MS-SSIM Diversity = 0.5368。

如你在前面的圖像和結(jié)果中所看到的,我們現(xiàn)在有一個(gè)高分辨率圖像的模型,質(zhì)量非常好。然而,我們對(duì)圖像的外觀沒有任何空間控制。為此,我們將使用一個(gè)ControlNet來引導(dǎo)我們的Latent Diffusion Model的生成。

ControlNet訓(xùn)練

ControlNet架構(gòu)ControlNet架構(gòu)

ControlNet架構(gòu)包括兩個(gè)主要組成部分:一個(gè)是U-Net模型的編碼器(可訓(xùn)練版本),包括中間塊,以及一個(gè)預(yù)訓(xùn)練的“鎖定”版本的擴(kuò)散模型。在這里,鎖定的副本保留了生成能力,而可訓(xùn)練的副本在特定的圖像到圖像數(shù)據(jù)集上進(jìn)行訓(xùn)練,以學(xué)習(xí)條件控制。這兩個(gè)組件通過一個(gè)“零卷積”層相互連接——一個(gè)1×1的卷積層,其初始化的權(quán)重和偏置設(shè)為零。卷積權(quán)重逐漸從零過渡到優(yōu)化的參數(shù),確保在初始訓(xùn)練步驟中,可訓(xùn)練和鎖定副本的輸出與在沒有ControlNet的情況下保持一致。換句話說,當(dāng)ControlNet在任何優(yōu)化之前應(yīng)用到某些神經(jīng)網(wǎng)絡(luò)塊時(shí),它不會(huì)對(duì)深度神經(jīng)特征引入任何額外的影響或噪聲。

【編者:描述ControlNet架構(gòu)的設(shè)計(jì)和工作原理。ControlNet架構(gòu)包括兩個(gè)主要部分:一個(gè)可訓(xùn)練的U-Net模型的編碼器,以及一個(gè)預(yù)訓(xùn)練的、被"鎖定"的擴(kuò)散模型。

用一個(gè)例子來幫助理解。假設(shè)你正在學(xué)習(xí)畫畫,你有一個(gè)老師(預(yù)訓(xùn)練的"鎖定"的擴(kuò)散模型)和一個(gè)可以修改的畫布(可訓(xùn)練的U-Net模型的編碼器)。你的老師已經(jīng)是一個(gè)經(jīng)驗(yàn)豐富的藝術(shù)家,他的技能(生成能力)是固定的,不會(huì)改變。而你的畫布是可以修改的,你可以在上面嘗試不同的畫法,學(xué)習(xí)如何畫畫(學(xué)習(xí)條件控制)。

這兩個(gè)部分通過一個(gè)"零卷積"層相互連接。這個(gè)"零卷積"層就像是一個(gè)透明的過濾器,它最初不會(huì)改變?nèi)魏螙|西(因?yàn)樗臋?quán)重和偏置都初始化為零),所以你可以看到你的老師的原始畫作。然后,隨著你的學(xué)習(xí)進(jìn)步,這個(gè)過濾器會(huì)逐漸改變(卷積權(quán)重從零過渡到優(yōu)化的參數(shù)),開始對(duì)你的老師的畫作進(jìn)行修改,使其更符合你的風(fēng)格。

這個(gè)設(shè)計(jì)確保了在初始訓(xùn)練步驟中,可訓(xùn)練部分和"鎖定"部分的輸出是一致的,即你的畫作最初會(huì)和你的老師的畫作一樣。然后,隨著訓(xùn)練的進(jìn)行,你的畫作會(huì)逐漸展現(xiàn)出你自己的風(fēng)格,但仍然保持著你老師的基本技巧?!?/span>

通過整合這兩個(gè)組件,ControlNet使我們能夠控制Diffusion Model(擴(kuò)散模型)的U-Net中每個(gè)級(jí)別的行為。

在我們的例子中,我們?cè)谶@個(gè)腳本中實(shí)例化了ControlNet,使用了以下等效的代碼片段。

import torch
from generative.networks.nets import ControlNet, DiffusionModelUNet

# Load pre-trained diffusion model
diffusion_model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_res_blocks=2,
    num_channels=[256, 512, 768],
    attention_levels=[False, True, True],
    with_conditioning=True,
    cross_attention_dim=1024,
    num_head_channels=[0, 512, 768],
)
diffusion_model.load_state_dict(torch.load("diffusion_model.pt"))

# Create ControlNet
controlnet = ControlNet(
    spatial_dims=2,
    in_channels=3,
    num_res_blocks=2,
    num_channels=[256, 512, 768],
    attention_levels=[False, True, True],
    with_conditioning=True,
    cross_attention_dim=1024,
    num_head_channels=[0, 512, 768],
    conditioning_embedding_in_channels=1,
    conditioning_embedding_num_channels=[64, 128, 128, 256],
)

# Create trainable copy of the diffusion model
controlnet.load_state_dict(diffusion_model.state_dict(), strict=False)

# Lock the weighht of the diffusion model
for p in diffusion_model.parameters():
    p.requires_grad = False

由于我們使用的是Latent Diffusion Model,這需要ControlNets將基于圖像的條件轉(zhuǎn)換為相同的潛在空間,以匹配卷積的大小。為此,我們使用一個(gè)與完整模型一起訓(xùn)練的卷積網(wǎng)絡(luò)。在我們的案例中,我們有三個(gè)下采樣級(jí)別(類似于自動(dòng)編碼器KL),定義在“conditioning_embedding_num_channels=[64, 128, 128, 256]”。由于我們的條件圖像是一個(gè)FLAIR圖像,只有一個(gè)通道,我們也需要在“conditioning_embedding_in_channels=1”中指定其輸入通道的數(shù)量。

初始化我們的網(wǎng)絡(luò)后,我們像訓(xùn)練擴(kuò)散模型一樣訓(xùn)練它。在以下的代碼片段(以及代碼的這部分)中,可以看到首先將條件FLAIR圖像傳遞給可訓(xùn)練的網(wǎng)絡(luò),并從其跳過連接中獲取輸出。然后,當(dāng)計(jì)算預(yù)測(cè)的噪聲時(shí),這些值被輸入到擴(kuò)散模型中。在內(nèi)部,擴(kuò)散模型將ControlNets的跳過連接與自己的連接相加,然后在饋送解碼器部分之前(代碼)。以下是訓(xùn)練循環(huán)的一部分:

# Training Loop...
images = batch["t1w"].to(device)
cond = batch["flair"].to(device)
...
noise = torch.randn_like(latent_representation).to(device)
noisy_z = scheduler.add_noise(
    original_samples=latent_representation, noise=noise, timesteps=timesteps)

# Compute trainable part
down_block_res_samples, mid_block_res_sample = controlnet(
    x=noisy_z, timesteps=timesteps, context=prompt_embeds, controlnet_cond=cond)

# Using controlnet outputs to control diffusion model behaviour
noise_pred = diffusion_model(
    x=noisy_z,
    timesteps=timesteps,
    context=prompt_embeds,
    down_block_additional_residuals=down_block_res_samples,
    mid_block_additional_residual=mid_block_res_sample,)

# Then compute diffusion model loss as usual...

ControlNet采樣和評(píng)估

在訓(xùn)練模型之后,我們可以對(duì)它們進(jìn)行采樣和評(píng)估。在這里,使用測(cè)試集中的FLAIR圖像來生成條件化的T1w圖像。與我們的訓(xùn)練類似,采樣過程非常接近于擴(kuò)散模型使用的過程,唯一的區(qū)別是將條件圖像傳遞給訓(xùn)練過的ControlNet,并在每個(gè)采樣時(shí)間步中使用其輸出來饋送擴(kuò)散模型。如下圖所示,我們生成的圖像具有高空間保真度的原始條件,皮層回旋遵循類似的形狀,圖像保留了不同組織之間的邊界。

測(cè)試集中的原始FLAIR圖像作為輸入到ControlNet(左),生成的T1加權(quán)圖像(中),和原始的T1加權(quán)圖像,也就是預(yù)期的輸出(右)

在我們對(duì)模型的圖像進(jìn)行采樣之后,可以量化ControlNet在將圖像在不同對(duì)比度之間轉(zhuǎn)換時(shí)的性能。由于我們從測(cè)試集中得到了預(yù)期的T1w圖像,我們也可以檢查它們的差異,并使用平均絕對(duì)誤差(MAE)、峰值信噪比(PSNR)和MS-SSIM計(jì)算真實(shí)和合成圖像之間的距離。在我們的測(cè)試集中,當(dāng)執(zhí)行這個(gè)腳本時(shí),我們得到了PSNR= 26.2458+-1.0092,MAE=0.02632+-0.0036和MSSIM=0.9526+-0.0111。

ControlNet為我們的擴(kuò)散模型提供了令人難以置信的控制,最近的方法已經(jīng)擴(kuò)展了其方法,結(jié)合了不同的訓(xùn)練ControlNets(Multi-ControlNet),在同一模型中處理不同類型的條件(T2I適配器),甚至在模型上設(shè)置條件(使用像ControlNet 1.1這樣的方法)。如果這些方法聽起來很有趣,不要猶豫,嘗試一下!

總結(jié)

在這篇文章中,我們展示了如何使用ControlNets來控制Latent Diffusion Models生成醫(yī)學(xué)圖像的過程。我們的目標(biāo)是展示這些模型在將腦圖像轉(zhuǎn)換為各種對(duì)比度的能力。為了實(shí)現(xiàn)這一目標(biāo),我們利用了最近推出的MONAI的開源擴(kuò)展,即MONAI Generative Models!我們希望這篇文章能幫助你理解如何使用這些工具,并鼓勵(lì)你在你自己的項(xiàng)目中使用它們。

作者介紹

崔皓,51CTO社區(qū)編輯,資深架構(gòu)師,擁有18年的軟件開發(fā)和架構(gòu)經(jīng)驗(yàn),10年分布式架構(gòu)經(jīng)驗(yàn)。

原文標(biāo)題:Controllable Medical Image Generation with ControlNets,作者:Walter Hugo Lopez Pinaya

責(zé)任編輯:華軒 來源: 51CTO
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)