Reconstructing a Complete Image from a Few Patches | Building a Masked Autoencoder as a Scalable Learner
So far in this Vision Transformer series we have walked through several important ViT architectures in detail. In this part, I will build a Masked Autoencoder Vision Transformer from scratch using PyTorch. Without further ado, let's dive right in!
Masked Autoencoder
MAE is a self-supervised learning method, which means it has no pre-labeled target data and instead derives its training signal from the input itself. The approach mainly involves masking 75% of the image patches. So after creating the patches, a grid of (H/patch_size, W/patch_size) where H and W are the image height and width, we mask 75% of them and feed only the remaining patches into a standard ViT. For example, a 224×224 image with patch size 16 gives 14×14 = 196 patches, of which only 49 stay visible to the encoder. The main objective is to reconstruct the missing patches using only the visible patches of the image.
Input (75% of the patches masked) | Target (reconstruction of the missing pixels)
MAE mainly consists of these three components:
- Random masking
- Encoder
- Decoder
1. Random Masking
This could be as simple as picking random patches of the image and masking out 3/4 of them. The official implementation, however, uses a different but more efficient technique.
import torch

def random_masking(x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort of random noise.
    x: [B, T, D], sequence of patch embeddings
    """
    B, T, D = x.shape
    len_keep = int(T * (1 - mask_ratio))
    # create noise of shape (B, T) to later generate random indices
    noise = torch.rand(B, T, device=x.device)
    # sort the noise to get shuffled indices; argsort again recovers the original index order
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    # gather the first len_keep tokens
    ids_keep = ids_shuffle[:, :len_keep]
    x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([B, T], device=x.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask in the original patch order
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return x, mask, ids_restore
- Assume the input has shape (B, T, C). We first create a random tensor of shape (B, T) and pass it to argsort, which gives us a tensor of sorted indices. For example, torch.argsort([0.3, 0.4, 0.2]) = [2, 0, 1].
- We also pass ids_shuffle through another argsort to get ids_restore. This is simply a tensor that maps the shuffled positions back to the original index order.
- Next, we gather the tokens we want to keep.
- We generate a binary mask, marking the kept tokens with 0 and the rest with 1.
- Finally, the mask is unshuffled. The ids_restore tensor we created tells us, with respect to the original input, which token indices are kept (0) and which are masked (1).
Note: Rather than simply picking patches at random positions, the official implementation works as follows.
Generate random indices for the image, just as we did with ids_shuffle. Then take the first 25% of those indices (int(T * (1 - 3/4)), i.e. int(T/4)). Only the tokens at these first 25% random indices are kept, and the rest are masked.
We then unshuffle the mask with the help of ids_restore, which records the original index order. Before the gather, the first 25% of the mask entries are 0, but remember these positions correspond to random indices; that is why we reorder, so that each mask value lands at the exact index it belongs to in the original input.
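As a quick sanity check of the function above, here is a small usage sketch (shapes only; the 196 tokens correspond to a 224×224 image with patch size 16):

import torch

# Quick shape check for random_masking (assumes the function defined above is in scope)
x = torch.randn(2, 196, 1024)                    # (B, T, D)
x_kept, mask, ids_restore = random_masking(x, mask_ratio=0.75)
print(x_kept.shape)       # torch.Size([2, 49, 1024]) -> only 25% of the tokens survive
print(mask.shape)         # torch.Size([2, 196])      -> 0 = keep, 1 = masked, in original order
print(mask.sum(dim=1))    # tensor([147., 147.])      -> 147 of the 196 patches are masked
print(ids_restore.shape)  # torch.Size([2, 196])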
2. Encoder
class MaskedAutoEncoder(nn.Module):
    def __init__(self, emb_size=1024, decoder_emb_size=512, patch_size=16, num_head=16, encoder_num_layers=24, decoder_num_layers=8, in_channels=3, img_size=224):
        super().__init__()
        self.patch_embed = PatchEmbedding(emb_size=emb_size)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_emb_size))
        self.encoder_transformer = nn.Sequential(*[Block(emb_size, num_head) for _ in range(encoder_num_layers)])

    def encoder(self, x, mask_ratio):
        # (B, C, H, W) -> (B, T+1, emb_size), with the cls_token prepended by PatchEmbedding
        x = self.patch_embed(x)
        cls_token = x[:, :1, :]
        x = x[:, 1:, :]
        # keep only 25% of the patch tokens
        x, mask, restore_id = random_masking(x, mask_ratio)
        x = torch.cat((cls_token, x), dim=1)
        x = self.encoder_transformer(x)
        return x, mask, restore_id
PatchEmbedding and Block are the standard implementations from the ViT model.
We first obtain the patch embeddings of the image, (B, C, H, W) → (B, T, C); the PatchEmbedding implementation here also returns the cls_token already concatenated into the embedded tensor x. If you prefer, you can use the standard PatchEmbed and Block from the timm library instead, i.e. from timm.models.vision_transformer import PatchEmbed, Block; this implementation works the same way.
Since the cls_token is already present, we first remove it and then pass the remaining tokens through random masking: x: (B, K, C), mask: (B, T), restore_id: (B, T), where K is the number of kept tokens, i.e. T/4.
We then concatenate the cls_token back and pass everything through the standard encoder_transformer.
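For completeness, here is a minimal sketch of what a compatible PatchEmbedding could look like, assuming (as described above) that it prepends a learnable cls_token and adds positional embeddings. The exact implementation from the earlier parts of this series may differ, and Block can be any standard ViT transformer block (e.g. the one from timm).

import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Illustrative sketch: splits the image into patches, projects each patch to
    emb_size, prepends a learnable cls_token, and adds positional embeddings."""
    def __init__(self, in_channels=3, patch_size=16, emb_size=1024, img_size=224):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_size))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_size))

    def forward(self, x):                     # x: (B, C, H, W)
        x = self.proj(x)                      # (B, emb_size, H/patch, W/patch)
        x = x.flatten(2).transpose(1, 2)      # (B, T, emb_size)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)  # (B, T+1, emb_size)
        return x + self.pos_embed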
3. Decoder
The decoding stage starts by projecting the input embeddings to decoder_emb_size. Recall that the input has shape (B, K, C), where K is T/4. We then concatenate the unmasked tokens with placeholder mask tokens and feed them into another vision transformer model (the decoder), as shown in Figure 1.
class MaskedAutoEncoder(nn.Module):
    def __init__(self, emb_size=1024, decoder_emb_size=512, patch_size=16, num_head=16, encoder_num_layers=24, decoder_num_layers=8, in_channels=3, img_size=224):
        super().__init__()
        self.patch_embed = PatchEmbedding(emb_size=emb_size)
        self.decoder_embed = nn.Linear(emb_size, decoder_emb_size)
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, (img_size//patch_size)**2 + 1, decoder_emb_size), requires_grad=False)
        self.decoder_pred = nn.Linear(decoder_emb_size, patch_size**2 * in_channels, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_emb_size))
        self.encoder_transformer = nn.Sequential(*[Block(emb_size, num_head) for _ in range(encoder_num_layers)])
        self.decoder_transformer = nn.Sequential(*[Block(decoder_emb_size, num_head) for _ in range(decoder_num_layers)])
        self.project = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=patch_size**2 * in_channels, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def encoder(self, x, mask_ratio):
        x = self.patch_embed(x)
        cls_token = x[:, :1, :]
        x = x[:, 1:, :]
        x, mask, restore_id = random_masking(x, mask_ratio)
        x = torch.cat((cls_token, x), dim=1)
        x = self.encoder_transformer(x)
        return x, mask, restore_id

    def decoder(self, x, restore_id):
        # project encoder output to the decoder embedding size
        x = self.decoder_embed(x)
        # one mask token per patch that was removed during random masking
        mask_tokens = self.mask_token.repeat(x.shape[0], restore_id.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        # unshuffle back to the original patch order
        x_ = torch.gather(x_, dim=1, index=restore_id.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)
        # add pos embed
        x = x + self.decoder_pos_embed
        x = self.decoder_transformer(x)
        # predictor projection
        x = self.decoder_pred(x)
        # remove cls token
        x = x[:, 1:, :]
        return x
We pass the input through decoder_embed. We then create mask_tokens, one for every token we masked out, and concatenate them with the original input x, excluding its cls_token.
The tensor now holds the K unmasked tokens first, followed by the mask tokens, but we want to reorder them into the exact original index order. We can do this with the help of ids_restore.
ids_restore holds the indices that, when passed to torch.gather, unshuffle the input. So the unmasked tokens we selected during random masking (the first few random indices in ids_shuffle) are rearranged into the exact positions they belong to. Afterwards, we concatenate the cls_token back onto the reordered patches.
Finally, we pass the whole sequence through a standard vision transformer (the decoder_transformer), remove the cls_token, and return the tensor x to compute the loss.
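To make the unshuffling step concrete, here is a tiny standalone example (toy values, a single feature dimension for readability) showing how torch.gather with ids_restore puts a kept token back at its original position:

import torch

# 4 patch tokens, keep 25% (1 token); the token with the smallest noise (index 1) is kept
noise = torch.tensor([[0.9, 0.1, 0.7, 0.3]])
ids_shuffle = torch.argsort(noise, dim=1)        # tensor([[1, 3, 2, 0]])
ids_restore = torch.argsort(ids_shuffle, dim=1)  # tensor([[3, 0, 2, 1]])

kept = torch.tensor([[10.0]])                    # "embedding" of the kept token (original index 1)
mask_tokens = torch.zeros(1, 3)                  # placeholders for the 3 masked tokens
x_ = torch.cat([kept, mask_tokens], dim=1)       # tensor([[10., 0., 0., 0.]])

# gather with ids_restore unshuffles: the kept token lands back at position 1
print(torch.gather(x_, dim=1, index=ids_restore))  # tensor([[ 0., 10.,  0.,  0.]])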
Loss Function
The masked autoencoder is trained on both masked and unmasked patches and learns to reconstruct the masked patches of the image. The loss function used in the masked autoencoder vision transformer is the mean squared error, computed only over the masked patches.
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

class MaskedAutoEncoder(nn.Module):
    def __init__(self, emb_size=1024, decoder_emb_size=512, patch_size=16, num_head=16, encoder_num_layers=24, decoder_num_layers=8, in_channels=3, img_size=224):
        super().__init__()
        self.patch_embed = PatchEmbedding(emb_size=emb_size)
        self.decoder_embed = nn.Linear(emb_size, decoder_emb_size)
        # fixed (non-learnable) positional embedding for the decoder
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, (img_size//patch_size)**2 + 1, decoder_emb_size), requires_grad=False)
        self.decoder_pred = nn.Linear(decoder_emb_size, patch_size**2 * in_channels, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_emb_size))
        self.encoder_transformer = nn.Sequential(*[Block(emb_size, num_head) for _ in range(encoder_num_layers)])
        self.decoder_transformer = nn.Sequential(*[Block(decoder_emb_size, num_head) for _ in range(decoder_num_layers)])
        # projects the image into per-patch target vectors of size patch_size**2 * in_channels
        self.project = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=patch_size**2 * in_channels, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort of random noise.
        x: [B, T, D], sequence of patch embeddings
        """
        B, T, D = x.shape
        len_keep = int(T * (1 - mask_ratio))
        # create noise of shape (B, T) to later generate random indices
        noise = torch.rand(B, T, device=x.device)
        # sort the noise to get shuffled indices; argsort again recovers the original index order
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        # gather the first len_keep tokens
        ids_keep = ids_shuffle[:, :len_keep]
        x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([B, T], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask in the original patch order
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x, mask, ids_restore

    def encoder(self, x, mask_ratio):
        x = self.patch_embed(x)
        cls_token = x[:, :1, :]
        x = x[:, 1:, :]
        x, mask, restore_id = self.random_masking(x, mask_ratio)
        x = torch.cat((cls_token, x), dim=1)
        x = self.encoder_transformer(x)
        return x, mask, restore_id

    def decoder(self, x, restore_id):
        x = self.decoder_embed(x)
        mask_tokens = self.mask_token.repeat(x.shape[0], restore_id.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=restore_id.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)
        # add pos embed
        x = x + self.decoder_pos_embed
        x = self.decoder_transformer(x)
        # predictor projection
        x = self.decoder_pred(x)
        # remove cls token
        x = x[:, 1:, :]
        return x

    def loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, patch*patch*3]
        mask: [N, L], 0 is keep, 1 is remove
        """
        target = self.project(imgs)
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, img):
        mask_ratio = 0.75
        x, mask, restore_ids = self.encoder(img, mask_ratio)
        pred = self.decoder(x, restore_ids)
        loss = self.loss(img, pred, mask)
        return loss, pred, mask
In summary, the forward pass does the following:
- The vision transformer encoder runs only on the unmasked patches, and its output is then reordered together with the mask tokens.
- The vision transformer decoder runs on the masked and unmasked tokens combined, restored to their original order.
- The mean squared error loss is computed between the decoder's predictions (B, T, patch_size**2 * in_channels) and the per-patch targets obtained by projecting the image (B, T, patch_size**2 * in_channels), averaged over the masked patches only.
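As a final check, here is a hedged sketch of how a forward pass could be run on a dummy batch (a smaller configuration than the defaults is used purely to keep it fast; PatchEmbedding and Block from the earlier parts of this series are assumed to be in scope):

import torch

# Dummy forward pass on random data
model = MaskedAutoEncoder(emb_size=256, decoder_emb_size=128, patch_size=16, num_head=8,
                          encoder_num_layers=2, decoder_num_layers=2, in_channels=3, img_size=224)
img = torch.randn(2, 3, 224, 224)
loss, pred, mask = model(img)
print(loss.item())   # scalar MSE, averaged over the masked patches only
print(pred.shape)    # torch.Size([2, 196, 768]) -> 196 patches, each predicting 16*16*3 values
print(mask.shape)    # torch.Size([2, 196])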
Source code: https://github.com/mishra-18/ML-Models/blob/main/Vission Transformers/mae.py