自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

Meta公司開源大數(shù)據(jù)模型SAM實(shí)戰(zhàn)演練 原創(chuàng)

發(fā)布于 2024-7-15 09:08
瀏覽
0收藏

本文首先介紹Meta公司開發(fā)的開源圖像分割模型SAM的架構(gòu),然后通過一個河流像素分割遙感任務(wù)的實(shí)戰(zhàn)案例展示SAM模型應(yīng)用開發(fā)涉及的關(guān)鍵技術(shù)與模型優(yōu)勢。

引言

當(dāng)前,許多強(qiáng)大的開源基礎(chǔ)模型的發(fā)布,加上微調(diào)技術(shù)的不斷進(jìn)步,已經(jīng)帶來了機(jī)器學(xué)習(xí)和人工智能的新范式。整體來看,這場革命的核心在于轉(zhuǎn)換器模型(https://arxiv.org/pdf/1706.03762)。

雖然除了資金充足的公司之外,所有公司都曾經(jīng)無法獲得高精度的特定領(lǐng)域模型;但是如今,基礎(chǔ)模型范式甚至能夠允許學(xué)生或獨(dú)立研究人員獲得適度的資源,以實(shí)現(xiàn)與最先進(jìn)的專有模型相抗衡的結(jié)果。

Meta公司開源大數(shù)據(jù)模型SAM實(shí)戰(zhàn)演練-AI.x社區(qū)

微調(diào)技術(shù)可以極大地提高模型任務(wù)的性能

本文旨在探討Meta公司的分割一切模型(SAM:Segment Anything Model)在河流像素分割遙感任務(wù)中的應(yīng)用。如果您想直接跳到有關(guān)項(xiàng)目代碼,那么這個項(xiàng)目的源文件可以在GitHub(https://github.com/geo-smart/water-surf/blob/main/book/chapters/masking_distributed.ipynb)上獲得,數(shù)據(jù)也可以在HuggingFace(https://huggingface.co/datasets/stodoran/elwha-segmentation-v1)上找到。當(dāng)然,我還是建議您先閱讀一下本文。

項(xiàng)目需求

第一個任務(wù)是找到或創(chuàng)建一個合適的數(shù)據(jù)集。根據(jù)現(xiàn)有文獻(xiàn),一個良好的SAM微調(diào)數(shù)據(jù)集將至少包含200–800張圖像。過去十年深度學(xué)習(xí)進(jìn)步的一個關(guān)鍵教訓(xùn)是,數(shù)據(jù)越多越好,這樣一來,更大規(guī)模的微調(diào)數(shù)據(jù)集就不會出錯。然而,基礎(chǔ)模型研究的一個重要目標(biāo)是,允許即使是相對較小的數(shù)據(jù)集也足以實(shí)現(xiàn)強(qiáng)大的性能。

此外,我們還需要有一個HuggingFace帳戶,這可以在鏈接https://huggingface.co/join處創(chuàng)建。使用HuggingFace,我們可以隨時從任何設(shè)備輕松存儲和獲取數(shù)據(jù)集,這使得協(xié)作和再現(xiàn)性更加容易。

最后一個需求是,具備一臺帶有GPU的設(shè)備,我們可以在其上運(yùn)行訓(xùn)練工作流程。通過Google Colab(https://colab.research.google.com/)免費(fèi)提供的Nvidia T4 GPU足夠強(qiáng)大,可以在12小時內(nèi)對1000張圖像進(jìn)行50個時期的最大SAM模型檢查點(diǎn)(sam-vit-huge)訓(xùn)練。

為了避免在托管運(yùn)行時因使用限制而影響進(jìn)度,您可以安裝Google Drive并將每個模型檢查點(diǎn)保存在那里?;蛘?,部署并連接到GCP虛擬機(jī)(https://console.cloud.google.com/marketplace/product/colab-marketplace-image-public/colab)以完全繞過限制。如果您以前從未使用過GCP,那么您就有資格獲得300美元的免費(fèi)信貸,這足以支持對模型進(jìn)行至少十幾次的訓(xùn)練了。

理解SAM架構(gòu)

在開始訓(xùn)練之前,我們需要先來了解一下SAM模型的架構(gòu)。該模型包含三個組件:一個是從稍經(jīng)修改的掩碼自動編碼器(https://arxiv.org/pdf/2111.06377)得到的圖像編碼器,一個能夠處理各種提示類型的相當(dāng)靈活的提示編碼器,還有一個快速輕量級的掩碼解碼器。這種設(shè)計(jì)架構(gòu)背后的一個重要動機(jī)是允許在邊緣設(shè)備上(例如在瀏覽器中)進(jìn)行快速、實(shí)時的分割,因?yàn)閳D像嵌入只需要計(jì)算一次,并且掩碼解碼器可以在CPU上運(yùn)行約50ms。

Meta公司開源大數(shù)據(jù)模型SAM實(shí)戰(zhàn)演練-AI.x社區(qū)

SAM的模型架構(gòu)向我們展示了模型接受哪些輸入以及需要訓(xùn)練模型的哪些部分(圖片來源于SAM GitHub:https://github.com/facebookresearch/segment-anything)。

理論上,圖像編碼器已經(jīng)學(xué)會了嵌入圖像的最佳方式,包括識別形狀、邊緣和其他一般視覺特征等。類似地,在理論上,提示編碼器已經(jīng)能夠以最優(yōu)方式對提示進(jìn)行編碼。掩碼解碼器是模型架構(gòu)的一部分,它采用這些圖像和提示嵌入,并通過對圖像和提示嵌入式進(jìn)行操作來實(shí)際創(chuàng)建掩碼。

因此,一種方法是在訓(xùn)練期間凍結(jié)與圖像和提示編碼器相關(guān)聯(lián)的模型參數(shù),并且僅更新掩碼解碼器權(quán)重。這種方法的優(yōu)點(diǎn)是允許有監(jiān)督和無監(jiān)督的下游任務(wù),因?yàn)榭刂泣c(diǎn)和邊界框提示都是自動的,并且可供人工使用。

Meta公司開源大數(shù)據(jù)模型SAM實(shí)戰(zhàn)演練-AI.x社區(qū)

圖中顯示了AutoSAM體系架構(gòu)中使用的凍結(jié)SAM圖像編碼器和掩碼解碼器,以及過載提示編碼器(來源于AutoSAM論文:https://arxiv.org/pdf/2306.06370)。

另一種方法是使提示編碼器過載,凍結(jié)圖像編碼器和掩碼解碼器,并且只是簡單地不使用原始SAM掩碼編碼器。例如,AutoSAM體系架構(gòu)使用基于Harmonic Dense Net的網(wǎng)絡(luò)來基于圖像本身生成提示嵌入。在本教程中,我們將介紹第一種方法,即凍結(jié)圖像和提示編碼器,只訓(xùn)練掩碼解碼器,但這種替代方法的代碼可以在AutoSAM GitHub(https://github.com/talshaharabany/AutoSAM/blob/main/inference.py)和論文(https://arxiv.org/pdf/2306.06370)中找到。

配置提示

接下來的一步是確定模型在推理過程中會收到什么類型的提示,以便我們可以在訓(xùn)練時提供這種類型的提示。就我個人而言,考慮到自然語言處理的不可預(yù)測/不一致性,我不建議在任何嚴(yán)肅的計(jì)算機(jī)視覺項(xiàng)目架構(gòu)中使用文本提示。剩下的解決方案就需要依賴控制點(diǎn)和邊界框技術(shù)了;但是,最終的選擇還要取決于特定數(shù)據(jù)集的特定性質(zhì),盡管有關(guān)文獻(xiàn)中已經(jīng)指出邊界框方案的表現(xiàn)相當(dāng)一致地優(yōu)于控制點(diǎn)方案。

造成這種情況的原因尚不完全清楚,但可能是以下任何因素之一,或者是這些因素的組合:

  • 在推理時(當(dāng)真實(shí)值掩碼未知時),好的控制點(diǎn)比邊界框更難選擇。
  • 可能的點(diǎn)提示的空間比可能的邊界框提示的空間大幾個數(shù)量級,因此它沒有經(jīng)過徹底的訓(xùn)練。
  • 最初的SAM模型作者主要專注于模型的零樣本和少樣本(根據(jù)人工提示交互計(jì)算)功能,因此預(yù)訓(xùn)練可能更多地關(guān)注邊界框。

無論如何,河流分割實(shí)際上是一種罕見的情況;在這種情況下,點(diǎn)提示方案實(shí)際上優(yōu)于邊界框(盡管只是輕微的,即使是在非常有利的域中)。假設(shè)在河流的任何圖像中,水體將從圖像的一端延伸到另一端,任何包含的邊界框幾乎總是覆蓋圖像的大部分。因此,河流非常不同部分的邊界框提示看起來非常相似。理論上,這意味著邊界框?yàn)槟P吞峁┑男畔⒈瓤刂泣c(diǎn)少得多;因此,導(dǎo)致性能較差。

Meta公司開源大數(shù)據(jù)模型SAM實(shí)戰(zhàn)演練-AI.x社區(qū)

控制點(diǎn)、邊界框提示和疊加在兩個樣本訓(xùn)練圖像上的真實(shí)分割

請注意,在上圖中,盡管兩條河流部分的真實(shí)分割掩碼完全不同,但它們各自的邊界框幾乎相同,而它們的點(diǎn)提示(相對而言)差異更大。

另一個需要考慮的重要因素是在推理時生成輸入提示的容易程度。如果您希望在循環(huán)執(zhí)行階段有人工介入,那么請注意邊界框和控制點(diǎn)在推理階段都是相當(dāng)瑣碎的。然而,如果您打算使用一個完全自動化的架構(gòu)方案,那么回答這些問題將變得更加復(fù)雜。

無論是使用控制點(diǎn)還是邊界框,生成提示通常首先包括估計(jì)感興趣對象的粗略掩碼。邊界框可以只是包裹粗略掩碼的最小框,而控制點(diǎn)需要從粗略掩碼中采樣。這意味著,當(dāng)真實(shí)值掩碼未知時,邊界框更容易獲得,因?yàn)楦信d趣對象的估計(jì)掩碼只需要大致匹配真實(shí)對象的相同大小和位置;而對于控制點(diǎn),估計(jì)掩碼將需要更緊密地匹配對象的輪廓。

Meta公司開源大數(shù)據(jù)模型SAM實(shí)戰(zhàn)演練-AI.x社區(qū)

當(dāng)使用估計(jì)的掩碼而不是真實(shí)值時,控制點(diǎn)的放置可能包括錯誤標(biāo)注的點(diǎn),而邊界框通常位于正確的位置

對于河流分割,如果我們可以同時使用RGB和NIR,那么我們可以使用光譜指數(shù)閾值方法來獲得我們的粗略掩模。如果我們只能使用RGB模式,我們可以將圖像轉(zhuǎn)換為HSV模式,并對特定色調(diào)、飽和度和值范圍內(nèi)的所有像素設(shè)置閾值。然后,我們可以移除低于特定大小閾值的連接內(nèi)容,并使用skimage.morphology子模塊中的erosion函數(shù)來確保我們的掩模中只有1個像素是朝向藍(lán)色大斑點(diǎn)中心的像素。

模型訓(xùn)練

為了訓(xùn)練我們的模型,我們需要一個包含所有訓(xùn)練數(shù)據(jù)的數(shù)據(jù)加載器,我們可以在每個訓(xùn)練時期對這些數(shù)據(jù)進(jìn)行迭代。當(dāng)我們從HuggingFace加載數(shù)據(jù)集時,它采用datasets.Dataset類的形式。如果數(shù)據(jù)集是私有的,請確保首先安裝HuggingFace CLI并使用“!huggingface-cli login”方式登錄。

from datasets import load_dataset, load_from_disk, Dataset
hf_dataset_name = "stodoran/elwha-segmentation-v1"
training_data = load_dataset(hf_dataset_name, split="train")
validation_data = load_dataset(hf_dataset_name, split="validation")

然后,我們需要編寫自己的自定義數(shù)據(jù)集類,該類不僅返回任何索引的圖像和標(biāo)簽,還返回提示詞信息。下面是一個可以同時處理控制點(diǎn)和邊界框提示的實(shí)現(xiàn)。要完成初始化工作,需要一個HuggingFace datasets.Dataset實(shí)例和SAM模型的處理器實(shí)例。

from torch.utils.data import Dataset
class PromptType:
CONTROL_POINTS = "pts"
BOUNDING_BOX = "bbox"
class SAMDataset(Dataset):
def __init__(
self, 
dataset, 
processor, 
prompt_type = PromptType.CONTROL_POINTS,
num_positive = 3,
num_negative = 0,
erode = True,
multi_mask = "mean",
perturbation = 10,
image_size = (1024, 1024),
mask_size = (256, 256),
):
#將所有值賦給self
...

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
datapoint = self.dataset[idx]
input_image = cv2.resize(np.array(datapoint["image"]), self.image_size)
ground_truth_mask = cv2.resize(np.array(datapoint["label"]), self.mask_size)

if self.prompt_type == PromptType.CONTROL_POINTS:
inputs = self._getitem_ctrlpts(input_image, ground_truth_mask)
elif self.prompt_type == PromptType.BOUNDING_BOX:
inputs = self._getitem_bbox(input_image, ground_truth_mask)

inputs["ground_truth_mask"] = ground_truth_mask
return inputs

我們還必須定義SAMDataset_getitem_ctrlpts和SAMDataset_getitem_box函數(shù),盡管如果您只計(jì)劃使用一種提示類型,那么您可以重構(gòu)代碼以直接處理SAMDataset.__getitem__中的該類型,并刪除幫助類工具函數(shù)。

class SAMDataset(Dataset):
...
def _getitem_ctrlpts(self, input_image, ground_truth_mask):
# 獲取控制點(diǎn)提示。請參閱GitHub獲取該函數(shù)的源代碼,或?qū)⑵涮鎿Q為您自己的點(diǎn)選擇算法。
input_points, input_labels = generate_input_points(
num_positive=self.num_positive,
num_negative=self.num_negative,
mask=ground_truth_mask,
dynamic_distance=True,
erode=self.erode,
)
input_points = input_points.astype(float).tolist()
input_labels = input_labels.tolist()
input_labels = [[x] for x in input_labels]

# 為模型準(zhǔn)備圖像和提示。
inputs = self.processor(
input_image,
input_points=input_points,
input_labels=input_labels,
return_tensors="pt"
)

#刪除處理器默認(rèn)添加的批次維度。
inputs = {k: v.squeeze(0) for k, v in inputs.items()}
inputs["input_labels"] = inputs["input_labels"].squeeze(1)

return inputs

def _getitem_bbox(self, input_image, ground_truth_mask):
#獲取邊界框提示。
bbox = get_input_bbox(ground_truth_mask, perturbation=self.perturbation)

#為模型準(zhǔn)備圖像和提示。
inputs = self.processor(input_image, input_boxes=[[bbox]], return_tensors="pt")
inputs = {k: v.squeeze(0) for k, v in inputs.items()} # 刪除處理器默認(rèn)添加的批次維度。

return inputs

將所有這些功能組合到一起,我們可以創(chuàng)建一個函數(shù),該函數(shù)在給定HuggingFace數(shù)據(jù)集的任一部分的情況下創(chuàng)建并返回PyTorch數(shù)據(jù)加載器。編寫返回?cái)?shù)據(jù)加載器的函數(shù),而不僅僅是用相同的代碼執(zhí)行單元,這不僅是編寫靈活和可維護(hù)代碼的好方法,而且如果您計(jì)劃使用HuggingFace Accelerate(https://huggingface.co/docs/accelerate/index)來運(yùn)行分布式訓(xùn)練的話,這也是必要的。

from transformers import SamProcessor
from torch.utils.data import DataLoader

def get_dataloader(
hf_dataset,
model_size = "base",  # One of "base", "large", or "huge" 
batch_size = 8, 
prompt_type = PromptType.CONTROL_POINTS,
num_positive = 3,
num_negative = 0,
erode = True,
multi_mask = "mean",
perturbation = 10,
image_size = (256, 256),
mask_size = (256, 256),
):
processor = SamProcessor.from_pretrained(f"facebook/sam-vit-{model_size}")

sam_dataset = SAMDataset(
dataset=hf_dataset, 
processor=processor, 
prompt_type=prompt_type,
num_positive=num_positive,
num_negative=num_negative,
erode=erode,
multi_mask=multi_mask,
perturbation=perturbation,
image_size=image_size,
mask_size=mask_size,
)
dataloader = DataLoader(sam_dataset, batch_size=batch_size, shuffle=True)

return dataloader

在此之后,訓(xùn)練只需加載模型、凍結(jié)圖像和提示編碼器,并進(jìn)行所需次數(shù)的迭代訓(xùn)練。

model = SamModel.from_pretrained(f"facebook/sam-vit-{model_size}")
optimizer = AdamW(model.mask_decoder.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Train only the decoder.
for name, param in model.named_parameters():
if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
param.requires_grad_(False)

以下列出的是訓(xùn)練過程循環(huán)部分的基本框架代碼。請注意,為了簡潔起見,forward_pass、calculate loss、evaluate_mode和save_model_checkpoint函數(shù)被省略了,但GitHub上提供了實(shí)現(xiàn)。正向傳遞碼根據(jù)提示類型略有不同,損失計(jì)算也需要基于提示類型的特殊情況;當(dāng)使用點(diǎn)提示時,SAM模型為每個單個輸入點(diǎn)返回一個預(yù)測掩碼,因此為了獲得可以與真實(shí)數(shù)據(jù)進(jìn)行比較的單個掩碼,需要對預(yù)測掩碼進(jìn)行平均,或者需要選擇最佳預(yù)測掩碼(基于SAM的預(yù)測IoU分?jǐn)?shù)來識別)。

train_losses = []
validation_losses = []
epoch_loop = tqdm(total=num_epochs, position=epoch, leave=False)
batch_loop = tqdm(total=len(train_dataloader), position=0, leave=True)

while epoch < num_epochs:
epoch_losses = []

batch_loop.n = 0  #循環(huán)重置
for idx, batch in enumerate(train_dataloader):
# 正向傳遞
batch = {k: v.to(accelerator.device) for k, v in batch.items()}
outputs = forward_pass(model, batch, prompt_type)

#計(jì)算損失值
ground_truth_masks = batch["ground_truth_mask"].float()
train_loss = calculate_loss(outputs, ground_truth_masks, prompt_type, loss_fn, multi_mask="best")
epoch_losses.append(train_loss)

# 反向傳遞與優(yōu)化環(huán)節(jié)
optimizer.zero_grad()
accelerator.backward(train_loss)
optimizer.step()
lr_scheduler.step()

batch_loop.set_description(f"Train Loss: {train_loss.item():.4f}")
batch_loop.update(1)

validation_loss = evaluate_model(model, validation_dataloader, accelerator.device, loss_fn)
train_losses.append(torch.mean(torch.Tensor(epoch_losses)))
validation_losses.append(validation_loss)

if validation_loss < best_loss:
save_model_checkpoint(
accelerator,
best_checkpoint_path,
model,
optimizer,
lr_scheduler,
epoch,
train_history,
validation_loss,
train_losses,
validation_losses,
loss_config,
model_descriptor=model_descriptor,
)
best_loss = validation_loss

epoch_loop.set_description(f"Best Loss: {best_loss:.4f}")
epoch_loop.update(1)
epoch += 1

微調(diào)結(jié)果分析

對于艾爾瓦河項(xiàng)目,最佳設(shè)置是在不到12小時的時間內(nèi)使用GCP虛擬機(jī)實(shí)例,使用超過1k個分割掩碼的數(shù)據(jù)集訓(xùn)練成功“sam-vit-base”模型。

與基準(zhǔn)型SAM相比,微調(diào)顯著提高了性能,中值掩碼從不可用變?yōu)楦叨葴?zhǔn)確。

Meta公司開源大數(shù)據(jù)模型SAM實(shí)戰(zhàn)演練-AI.x社區(qū)

相對于基于默認(rèn)提示詞的基準(zhǔn)型SAM模型,微調(diào)后的SAM模型極大地提高了分割性能

需要注意的一個重要事實(shí)是,1k河流圖像的訓(xùn)練數(shù)據(jù)集是不完美的,分割標(biāo)簽在正確分類的像素?cái)?shù)量上變化很大。因此,上述指標(biāo)是在225幅河流圖像的像素完美數(shù)據(jù)集上計(jì)算出來的。

