譯者 | 朱先忠
審校 | 重樓
當(dāng)前,許多強(qiáng)大的開(kāi)源基礎(chǔ)模型的發(fā)布,加上微調(diào)技術(shù)的不斷進(jìn)步,已經(jīng)帶來(lái)了機(jī)器學(xué)習(xí)和人工智能的新范式。整體來(lái)看,這場(chǎng)革命的核心在于轉(zhuǎn)換器模型(https://arxiv.org/pdf/1706.03762)。
雖然除了資金充足的公司之外,所有公司都曾經(jīng)無(wú)法獲得高精度的特定領(lǐng)域模型;但是如今,基礎(chǔ)模型范式甚至能夠允許學(xué)生或獨(dú)立研究人員獲得適度的資源,以實(shí)現(xiàn)與最先進(jìn)的專(zhuān)有模型相抗衡的結(jié)果。
微調(diào)技術(shù)可以極大地提高模型任務(wù)的性能
本文旨在探討Meta公司的分割一切模型(SAM:Segment Anything Model)在河流像素分割遙感任務(wù)中的應(yīng)用。如果您想直接跳到有關(guān)項(xiàng)目代碼,那么這個(gè)項(xiàng)目的源文件可以在GitHub(https://github.com/geo-smart/water-surf/blob/main/book/chapters/masking_distributed.ipynb)上獲得,數(shù)據(jù)也可以在HuggingFace(https://huggingface.co/datasets/stodoran/elwha-segmentation-v1)上找到。當(dāng)然,我還是建議您先閱讀一下本文。
項(xiàng)目需求
第一個(gè)任務(wù)是找到或創(chuàng)建一個(gè)合適的數(shù)據(jù)集。根據(jù)現(xiàn)有文獻(xiàn),一個(gè)良好的SAM微調(diào)數(shù)據(jù)集將至少包含200–800張圖像。過(guò)去十年深度學(xué)習(xí)進(jìn)步的一個(gè)關(guān)鍵教訓(xùn)是,數(shù)據(jù)越多越好,這樣一來(lái),更大規(guī)模的微調(diào)數(shù)據(jù)集就不會(huì)出錯(cuò)。然而,基礎(chǔ)模型研究的一個(gè)重要目標(biāo)是,允許即使是相對(duì)較小的數(shù)據(jù)集也足以實(shí)現(xiàn)強(qiáng)大的性能。
此外,我們還需要有一個(gè)HuggingFace帳戶,這可以在鏈接https://huggingface.co/join處創(chuàng)建。使用HuggingFace,我們可以隨時(shí)從任何設(shè)備輕松存儲(chǔ)和獲取數(shù)據(jù)集,這使得協(xié)作和再現(xiàn)性更加容易。
最后一個(gè)需求是,具備一臺(tái)帶有GPU的設(shè)備,我們可以在其上運(yùn)行訓(xùn)練工作流程。通過(guò)Google Colab(https://colab.research.google.com/)免費(fèi)提供的Nvidia T4 GPU足夠強(qiáng)大,可以在12小時(shí)內(nèi)對(duì)1000張圖像進(jìn)行50個(gè)時(shí)期的最大SAM模型檢查點(diǎn)(sam-vit-huge)訓(xùn)練。
為了避免在托管運(yùn)行時(shí)因使用限制而影響進(jìn)度,您可以安裝Google Drive并將每個(gè)模型檢查點(diǎn)保存在那里。或者,部署并連接到GCP虛擬機(jī)(https://console.cloud.google.com/marketplace/product/colab-marketplace-image-public/colab)以完全繞過(guò)限制。如果您以前從未使用過(guò)GCP,那么您就有資格獲得300美元的免費(fèi)信貸,這足以支持對(duì)模型進(jìn)行至少十幾次的訓(xùn)練了。
理解SAM架構(gòu)
在開(kāi)始訓(xùn)練之前,我們需要先來(lái)了解一下SAM模型的架構(gòu)。該模型包含三個(gè)組件:一個(gè)是從稍經(jīng)修改的掩碼自動(dòng)編碼器(https://arxiv.org/pdf/2111.06377)得到的圖像編碼器,一個(gè)能夠處理各種提示類(lèi)型的相當(dāng)靈活的提示編碼器,還有一個(gè)快速輕量級(jí)的掩碼解碼器。這種設(shè)計(jì)架構(gòu)背后的一個(gè)重要?jiǎng)訖C(jī)是允許在邊緣設(shè)備上(例如在瀏覽器中)進(jìn)行快速、實(shí)時(shí)的分割,因?yàn)閳D像嵌入只需要計(jì)算一次,并且掩碼解碼器可以在CPU上運(yùn)行約50ms。
SAM的模型架構(gòu)向我們展示了模型接受哪些輸入以及需要訓(xùn)練模型的哪些部分(圖片來(lái)源于SAM GitHub:https://github.com/facebookresearch/segment-anything)。
理論上,圖像編碼器已經(jīng)學(xué)會(huì)了嵌入圖像的最佳方式,包括識(shí)別形狀、邊緣和其他一般視覺(jué)特征等。類(lèi)似地,在理論上,提示編碼器已經(jīng)能夠以最優(yōu)方式對(duì)提示進(jìn)行編碼。掩碼解碼器是模型架構(gòu)的一部分,它采用這些圖像和提示嵌入,并通過(guò)對(duì)圖像和提示嵌入式進(jìn)行操作來(lái)實(shí)際創(chuàng)建掩碼。
因此,一種方法是在訓(xùn)練期間凍結(jié)與圖像和提示編碼器相關(guān)聯(lián)的模型參數(shù),并且僅更新掩碼解碼器權(quán)重。這種方法的優(yōu)點(diǎn)是允許有監(jiān)督和無(wú)監(jiān)督的下游任務(wù),因?yàn)榭刂泣c(diǎn)和邊界框提示都是自動(dòng)的,并且可供人工使用。
圖中顯示了AutoSAM體系架構(gòu)中使用的凍結(jié)SAM圖像編碼器和掩碼解碼器,以及過(guò)載提示編碼器(來(lái)源于AutoSAM論文:https://arxiv.org/pdf/2306.06370)。
另一種方法是使提示編碼器過(guò)載,凍結(jié)圖像編碼器和掩碼解碼器,并且只是簡(jiǎn)單地不使用原始SAM掩碼編碼器。例如,AutoSAM體系架構(gòu)使用基于Harmonic Dense Net的網(wǎng)絡(luò)來(lái)基于圖像本身生成提示嵌入。在本教程中,我們將介紹第一種方法,即凍結(jié)圖像和提示編碼器,只訓(xùn)練掩碼解碼器,但這種替代方法的代碼可以在AutoSAM GitHub(https://github.com/talshaharabany/AutoSAM/blob/main/inference.py)和論文(https://arxiv.org/pdf/2306.06370)中找到。
配置提示
接下來(lái)的一步是確定模型在推理過(guò)程中會(huì)收到什么類(lèi)型的提示,以便我們可以在訓(xùn)練時(shí)提供這種類(lèi)型的提示。就我個(gè)人而言,考慮到自然語(yǔ)言處理的不可預(yù)測(cè)/不一致性,我不建議在任何嚴(yán)肅的計(jì)算機(jī)視覺(jué)項(xiàng)目架構(gòu)中使用文本提示。剩下的解決方案就需要依賴(lài)控制點(diǎn)和邊界框技術(shù)了;但是,最終的選擇還要取決于特定數(shù)據(jù)集的特定性質(zhì),盡管有關(guān)文獻(xiàn)中已經(jīng)指出邊界框方案的表現(xiàn)相當(dāng)一致地優(yōu)于控制點(diǎn)方案。
造成這種情況的原因尚不完全清楚,但可能是以下任何因素之一,或者是這些因素的組合:
- 在推理時(shí)(當(dāng)真實(shí)值掩碼未知時(shí)),好的控制點(diǎn)比邊界框更難選擇。
- 可能的點(diǎn)提示的空間比可能的邊界框提示的空間大幾個(gè)數(shù)量級(jí),因此它沒(méi)有經(jīng)過(guò)徹底的訓(xùn)練。
- 最初的SAM模型作者主要專(zhuān)注于模型的零樣本和少樣本(根據(jù)人工提示交互計(jì)算)功能,因此預(yù)訓(xùn)練可能更多地關(guān)注邊界框。
無(wú)論如何,河流分割實(shí)際上是一種罕見(jiàn)的情況;在這種情況下,點(diǎn)提示方案實(shí)際上優(yōu)于邊界框(盡管只是輕微的,即使是在非常有利的域中)。假設(shè)在河流的任何圖像中,水體將從圖像的一端延伸到另一端,任何包含的邊界框幾乎總是覆蓋圖像的大部分。因此,河流非常不同部分的邊界框提示看起來(lái)非常相似。理論上,這意味著邊界框?yàn)槟P吞峁┑男畔⒈瓤刂泣c(diǎn)少得多;因此,導(dǎo)致性能較差。
控制點(diǎn)、邊界框提示和疊加在兩個(gè)樣本訓(xùn)練圖像上的真實(shí)分割
請(qǐng)注意,在上圖中,盡管兩條河流部分的真實(shí)分割掩碼完全不同,但它們各自的邊界框幾乎相同,而它們的點(diǎn)提示(相對(duì)而言)差異更大。
另一個(gè)需要考慮的重要因素是在推理時(shí)生成輸入提示的容易程度。如果您希望在循環(huán)執(zhí)行階段有人工介入,那么請(qǐng)注意邊界框和控制點(diǎn)在推理階段都是相當(dāng)瑣碎的。然而,如果您打算使用一個(gè)完全自動(dòng)化的架構(gòu)方案,那么回答這些問(wèn)題將變得更加復(fù)雜。
無(wú)論是使用控制點(diǎn)還是邊界框,生成提示通常首先包括估計(jì)感興趣對(duì)象的粗略掩碼。邊界框可以只是包裹粗略掩碼的最小框,而控制點(diǎn)需要從粗略掩碼中采樣。這意味著,當(dāng)真實(shí)值掩碼未知時(shí),邊界框更容易獲得,因?yàn)楦信d趣對(duì)象的估計(jì)掩碼只需要大致匹配真實(shí)對(duì)象的相同大小和位置;而對(duì)于控制點(diǎn),估計(jì)掩碼將需要更緊密地匹配對(duì)象的輪廓。
當(dāng)使用估計(jì)的掩碼而不是真實(shí)值時(shí),控制點(diǎn)的放置可能包括錯(cuò)誤標(biāo)注的點(diǎn),而邊界框通常位于正確的位置
對(duì)于河流分割,如果我們可以同時(shí)使用RGB和NIR,那么我們可以使用光譜指數(shù)閾值方法來(lái)獲得我們的粗略掩模。如果我們只能使用RGB模式,我們可以將圖像轉(zhuǎn)換為HSV模式,并對(duì)特定色調(diào)、飽和度和值范圍內(nèi)的所有像素設(shè)置閾值。然后,我們可以移除低于特定大小閾值的連接內(nèi)容,并使用skimage.morphology子模塊中的erosion函數(shù)來(lái)確保我們的掩模中只有1個(gè)像素是朝向藍(lán)色大斑點(diǎn)中心的像素。
模型訓(xùn)練
為了訓(xùn)練我們的模型,我們需要一個(gè)包含所有訓(xùn)練數(shù)據(jù)的數(shù)據(jù)加載器,我們可以在每個(gè)訓(xùn)練時(shí)期對(duì)這些數(shù)據(jù)進(jìn)行迭代。當(dāng)我們從HuggingFace加載數(shù)據(jù)集時(shí),它采用datasets.Dataset類(lèi)的形式。如果數(shù)據(jù)集是私有的,請(qǐng)確保首先安裝HuggingFace CLI并使用“!huggingface-cli login”方式登錄。
from datasets import load_dataset, load_from_disk, Dataset
hf_dataset_name = "stodoran/elwha-segmentation-v1"
training_data = load_dataset(hf_dataset_name, split="train")
validation_data = load_dataset(hf_dataset_name, split="validation")
然后,我們需要編寫(xiě)自己的自定義數(shù)據(jù)集類(lèi),該類(lèi)不僅返回任何索引的圖像和標(biāo)簽,還返回提示詞信息。下面是一個(gè)可以同時(shí)處理控制點(diǎn)和邊界框提示的實(shí)現(xiàn)。要完成初始化工作,需要一個(gè)HuggingFace datasets.Dataset實(shí)例和SAM模型的處理器實(shí)例。
from torch.utils.data import Dataset
class PromptType:
CONTROL_POINTS = "pts"
BOUNDING_BOX = "bbox"
class SAMDataset(Dataset):
def __init__(
self,
dataset,
processor,
prompt_type = PromptType.CONTROL_POINTS,
num_positive = 3,
num_negative = 0,
erode = True,
multi_mask = "mean",
perturbation = 10,
image_size = (1024, 1024),
mask_size = (256, 256),
):
#將所有值賦給self
...
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
datapoint = self.dataset[idx]
input_image = cv2.resize(np.array(datapoint["image"]), self.image_size)
ground_truth_mask = cv2.resize(np.array(datapoint["label"]), self.mask_size)
if self.prompt_type == PromptType.CONTROL_POINTS:
inputs = self._getitem_ctrlpts(input_image, ground_truth_mask)
elif self.prompt_type == PromptType.BOUNDING_BOX:
inputs = self._getitem_bbox(input_image, ground_truth_mask)
inputs["ground_truth_mask"] = ground_truth_mask
return inputs
我們還必須定義SAMDataset_getitem_ctrlpts和SAMDataset_getitem_box函數(shù),盡管如果您只計(jì)劃使用一種提示類(lèi)型,那么您可以重構(gòu)代碼以直接處理SAMDataset.__getitem__中的該類(lèi)型,并刪除幫助類(lèi)工具函數(shù)。
class SAMDataset(Dataset):
...
def _getitem_ctrlpts(self, input_image, ground_truth_mask):
# 獲取控制點(diǎn)提示。請(qǐng)參閱GitHub獲取該函數(shù)的源代碼,或?qū)⑵涮鎿Q為您自己的點(diǎn)選擇算法。
input_points, input_labels = generate_input_points(
num_positive=self.num_positive,
num_negative=self.num_negative,
mask=ground_truth_mask,
dynamic_distance=True,
erode=self.erode,
)
input_points = input_points.astype(float).tolist()
input_labels = input_labels.tolist()
input_labels = [[x] for x in input_labels]
# 為模型準(zhǔn)備圖像和提示。
inputs = self.processor(
input_image,
input_points=input_points,
input_labels=input_labels,
return_tensors="pt"
)
#刪除處理器默認(rèn)添加的批次維度。
inputs = {k: v.squeeze(0) for k, v in inputs.items()}
inputs["input_labels"] = inputs["input_labels"].squeeze(1)
return inputs
def _getitem_bbox(self, input_image, ground_truth_mask):
#獲取邊界框提示。
bbox = get_input_bbox(ground_truth_mask, perturbation=self.perturbation)
#為模型準(zhǔn)備圖像和提示。
inputs = self.processor(input_image, input_boxes=[[bbox]], return_tensors="pt")
inputs = {k: v.squeeze(0) for k, v in inputs.items()} # 刪除處理器默認(rèn)添加的批次維度。
return inputs
將所有這些功能組合到一起,我們可以創(chuàng)建一個(gè)函數(shù),該函數(shù)在給定HuggingFace數(shù)據(jù)集的任一部分的情況下創(chuàng)建并返回PyTorch數(shù)據(jù)加載器。編寫(xiě)返回?cái)?shù)據(jù)加載器的函數(shù),而不僅僅是用相同的代碼執(zhí)行單元,這不僅是編寫(xiě)靈活和可維護(hù)代碼的好方法,而且如果您計(jì)劃使用HuggingFace Accelerate(https://huggingface.co/docs/accelerate/index)來(lái)運(yùn)行分布式訓(xùn)練的話,這也是必要的。
from transformers import SamProcessor
from torch.utils.data import DataLoader
def get_dataloader(
hf_dataset,
model_size = "base", # One of "base", "large", or "huge"
batch_size = 8,
prompt_type = PromptType.CONTROL_POINTS,
num_positive = 3,
num_negative = 0,
erode = True,
multi_mask = "mean",
perturbation = 10,
image_size = (256, 256),
mask_size = (256, 256),
):
processor = SamProcessor.from_pretrained(f"facebook/sam-vit-{model_size}")
sam_dataset = SAMDataset(
dataset=hf_dataset,
processor=processor,
prompt_type=prompt_type,
num_positive=num_positive,
num_negative=num_negative,
erode=erode,
multi_mask=multi_mask,
perturbation=perturbation,
image_size=image_size,
mask_size=mask_size,
)
dataloader = DataLoader(sam_dataset, batch_size=batch_size, shuffle=True)
return dataloader
在此之后,訓(xùn)練只需加載模型、凍結(jié)圖像和提示編碼器,并進(jìn)行所需次數(shù)的迭代訓(xùn)練。
model = SamModel.from_pretrained(f"facebook/sam-vit-{model_size}")
optimizer = AdamW(model.mask_decoder.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Train only the decoder.
for name, param in model.named_parameters():
if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
param.requires_grad_(False)
以下列出的是訓(xùn)練過(guò)程循環(huán)部分的基本框架代碼。請(qǐng)注意,為了簡(jiǎn)潔起見(jiàn),forward_pass、calculate loss、evaluate_mode和save_model_checkpoint函數(shù)被省略了,但GitHub上提供了實(shí)現(xiàn)。正向傳遞碼根據(jù)提示類(lèi)型略有不同,損失計(jì)算也需要基于提示類(lèi)型的特殊情況;當(dāng)使用點(diǎn)提示時(shí),SAM模型為每個(gè)單個(gè)輸入點(diǎn)返回一個(gè)預(yù)測(cè)掩碼,因此為了獲得可以與真實(shí)數(shù)據(jù)進(jìn)行比較的單個(gè)掩碼,需要對(duì)預(yù)測(cè)掩碼進(jìn)行平均,或者需要選擇最佳預(yù)測(cè)掩碼(基于SAM的預(yù)測(cè)IoU分?jǐn)?shù)來(lái)識(shí)別)。
train_losses = []
validation_losses = []
epoch_loop = tqdm(total=num_epochs, position=epoch, leave=False)
batch_loop = tqdm(total=len(train_dataloader), position=0, leave=True)
while epoch < num_epochs:
epoch_losses = []
batch_loop.n = 0 #循環(huán)重置
for idx, batch in enumerate(train_dataloader):
# 正向傳遞
batch = {k: v.to(accelerator.device) for k, v in batch.items()}
outputs = forward_pass(model, batch, prompt_type)
#計(jì)算損失值
ground_truth_masks = batch["ground_truth_mask"].float()
train_loss = calculate_loss(outputs, ground_truth_masks, prompt_type, loss_fn, multi_mask="best")
epoch_losses.append(train_loss)
# 反向傳遞與優(yōu)化環(huán)節(jié)
optimizer.zero_grad()
accelerator.backward(train_loss)
optimizer.step()
lr_scheduler.step()
batch_loop.set_description(f"Train Loss: {train_loss.item():.4f}")
batch_loop.update(1)
validation_loss = evaluate_model(model, validation_dataloader, accelerator.device, loss_fn)
train_losses.append(torch.mean(torch.Tensor(epoch_losses)))
validation_losses.append(validation_loss)
if validation_loss < best_loss:
save_model_checkpoint(
accelerator,
best_checkpoint_path,
model,
optimizer,
lr_scheduler,
epoch,
train_history,
validation_loss,
train_losses,
validation_losses,
loss_config,
model_descriptor=model_descriptor,
)
best_loss = validation_loss
epoch_loop.set_description(f"Best Loss: {best_loss:.4f}")
epoch_loop.update(1)
epoch += 1
微調(diào)結(jié)果分析
對(duì)于艾爾瓦河項(xiàng)目,最佳設(shè)置是在不到12小時(shí)的時(shí)間內(nèi)使用GCP虛擬機(jī)實(shí)例,使用超過(guò)1k個(gè)分割掩碼的數(shù)據(jù)集訓(xùn)練成功“sam-vit-base”模型。
與基準(zhǔn)型SAM相比,微調(diào)顯著提高了性能,中值掩碼從不可用變?yōu)楦叨葴?zhǔn)確。
相對(duì)于基于默認(rèn)提示詞的基準(zhǔn)型SAM模型,微調(diào)后的SAM模型極大地提高了分割性能
需要注意的一個(gè)重要事實(shí)是,1k河流圖像的訓(xùn)練數(shù)據(jù)集是不完美的,分割標(biāo)簽在正確分類(lèi)的像素?cái)?shù)量上變化很大。因此,上述指標(biāo)是在225幅河流圖像的像素完美數(shù)據(jù)集上計(jì)算出來(lái)的。
實(shí)驗(yàn)過(guò)程中,我們觀察到的一個(gè)有趣的行為是,模型學(xué)會(huì)了從不完美的訓(xùn)練數(shù)據(jù)中進(jìn)行歸納。當(dāng)在訓(xùn)練樣本包含明顯錯(cuò)誤分類(lèi)的數(shù)據(jù)點(diǎn)上進(jìn)行評(píng)估時(shí),我們可以觀察到模型預(yù)測(cè)避免了誤差。請(qǐng)注意,顯示訓(xùn)練樣本的頂行中的圖像包含的掩碼不會(huì)一直填充到河岸,而顯示模型預(yù)測(cè)的底行則更緊密地分割河流邊界。
即使訓(xùn)練數(shù)據(jù)不完美,經(jīng)微調(diào)的SAM模型也能帶來(lái)令人印象深刻的泛化效果。請(qǐng)注意,與訓(xùn)練數(shù)據(jù)(頂行)相比,預(yù)測(cè)(底行)的錯(cuò)誤分類(lèi)更少,并且河流的填充程度更高。
結(jié)論
如果您已經(jīng)順利完成本文中的實(shí)例內(nèi)容,那么祝賀您!您已經(jīng)學(xué)會(huì)了為任何下游愿景任務(wù)完全微調(diào)Meta的分割一切模型SAM所需的一切!
雖然您的微調(diào)工作流程無(wú)疑與本教程中介紹的實(shí)施方式不同,但從閱讀本教程中獲得的知識(shí)不僅會(huì)影響到您的細(xì)分項(xiàng)目,還會(huì)影響到未來(lái)的深度學(xué)習(xí)項(xiàng)目及其他項(xiàng)目。
最后,希望您繼續(xù)探索機(jī)器學(xué)習(xí)的世界,保持好奇心,并一如既往地快樂(lè)編程!
附錄
本文實(shí)例中使用的數(shù)據(jù)集是Elwha V1數(shù)據(jù)集(https://huggingface.co/datasets/stodoran/elwha-segmentation-v1),該數(shù)據(jù)集由華盛頓大學(xué)的GeoSMART研究實(shí)驗(yàn)室(https://geo-smart.github.io/)創(chuàng)建,用于將微調(diào)的大型視覺(jué)變換器應(yīng)用于地理空間分割任務(wù)的研究項(xiàng)目。本文描述的內(nèi)容代表了即將發(fā)表的論文的精簡(jiǎn)版和一個(gè)更易于實(shí)現(xiàn)的版本。在高水平上,Elwha V1數(shù)據(jù)集由SAM檢查點(diǎn)的后處理模型預(yù)測(cè)組成,該檢查點(diǎn)使用Buscombe等人(https://zenodo.org/records/10155783)發(fā)布并在多學(xué)科研究數(shù)據(jù)知識(shí)庫(kù)和文獻(xiàn)資源網(wǎng)站Zenodo上發(fā)布的標(biāo)注正射影像的子集進(jìn)行了微調(diào)。
譯者介紹
朱先忠,51CTO社區(qū)編輯,51CTO專(zhuān)家博客、講師,濰坊一所高校計(jì)算機(jī)教師,自由編程界老兵一枚。
原文標(biāo)題:Learn Transformer Fine-Tuning and Segment Anything,作者:Stefan Todoran
鏈接:https://towardsdatascience.com/learn-transformer-fine-tuning-and-segment-anything-481c6c4ac802