萬字長文從零解讀SAM(Segment Anything Model)大模型!萬物皆可分割!(含源碼解析)
SAM(Segment Anything Model),顧名思義,即為分割一切!該模型由Facebook的Meta AI實(shí)驗(yàn)室,能夠根據(jù)文本指令或圖像識別,實(shí)現(xiàn)對任意物體的識別與分割。它的誕生,無疑是CV領(lǐng)域的一次重要里程碑。
論文地址:https://arxiv.org/abs/2304.02643
項(xiàng)目地址:https://github.com/facebookresearch/segment-anything
一、SAM Task
SAM借鑒了NLP領(lǐng)域的Prompt策略,通過給圖像分割任務(wù)提供Prompt提示來完成任意目標(biāo)的快速分割。Prompt類型可以是「前景/背景點(diǎn)集、粗略的框或遮罩、任意形式的文本或者任何指示圖像中需要進(jìn)行分割」的信息。如下圖(a)所示,模型的輸入是原始的圖像和一些prompt,目標(biāo)是輸出"valid"的分割,所謂valid,就是當(dāng)prompt的指向是模糊時(shí),模型能夠輸出至少其中一個(gè)mask。
這樣,可以是的SAM能夠適配各種下游任務(wù)。例如,給定一個(gè)貓的邊界框,SAM能夠輸出其mask,從而和實(shí)例分割任務(wù)搭配起來。
二、SAM Model
如下圖所示,SAM模型包含三個(gè)核心組件,Image Encoder、Prompt Encoder和Mask Decoder。圖像經(jīng)過Image Encoder編碼,Prompt提示經(jīng)過Prompt Encoder編碼,兩部分Embedding再經(jīng)過一個(gè)輕量化的Mask Decoder得到融合后的特征。其中,Encoder部分使用的是已有模型,Decoder部分使用Transformer。
1. Image Encoder
Image Encoder的作用是把圖像映射到特征空間,整體過程如下圖所示。
正如論文中所講,本質(zhì)上這個(gè)Encoder可以是任何網(wǎng)絡(luò)結(jié)構(gòu),在這里使用的是微調(diào)的Detectron的ViT,當(dāng)然它也可以被改成傳統(tǒng)的卷積結(jié)構(gòu),非常合理。
輸入圖像經(jīng)過ViT結(jié)構(gòu)的過程如下:
(1) Patch Embedding
輸入圖像通過一個(gè)卷積base,將圖像劃分為16x16的patches,步長也為16,這樣feature map的尺寸就縮小了16倍,同時(shí)channel從3映射到768。Patch Embedding示意圖如下所示。
圖像大小決定了patch的數(shù)量。
代碼實(shí)現(xiàn):
'''
將輸入的圖像轉(zhuǎn)換為序列化的特征向量
'''
class PatchEmbed(nn.Module):
def __init__(
self,
# 卷積核大小
# 這里是 (16, 16),意味著圖像將被劃分為16x16的patches
kernel_size: Tuple[int, int] = (16, 16),
# 卷積的步長,與kernel_size相同,即(16, 16),
# 意味著每一步移動16個(gè)像素,這樣圖像的尺寸就會減少到原來的1/16
stride: Tuple[int, int] = (16, 16),
# 控制邊緣填充,這里設(shè)置為 (0, 0),意味著沒有額外的填充
padding: Tuple[int, int] = (0, 0),
# 輸入圖像的通道數(shù),通常為3(RGB圖像)
in_chans: int = 3,
# 輸出的特征維度,也就是每個(gè)patch被編碼為的向量的長度,這里設(shè)置為768
embed_dim: int = 768,
) -> None:
'''
初始化這個(gè)子類實(shí)例的屬性
'''
# PatchEmbed的子類,繼承自nn.Module,用于構(gòu)建神經(jīng)網(wǎng)絡(luò)模塊
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
'''前向傳播:
接收輸入張量 x,形狀 (B, C, H, W),其中,
- B表示批次大小
- C 是輸入通道數(shù)
- H 和 W 是圖像的高度和寬度
'''
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 卷積,將輸入的通道數(shù)從 in_chans 轉(zhuǎn)換為 embed_dim
x = self.proj(x)
# 將張量的維度順序從 (B, C, H, W) 調(diào)整為 (B, H, W, C)
x = x.permute(0, 2, 3, 1)
return x
Patch Embedding過程在Vision Transformer結(jié)構(gòu)圖中對應(yīng)下圖所示。
(2) Positional Embedding
經(jīng)過Patch Embedding后輸出tokens需要加入位置編碼,以保留圖像的空間信息。位置編碼可以理解為一張map,map的行數(shù)與輸入序列個(gè)數(shù)相同,每一行代表一個(gè)向量,向量的維度和輸入序列tokens的維度相同,位置編碼的操作是sum,所以維度依舊保持不變。
圖像尺寸是1024,因此patch的數(shù)量是1024/16=64。
代碼實(shí)現(xiàn):
# 在ImageEncoderViT的__init__定義
if use_abs_pos:
# 使用預(yù)訓(xùn)練圖像大小初始化絕對位置嵌入
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
# 在ImageEncoderViT的forward添加位置編碼
if self.pos_embed is not None:
x = x + self.pos_embed
Positiona Embedding過程在結(jié)構(gòu)圖中對應(yīng)的部分:
(3) Transformer Encoder
feature map通過16個(gè)Transformer Block,其中12個(gè)Block使用了基于Window Partition(就是把特征圖分成14*14的windows做局部的Attention)的注意力機(jī)制,以處理局部信息。另外4個(gè)Block是全局注意力模塊,它們穿插在Window Partition模塊之間,以捕捉圖像的全局上下文。
# 在ImageEncoderViT的__init__定義
# -----Transformer Encoder-----
# 初始化一個(gè)ModuleList,用于存儲Block實(shí)例
self.blocks = nn.ModuleList()
# 循環(huán)創(chuàng)建Block,depth是Transformer Encoder層數(shù)
for i in range(depth):
# 創(chuàng)建單個(gè)Block
block = Block(
# 輸入的通道數(shù),即每個(gè)patch編碼后的向量維度
dim=embed_dim,
# 自注意力機(jī)制中的注意力頭數(shù)
num_heads=num_heads,
# MLP層的通道數(shù)相對于輸入通道數(shù)的比例
mlp_ratio=mlp_ratio,
# 是否在QKV全連接層中使用偏置
qkv_bias=qkv_bias,
# 歸一化層
norm_layer=norm_layer,
# 激活函數(shù)
act_layer=act_layer,
# 是否使用相對位置編碼
use_rel_pos=use_rel_pos,
# 相對位置編碼的初始化設(shè)置
rel_pos_zero_init=rel_pos_zero_init,
# 如果當(dāng)前Block不是全局注意力層,則使用窗口大小,否則使用0
window_size=window_size if i not in global_attn_indexes else 0,
# 輸入特征的尺寸,基于原始圖像大小和patch大小計(jì)算得出
input_size=(img_size // patch_size, img_size // patch_size),
)
# 將創(chuàng)建的Block對象添加到self.blocks列表中
self.blocks.append(block)
# -----Transformer Encoder-----
Transformer Encoder過程在結(jié)構(gòu)圖中對應(yīng)的部分:
① Encoder Block
如上圖右所示,Encoder Block從低到高主要由LayerNorm 、Multi-Head Attention和MLP構(gòu)成。
class Block(nn.Module):
def __init__(
self,
dim: int, # 輸入通道數(shù)
num_heads: int, # attention中head的個(gè)數(shù)
mlp_ratio: float = 4.0, # MLP層的通道數(shù)相對于輸入通道數(shù)的比例。
qkv_bias: bool = True, # 如果為True,QKV全連接層包含偏置。
norm_layer: Type[nn.Module] = nn.LayerNorm, # 歸一化層
act_layer: Type[nn.Module] = nn.GELU, # 激活層
use_rel_pos: bool = False, # 是否使用相對位置編碼
rel_pos_zero_init: bool = True, # 相對位置編碼的初始化設(shè)置
window_size: int = 0, # 注意力層的窗口大小
input_size: Optional[Tuple[int, int]] = None, # 輸入特征的尺寸
) -> None:
super().__init__()
self.norm1 = norm_layer(dim) # 第一個(gè)歸一化層,用于注意力層
self.attn = Attention( # Multi-Head Attention
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim) #第二個(gè)歸一化層,用于MLP之前
# MLP
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
# 前向傳播
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 保存輸入張量的副本
shortcut = x
# 對輸入張量應(yīng)用第一個(gè)歸一化層
x = self.norm1(x)
# Window partition 對X進(jìn)行padding
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
# Multi-Head Attention
x = self.attn(x)
# 如果 window_size > 0,使用window_unpartition去除窗口分區(qū)的padding,恢復(fù)原始尺寸
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
# 將注意力層的輸出與輸入張量相加,實(shí)現(xiàn)殘差連接
x = shortcut + x
# 對經(jīng)過第二個(gè)歸一化層的張量應(yīng)用MLP層,再次使用殘差連接
x = x + self.mlp(self.norm2(x))
# 返回最終的張量 x
return x
② Partition操作
在非全局注意力的Block中,為了適應(yīng)14x14的窗口大小,輸入特征圖需要進(jìn)行補(bǔ)邊(padding)和拆分操作。具體流程如下:
- 輸入特征圖:輸入特征圖的初始尺寸為 1x64x64x768。
- 確定最小可整除尺寸:窗口大小為14*14,要找到能夠被14整除的最小特征圖尺寸。對于寬度和高度,我們需要找到大于等于64且能被14整除的最小數(shù)。這兩個(gè)數(shù)分別是70(64+6)和70(64+6),所以最小可整除特征圖的尺寸是 1x70x70x768。
- padding:為了將特征圖尺寸從 64x64 擴(kuò)展到 70x70,我們需要在右下角填充 6x6 的區(qū)域,因?yàn)?0-64=6。這種padding方式確保了窗口可以在特征圖的邊緣正確地劃分。
- 拆分特征圖:將padding后的特征圖1x70x70x768按照窗口大小14x14進(jìn)行拆分。因?yàn)?0/14=5,所以特征圖可以被拆分為 5x5個(gè)14x14的窗口,總共5x5=25個(gè)窗口。每個(gè)窗口的尺寸為14x14x768。
如下圖所示:
# 將輸入張量x分割成指定大小的窗口
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
# 獲取輸入張量形狀
# B表示批次大小,H和W表示高和寬,C表示通道數(shù)
B, H, W, C = x.shape
# 計(jì)算填充高度和寬度 pad_h 和 pad_w,以使得輸入尺寸能被window_size整除
# 避免在分割時(shí)產(chǎn)生非完整的窗口
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
# 如果需要填充,使用F.pad函數(shù)在寬度和高度方向上進(jìn)行填充
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
# 更新填充后張量的高度和寬度 Hp 和 Wp
Hp, Wp = H + pad_h, W + pad_w
# 張量重塑為:B,Hp/S,S,Wp/S,S,C,這樣可以將輸入張量分割成多個(gè)窗口
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
# 調(diào)整張量的形狀,使其由B,Hp/S,Wp/S,S,S,C-->B*Hp*Wp/(S*S),S,S,C
# 這樣每個(gè)窗口都在張量的連續(xù)部分
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
# 返回一個(gè)包含所有窗口的張量和原始張量的填充后尺寸 (Hp, Wp)
return windows, (Hp, Wp)
③ Unpartition操作
在非全局注意力的Block中,將attention層輸出的特征圖1x70x70x768轉(zhuǎn)化為1x64x64x768的特征圖,實(shí)際上是通過切片操作x = x[:1, :64, :64, :],從1x70x70x768的特征圖中取出左上角的1x64x64x768部分。
# 用于將window_partition函數(shù)分割的窗口重新組合回原始尺寸的張量
def window_unpartition(
# 獲取輸入張量 windows 的形狀,以及窗口大小 window_size
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
# 原始尺寸的填充高度和寬度
Hp, Wp = pad_hw
# 原始尺寸的無填充高度和寬度
H, W = hw
# 從窗口張量的總大小中計(jì)算出原始批量大小 B
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
# 重塑窗口張量:B*Hp*Wp/(S*S),S,S,C-->B,Hp/S,Wp/S,S,S,C
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
# 再次重塑張量:B,Hp/S,Wp/S,S,S,C-->B,Hp,Wp,C
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
# 如果原始尺寸小于填充后的尺寸
if Hp > H or Wp > W:
# 通過切片 x[:, :H, :W, :] 去除填充部分,只保留原始大小的區(qū)域
x = x[:, :H, :W, :].contiguous()
# B,H,W,C
# 返回合并后的張量,其形狀為 (B,H,W,C),即原始的批量大小、高度、寬度和通道數(shù)
return x
Encoder Block過程如下圖所示:
window_partition將輸入特征的尺寸從(H, W)調(diào)整為(S, S)的窗口,其中S是窗口大小。這種調(diào)整是為了在多頭注意力(Multi-Head Attention)中將相對位置嵌入添加到注意力圖(attn)。然而,并非所有Transformer Block都需要在注意力圖中嵌入相對位置信息。 window_unpartition 函數(shù)的作用是將經(jīng)過注意力計(jì)算的窗口特征重新組合回原始尺寸(S×S–>H×W)。 Hp和Wp是S的整數(shù)倍。
④ Multi-Head Attention
先來看Attention,結(jié)構(gòu)如下圖所示。
Attention中q、k和v的作用:
代碼實(shí)現(xiàn)如下:
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int, # 輸入通道數(shù)
num_heads: int = 8, # head數(shù)目
qkv_bias: bool = True, # 是否在QKV線性變換中使用偏置項(xiàng),默認(rèn)為True
use_rel_pos: bool = False, #是否使用相對位置編碼,默認(rèn)為False
rel_pos_zero_init: bool = True, #如果使用相對位置編碼,是否以零初始化,默認(rèn)為True
input_size: Optional[Tuple[int, int]] = None, # 可選參數(shù),用于指定相對位置編碼的尺寸,只有在使用相對位置編碼時(shí)才需要
) -> None:
super().__init__()
self.num_heads = num_heads #輸入head數(shù)目
head_dim = dim // num_heads #每個(gè)head維度
self.scale = head_dim**-0.5 #用于縮放注意力得分的因子,以避免數(shù)值溢出,取值為head_dim的平方根的倒數(shù)
#一個(gè)全連接層(nn.Linear),將輸入映射到Q、K、V的組合
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# 一個(gè)全連接層,用于將注意力機(jī)制的輸出投影回原始維度
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos: # 使用相對位置編碼
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# 初始化水平方向(rel_pos_h)和垂直方向(rel_pos_w)的相對位置嵌入
# 2S-1,Epos
# 輸入尺寸為(H, W),則水平方向的位置嵌入長度為2*H-1,垂直方向的位置嵌入長度為2*W-1
# 每個(gè)位置嵌入的維度為head_dim
# 這些位置嵌入以模型參數(shù)的形式定義(nn.Parameter),意味著它們會在訓(xùn)練過程中被學(xué)習(xí)和更新
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 輸入張量x的形狀為(B, H, W, C),其中B是批次大小,H和W是高度和寬度,C是通道數(shù)(即dim)
B, H, W, _ = x.shape
# 使用qkv層將x轉(zhuǎn)換為Q、K、V的組合,然后通過重塑和重新排列來準(zhǔn)備多頭注意力計(jì)算
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
# attn with shape (B * nHead, H * W, H * W)
# 計(jì)算注意力分?jǐn)?shù)
# q * self.scale: q是查詢向量(query vectors),形狀為(B * nHead, H * W, C),其中B是批次大小,nHead是注意力頭的數(shù)量,H * W是序列的長度,C是每個(gè)位置的特征維度
# self.scale是用于縮放注意力分?jǐn)?shù)的因子,通常取head_dim的平方根的倒數(shù),以防止數(shù)值過大
# 乘以self.scale是為了穩(wěn)定計(jì)算并防止梯度消失
# k.transpose(-2, -1): k是鍵向量(key vectors),形狀與q相同。transpose(-2, -1)是對k進(jìn)行轉(zhuǎn)置操作,即將最后一個(gè)和倒數(shù)第二個(gè)維度互換,目的是讓q和k在計(jì)算點(diǎn)積時(shí)的維度匹配。轉(zhuǎn)置后的k形狀變?yōu)?B * nHead, C, H * W)
# 將q和轉(zhuǎn)置后的k進(jìn)行矩陣乘法。計(jì)算每個(gè)查詢位置q與所有鍵位置k的點(diǎn)積,生成一個(gè)形狀為(B * nHead, H * W, H * W)的注意力分?jǐn)?shù)矩陣attn。每個(gè)位置i和j的注意力分?jǐn)?shù)表示q_i與k_j的相似度
attn = (q * self.scale) @ k.transpose(-2, -1)
# 如果啟用了相對位置編碼
if self.use_rel_pos:
# (H, W)代表輸入序列的尺寸,這里假設(shè)H和W是相等的(S×S),即輸入是一個(gè)正方形網(wǎng)格(例如,圖像的像素網(wǎng)格)
# attn: 上述計(jì)算得到的注意力分?jǐn)?shù)矩陣,形狀為(B * nHead, H * W, H * W)
# q: 查詢向量,形狀為(B * nHead, H * W, C)
# self.rel_pos_h和self.rel_pos_w: 分別表示水平和垂直方向上的相對位置嵌入,形狀分別為(2 * S - 1, head_dim)
# (H, W): 輸入序列的尺寸,用于指導(dǎo)相對位置嵌入的計(jì)算
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
# 生成的注意力分?jǐn)?shù)矩陣attn隨后會經(jīng)過Softmax函數(shù),將每個(gè)位置的分?jǐn)?shù)歸一化到[0, 1]區(qū)間,形成一個(gè)概率分布
attn = attn.softmax(dim=-1)
# 加權(quán)求和:
# 使用attn @ v計(jì)算加權(quán)和,其中@表示矩陣乘法,v是值向量(value vectors),形狀為(B * nHead, H * W, C)
# 注意力權(quán)重矩陣attn(形狀為(B * nHead, H * W, H * W))與v按元素相乘后,再進(jìn)行矩陣乘法,得到加權(quán)后的值向量,形狀為(B * nHead, H * W, C)
# 使用.view()將加權(quán)后的值向量重塑為(B, self.num_heads, H, W, -1),然后使用.permute(0, 2, 3, 1, 4)進(jìn)行重排,將self.num_heads移動到第四個(gè)維度。最后,使用.reshape(B, H, W, -1)將結(jié)果進(jìn)一步重塑為(B, H, W, -1),與輸入張量的形狀一致,但保留了多頭注意力的輸出
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
# 使用self.proj(一個(gè)全連接層,形狀為(dim, dim))對上述處理后的張量進(jìn)行線性投影,以將其投影回原始的特征維度
x = self.proj(x)
# 最終,返回經(jīng)過線性投影的張量x作為注意力模塊的輸出
return x
在多頭注意力(Multi-Head Attention)模塊中,輸入特征F(N×E)表示一個(gè)序列,其中N是序列中的元素?cái)?shù)量,E是每個(gè)元素的特征維度。具體流程如下。
首先將每個(gè)token的qkv特征維度embed_dim均拆分到每個(gè)head上。
每個(gè)head分別通過q和k計(jì)算得到權(quán)重w,權(quán)重w和v得到輸出output,合并所有head的output得到最終的output。
get_rel_pos用于計(jì)算查詢(query)和鍵(key)之間在二維空間中的相對位置編碼,如下圖所示。
實(shí)現(xiàn)代碼:
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
# 表示查詢(query)和鍵(key)在二維空間中的最大相對距離
# max(q_size, k_size):取查詢的寬度q_size和鍵的寬度k_size中的較大值
# 如果q_size和k_size都為S,則最大的正向距離是S-1,最大的負(fù)向距離也是S-1,所以總的最大距離是2 * S
# - 1:減去1是因?yàn)樵谟?jì)算相對位置時(shí),0被包含在內(nèi),所以最大距離是2 * S - 1
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# 如果rel_pos的形狀的第0個(gè)維度(即長度)不等于max_rel_dist,說明需要進(jìn)行插值
if rel_pos.shape[0] != max_rel_dist:
# 使用F.interpolate進(jìn)行線性插值
rel_pos_resized = F.interpolate(
# 1,N,Ep --> 1,Ep,N --> 1,Ep,2S-1
# 將rel_pos重塑為(1, N, Ep),其中N是原始的長度,Ep是每個(gè)位置編碼的特征維度
# 通過permute(0, 2, 1)進(jìn)行轉(zhuǎn)置,使其形狀變?yōu)?1, Ep, N)
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
# 設(shè)置插值的目標(biāo)長度為max_rel_dist
size=max_rel_dist,
# 指定插值方法為線性插值
mode="linear",
)
# Ep,2S-1 --> 2S-1,Ep
# 插值后的rel_pos形狀為(1, Ep, max_rel_dist),通過reshape(-1, max_rel_dist)將其重塑為(Ep, max_rel_dist)
# 再通過permute(1, 0)轉(zhuǎn)置為(max_rel_dist, Ep)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
# 如果rel_pos的長度與max_rel_dist相等,說明已經(jīng)足夠覆蓋所有可能的相對位置,因此直接使用rel_pos,不進(jìn)行任何處理
rel_pos_resized = rel_pos
# 如果q和k長度值不同,則用短邊長度縮放坐標(biāo)
# 創(chuàng)建查詢坐標(biāo)q_coords
# torch.arange(q_size)生成一個(gè)從0到q_size - 1的整數(shù)序列,表示q_size個(gè)位置
# [:, None]在序列末尾添加一個(gè)維度,使其形狀為(q_size, 1),這樣可以方便與一個(gè)標(biāo)量進(jìn)行逐元素乘法
# max(k_size / q_size, 1.0)計(jì)算比例因子,如果k_size大于q_size,則使用k_size / q_size,否則使用1.0
# 這確保了在q_size小于k_size的情況下,q_coords的坐標(biāo)會被適當(dāng)放大,以匹配k_coords的尺度
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
# 創(chuàng)建鍵坐標(biāo)k_coords
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
# S,S
# 計(jì)算了查詢(query)和鍵(key)在二維空間中的相對坐標(biāo)relative_coords
# (q_coords - k_coords):每個(gè)查詢位置相對于每個(gè)鍵位置的水平距離
# (k_size - 1) * max(q_size / k_size, 1.0):計(jì)算了一個(gè)偏移量,用于確保相對坐標(biāo)在正確的范圍內(nèi)
# (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0):將計(jì)算出的差值和偏移量相加,得到最終的相對坐標(biāo)relative_coords
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
# tensor索引是tensor時(shí),即tensor1[tensor2]
# 假設(shè)tensor2某個(gè)具體位置值是2,則tensor1[2]位置的tensor1切片替換tensor2中的2
# tensor1->shape 5,5,3 tensor2->shape 2,2,3 tensor1切片->shape 5,3 tensor1[tensor2]->shape 2,2,3,5,3
# tensor1->shape 5,5 tensor2->shape 3,2,3 tensor1切片->shape 5 tensor1[tensor2]->shape 3,2,3,5
# 2S-1,Ep-->S,S,Ep
return rel_pos_resized[relative_coords.long()]
add_decomposed_rel_pos為atten注意力特征添加相對位置的嵌入特征,如下圖所示。
def add_decomposed_rel_pos(
# 注意力分?jǐn)?shù)矩陣
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
# S,S
q_h, q_w = q_size
k_h, k_w = k_size
# rel_pos_h -> 2S-1×Epos
# 查詢(query)和鍵(key)在高度方向上的相對位置編碼
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
# 查詢(query)和鍵(key)在寬度方向上的相對位置編碼
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
# 重塑q為(B, q_h, q_w, dim)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
# 計(jì)算相對位置加權(quán)
# 計(jì)算rel_h和rel_w,這兩個(gè)張量表示在每個(gè)位置上,查詢與相對位置編碼的加權(quán)和
# B,q_h,q_w,k_h
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
# B,q_h, q_w, k_w
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
# 合并注意力分?jǐn)?shù)和相對位置編碼
# 將attn重塑為(B, q_h, q_w, k_h, k_w),然后與rel_h和rel_w按元素相加
# 將attn重塑為(B, q_h, q_w, k_h, k_w),然后與rel_h和rel_w按元素相加
attn = (
# B,q_h, q_w, k_h, k_w
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
Multi-Head Attention模塊為注意力特征嵌入了相對位置特征(add_decomposed_rel_pos):
⑤ Neck Convolution
最后,通過兩層卷積(Neck)將通道數(shù)降低至256,生成最終的Image Embedding。其結(jié)構(gòu)圖如下所示。
代碼實(shí)現(xiàn)如下:
# neck: nn.Sequential,它包含兩個(gè)卷積層和兩個(gè)LayerNorm2d)
self.neck = nn.Sequential(
# 1x1的卷積層,用于將輸入通道數(shù)從embed_dim減小到out_chans
# 1x1卷積主要用于通道間的信息融合,而不改變特征圖的空間尺寸
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
# 不使用偏置項(xiàng)
bias=False,
),
# 歸一化層,用于規(guī)范化輸出通道的均值和方差,提高模型的穩(wěn)定性和收斂速度
# out_chans:歸一化層的通道數(shù)
LayerNorm2d(out_chans),
# 3x3的卷積層
nn.Conv2d(
# 使用out_chans作為輸入和輸出通道數(shù)
out_chans,
out_chans,
kernel_size=3,
# 輸入和輸出的特征圖尺寸保持不變,避免尺寸收縮
padding=1,
# 不使用偏置
bias=False,
),
# 第二個(gè)歸一化層,再次對輸出進(jìn)行規(guī)范化
LayerNorm2d(out_chans),
)
# 歸一化
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
# 創(chuàng)建了兩個(gè)可學(xué)習(xí)的參數(shù):weight和bias
# weight初始化為全1,bias初始化為全0
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 沿著通道維度求均值,keepdim=True保留維度,使得u的形狀與x相同,除了通道維度的大小為1
u = x.mean(1, keepdim=True) # dim=1維度求均值并保留通道
# 計(jì)算標(biāo)準(zhǔn)化因子 s,即減去均值后的平方差的平均值,也保留通道維度
s = (x - u).pow(2).mean(1, keepdim=True)
# 歸一化,將每個(gè)像素的值減去均值 u,然后除以標(biāo)準(zhǔn)差的平方根加上一個(gè)小的常數(shù) eps 以保證數(shù)值穩(wěn)定性
x = (x - u) / torch.sqrt(s + self.eps)
# 應(yīng)用可學(xué)習(xí)的權(quán)重和偏置
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
2. Prompt Encoder
SAM模型中Prompt Encoder網(wǎng)絡(luò)結(jié)構(gòu)如下圖所示。主要包括三步驟:
- Embed_Points:標(biāo)記點(diǎn)編碼(標(biāo)記點(diǎn)由點(diǎn)轉(zhuǎn)變?yōu)橄蛄?
- Embed_Boxes:標(biāo)記框編碼(標(biāo)記框由點(diǎn)轉(zhuǎn)變?yōu)橄蛄?
- Embed_Masks:mask編碼(mask下采樣保證與Image Encoder輸出一致)
(1) Embed_Points
Embed_Points結(jié)構(gòu)如下圖所示:
標(biāo)記點(diǎn)預(yù)處理,將channel由2變?yōu)閑mbed_dim(MatMul:forward_with_coords),然后再加上位置編碼權(quán)重。其中:
- 2:坐標(biāo)(h,w)
- embed_dim:提示編碼的channel
代碼實(shí)現(xiàn):
# 將輸入的點(diǎn)坐標(biāo)和對應(yīng)的標(biāo)簽轉(zhuǎn)化為高維的嵌入表示,以便于后續(xù)的模型處理
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool,
) -> torch.Tensor:
# 將輸入的點(diǎn)坐標(biāo)points的每個(gè)坐標(biāo)值增加0.5,以將坐標(biāo)從像素的左上角移動到像素中心
points = points + 0.5
# points和boxes聯(lián)合則不需要pad
if pad:
# 在點(diǎn)坐標(biāo) points 和標(biāo)簽 labels 中添加一個(gè)填充項(xiàng)
# 以保持批次處理的一致性,即使某些樣本的點(diǎn)數(shù)量少于最大數(shù)量。
# 填充的點(diǎn)坐標(biāo)為(0,0),標(biāo)簽為-1
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) # B,1,2
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) # B,1
points = torch.cat([points, padding_point], dim=1) # B,N+1,2
labels = torch.cat([labels, padding_label], dim=1) # B,N+1
# 根據(jù)調(diào)整后的點(diǎn)坐標(biāo)和輸入圖像的尺寸生成位置編碼
# 生成的嵌入維度:B,N+1,2f
# 2f 表示每個(gè)點(diǎn)位置編碼的維度,是通過某種函數(shù)(如正弦或余弦函數(shù))從原始的2D坐標(biāo)擴(kuò)展而來
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
# 根據(jù)標(biāo)簽 labels 的值,對每個(gè)點(diǎn)的嵌入進(jìn)行調(diào)整。
# labels為-1是非標(biāo)記點(diǎn),設(shè)為非標(biāo)記點(diǎn)權(quán)重
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
# labels為0是背景點(diǎn),加上背景點(diǎn)權(quán)重
point_embedding[labels == 0] += self.point_embeddings[0].weight
# labels為1是目標(biāo)點(diǎn),加上目標(biāo)點(diǎn)權(quán)重
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
(2) Embed_Boxes
Embed_Boxes結(jié)構(gòu)如下圖所示。
標(biāo)記框(Bounding Box)一般有兩個(gè)點(diǎn),編碼步驟如下:
- 將輸入的邊界框坐標(biāo)張量boxes從BxNx4轉(zhuǎn)換為BxNx2x2;
- 再使用point embedding編碼的方式,得到corner_embedding;
- 加上之前生成的可學(xué)習(xí)的embeding向量。
- 最后輸出的corner_embedding大小為Nx2x256。
代碼實(shí)現(xiàn):
# 將輸入的邊界框(boxes)轉(zhuǎn)換為高維的嵌入表示
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
# 將坐標(biāo)從像素的左上角移動到像素中心
boxes = boxes + 0.5
# 將輸入的邊界框坐標(biāo)張量boxes從BxN*4轉(zhuǎn)換為B*Nx2x2
# 其中B是批次大小,N是每個(gè)樣本中的邊界框數(shù)量
coords = boxes.reshape(-1, 2, 2)
# 對每個(gè)邊界框的角點(diǎn)坐標(biāo)進(jìn)行位置編碼
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) #
# 分別對每個(gè)邊界框的起始點(diǎn)和末尾點(diǎn)的嵌入向量加上特定的權(quán)重
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
# 返回加權(quán)后嵌入向量,形狀為 B*Nx2xembed_dim,其中 embed_dim 是位置編碼的維度
return corner_embedding
(3) Embed_Mask
mask提示允許我們直接在原圖上指示感興趣區(qū)域來引導(dǎo)模型。這些mask通過卷積操作被轉(zhuǎn)換為與圖像嵌入空間相匹配的特征,然后與圖像嵌入相加結(jié)合,為模型提供分割的精確位置信息。
如果沒有使用mask提示,則將一組可學(xué)習(xí)向量(no_mask_embed,1*256)expand為1x256×64×64后替代,使得在處理序列數(shù)據(jù)時(shí),即使沒有具體的mask信息,也能有一個(gè)統(tǒng)一的處理方式。
# 在PromptEncoder的forward定義
'''
首先獲取no_mask_embed權(quán)重矩陣,并將其重塑成一個(gè)形狀為(1, num_embeddings, 1, 1)的四維張量。
再利用.expand方法將這個(gè)張量擴(kuò)展到與圖像編碼相同的尺寸。bs是batch大小,-1是一個(gè)占位符,它會自動計(jì)算出
num_embeddings的值以保持張量的元素總數(shù)不變。self.image_embedding_size[0]和self.image_embedding_size[1]分別表示圖像編碼的寬度和高度。
'''
self.no_mask_embed = nn.Embedding(1, embed_dim) # embed_dim=256
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])
)
如果有配置mask,Embed_Masks結(jié)構(gòu)如下圖所示。
已知輸入mask是Nx1x256x256,經(jīng)過3層卷積,最后得到與Image Embedding一樣的size:
首先,mask進(jìn)入一個(gè)1x2x2x4的卷積,stride=2;LN;再進(jìn)入一個(gè)4x2x2x16的卷積,stride=2;LN;最后再進(jìn)入一個(gè)16x1x1x256的卷積;得到最后的mask_embedding的size為Nx256x64x64,最終mask_embedding作為dense_embedding輸出,大小為Nx256x64x64。
mask的輸出尺寸是Image Encoder模塊輸出的圖像編碼尺寸的4倍,因此為了保持一致,需要4倍下采樣。
代碼實(shí)現(xiàn):
# 將輸入的掩模(mask)張量轉(zhuǎn)換為一個(gè)低分辨率的嵌入表示
# 掩模 masks 是一個(gè)形狀為 BxCxHxW 的張量
# 其中 B 是批次大小,C 是通道數(shù)(通常為1,因?yàn)檠谀MǔV挥幸煌ǖ溃?,H 和 W 分別是高度和寬度。
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
# mask下采樣4倍
mask_embedding = self.mask_downscaling(masks)
# 返回下采樣并轉(zhuǎn)換后的掩模嵌入,其形狀為 B*embed_dim*H'*W',其中 H' 和 W' 是下采樣后的高度和寬度
return mask_embedding
# mask_downscaling包括多個(gè)卷積層、層歸一化(LayerNorm2d)和激活函數(shù),目的是減少掩模的空間維度,同時(shí)增加通道維度
self.mask_downscaling = nn.Sequential(
# 將通道數(shù)從1減少到mask_in_chans//4,同時(shí)使用2x2的卷積核和步長2進(jìn)行下采樣,降低了空間分辨率
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
# 規(guī)范化通道維度上的特征
LayerNorm2d(mask_in_chans // 4),
# 激活函數(shù),引入非線性
activation(),
# 將通道數(shù)恢復(fù)到 mask_in_chans,再次使用2x2的卷積核和步長2進(jìn)行下采樣,進(jìn)一步降低空間分辨率
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
# LayerNorm2d 層和激活函數(shù)
LayerNorm2d(mask_in_chans),
activation(),
# 將通道數(shù)增加到 embed_dim,通常是為了與模型的其他部分保持一致
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
(4) PositionEmbeddingRandom
用于將標(biāo)記點(diǎn)和標(biāo)記框的坐標(biāo)進(jìn)行提示編碼預(yù)處理。就是將64x64個(gè)坐標(biāo)點(diǎn)歸一化后,與隨機(jī)高斯矩陣相乘(2x128),再將結(jié)果分別進(jìn)行sin和cos,最后再拼到一起,輸出的大小為256x64x64,與image_embedding大小基本一致了。
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def init(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().init()
if scale is None or scale <= 0.0:
scale = 1.0
# 構(gòu)建一個(gè)2x128的隨機(jī)矩陣作為位置編碼高斯矩陣
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((2, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
# 矩陣乘法:64x64xx2 @ 2x128 ---> 64x64x128
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
# cat, 最后一個(gè)維度上拼接:64x64x256
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
# 構(gòu)造一個(gè)64x64的全1矩陣
grid = torch.ones((h, w), device=device, dtype=torch.float32)
# 行、列累加
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
# 行列累加結(jié)果歸一化
y_embed = y_embed / h
x_embed = x_embed / w
# 行列拼接:64x64x2,編碼后的結(jié)果是64x64x256
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
# 最后輸出256x64x64
return pe.permute(2, 0, 1) # C x H x W
3. Mask Decoder
Mask Decoder網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)配置如下:
def __init__(
self,
*,
# transformer通道數(shù)
transformer_dim: int,
# 用于預(yù)測mask的Transformer網(wǎng)絡(luò)模塊
transformer: nn.Module,
# 消除掩碼歧義預(yù)測的掩碼數(shù)量,默認(rèn)為3
num_multimask_outputs: int = 3,
# 激活函數(shù),默認(rèn)為GELU
activation: Type[nn.Module] = nn.GELU,
# MLP用于預(yù)測掩模質(zhì)量的深度
iou_head_depth: int = 3,
# MLP的隱藏層通道數(shù)
iou_head_hidden_dim: int = 256,
) -> None:
super().__init__()
self.transformer_dim = transformer_dim #存儲傳入的transformer_dim
# 存儲傳入的transformer模塊
self.transformer = transformer
# 存儲掩碼預(yù)測的輸出數(shù)量
self.num_multimask_outputs = num_multimask_outputs
# 用于表示IoU(Intersection over Union)的嵌入層,大小為1×transformer_dim
# 可學(xué)習(xí)的iou tokens:1x256
self.iou_token = nn.Embedding(1, transformer_dim)
# 包含IoU token在內(nèi)的總mask token數(shù)量
# # num_mask_tokens = 3 + 1 = 4, transformer_dim = 256
# 輸出一個(gè)4x256的矩陣
self.num_mask_tokens = num_multimask_outputs + 1
# 存儲所有mask token的嵌入層,大小為num_mask_tokens×transformer_dim
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
#----- upscaled -----
# 用于4倍上采樣的序列,包含兩個(gè)轉(zhuǎn)置卷積層,每個(gè)上采樣2倍,中間夾著LayerNorm和激活函數(shù)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), #轉(zhuǎn)置卷積 上采樣2倍
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
# ----- upscaled -----
# 多層感知機(jī)(MLP)模塊
# 一個(gè)模塊列表,包含了num_mask_tokens個(gè)MLP,每個(gè)MLP用于處理不同mask的輸出
self.output_hypernetworks_mlps = nn.ModuleList(
[
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
]
)
# ----- MLP -----
# ----- MLP -----
# 一個(gè)MLP,用于預(yù)測IoU,輸入是transformer_dim,經(jīng)過iou_head_hidden_dim的隱藏層,輸出是num_mask_tokens
self.iou_prediction_head = MLP(
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
)
# ----- MLP -----
SAM模型Mask Decoder網(wǎng)絡(luò)結(jié)構(gòu)如下圖所示:
spa_pro_emb(sparse embedding)、iou_token、mask_token合并成一個(gè)tokens,作為point_embeddings。
- spa_pro_emb: point、bbox prompt合并后的產(chǎn)物,一般為NxXx256。
- iou_token:可學(xué)習(xí)參數(shù),大小為1x256。
- mask_token:可學(xué)習(xí)參數(shù),大小為4x256。
原論文中Mask Decoder模塊各部分結(jié)構(gòu)示意圖如下:
Mask Decoder網(wǎng)絡(luò)在特征提取中的基本步驟如下:
- transformer:將來自編碼器的圖像特征與額外的提示信息(如掩碼提示或查詢向量)融合,以捕捉目標(biāo)區(qū)域的上下文信息。
- upscaled:對粗略mask src進(jìn)行上采樣,使其與原始圖像尺寸相匹配,以便進(jìn)行更精細(xì)的mask預(yù)測。
- mask_MLP:通過一系列全連接層,對上采樣后的特征進(jìn)行變換,計(jì)算出針對每個(gè)像素的mask概率。這些層可以設(shè)計(jì)為學(xué)習(xí)如何為每個(gè)mask通道分配權(quán)重,從而生成最終的mask輸出。
- iou_MLP:評估生成的mask與真實(shí)mask之間的重疊程度,即預(yù)測mask的質(zhì)量。
def forward(
self,
# image encoder 圖像特征
image_embeddings: torch.Tensor,
# 位置編碼
# 256x64x64
image_pe: torch.Tensor,
# 標(biāo)記點(diǎn)和標(biāo)記框的嵌入編碼
sparse_prompt_embeddings: torch.Tensor,
# 輸入mask的嵌入編碼
dense_prompt_embeddings: torch.Tensor,
# 是否輸出多個(gè)mask
multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 將這些特征融合,通過Transformer和后續(xù)的上采樣及MLP層,生成掩膜預(yù)測和IoU分?jǐn)?shù)
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# 如果multimask_output為True,表示需要輸出多個(gè)掩模,選取索引為1到num_multimask_outputs的所有掩模
if multimask_output:
mask_slice = slice(1, None)
# 否則,如果multimask_output為False,僅輸出第一個(gè)掩模(通常是最高得分的掩模)
else:
mask_slice = slice(0, 1)
# 根據(jù)multimask_output選擇后的掩模,維度調(diào)整為(batch_size, num_selected_masks, height, width)
masks = masks[:, mask_slice, :, :]
# 根據(jù)multimask_output選擇后的IoU預(yù)測,維度調(diào)整為(batch_size, num_selected_masks)
iou_pred = iou_pred[:, mask_slice]
return masks, iou_pred
def predict_masks(
self,
# image embedding: 是image encoder的輸出,大小為為1x256x64x64
image_embeddings: torch.Tensor,
# image_pe位置編碼也拓展成Nx256x64x64的矩陣
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 首先將iou token和mask token 拼接得到一個(gè)5x256的矩陣,再將其拓展到與sparse embedding一個(gè)維度Nx5x256
# 1,E and 4,E --> 5,E
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
# 再將拓展后的矩陣與sparse embedding拼接得到tokens,其大小Nx(5+X)x256
# 5,E --> B,5,E
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
# 再與稀疏矩陣拼接,假設(shè)稀疏矩陣只有point為Nx2x256,拼接之后則為Nx(5+2)x256
# B,5,E and B,N,E -->B,5+N,E N是點(diǎn)的個(gè)數(shù)(標(biāo)記點(diǎn)和標(biāo)記框的點(diǎn))
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# 將image embedding(1x256x64x64)拓展成稠密prompt的維度:Nx256x64x64
# B,C,H,W
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
#將拓展后的image embedding直接與稠密prompt相加:Nx256x64x64
# B,C,H,W + 1,C,H,W ---> B,C,H,W
src = src + dense_prompt_embeddings
# # 將256x64x64的位置編碼,拓展成Nx256x64x64
# 1,C,H,W---> B,C,H,W
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# ----- transformer -----
# Run the transformer:這里使用的TwoWayTransformer,有必要對輸入再說明一下
# src:image_bedding + dense_prompt(mask),Nx256x64x64
# pos_src: 位置編碼,Nx256x64x64
# tokens: iou_tokens + mask_tokens + sparse_prompt(point/bbox),Nx(5+x)x256
# B,N,C
hs, src = self.transformer(src, pos_src, tokens)
# ----- transformer -----
# # 后處理
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :]
# 通過上采樣層將Transformer輸出的掩模部分恢復(fù)到(batch_size, channels, height, width)的形狀
# B,N,C-->B,C,H,W
src = src.transpose(1, 2).view(b, c, h, w)
# ----- upscaled -----
# 4倍上采樣
upscaled_embedding = self.output_upscaling(src)
# ----- upscaled -----
# 對每個(gè)mask token,通過其對應(yīng)的MLP得到一個(gè)權(quán)重張量,使用這些權(quán)重與上采樣后的特征張量進(jìn)行點(diǎn)乘,得到掩模預(yù)測(batch_size, num_mask_tokens, height, width)
hyper_in_list: List[torch.Tensor] = []
# ----- mlp -----
for i in range(self.num_mask_tokens):
# mask_tokens_out[:, i, :]: B,1,C
# output_hypernetworks_mlps: B,1,c
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
# B,n,c
hyper_in = torch.stack(hyper_in_list, dim=1)
# ----- mlp -----
b, c, h, w = upscaled_embedding.shape
# B,n,c × B,c,N-->B,n,h,w
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# ----- mlp -----
# 通過IoU預(yù)測頭(MLP)對IoU token的輸出進(jìn)行處理,得到(batch_size, num_mask_tokens)的IoU分?jǐn)?shù)
# iou_token_out: B,1,n
iou_pred = self.iou_prediction_head(iou_token_out)
# ----- mlp -----
# 返回預(yù)測的掩模和IoU分?jǐn)?shù)
# masks: B,n,h,w
# iou_pred: B,1,n
return masks, iou_pred
(1) transformer
Mask Decoder由多個(gè)重復(fù)堆疊TwoWayAttention Block和1個(gè)Multi-Head Attention組成。
(2) TwoWayAttention Block
TwoWayAttention Block由LayerNorm 、Multi-Head Attention和MLP構(gòu)成。所謂的TwoWay:即是兩輪次循環(huán),第一次point_embedding自注意,第二次則加上上一輪輸出的queries進(jìn)行attention。
原論文中TwoWayAttention部分示意圖。
class TwoWayAttentionBlock(nn.Module):
def __init__(
self,
embedding_dim: int, # 輸入特征維度
num_heads: int, # 注意力頭的數(shù)量,決定了注意力機(jī)制的并行度
mlp_dim: int = 2048, # MLP(多層感知機(jī))中間層的維度,用于特征變換和非線性增強(qiáng)
activation: Type[nn.Module] = nn.ReLU, # 激活函數(shù)類型,默認(rèn)為ReLU
attention_downsample_rate: int = 2, # 下采樣比率
# 是否在第一層自注意力中跳過位置編碼的殘差連接
skip_first_layer_pe: bool = False,
) -> None:
super().__init__()
# 自注意力模塊,用于增強(qiáng)queries內(nèi)部的信息交互
self.self_attn = Attention(embedding_dim, num_heads)
# norm1/2/3/4: LayerNorm層,用于穩(wěn)定訓(xùn)練和加速收斂
self.norm1 = nn.LayerNorm(embedding_dim)
# cross_attn_token_to_image和cross_attn_image_to_token: 交叉注意力模塊,分別讓標(biāo)記點(diǎn)特征關(guān)注圖像特征,以及圖像特征反過來關(guān)注標(biāo)記點(diǎn)特征
self.cross_attn_token_to_image = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.norm2 = nn.LayerNorm(embedding_dim)
# mlp: 多層感知機(jī)模塊,增加模型的表達(dá)能力
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
)
self.skip_first_layer_pe = skip_first_layer_pe
# 前向傳播
def forward(
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
) -> Tuple[Tensor, Tensor]:
# queries:標(biāo)記點(diǎn)編碼相關(guān)(原始標(biāo)記點(diǎn)編碼經(jīng)過一系列特征提取)
# keys:原始圖像編碼相關(guān)(原始圖像編碼經(jīng)過一系列特征提取)
# query_pe:原始標(biāo)記點(diǎn)編碼
# key_pe:原始圖像位置編碼
# 第一輪本身queries==query_pe沒比較再"殘差"
# 首先對queries應(yīng)用自注意力,若skip_first_layer_pe=True,直接使用queries進(jìn)行自注意力計(jì)算;否則,將queries與query_pe相加后進(jìn)行自注意力計(jì)算,并殘差連接回queries,之后進(jìn)行LayerNorm
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# 調(diào)整queries和keys(圖像特征)加上各自的位置編碼,然后通過cross_attn_token_to_image交叉注意力層,使標(biāo)記點(diǎn)特征關(guān)注圖像特征,結(jié)果與原始queries殘差連接并進(jìn)行LayerNorm
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block:將更新后的queries通過MLP模塊進(jìn)行非線性變換,結(jié)果與原queries殘差連接并進(jìn)行LayerNorm
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# 交叉注意力(圖像到標(biāo)記點(diǎn)):再次調(diào)整queries和keys加上位置編碼,但這次通過cross_attn_image_to_token讓圖像特征關(guān)注標(biāo)記點(diǎn)特征,更新后的keys與原始keys殘差連接并進(jìn)行LayerNorm
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
(3) Attention
Mask Decoder的Attention與ViT的Attention有些細(xì)微的不同:
- Mask Decoder的Attention是3個(gè)FC層分別接受3個(gè)輸入獲得q、k和v。
- ViT的Attention是1個(gè)FC層接受1個(gè)輸入后將結(jié)果均拆分獲得q、k和v。
如下圖所示:
原論文中Attention部分示意圖:
class Attention(nn.Module):
def __init__(
self,
embedding_dim: int, # 輸入特征的維度
num_heads: int, # attention的head數(shù)
downsample_rate: int = 1, # 下采樣
) -> None:
super().__init__()
self.embedding_dim = embedding_dim
# 內(nèi)部維度
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
# 四個(gè)線性層(全連接層):用于生成query向量、key向量、value向量
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
# 用于將注意力機(jī)制后的輸出投影回原始的特征維度
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
# 將輸入張量分解為多頭注意力所需的形狀
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
# 在注意力計(jì)算后重新組合這些頭部
def _recombine_heads(self, x: Tensor) -> Tensor:
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
# 輸入投影:分別使用q_proj、k_proj和v_proj對query、key和value進(jìn)行線性變換
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# 分離頭部:將變換后的query、key和value張量按照num_heads進(jìn)行重塑,以便進(jìn)行多頭注意力計(jì)算
# B,N_heads,N_tokens,C_per_head
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# 注意力計(jì)算:
# 計(jì)算query和key的點(diǎn)積,然后除以c_per_head的平方根進(jìn)行歸一化,以防止數(shù)值過大
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # B,N_heads,N_tokens,C_per_head
# 歸一化Scale
attn = attn / math.sqrt(c_per_head)
# 應(yīng)用softmax函數(shù)得到注意力權(quán)重
attn = torch.softmax(attn, dim=-1)
# 使用注意力權(quán)重對value進(jìn)行加權(quán)求和,得到注意力輸出
out = attn @ v
# # B,N_tokens,C
# 重新組合頭部:將多頭注意力輸出合并回原始的特征維度。
out = self._recombine_heads(out)
# 輸出投影:最后,通過out_proj將輸出投影回原始的embedding_dim
out = self.out_proj(out)
return out
(4) transformer_MLP
transformer中MLP的結(jié)構(gòu)如下圖所示。
# MLPBlock類是一個(gè)簡單的多層感知機(jī)(MLP)模塊,由兩個(gè)全連接層(Linear)和一個(gè)激活函數(shù)組成
class MLPBlock(nn.Module):
def __init__(
self,
# 輸入的維度,通常是特征向量的長度
embedding_dim: int,
# MLP中間層的寬度,可以設(shè)置為比輸入維度更大的值以增加模型的表達(dá)能力
mlp_dim: int,
# 激活函數(shù),這里默認(rèn)使用GELU
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
# 第一個(gè)全連接層,將輸入從embedding_dim維度變換到mlp_dim維度
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
# 第二個(gè)全連接層,將mlp_dim維度的結(jié)果變換回embedding_dim維度,以保持與輸入相同的維度
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
# 激活函數(shù)實(shí)例,用于在全連接層之間引入非線性
self.act = act()
# 接收輸入張量x,將其傳遞給lin1,然后應(yīng)用激活函數(shù)act。
# 將激活函數(shù)的輸出傳遞給lin2,得到最終的輸出張量
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
(5) upscaled
這個(gè)上采樣過程將Transformer的輸出特征圖恢復(fù)到更接近輸入圖像的分辨率,以便于生成掩模預(yù)測。upscaled的結(jié)構(gòu)如下圖所示。
# 在MaskDecoder的__init__定義
# output_upscaling是一個(gè)序列模塊,用于上采樣Transformer輸出的特征圖
self.output_upscaling = nn.Sequential(
# 使用nn.ConvTranspose2d,輸入通道數(shù)為transformer_dim,輸出通道數(shù)為transformer_dim // 4,內(nèi)核大小為2,步長為2
# 將特征圖的尺寸放大兩倍,同時(shí)將通道數(shù)減半
# 內(nèi)核大小為2的轉(zhuǎn)置卷積相當(dāng)于上采樣2倍,步長為2確保輸出尺寸翻倍
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), #轉(zhuǎn)置卷積 上采樣2倍
# 層歸一化(LayerNorm2d)
LayerNorm2d(transformer_dim // 4),
# 激活函數(shù)
activation(),
# 再次使用nn.ConvTranspose2d,輸入通道數(shù)為transformer_dim // 4,輸出通道數(shù)為transformer_dim // 8,內(nèi)核大小為2,步長為2。這一步繼續(xù)將特征圖的尺寸放大兩倍,同時(shí)通道數(shù)再次減半
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
# 重復(fù)激活函數(shù)的過程,以進(jìn)一步增強(qiáng)非線性表達(dá)
activation(),
)
# 在MaskDecoder的predict_masks添加位置編碼
upscaled_embedding = self.output_upscaling(src)
(6) mask_MLP
此處的MLP基礎(chǔ)模塊不同于ViT的MLP(transformer_MLP)基礎(chǔ)模塊。
# 在MaskDecoder的__init__定義
# output_hypernetworks_mlps是一個(gè)nn.ModuleList,包含了多個(gè)多層感知機(jī)(MLP)。每個(gè)MLP的目的是根據(jù)輸入的mask_tokens_out生成特定掩模的超網(wǎng)絡(luò)權(quán)重
self.output_hypernetworks_mlps = nn.ModuleList(
[
# transformer_dim: Transformer的輸出維度,也是輸入到MLP的通道數(shù)
# transformer_dim // 8: MLP的輸出通道數(shù),用于生成超網(wǎng)絡(luò)的權(quán)重
# 3: MLP的中間層維度,用于增加模型的表達(dá)能力
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for i in range(self.num_mask_tokens)
]
)
# 在MaskDecoder的predict_masks添加位置編碼
# 對于self.num_mask_tokens個(gè)掩模token,遍歷output_hypernetworks_mlps列表
for i in range(self.num_mask_tokens):
# mask_tokens_out[:, i, :]: B,1,C
# output_hypernetworks_mlps: B,1,c
# 對每個(gè)掩模token,應(yīng)用對應(yīng)的MLP,輸入是mask_tokens_out中對應(yīng)位置的特征,輸出為B, 1, c形狀的張量,其中c是超網(wǎng)絡(luò)的輸出通道數(shù)
# 將每個(gè)MLP的輸出收集到hyper_in_list列表中
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
# B,n,c
# 將hyper_in_list堆疊成一個(gè)B, n, c形狀的張量hyper_in,其中n是掩模token的數(shù)量
hyper_in = torch.stack(hyper_in_list, dim=1)
# 獲取upscaled_embedding的形狀b, c, h, w,其中b是批次大小,c是通道數(shù),h和w是高度和寬度
b, c, h, w = upscaled_embedding.shape
# B,n,c × B,c,N-->B,n,h,w
# 執(zhí)行矩陣乘法(@運(yùn)算符)將hyper_in(B, n, c)與upscaled_embedding(在通道維度上展平為B, c, h * w)相結(jié)合
# 計(jì)算每個(gè)掩模token的超網(wǎng)絡(luò)權(quán)重與上采樣特征圖的點(diǎn)積,得到B, n, h * w形狀的張量
# 通過view操作將結(jié)果轉(zhuǎn)換回B, n, h, w形狀,生成了masks張量,表示每個(gè)掩模token對應(yīng)的預(yù)測掩模
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
(7) iou_MLP
此處的MLP基礎(chǔ)模塊不同于ViT的MLP(transformer_MLP)基礎(chǔ)模塊。
# 在MaskDecoder的__init__定義
# 一個(gè)多層感知機(jī)(MLP)模塊,其目的是預(yù)測每個(gè)掩模token對應(yīng)的IoU(Intersection over Union,交并比)值,以評估預(yù)測掩模與真實(shí)掩模的重合程度
self.iou_prediction_head = MLP(
# transformer_dim: 輸入到MLP的特征維度,通常與Transformer的輸出維度相同
# iou_head_hidden_dim: MLP中間層的維度,用于增強(qiáng)模型的表達(dá)能力
# self.num_mask_tokens: 輸出維度,即預(yù)測的掩模令牌數(shù)量,每個(gè)令牌對應(yīng)一個(gè)IoU預(yù)測值
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
)
# 在MaskDecoder的predict_masks添加位置編碼
iou_pred = self.iou_prediction_head(iou_token_out)
(8) MaskDeco_MLP
Mask Decoder中MLP的結(jié)構(gòu)如下圖所示。
'''
定義了一個(gè)多層感知機(jī),它包含一個(gè)可配置的隱藏層數(shù)目、輸入和輸出維度,并可以選擇是否在輸出層應(yīng)用Sigmoid激活函數(shù)
'''
class MLP(nn.Module):
def __init__(
self,
input_dim: int, # 輸入特征的維度,即輸入張量的通道數(shù)
hidden_dim: int, # 隱藏層的通道數(shù),中間層的寬度
output_dim: int, # 輸出特征的維度,即輸出張量的通道數(shù)
num_layers: int, # 多層感知機(jī)的層數(shù),包括輸入層和輸出層
sigmoid_output: bool = False, # 一個(gè)布爾值,表示是否在輸出層應(yīng)用Sigmoid激活函數(shù),默認(rèn)為False
) -> None:
'''
內(nèi)部組件
'''
super().__init__()
# 存儲輸入的層數(shù)
self.num_layers = num_layers
# 一個(gè)列表,包含num_layers - 1個(gè)hidden_dim,用于構(gòu)建中間層的線性變換
h = [hidden_dim] * (num_layers - 1)
# 一個(gè)nn.ModuleList,包含num_layers個(gè)線性層(全連接層),每個(gè)層的輸入和輸出通道數(shù)由h和input_dim、output_dim決定
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.sigmoid_output = sigmoid_output
def forward(self, x):
# 對輸入張量x,遍歷layers列表中的每個(gè)線性層
for i, layer in enumerate(self.layers):
# 如果當(dāng)前層不是最后一層,應(yīng)用ReLU激活函數(shù)(F.relu)
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
# 如果sigmoid_output為True,最后對輸出應(yīng)用Sigmoid激活函數(shù)
if self.sigmoid_output:
x = F.sigmoid(x)
return x