實(shí)驗(yàn)過程中,我們觀察到的一個有趣的行為是,模型學(xué)會了從不完美的訓(xùn)練數(shù)據(jù)中進(jìn)行歸納。當(dāng)在訓(xùn)練樣本包含明顯錯誤分類的數(shù)據(jù)點(diǎn)上進(jìn)行評估時,我們可以觀察到模型預(yù)測避免了誤差。請注意,顯示訓(xùn)練樣本的頂行中的圖像包含的掩碼不會一直填充到河岸,而顯示模型預(yù)測的底行則更緊密地分割河流邊界。

Meta公司開源大數(shù)據(jù)模型SAM實(shí)戰(zhàn)演練-AI.x社區(qū)

即使訓(xùn)練數(shù)據(jù)不完美,經(jīng)微調(diào)的SAM模型也能帶來令人印象深刻的泛化效果。請注意,與訓(xùn)練數(shù)據(jù)(頂行)相比,預(yù)測(底行)的錯誤分類更少,并且河流的填充程度更高。

結(jié)論

如果您已經(jīng)順利完成本文中的實(shí)例內(nèi)容,那么祝賀您!您已經(jīng)學(xué)會了為任何下游愿景任務(wù)完全微調(diào)Meta的分割一切模型SAM所需的一切!

