導(dǎo)讀
在不同的任務(wù)上對(duì)比了UNet和UNet++以及使用不同的預(yù)訓(xùn)練編碼器的效果。
介紹
語(yǔ)義分割是計(jì)算機(jī)視覺(jué)的一個(gè)問(wèn)題,我們的任務(wù)是使用圖像作為輸入,為圖像中的每個(gè)像素分配一個(gè)類(lèi)。在語(yǔ)義分割的情況下,我們不關(guān)心是否有同一個(gè)類(lèi)的多個(gè)實(shí)例(對(duì)象),我們只是用它們的類(lèi)別來(lái)標(biāo)記它們。有多種關(guān)于不同計(jì)算機(jī)視覺(jué)問(wèn)題的介紹課程,但用一張圖片可以總結(jié)不同的計(jì)算機(jī)視覺(jué)問(wèn)題:
語(yǔ)義分割在生物醫(yī)學(xué)圖像分析中有著廣泛的應(yīng)用:x射線(xiàn)、MRI掃描、數(shù)字病理、顯微鏡、內(nèi)窺鏡等。https://grand-challenge.org/challenges上有許多不同的有趣和重要的問(wèn)題有待探索。
從技術(shù)角度來(lái)看,如果我們考慮語(yǔ)義分割問(wèn)題,對(duì)于N×M×3(假設(shè)我們有一個(gè)RGB圖像)的圖像,我們希望生成對(duì)應(yīng)的映射N(xiāo)×M×k(其中k是類(lèi)的數(shù)量)。有很多架構(gòu)可以解決這個(gè)問(wèn)題,但在這里我想談?wù)剝蓚€(gè)特定的架構(gòu),Unet和Unet++。
有許多關(guān)于Unet的評(píng)論,它如何永遠(yuǎn)地改變了這個(gè)領(lǐng)域。它是一個(gè)統(tǒng)一的非常清晰的架構(gòu),由一個(gè)編碼器和一個(gè)解碼器組成,前者生成圖像的表示,后者使用該表示來(lái)構(gòu)建分割。每個(gè)空間分辨率的兩個(gè)映射連接在一起(灰色箭頭),因此可以將圖像的兩種不同表示組合在一起。并且它成功了!
接下來(lái)是使用一個(gè)訓(xùn)練好的編碼器??紤]圖像分類(lèi)的問(wèn)題,我們?cè)噲D建立一個(gè)圖像的特征表示,這樣不同的類(lèi)在該特征空間可以被分開(kāi)。我們可以(幾乎)使用任何CNN,并將其作為一個(gè)編碼器,從編碼器中獲取特征,并將其提供給我們的解碼器。據(jù)我所知,Iglovikov & Shvets 使用了VGG11和resnet34分別為Unet解碼器以生成更好的特征和提高其性能。
Unet++是最近對(duì)Unet體系結(jié)構(gòu)的改進(jìn),它有多個(gè)跳躍連接。
根據(jù)論文, Unet++的表現(xiàn)似乎優(yōu)于原來(lái)的Unet。就像在Unet中一樣,這里可以使用多個(gè)編碼器(骨干)來(lái)為輸入圖像生成強(qiáng)特征。
我應(yīng)該使用哪個(gè)編碼器?
這里我想重點(diǎn)介紹Unet和Unet++,并比較它們使用不同的預(yù)訓(xùn)練編碼器的性能。為此,我選擇使用胸部x光數(shù)據(jù)集來(lái)分割肺部。這是一個(gè)二值分割,所以我們應(yīng)該給每個(gè)像素分配一個(gè)類(lèi)為“1”的概率,然后我們可以二值化來(lái)制作一個(gè)掩碼。首先,讓我們看看數(shù)據(jù)。
這些是非常大的圖像,通常是2000×2000像素,有很大的mask,從視覺(jué)上看,找到肺不是問(wèn)題。使用segmentation_models_pytorch庫(kù),我們?yōu)閁net和Unet++使用100+個(gè)不同的預(yù)訓(xùn)練編碼器。我們做了一個(gè)快速的pipeline來(lái)訓(xùn)練模型,使用Catalyst (pytorch的另一個(gè)庫(kù),這可以幫助你訓(xùn)練模型,而不必編寫(xiě)很多無(wú)聊的代碼)和Albumentations(幫助你應(yīng)用不同的圖像轉(zhuǎn)換)。
- 定義數(shù)據(jù)集和增強(qiáng)。我們將調(diào)整圖像大小為256×256,并對(duì)訓(xùn)練數(shù)據(jù)集應(yīng)用一些大的增強(qiáng)。
importalbumentationsasA
fromtorch.utils.dataimportDataset,DataLoader
fromcollectionsimportOrderedDict
classChestXRayDataset(Dataset):
def__init__(
self,
images,
masks,
transforms):
self.images=images
self.masks=masks
self.transforms=transforms
def__len__(self):
return(len(self.images))
def__getitem__(self,idx):
"""Willloadthemask,getrandomcoordinatesaround/withthemask,
loadtheimagebycoordinates
"""
sample_image=imread(self.images[idx])
iflen(sample_image.shape)==3:
sample_image=sample_image[...,0]
sample_image=np.expand_dims(sample_image,2)/255
sample_mask=imread(self.masks[idx])/255
iflen(sample_mask.shape)==3:
sample_mask=sample_mask[...,0]
augmented=self.transforms(image=sample_image,mask=sample_mask)
sample_image=augmented['image']
sample_mask=augmented['mask']
sample_image=sample_image.transpose(2,0,1)#channelsfirst
sample_mask=np.expand_dims(sample_mask,0)
data={'features':torch.from_numpy(sample_image.copy()).float(),
'mask':torch.from_numpy(sample_mask.copy()).float()}
return(data)
defget_valid_transforms(crop_size=256):
returnA.Compose(
[
A.Resize(crop_size,crop_size),
],
p=1.0)
deflight_training_transforms(crop_size=256):
returnA.Compose([
A.RandomResizedCrop(height=crop_size,width=crop_size),
A.OneOf(
[
A.Transpose(),
A.VerticalFlip(),
A.HorizontalFlip(),
A.RandomRotate90(),
A.NoOp()
],p=1.0),
])
defmedium_training_transforms(crop_size=256):
returnA.Compose([
A.RandomResizedCrop(height=crop_size,width=crop_size),
A.OneOf(
[
A.Transpose(),
A.VerticalFlip(),
A.HorizontalFlip(),
A.RandomRotate90(),
A.NoOp()
],p=1.0),
A.OneOf(
[
A.CoarseDropout(max_holes=16,max_height=16,max_width=16),
A.NoOp()
],p=1.0),
])
defheavy_training_transforms(crop_size=256):
returnA.Compose([
A.RandomResizedCrop(height=crop_size,width=crop_size),
A.OneOf(
[
A.Transpose(),
A.VerticalFlip(),
A.HorizontalFlip(),
A.RandomRotate90(),
A.NoOp()
],p=1.0),
A.ShiftScaleRotate(p=0.75),
A.OneOf(
[
A.CoarseDropout(max_holes=16,max_height=16,max_width=16),
A.NoOp()
],p=1.0),
])
defget_training_trasnforms(transforms_type):
iftransforms_type=='light':
return(light_training_transforms())
eliftransforms_type=='medium':
return(medium_training_transforms())
eliftransforms_type=='heavy':
return(heavy_training_transforms())
else:
raiseNotImplementedError("Notimplementedtransformationconfiguration")
- 定義模型和損失函數(shù)。這里我們使用帶有regnety_004編碼器的Unet++,并使用RAdam + Lookahed優(yōu)化器使用DICE + BCE損失之和進(jìn)行訓(xùn)練。
importtorch
importsegmentation_models_pytorchassmp
importnumpyasnp
importmatplotlib.pyplotasplt
fromcatalystimportdl,metrics,core,contrib,utils
importtorch.nnasnn
fromskimage.ioimportimread
importos
fromsklearn.model_selectionimporttrain_test_split
fromcatalyst.dlimportCriterionCallback,MetricAggregationCallback
encoder='timm-regnety_004'
model=smp.UnetPlusPlus(encoder,classes=1,in_channels=1)
#model.cuda()
learning_rate=5e-3
encoder_learning_rate=5e-3/10
layerwise_params={"encoder*":dict(lr=encoder_learning_rate,weight_decay=0.00003)}
model_params=utils.process_model_params(model,layerwise_params=layerwise_params)
base_optimizer=contrib.nn.RAdam(model_params,lr=learning_rate,weight_decay=0.0003)
optimizer=contrib.nn.Lookahead(base_optimizer)
scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor=0.25,patience=10)
criterion={
"dice":DiceLoss(mode='binary'),
"bce":nn.BCEWithLogitsLoss()
}
- 定義回調(diào)函數(shù)并訓(xùn)練!
callbacks=[
#Eachcriterioniscalculatedseparately.
CriterionCallback(
input_key="mask",
prefix="loss_dice",
criterion_key="dice"
),
CriterionCallback(
input_key="mask",
prefix="loss_bce",
criterion_key="bce"
),
#Andonlythenweaggregateeverythingintooneloss.
MetricAggregationCallback(
prefix="loss",
mode="weighted_sum",
metrics={
"loss_dice":1.0,
"loss_bce":0.8
},
),
#metrics
IoUMetricsCallback(
mode='binary',
input_key='mask',
)
]
runner=dl.SupervisedRunner(input_key="features",input_target_key="mask")
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
scheduler=scheduler,
loaders=loaders,
callbacks=callbacks,
logdir='../logs/xray_test_log',
num_epochs=100,
main_metric="loss",
minimize_metric=True,
verbose=True,
)
如果我們用不同的編碼器對(duì)Unet和Unet++進(jìn)行驗(yàn)證,我們可以看到每個(gè)訓(xùn)練模型的驗(yàn)證質(zhì)量,并總結(jié)如下:
我們注意到的第一件事是,在所有編碼器中,Unet++的性能似乎都比Unet好。當(dāng)然,有時(shí)這種差異并不是很大,我們不能說(shuō)它們?cè)诮y(tǒng)計(jì)上是否完全不同 —— 我們需要在多個(gè)folds上訓(xùn)練,看看分?jǐn)?shù)分布,單點(diǎn)不能證明任何事情。第二,resnest200e顯示了最高的質(zhì)量,同時(shí)仍然有合理的參數(shù)數(shù)量。有趣的是,如果我們看看https://paperswithcode.com/task/semantic-segmentation,我們會(huì)發(fā)現(xiàn)resnest200在一些基準(zhǔn)測(cè)試中也是SOTA。
好的,但是讓我們用Unet++和Unet使用resnest200e編碼器來(lái)比較不同的預(yù)測(cè)。
在某些個(gè)別情況下,Unet++實(shí)際上比Unet更糟糕。但總的來(lái)說(shuō)似乎更好一些。
一般來(lái)說(shuō),對(duì)于分割網(wǎng)絡(luò)來(lái)說(shuō),這個(gè)數(shù)據(jù)集看起來(lái)是一個(gè)容易的任務(wù)。讓我們?cè)谝粋€(gè)更難的任務(wù)上測(cè)試Unet++。為此,我使用PanNuke數(shù)據(jù)集,這是一個(gè)帶標(biāo)注的組織學(xué)數(shù)據(jù)集(205,343個(gè)標(biāo)記核,19種不同的組織類(lèi)型,5個(gè)核類(lèi))。數(shù)據(jù)已經(jīng)被分割成3個(gè)folds。
我們可以使用類(lèi)似的代碼在這個(gè)數(shù)據(jù)集上訓(xùn)練Unet++模型,如下所示:
我們?cè)谶@里看到了相同的模式 - resnest200e編碼器似乎比其他的性能更好。我們可以用兩個(gè)不同的模型(最好的是resnest200e編碼器,最差的是regnety_002)來(lái)可視化一些例子。
我們可以肯定地說(shuō),這個(gè)數(shù)據(jù)集是一項(xiàng)更難的任務(wù) —— 不僅mask不夠精確,而且個(gè)別的核被分配到錯(cuò)誤的類(lèi)別。然而,使用resnest200e編碼器的Unet++仍然表現(xiàn)很好。
總結(jié)
這不是一個(gè)全面語(yǔ)義分割的指導(dǎo),這更多的是一個(gè)想法,使用什么來(lái)獲得一個(gè)堅(jiān)實(shí)的基線(xiàn)。有很多模型、FPN,DeepLabV3, Linknet與Unet有很大的不同,有許多Unet-like架構(gòu),例如,使用雙編碼器的Unet,MAnet,PraNet,U2-net — 有很多的型號(hào)供你選擇,其中一些可能在你的任務(wù)上表現(xiàn)的比較好,但是,一個(gè)堅(jiān)實(shí)的基線(xiàn)可以幫助你從正確的方向上開(kāi)始。
-
編碼器
+關(guān)注
關(guān)注
45文章
3875瀏覽量
140528 -
醫(yī)學(xué)影像
+關(guān)注
關(guān)注
1文章
112瀏覽量
17707
原文標(biāo)題:UNet 和 UNet++:醫(yī)學(xué)影像經(jīng)典分割網(wǎng)絡(luò)對(duì)比
文章出處:【微信號(hào):CVSCHOOL,微信公眾號(hào):OpenCV學(xué)堂】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
RK3576 yolo11-seg訓(xùn)練部署教程

【正點(diǎn)原子STM32MP257開(kāi)發(fā)板試用】基于 DeepLab 模型的圖像分割
東軟集團(tuán)入選國(guó)家數(shù)據(jù)局?jǐn)?shù)據(jù)標(biāo)注優(yōu)秀案例
RK3576 yolov11-seg訓(xùn)練部署教程

醫(yī)療設(shè)備工業(yè)成像采集卡:提升醫(yī)療影像診斷水平的關(guān)鍵組件

無(wú)法在在DL Workbench中導(dǎo)入unet-camvid-onnx-0001模型之前下載CamVid數(shù)據(jù)集?
中信建投報(bào)告泄密,AI硬件正在重塑醫(yī)療影像與IVD領(lǐng)域的未來(lái)

三維測(cè)量在醫(yī)療領(lǐng)域的應(yīng)用
英特爾助力東軟PACS&RIS賦能三維可視化與AI輔助診斷

NVIDIA助力西門(mén)子醫(yī)療加速醫(yī)學(xué)影像AI部署
經(jīng)典圖神經(jīng)網(wǎng)絡(luò)(GNNs)的基準(zhǔn)分析研究

評(píng)論