如何為您的機(jī)器學(xué)習(xí)問題選擇正確的預(yù)訓(xùn)練模型
在這篇文章中,我們將簡(jiǎn)要介紹一下遷移學(xué)習(xí)是什么,以及如何使用它。
什么是遷移學(xué)習(xí)?
遷移學(xué)習(xí)是使用預(yù)訓(xùn)練模型解決深度學(xué)習(xí)問題的藝術(shù)。
遷移學(xué)習(xí)是一種機(jī)器學(xué)習(xí)技術(shù),你可以使用一個(gè)預(yù)訓(xùn)練好的神經(jīng)網(wǎng)絡(luò)來解決一個(gè)問題,這個(gè)問題類似于網(wǎng)絡(luò)最初訓(xùn)練用來解決的問題。例如,您可以利用構(gòu)建好的用于識(shí)別狗的品種的深度學(xué)習(xí)模型來對(duì)狗和貓進(jìn)行分類,而不是構(gòu)建您自己的模型。這可以為您省去尋找有效的神經(jīng)網(wǎng)絡(luò)體系結(jié)構(gòu)的痛苦,可以為你節(jié)省花在訓(xùn)練上的時(shí)間,并可以保證有良好的結(jié)果。也就是說,你可以花很長(zhǎng)時(shí)間來制作一個(gè)50層的CNN來***地區(qū)分你的貓和狗,或者你可以簡(jiǎn)單地使用許多預(yù)訓(xùn)練好的圖像分類模型。
使用預(yù)訓(xùn)練模型的三種不同方式
主要有三種不同的方式可以重新定位預(yù)訓(xùn)練模型。他們是,
- 特征提取 。
- 復(fù)制預(yù)訓(xùn)練的網(wǎng)絡(luò)的體系結(jié)構(gòu)。
- 凍結(jié)一些層并訓(xùn)練其他層。
特征提取:這里我們所需要做的就是改變輸出層,以給出cat和dog的概率(或者您的模型試圖將內(nèi)容分類到的類的數(shù)量),而不是最初訓(xùn)練它將內(nèi)容分類到的數(shù)千個(gè)類。當(dāng)我們?cè)噲D訓(xùn)練模型所使用的數(shù)據(jù)與預(yù)訓(xùn)練的模型最初所訓(xùn)練的數(shù)據(jù)非常相似且數(shù)據(jù)集的大小很小時(shí),這是理想的。這種機(jī)制稱為固定特征提取。我們只對(duì)添加的新輸出層進(jìn)行重新訓(xùn)練,并保留每一層的權(quán)重。
復(fù)制預(yù)訓(xùn)練網(wǎng)絡(luò)的架構(gòu) :在這里,我們定義了一個(gè)與預(yù)訓(xùn)練模型具有相同體系結(jié)構(gòu)的機(jī)器學(xué)習(xí)模型,該模型在執(zhí)行與我們?cè)噲D實(shí)現(xiàn)的任務(wù)類似的任務(wù)時(shí)顯示了出色的結(jié)果,并從頭開始訓(xùn)練它。我們從預(yù)訓(xùn)練的模型中丟棄每一層的權(quán)重,然后根據(jù)我們的數(shù)據(jù)重新訓(xùn)練整個(gè)模型。當(dāng)我們有大量的數(shù)據(jù)要訓(xùn)練時(shí),我們會(huì)采用這種方法,但它與訓(xùn)練前的模型所訓(xùn)練的數(shù)據(jù)并不十分相似。
凍結(jié)一些層并訓(xùn)練其他層:我們可以選擇凍結(jié)一個(gè)預(yù)訓(xùn)練模型的初始k層,只訓(xùn)練最頂層的n-k層。我們保持初始值的權(quán)重與預(yù)訓(xùn)練模型的權(quán)重相同且不變,并對(duì)數(shù)據(jù)的高層進(jìn)行再訓(xùn)練。當(dāng)數(shù)據(jù)集較小且數(shù)據(jù)相似度較低時(shí),采用該方法。較低的層主要關(guān)注可以從數(shù)據(jù)中提取的最基本的信息,因此可以將其用于其他問題,因?yàn)榛炯?jí)別的信息通常是相同的。
另一種常見情況是數(shù)據(jù)相似性高且數(shù)據(jù)集也很大。在這種情況下,我們保留模型的體系結(jié)構(gòu)和模型的初始權(quán)重。然后,我們對(duì)整個(gè)模型進(jìn)行再訓(xùn)練,以更新預(yù)訓(xùn)練模型的權(quán)重,以更好地適應(yīng)我們的特定問題。這是使用遷移學(xué)習(xí)的理想情況。
下圖顯示了隨著數(shù)據(jù)集大小和數(shù)據(jù)相似性的變化而采用的方法。