雖然您的微調(diào)工作流程無疑與本教程中介紹的實(shí)施方式不同,但從閱讀本教程中獲得的知識不僅會影響到您的細(xì)分項(xiàng)目,還會影響到未來的深度學(xué)習(xí)項(xiàng)目及其他項(xiàng)目。

最后,希望您繼續(xù)探索機(jī)器學(xué)習(xí)的世界,保持好奇心,并一如既往地快樂編程!

附錄

本文實(shí)例中使用的數(shù)據(jù)集是Elwha V1數(shù)據(jù)集(https://huggingface.co/datasets/stodoran/elwha-segmentation-v1),該數(shù)據(jù)集由華盛頓大學(xué)的GeoSMART研究實(shí)驗(yàn)室(https://geo-smart.github.io/)創(chuàng)建,用于將微調(diào)的大型視覺變換器應(yīng)用于地理空間分割任務(wù)的研究項(xiàng)目。本文描述的內(nèi)容代表了即將發(fā)表的論文的精簡版和一個更易于實(shí)現(xiàn)的版本。在高水平上,Elwha V1數(shù)據(jù)集由SAM檢查點(diǎn)的后處理模型預(yù)測組成,該檢查點(diǎn)使用Buscombe等人(https://zenodo.org/records/10155783)發(fā)布并在多學(xué)科研究數(shù)據(jù)知識庫和文獻(xiàn)資源網(wǎng)站Zenodo上發(fā)布的標(biāo)注正射影像的子集進(jìn)行了微調(diào)。

譯者介紹

朱先忠,51CTO社區(qū)編輯,51CTO專家博客、講師,濰坊一所高校計(jì)算機(jī)教師,自由編程界老兵一枚。

原文標(biāo)題:Learn Transformer Fine-Tuning and Segment Anything,作者:Stefan Todoran。

鏈接:??https://towardsdatascience.com/learn-transformer-fine-tuning-and-segment-anything-481c6c4ac802?。

?著作權(quán)歸作者所有,如需轉(zhuǎn)載,請注明出處,否則將追究法律責(zé)任
收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