PyTorch中的遷移學(xué)習(xí)
在torchvision.models模塊下,PyTorch中有八種不同的預(yù)訓(xùn)練模型。他們是 :
- AlexNet
- VGG
- RESNET
- SqueezeNet
- DenseNet
- Inception v3
- GoogLeNet
- ShuffleNet v2
這些都是為圖像分類而構(gòu)建的卷積神經(jīng)網(wǎng)絡(luò),在ImageNet數(shù)據(jù)集上進(jìn)行訓(xùn)練。ImageNet是根據(jù)WordNet層次結(jié)構(gòu)組織的圖像數(shù)據(jù)庫,包含14,197,122張屬于21841類的圖像。

由于PyTorch中的所有預(yù)訓(xùn)練模型都針對(duì)相同的任務(wù)在相同的數(shù)據(jù)集上進(jìn)行訓(xùn)練,所以我們選擇哪一個(gè)并不重要。讓我們選擇ResNet網(wǎng)絡(luò),看看如何在前面討論的不同場(chǎng)景中使用它。
用于圖像識(shí)別的ResNet或深度殘差學(xué)習(xí)在pytorch、ResNet -18、ResNet -34、ResNet -50、ResNet -101和ResNet -152上有五個(gè)版本。
讓我們從torchvision下載ResNet-18。
- import torchvision.models as models
- model = models.resnet18(pretrained=True)

以下是我們剛剛下載的模型。


現(xiàn)在,讓我們看看嘗試,看看如何針對(duì)四個(gè)不同的問題訓(xùn)練這個(gè)模型。
數(shù)據(jù)集很小,數(shù)據(jù)相似性很高
考慮這個(gè)kaggle數(shù)據(jù)集(https://www.kaggle.com/mriganksingh/cat-images-dataset)。這包括貓的圖像和其他非貓的圖像。它有209個(gè)像素64*64*3的訓(xùn)練圖像和50個(gè)測(cè)試圖像。這顯然是一個(gè)非常小的數(shù)據(jù)集,但我們知道ResNet是在大量動(dòng)物和貓圖像上訓(xùn)練的,所以我們可以使用ResNet作為固定特征提取器來解決我們的貓與非貓的問題。
- num_ftrs = model.fc.in_features
- num_ftrs

Out: 512
- model.fc.out_features
Out: 1000
我們需要凍結(jié)除***一層之外的所有網(wǎng)絡(luò)。我們需要設(shè)置requires_grad = False來凍結(jié)參數(shù),這樣就不會(huì)在backward()中計(jì)算梯度。新構(gòu)造模塊的參數(shù)默認(rèn)為requires_grad=True。
- for param in model.parameters():
- param.requires_grad = False

由于我們只需要***一層提供兩個(gè)概率,即圖像的概率是否為cat,我們可以重新定義***一層中的輸出特征數(shù)。
- model.fc = nn.Linear(num_ftrs, 2)
這是我們模型的新架構(gòu)。


我們現(xiàn)在要做的就是訓(xùn)練模型的***一層,我們將能夠使用我們重新定位的vgg16來預(yù)測(cè)圖像是否是貓,而且數(shù)據(jù)和訓(xùn)練時(shí)間都非常少。
數(shù)據(jù)的大小很小,數(shù)據(jù)相似性也很低
考慮來自(https://www.kaggle.com/kvinicki/canine-coccidiosis),這個(gè)數(shù)據(jù)集包含了犬異孢球蟲和犬異孢球蟲卵囊的圖像和標(biāo)簽,異孢球蟲卵囊是一種球蟲寄生蟲,可感染狗的腸道。它是由薩格勒布獸醫(yī)學(xué)院創(chuàng)建的。它包含了兩種寄生蟲的341張圖片。

這個(gè)數(shù)據(jù)集很小,而且不是Imagenet中的一個(gè)類別。在這種情況下,我們保留預(yù)先訓(xùn)練好的模型架構(gòu),凍結(jié)較低的層并保留它們的權(quán)重,并訓(xùn)練較低的層更新它們的權(quán)重以適應(yīng)我們的問題。
- count = 0
- for child in model.children():
- count+=1
- print(count)

Out: 10
ResNet18共有10層。讓我們凍結(jié)前6層。
- count = 0
- for child in model.children():
- count+=1
- if count < 7:
- for param in child.parameters():
- param.requires_grad = False

現(xiàn)在我們已經(jīng)凍結(jié)了前6層,讓我們重新定義最終輸出層,只給出2個(gè)輸出,而不是1000。
- model.fc = nn.Linear(num_ftrs, 2)
這是更新的架構(gòu)。


現(xiàn)在,訓(xùn)練這個(gè)機(jī)器學(xué)習(xí)模型,更新***4層的權(quán)重。
數(shù)據(jù)集的大小很大,但數(shù)據(jù)相似性非常低
考慮這個(gè)來自kaggle,皮膚癌MNIST的數(shù)據(jù)集:HAM10000
其具有超過10015個(gè)皮膚鏡圖像,屬于7種不同類別。這不是我們?cè)贗magenet中可以找到的那種數(shù)據(jù)。
這就是我們只保留模型架構(gòu)而不保留來自預(yù)訓(xùn)練模型的任何權(quán)重的地方。讓我們重新定義輸出層,將項(xiàng)目分類為7個(gè)類別。
- model.fc = nn.Linear(num_ftrs, 7)
這個(gè)模型需要幾個(gè)小時(shí)才能在沒有GPU的機(jī)器上進(jìn)行訓(xùn)練,但是如果你運(yùn)行足夠的時(shí)代,你仍然會(huì)得到很好的結(jié)果,而不必定義你自己的模型架構(gòu)。
數(shù)據(jù)大小很大,數(shù)據(jù)相似性很高
考慮來自kaggle 的鮮花數(shù)據(jù)集(https://www.kaggle.com/alxmamaev/flowers-recognition)。它包含4242個(gè)花卉圖像。圖片分為五類:洋甘菊,郁金香,玫瑰,向日葵,蒲公英。每個(gè)類大約有800張照片。
這是應(yīng)用遷移學(xué)習(xí)的理想情況。我們保留了預(yù)訓(xùn)練模型的體系結(jié)構(gòu)和每一層的權(quán)重,并訓(xùn)練模型更新權(quán)重以匹配我們的特定問題。
- model.fc = nn.Linear(num_ftrs, 5)
- best_model_wts = copy.deepcopy(model.state_dict())

我們從預(yù)訓(xùn)練的模型中復(fù)制權(quán)重并初始化我們的模型。我們使用訓(xùn)練和測(cè)試階段來更新這些權(quán)重。
- for epoch in range(num_epochs):
- print(‘Epoch {}/{}’.format(epoch, num_epochs — 1))
- print(‘-’ * 10)
- for phase in [‘train’, ‘test’]:
- if phase == 'train':
- scheduler.step()
- model.train()
- else:
- model.eval()
- running_loss = 0.0
- running_corrects = 0
- for inputs, labels in dataloaders[phase]:
- inputs = inputs.to(device)
- labels = labels.to(device)
- optimizer.zero_grad()
- with torch.set_grad_enabled(phase == ‘train’):
- outputs = model(inputs)
- _, preds = torch.max(outputs, 1)
- loss = criterion(outputs, labels)
- if phase == ‘train’:
- loss.backward()
- optimizer.step()
- running_loss += loss.item() * inputs.size(0)
- running_corrects += torch.sum(preds == labels.data)
- epoch_loss = running_loss / dataset_sizes[phase]
- epoch_acc = running_corrects.double() / dataset_sizes[phase]
- print(‘{} Loss: {:.4f} Acc: {:.4f}’.format(
- phase, epoch_loss, epoch_acc))
- if phase == ‘test’ and epoch_acc > best_acc:
- best_acc = epoch_acc
- best_model_wts = copy.deepcopy(model.state_dict())
- print(‘Best val Acc: {:4f}’.format(best_acc))
- model.load_state_dict(best_model_wts)

這種機(jī)器學(xué)習(xí)模式也需要幾個(gè)小時(shí)的訓(xùn)練,但即使只有一個(gè)訓(xùn)練epoch ,也會(huì)產(chǎn)生出色的效果。
您可以按照相同的原則在任何其他平臺(tái)上使用任何其他預(yù)訓(xùn)練的網(wǎng)絡(luò)執(zhí)行遷移學(xué)習(xí)。本文隨機(jī)挑選了Resnet和pytorch。任何其他CNN都會(huì)給出類似的結(jié)果。希望這可以節(jié)省您使用計(jì)算機(jī)視覺解決現(xiàn)實(shí)世界問題的痛苦時(shí)間。