自拍偷在线精品自拍偷,亚洲欧美中文日韩v在线观看不卡

從零解讀 SAM (Segment Anything Model) 大模型!萬物皆可分割!(含源碼解析)

人工智能
SAM借鑒了NLP領(lǐng)域的Prompt策略,通過給圖像分割任務(wù)提供Prompt提示來完成任意目標(biāo)的快速分割。

SAM(Segment Anything Model),顧名思義,即為分割一切!該模型由Facebook的Meta AI實驗室,能夠根據(jù)文本指令或圖像識別,實現(xiàn)對任意物體的識別與分割。它的誕生,無疑是CV領(lǐng)域的一次重要里程碑。

論文地址:https://arxiv.org/abs/2304.02643 項目地址:https://github.com/facebookresearch/segment-anything

一、SAM Task

SAM借鑒了NLP領(lǐng)域的Prompt策略,通過給圖像分割任務(wù)提供Prompt提示來完成任意目標(biāo)的快速分割。Prompt類型可以是「前景/背景點集、粗略的框或遮罩、任意形式的文本或者任何指示圖像中需要進行分割」的信息。如下圖(a)所示,模型的輸入是原始的圖像和一些prompt,目標(biāo)是輸出"valid"的分割,所謂valid,就是當(dāng)prompt的指向是模糊時,模型能夠輸出至少其中一個mask。

這樣,可以是的SAM能夠適配各種下游任務(wù)。例如,給定一個貓的邊界框,SAM能夠輸出其mask,從而和實例分割任務(wù)搭配起來。

二、SAM Model

如下圖所示,SAM模型包含三個核心組件,Image Encoder、Prompt Encoder和Mask Decoder。圖像經(jīng)過Image Encoder編碼,Prompt提示經(jīng)過Prompt Encoder編碼,兩部分Embedding再經(jīng)過一個輕量化的Mask Decoder得到融合后的特征。其中,Encoder部分使用的是已有模型,Decoder部分使用Transformer。

1.Image Encoder

Image Encoder的作用是把圖像映射到特征空間,整體過程如下圖所示。

正如論文中所講,本質(zhì)上這個Encoder可以是任何網(wǎng)絡(luò)結(jié)構(gòu),在這里使用的是微調(diào)的Detectron的ViT,當(dāng)然它也可以被改成傳統(tǒng)的卷積結(jié)構(gòu),非常合理。

輸入圖像經(jīng)過ViT結(jié)構(gòu)的過程如下:

(1) Patch Embedding

輸入圖像通過一個卷積base,將圖像劃分為16x16的patches,步長也為16,這樣feature map的尺寸就縮小了16倍,同時channel從3映射到768。Patch Embedding示意圖如下所示。

圖像大小決定了patch的數(shù)量。

「代碼實現(xiàn):」

'''
將輸入的圖像轉(zhuǎn)換為序列化的特征向量
'''
class PatchEmbed(nn.Module):
    def __init__(
        self,
        # 卷積核大小
        # 這里是 (16, 16),意味著圖像將被劃分為16x16的patches
        kernel_size: Tuple[int, int] = (16, 16),
        # 卷積的步長,與kernel_size相同,即(16, 16),
        # 意味著每一步移動16個像素,這樣圖像的尺寸就會減少到原來的1/16
        stride: Tuple[int, int] = (16, 16),
        # 控制邊緣填充,這里設(shè)置為 (0, 0),意味著沒有額外的填充
        padding: Tuple[int, int] = (0, 0),
        # 輸入圖像的通道數(shù),通常為3(RGB圖像)
        in_chans: int = 3,
        # 輸出的特征維度,也就是每個patch被編碼為的向量的長度,這里設(shè)置為768
        embed_dim: int = 768,
    ) -> None:
        '''
        初始化這個子類實例的屬性
        '''
        # PatchEmbed的子類,繼承自nn.Module,用于構(gòu)建神經(jīng)網(wǎng)絡(luò)模塊
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )
    '''前向傳播:
       接收輸入張量 x,形狀 (B, C, H, W),其中,
       - B表示批次大小
       - C 是輸入通道數(shù)
       - H 和 W 是圖像的高度和寬度
    '''
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 卷積,將輸入的通道數(shù)從 in_chans 轉(zhuǎn)換為 embed_dim
        x = self.proj(x)
        # 將張量的維度順序從 (B, C, H, W) 調(diào)整為 (B, H, W, C)
        x = x.permute(0, 2, 3, 1)
        return x

Patch Embedding過程在Vision Transformer結(jié)構(gòu)圖中對應(yīng)下圖所示。

(2) Positiona Embedding

經(jīng)過Patch Embedding后輸出tokens需要加入位置編碼,以保留圖像的空間信息。位置編碼可以理解為一張map,map的行數(shù)與輸入序列個數(shù)相同,每一行代表一個向量,向量的維度和輸入序列tokens的維度相同,位置編碼的操作是sum,所以維度依舊保持不變。

圖像尺寸是1024,因此patch的數(shù)量是1024/16=64。

「代碼實現(xiàn):」

# 在ImageEncoderViT的__init__定義
if use_abs_pos:
    # 使用預(yù)訓(xùn)練圖像大小初始化絕對位置嵌入
    self.pos_embed = nn.Parameter(
        torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
    )
# 在ImageEncoderViT的forward添加位置編碼
if self.pos_embed is not None:
    x = x + self.pos_embed

Positiona Embedding過程在結(jié)構(gòu)圖中對應(yīng)的部分:

(3) Transformer Encoder

feature map通過16個Transformer Block,其中12個Block使用了基于Window Partition(就是把特征圖分成14*14的windows做局部的Attention)的注意力機制,以處理局部信息。另外4個Block是全局注意力模塊,它們穿插在Window Partition模塊之間,以捕捉圖像的全局上下文。

# 在ImageEncoderViT的__init__定義
# -----Transformer Encoder-----
# 初始化一個ModuleList,用于存儲Block實例
self.blocks = nn.ModuleList()
# 循環(huán)創(chuàng)建Block,depth是Transformer Encoder層數(shù)
for i in range(depth):
    # 創(chuàng)建單個Block
    block = Block(
        # 輸入的通道數(shù),即每個patch編碼后的向量維度
        dim=embed_dim,
        # 自注意力機制中的注意力頭數(shù)
        num_heads=num_heads,
        # MLP層的通道數(shù)相對于輸入通道數(shù)的比例
        mlp_ratio=mlp_ratio,
        # 是否在QKV全連接層中使用偏置
        qkv_bias=qkv_bias,
        # 歸一化層
        norm_layer=norm_layer,
        # 激活函數(shù)
        act_layer=act_layer,
        # 是否使用相對位置編碼
        use_rel_pos=use_rel_pos,
        # 相對位置編碼的初始化設(shè)置
        rel_pos_zero_init=rel_pos_zero_init,
        # 如果當(dāng)前Block不是全局注意力層,則使用窗口大小,否則使用0
        window_size=window_size if i not in global_attn_indexes else 0,
        # 輸入特征的尺寸,基于原始圖像大小和patch大小計算得出
        input_size=(img_size // patch_size, img_size // patch_size),
    )
    # 將創(chuàng)建的Block對象添加到self.blocks列表中
    self.blocks.append(block)
# -----Transformer Encoder-----

Transformer Encoder過程在結(jié)構(gòu)圖中對應(yīng)的部分:

① Encoder Block

如上圖右所示,Encoder Block從低到高主要由LayerNorm 、Multi-Head Attention和MLP構(gòu)成。

class Block(nn.Module):
    def __init__(
        self,
        dim: int,                           # 輸入通道數(shù)
        num_heads: int,                     # attention中head的個數(shù)
        mlp_ratio: float = 4.0,             # MLP層的通道數(shù)相對于輸入通道數(shù)的比例。
        qkv_bias: bool = True,              # 如果為True,QKV全連接層包含偏置。
        norm_layer: Type[nn.Module] = nn.LayerNorm,     # 歸一化層
        act_layer: Type[nn.Module] = nn.GELU,           # 激活層
        use_rel_pos: bool = False,                      # 是否使用相對位置編碼
        rel_pos_zero_init: bool = True,                 # 相對位置編碼的初始化設(shè)置
        window_size: int = 0,                           # 注意力層的窗口大小
        input_size: Optional[Tuple[int, int]] = None,   # 輸入特征的尺寸
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)         # 第一個歸一化層,用于注意力層
        self.attn = Attention(               # Multi-Head Attention
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )
        self.norm2 = norm_layer(dim)      #第二個歸一化層,用于MLP之前
        # MLP
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
        self.window_size = window_size
    # 前向傳播
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 保存輸入張量的副本
        shortcut = x
        # 對輸入張量應(yīng)用第一個歸一化層
        x = self.norm1(x)
        # Window partition 對X進行padding
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)
        # Multi-Head Attention
        x = self.attn(x)
        # 如果 window_size > 0,使用window_unpartition去除窗口分區(qū)的padding,恢復(fù)原始尺寸
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
        # 將注意力層的輸出與輸入張量相加,實現(xiàn)殘差連接
        x = shortcut + x
        # 對經(jīng)過第二個歸一化層的張量應(yīng)用MLP層,再次使用殘差連接
        x = x + self.mlp(self.norm2(x))
        # 返回最終的張量 x
        return x

② Partition操作

在非全局注意力的Block中,為了適應(yīng)14x14的窗口大小,輸入特征圖需要進行補邊(padding)和拆分操作。具體流程如下:

  • 輸入特征圖:輸入特征圖的初始尺寸為 1x64x64x768。
  • 確定最小可整除尺寸:窗口大小為14*14,要找到能夠被14整除的最小特征圖尺寸。對于寬度和高度,我們需要找到大于等于64且能被14整除的最小數(shù)。這兩個數(shù)分別是70(64+6)和70(64+6),所以最小可整除特征圖的尺寸是 1x70x70x768。
  • padding:為了將特征圖尺寸從 64x64 擴展到 70x70,我們需要在右下角填充 6x6 的區(qū)域,因為70-64=6。這種padding方式確保了窗口可以在特征圖的邊緣正確地劃分。
  • 拆分特征圖:將padding后的特征圖1x70x70x768按照窗口大小14x14進行拆分。因為70/14=5,所以特征圖可以被拆分為 5x5個14x14的窗口,總共5x5=25個窗口。每個窗口的尺寸為14x14x768。

如下圖所示。

# 將輸入張量x分割成指定大小的窗口
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    # 獲取輸入張量形狀
    # B表示批次大小,H和W表示高和寬,C表示通道數(shù)
    B, H, W, C = x.shape
    # 計算填充高度和寬度 pad_h 和 pad_w,以使得輸入尺寸能被window_size整除
    # 避免在分割時產(chǎn)生非完整的窗口
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    # 如果需要填充,使用F.pad函數(shù)在寬度和高度方向上進行填充
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    # 更新填充后張量的高度和寬度 Hp 和 Wp
    Hp, Wp = H + pad_h, W + pad_w
    # 張量重塑為:B,Hp/S,S,Wp/S,S,C,這樣可以將輸入張量分割成多個窗口
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    # 調(diào)整張量的形狀,使其由B,Hp/S,Wp/S,S,S,C-->B*Hp*Wp/(S*S),S,S,C
    # 這樣每個窗口都在張量的連續(xù)部分
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    # 返回一個包含所有窗口的張量和原始張量的填充后尺寸 (Hp, Wp)
    return windows, (Hp, Wp)

③ Unpartition操作

在非全局注意力的Block中,將attention層輸出的特征圖1x70x70x768轉(zhuǎn)化為1x64x64x768的特征圖,實際上是通過切片操作x = x[:1, :64, :64, :],從1x70x70x768的特征圖中取出左上角的1x64x64x768部分。

# 用于將window_partition函數(shù)分割的窗口重新組合回原始尺寸的張量
def window_unpartition(
    # 獲取輸入張量 windows 的形狀,以及窗口大小 window_size
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    # 原始尺寸的填充高度和寬度
    Hp, Wp = pad_hw
    # 原始尺寸的無填充高度和寬度
    H, W = hw
    # 從窗口張量的總大小中計算出原始批量大小 B
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    # 重塑窗口張量:B*Hp*Wp/(S*S),S,S,C-->B,Hp/S,Wp/S,S,S,C
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    # 再次重塑張量:B,Hp/S,Wp/S,S,S,C-->B,Hp,Wp,C
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    # 如果原始尺寸小于填充后的尺寸
    if Hp > H or Wp > W:
        # 通過切片 x[:, :H, :W, :] 去除填充部分,只保留原始大小的區(qū)域
        x = x[:, :H, :W, :].contiguous()
    # B,H,W,C
    # 返回合并后的張量,其形狀為 (B,H,W,C),即原始的批量大小、高度、寬度和通道數(shù)
    return x

Encoder Block過程如下圖所示:

window_partition將輸入特征的尺寸從(H, W)調(diào)整為(S, S)的窗口,其中S是窗口大小。這種調(diào)整是為了在多頭注意力(Multi-Head Attention)中將相對位置嵌入添加到注意力圖(attn)。然而,并非所有Transformer Block都需要在注意力圖中嵌入相對位置信息。 window_unpartition 函數(shù)的作用是將經(jīng)過注意力計算的窗口特征重新組合回原始尺寸(S×S–>H×W)。


Hp和Wp是S的整數(shù)倍。

④ Multi-Head Attention

先來看Attention,結(jié)構(gòu)如下圖所示。

Attention中q、k和v的作用:

代碼實現(xiàn)如下:

class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""
    def __init__(
        self,
        dim: int,               # 輸入通道數(shù)
        num_heads: int = 8,     # head數(shù)目
        qkv_bias: bool = True,  # 是否在QKV線性變換中使用偏置項,默認(rèn)為True
        use_rel_pos: bool = False, #是否使用相對位置編碼,默認(rèn)為False
        rel_pos_zero_init: bool = True, #如果使用相對位置編碼,是否以零初始化,默認(rèn)為True
        input_size: Optional[Tuple[int, int]] = None,       # 可選參數(shù),用于指定相對位置編碼的尺寸,只有在使用相對位置編碼時才需要
    ) -> None:
        super().__init__()
        self.num_heads = num_heads #輸入head數(shù)目
        head_dim = dim // num_heads #每個head維度
        self.scale = head_dim**-0.5 #用于縮放注意力得分的因子,以避免數(shù)值溢出,取值為head_dim的平方根的倒數(shù)
        #一個全連接層(nn.Linear),將輸入映射到Q、K、V的組合
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        #  一個全連接層,用于將注意力機制的輸出投影回原始維度
        self.proj = nn.Linear(dim, dim)
        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:        # 使用相對位置編碼
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # 初始化水平方向(rel_pos_h)和垂直方向(rel_pos_w)的相對位置嵌入
            # 2S-1,Epos
            # 輸入尺寸為(H, W),則水平方向的位置嵌入長度為2*H-1,垂直方向的位置嵌入長度為2*W-1
            # 每個位置嵌入的維度為head_dim
            # 這些位置嵌入以模型參數(shù)的形式定義(nn.Parameter),意味著它們會在訓(xùn)練過程中被學(xué)習(xí)和更新
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 輸入張量x的形狀為(B, H, W, C),其中B是批次大小,H和W是高度和寬度,C是通道數(shù)(即dim)
        B, H, W, _ = x.shape
        # 使用qkv層將x轉(zhuǎn)換為Q、K、V的組合,然后通過重塑和重新排列來準(zhǔn)備多頭注意力計算
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
        # attn with shape (B * nHead, H * W,  H * W)
        # 計算注意力分?jǐn)?shù)
        # q * self.scale: q是查詢向量(query vectors),形狀為(B * nHead, H * W, C),其中B是批次大小,nHead是注意力頭的數(shù)量,H * W是序列的長度,C是每個位置的特征維度
        # self.scale是用于縮放注意力分?jǐn)?shù)的因子,通常取head_dim的平方根的倒數(shù),以防止數(shù)值過大
        # 乘以self.scale是為了穩(wěn)定計算并防止梯度消失
        # k.transpose(-2, -1): k是鍵向量(key vectors),形狀與q相同。transpose(-2, -1)是對k進行轉(zhuǎn)置操作,即將最后一個和倒數(shù)第二個維度互換,目的是讓q和k在計算點積時的維度匹配。轉(zhuǎn)置后的k形狀變?yōu)?B * nHead, C, H * W)
        # 將q和轉(zhuǎn)置后的k進行矩陣乘法。計算每個查詢位置q與所有鍵位置k的點積,生成一個形狀為(B * nHead, H * W, H * W)的注意力分?jǐn)?shù)矩陣attn。每個位置i和j的注意力分?jǐn)?shù)表示q_i與k_j的相似度
        attn = (q * self.scale) @ k.transpose(-2, -1)
        # 如果啟用了相對位置編碼
        if self.use_rel_pos:
            # (H, W)代表輸入序列的尺寸,這里假設(shè)H和W是相等的(S×S),即輸入是一個正方形網(wǎng)格(例如,圖像的像素網(wǎng)格)
            # attn: 上述計算得到的注意力分?jǐn)?shù)矩陣,形狀為(B * nHead, H * W, H * W)
            # q: 查詢向量,形狀為(B * nHead, H * W, C)
            # self.rel_pos_h和self.rel_pos_w: 分別表示水平和垂直方向上的相對位置嵌入,形狀分別為(2 * S - 1, head_dim)
            # (H, W): 輸入序列的尺寸,用于指導(dǎo)相對位置嵌入的計算
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
        # 生成的注意力分?jǐn)?shù)矩陣attn隨后會經(jīng)過Softmax函數(shù),將每個位置的分?jǐn)?shù)歸一化到[0, 1]區(qū)間,形成一個概率分布
        attn = attn.softmax(dim=-1)
        # 加權(quán)求和: 
        # 使用attn @ v計算加權(quán)和,其中@表示矩陣乘法,v是值向量(value vectors),形狀為(B * nHead, H * W, C)
        # 注意力權(quán)重矩陣attn(形狀為(B * nHead, H * W, H * W))與v按元素相乘后,再進行矩陣乘法,得到加權(quán)后的值向量,形狀為(B * nHead, H * W, C)
        # 使用.view()將加權(quán)后的值向量重塑為(B, self.num_heads, H, W, -1),然后使用.permute(0, 2, 3, 1, 4)進行重排,將self.num_heads移動到第四個維度。最后,使用.reshape(B, H, W, -1)將結(jié)果進一步重塑為(B, H, W, -1),與輸入張量的形狀一致,但保留了多頭注意力的輸出
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        # 使用self.proj(一個全連接層,形狀為(dim, dim))對上述處理后的張量進行線性投影,以將其投影回原始的特征維度
        x = self.proj(x)
        # 最終,返回經(jīng)過線性投影的張量x作為注意力模塊的輸出
        return x

在多頭注意力(Multi-Head Attention)模塊中,輸入特征F(N×E)表示一個序列,其中N是序列中的元素數(shù)量,E是每個元素的特征維度。具體流程如下。

首先將每個token的qkv特征維度embed_dim均拆分到每個head上。

每個head分別通過q和k計算得到權(quán)重w,權(quán)重w和v得到輸出output,合并所有head的output得到最終的output。

get_rel_pos用于計算查詢(query)和鍵(key)之間在二維空間中的相對位置編碼,如下圖所示。

實現(xiàn)代碼:

def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    # 表示查詢(query)和鍵(key)在二維空間中的最大相對距離
    # max(q_size, k_size):取查詢的寬度q_size和鍵的寬度k_size中的較大值
    # 如果q_size和k_size都為S,則最大的正向距離是S-1,最大的負(fù)向距離也是S-1,所以總的最大距離是2 * S
    # - 1:減去1是因為在計算相對位置時,0被包含在內(nèi),所以最大距離是2 * S - 1
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # 如果rel_pos的形狀的第0個維度(即長度)不等于max_rel_dist,說明需要進行插值
    if rel_pos.shape[0] != max_rel_dist:
        # 使用F.interpolate進行線性插值
        rel_pos_resized = F.interpolate(
            # 1,N,Ep --> 1,Ep,N --> 1,Ep,2S-1
            # 將rel_pos重塑為(1, N, Ep),其中N是原始的長度,Ep是每個位置編碼的特征維度
            # 通過permute(0, 2, 1)進行轉(zhuǎn)置,使其形狀變?yōu)?1, Ep, N)
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            # 設(shè)置插值的目標(biāo)長度為max_rel_dist
            size=max_rel_dist,
            # 指定插值方法為線性插值
            mode="linear",
        )
        # Ep,2S-1 --> 2S-1,Ep
        # 插值后的rel_pos形狀為(1, Ep, max_rel_dist),通過reshape(-1, max_rel_dist)將其重塑為(Ep, max_rel_dist)
        # 再通過permute(1, 0)轉(zhuǎn)置為(max_rel_dist, Ep)
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        # 如果rel_pos的長度與max_rel_dist相等,說明已經(jīng)足夠覆蓋所有可能的相對位置,因此直接使用rel_pos,不進行任何處理
        rel_pos_resized = rel_pos

    # 如果q和k長度值不同,則用短邊長度縮放坐標(biāo)
    # 創(chuàng)建查詢坐標(biāo)q_coords
    # torch.arange(q_size)生成一個從0到q_size - 1的整數(shù)序列,表示q_size個位置
    # [:, None]在序列末尾添加一個維度,使其形狀為(q_size, 1),這樣可以方便與一個標(biāo)量進行逐元素乘法
    # max(k_size / q_size, 1.0)計算比例因子,如果k_size大于q_size,則使用k_size / q_size,否則使用1.0
    # 這確保了在q_size小于k_size的情況下,q_coords的坐標(biāo)會被適當(dāng)放大,以匹配k_coords的尺度
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    # 創(chuàng)建鍵坐標(biāo)k_coords
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    # S,S
    # 計算了查詢(query)和鍵(key)在二維空間中的相對坐標(biāo)relative_coords
    # (q_coords - k_coords):每個查詢位置相對于每個鍵位置的水平距離
    # (k_size - 1) * max(q_size / k_size, 1.0):計算了一個偏移量,用于確保相對坐標(biāo)在正確的范圍內(nèi)
    # (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0):將計算出的差值和偏移量相加,得到最終的相對坐標(biāo)relative_coords
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    # tensor索引是tensor時,即tensor1[tensor2]
    # 假設(shè)tensor2某個具體位置值是2,則tensor1[2]位置的tensor1切片替換tensor2中的2
    # tensor1->shape 5,5,3 tensor2->shape 2,2,3 tensor1切片->shape 5,3 tensor1[tensor2]->shape 2,2,3,5,3
    # tensor1->shape 5,5 tensor2->shape 3,2,3 tensor1切片->shape 5 tensor1[tensor2]->shape 3,2,3,5

    # 2S-1,Ep-->S,S,Ep
    return rel_pos_resized[relative_coords.long()]

add_decomposed_rel_pos為atten注意力特征添加相對位置的嵌入特征,如下圖所示。

def add_decomposed_rel_pos(
    # 注意力分?jǐn)?shù)矩陣
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    # S,S
    q_h, q_w = q_size
    k_h, k_w = k_size
    # rel_pos_h -> 2S-1×Epos
    # 查詢(query)和鍵(key)在高度方向上的相對位置編碼
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    # 查詢(query)和鍵(key)在寬度方向上的相對位置編碼
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
    # 重塑q為(B, q_h, q_w, dim)
    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    # 計算相對位置加權(quán)
    # 計算rel_h和rel_w,這兩個張量表示在每個位置上,查詢與相對位置編碼的加權(quán)和
    # B,q_h,q_w,k_h
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    # B,q_h, q_w, k_w
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
    # 合并注意力分?jǐn)?shù)和相對位置編碼
    # 將attn重塑為(B, q_h, q_w, k_h, k_w),然后與rel_h和rel_w按元素相加
    # 將attn重塑為(B, q_h, q_w, k_h, k_w),然后與rel_h和rel_w按元素相加
    attn = (
    # B,q_h, q_w, k_h, k_w
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)
    return attn

Multi-Head Attention模塊為注意力特征嵌入了相對位置特征(add_decomposed_rel_pos):

(4) Neck Convolution

最后,通過兩層卷積(Neck)將通道數(shù)降低至256,生成最終的Image Embedding。其結(jié)構(gòu)圖如下所示。

代碼實現(xiàn)如下:

# neck: nn.Sequential,它包含兩個卷積層和兩個LayerNorm2d)
self.neck = nn.Sequential(
    # 1x1的卷積層,用于將輸入通道數(shù)從embed_dim減小到out_chans
    # 1x1卷積主要用于通道間的信息融合,而不改變特征圖的空間尺寸
    nn.Conv2d(
        embed_dim,
        out_chans,
        kernel_size=1,
        # 不使用偏置項
        bias=False,
    ),
    # 歸一化層,用于規(guī)范化輸出通道的均值和方差,提高模型的穩(wěn)定性和收斂速度
    # out_chans:歸一化層的通道數(shù)
    LayerNorm2d(out_chans),
    # 3x3的卷積層
    nn.Conv2d(
        # 使用out_chans作為輸入和輸出通道數(shù)
        out_chans,
        out_chans,
        kernel_size=3,
        # 輸入和輸出的特征圖尺寸保持不變,避免尺寸收縮
        padding=1,
        # 不使用偏置
        bias=False,
    ),
    # 第二個歸一化層,再次對輸出進行規(guī)范化
    LayerNorm2d(out_chans),
)
# 歸一化
class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        # 創(chuàng)建了兩個可學(xué)習(xí)的參數(shù):weight和bias
        # weight初始化為全1,bias初始化為全0
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 沿著通道維度求均值,keepdim=True保留維度,使得u的形狀與x相同,除了通道維度的大小為1
        u = x.mean(1, keepdim=True)                 # dim=1維度求均值并保留通道
        # 計算標(biāo)準(zhǔn)化因子 s,即減去均值后的平方差的平均值,也保留通道維度
        s = (x - u).pow(2).mean(1, keepdim=True)
        # 歸一化,將每個像素的值減去均值 u,然后除以標(biāo)準(zhǔn)差的平方根加上一個小的常數(shù) eps 以保證數(shù)值穩(wěn)定性
        x = (x - u) / torch.sqrt(s + self.eps)
        # 應(yīng)用可學(xué)習(xí)的權(quán)重和偏置
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x

2.Prompt Encoder

SAM模型中Prompt Encoder網(wǎng)絡(luò)結(jié)構(gòu)如下圖所示。主要包括三步驟:

  • Embed_Points:標(biāo)記點編碼(標(biāo)記點由點轉(zhuǎn)變?yōu)橄蛄?
  • Embed_Boxes:標(biāo)記框編碼(標(biāo)記框由點轉(zhuǎn)變?yōu)橄蛄?
  • Embed_Masks:mask編碼(mask下采樣保證與Image Encoder輸出一致)

(1) Embed_Points

Embed_Points結(jié)構(gòu)如下圖所示。

標(biāo)記點預(yù)處理,將channel由2變?yōu)閑mbed_dim(MatMul:forward_with_coords),然后再加上位置編碼權(quán)重。其中,

  • 2:坐標(biāo)(h,w)
  • embed_dim:提示編碼的channel

「代碼實現(xiàn):」

# 將輸入的點坐標(biāo)和對應(yīng)的標(biāo)簽轉(zhuǎn)化為高維的嵌入表示,以便于后續(xù)的模型處理
def _embed_points(
    self,
    points: torch.Tensor,
    labels: torch.Tensor,
    pad: bool,
) -> torch.Tensor:
    # 將輸入的點坐標(biāo)points的每個坐標(biāo)值增加0.5,以將坐標(biāo)從像素的左上角移動到像素中心
    points = points + 0.5
    # points和boxes聯(lián)合則不需要pad
    if pad:
        # 在點坐標(biāo) points 和標(biāo)簽 labels 中添加一個填充項
        # 以保持批次處理的一致性,即使某些樣本的點數(shù)量少于最大數(shù)量。
        # 填充的點坐標(biāo)為(0,0),標(biāo)簽為-1
        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)  # B,1,2
        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)     # B,1
        points = torch.cat([points, padding_point], dim=1)                          # B,N+1,2
        labels = torch.cat([labels, padding_label], dim=1)                          # B,N+1
    # 根據(jù)調(diào)整后的點坐標(biāo)和輸入圖像的尺寸生成位置編碼
    # 生成的嵌入維度:B,N+1,2f
    # 2f 表示每個點位置編碼的維度,是通過某種函數(shù)(如正弦或余弦函數(shù))從原始的2D坐標(biāo)擴展而來
    point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)  
    # 根據(jù)標(biāo)簽 labels 的值,對每個點的嵌入進行調(diào)整。

    # labels為-1是非標(biāo)記點,設(shè)為非標(biāo)記點權(quán)重
    point_embedding[labels == -1] = 0.0
    point_embedding[labels == -1] += self.not_a_point_embed.weight
    # labels為0是背景點,加上背景點權(quán)重
    point_embedding[labels == 0] += self.point_embeddings[0].weight
    # labels為1是目標(biāo)點,加上目標(biāo)點權(quán)重
    point_embedding[labels == 1] += self.point_embeddings[1].weight
    return point_embedding

(2) Embed_Boxes

Embed_Boxes結(jié)構(gòu)如下圖所示。

標(biāo)記框(Bounding Box)一般有兩個點,編碼步驟如下:

  • 將輸入的邊界框坐標(biāo)張量boxes從BxNx4轉(zhuǎn)換為BxNx2x2;
  • 再使用point embedding編碼的方式,得到corner_embedding;
  • 加上之前生成的可學(xué)習(xí)的embeding向量。

最后輸出的corner_embedding大小為Nx2x256。

「代碼實現(xiàn):」

# 將輸入的邊界框(boxes)轉(zhuǎn)換為高維的嵌入表示
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    # 將坐標(biāo)從像素的左上角移動到像素中心
    boxes = boxes + 0.5
    # 將輸入的邊界框坐標(biāo)張量boxes從BxN*4轉(zhuǎn)換為B*Nx2x2
    # 其中B是批次大小,N是每個樣本中的邊界框數(shù)量
    coords = boxes.reshape(-1, 2, 2)
    # 對每個邊界框的角點坐標(biāo)進行位置編碼
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)    #
    # 分別對每個邊界框的起始點和末尾點的嵌入向量加上特定的權(quán)重
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    # 返回加權(quán)后嵌入向量,形狀為 B*Nx2xembed_dim,其中 embed_dim 是位置編碼的維度
    return corner_embedding

(3) Embed_Mask

mask提示允許我們直接在原圖上指示感興趣區(qū)域來引導(dǎo)模型。這些mask通過卷積操作被轉(zhuǎn)換為與圖像嵌入空間相匹配的特征,然后與圖像嵌入相加結(jié)合,為模型提供分割的精確位置信息。

如果沒有使用mask提示,則將一組可學(xué)習(xí)向量(no_mask_embed,1*256)expand為1x256×64×64后替代,使得在處理序列數(shù)據(jù)時,即使沒有具體的mask信息,也能有一個統(tǒng)一的處理方式。

# 在PromptEncoder的forward定義
'''
首先獲取no_mask_embed權(quán)重矩陣,并將其重塑成一個形狀為(1, num_embeddings, 1, 1)的四維張量。

再利用.expand方法將這個張量擴展到與圖像編碼相同的尺寸。bs是batch大小,-1是一個占位符,它會自動計算出
num_embeddings的值以保持張量的元素總數(shù)不變。self.image_embedding_size[0]和self.image_embedding_size[1]分別表示圖像編碼的寬度和高度。
'''
self.no_mask_embed = nn.Embedding(1, embed_dim)      # embed_dim=256
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])

)

如果有配置mask,Embed_Masks結(jié)構(gòu)如下圖所示。

已知輸入mask是Nx1x256x256,經(jīng)過3層卷積,最后得到與Image Embedding一樣的size:

首先,mask進入一個1x2x2x4的卷積,stride=2;LN;再進入一個4x2x2x16的卷積,stride=2;LN;最后再進入一個16x1x1x256的卷積;得到最后的mask_embedding的size為Nx256x64x64,最終mask_embedding作為dense_embedding輸出,大小為Nx256x64x64。

mask的輸出尺寸是Image Encoder模塊輸出的圖像編碼尺寸的4倍,因此為了保持一致,需要4倍下采樣。

「代碼實現(xiàn)」

# 將輸入的掩模(mask)張量轉(zhuǎn)換為一個低分辨率的嵌入表示
# 掩模 masks 是一個形狀為 BxCxHxW 的張量
# 其中 B 是批次大小,C 是通道數(shù)(通常為1,因為掩模通常只有一通道),H 和 W 分別是高度和寬度。
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
    # mask下采樣4倍
    mask_embedding = self.mask_downscaling(masks)
    # 返回下采樣并轉(zhuǎn)換后的掩模嵌入,其形狀為 B*embed_dim*H'*W',其中 H' 和 W' 是下采樣后的高度和寬度
    return mask_embedding

# mask_downscaling包括多個卷積層、層歸一化(LayerNorm2d)和激活函數(shù),目的是減少掩模的空間維度,同時增加通道維度
self.mask_downscaling = nn.Sequential(
    # 將通道數(shù)從1減少到mask_in_chans//4,同時使用2x2的卷積核和步長2進行下采樣,降低了空間分辨率
    nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
    # 規(guī)范化通道維度上的特征
    LayerNorm2d(mask_in_chans // 4),
    # 激活函數(shù),引入非線性
    activation(),
    # 將通道數(shù)恢復(fù)到 mask_in_chans,再次使用2x2的卷積核和步長2進行下采樣,進一步降低空間分辨率
    nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
    # LayerNorm2d 層和激活函數(shù)
    LayerNorm2d(mask_in_chans),
    activation(),
    # 將通道數(shù)增加到 embed_dim,通常是為了與模型的其他部分保持一致
    nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )

① PositionEmbeddingRandom

用于將標(biāo)記點和標(biāo)記框的坐標(biāo)進行提示編碼預(yù)處理。就是將64x64個坐標(biāo)點歸一化后,與隨機高斯矩陣相乘(2x128),再將結(jié)果分別進行sin和cos,最后再拼到一起,輸出的大小為256x64x64,與image_embedding大小基本一致了。

class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """
    def init(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().init()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # 構(gòu)建一個2x128的隨機矩陣作為位置編碼高斯矩陣
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1

        # 矩陣乘法:64x64xx2 @ 2x128 ---> 64x64x128
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords

        # outputs d_1 x ... x d_n x C shape
        # cat, 最后一個維度上拼接:64x64x256
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device

        # 構(gòu)造一個64x64的全1矩陣
        grid = torch.ones((h, w), device=device, dtype=torch.float32)

        # 行、列累加
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5

        # 行列累加結(jié)果歸一化
        y_embed = y_embed / h
        x_embed = x_embed / w

        # 行列拼接:64x64x2,編碼后的結(jié)果是64x64x256
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))

        # 最后輸出256x64x64
        return pe.permute(2, 0, 1)  # C x H x W

3.Mask Decoder

Mask Decoder網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)配置如下。

def __init__(
    self,
    *,
    # transformer通道數(shù)
    transformer_dim: int,
    # 用于預(yù)測mask的Transformer網(wǎng)絡(luò)模塊
    transformer: nn.Module,
    # 消除掩碼歧義預(yù)測的掩碼數(shù)量,默認(rèn)為3
    num_multimask_outputs: int = 3,
    # 激活函數(shù),默認(rèn)為GELU
    activation: Type[nn.Module] = nn.GELU,
    # MLP用于預(yù)測掩模質(zhì)量的深度
    iou_head_depth: int = 3,
    # MLP的隱藏層通道數(shù)
    iou_head_hidden_dim: int = 256,
) -> None:
    super().__init__()
    self.transformer_dim = transformer_dim #存儲傳入的transformer_dim
    # 存儲傳入的transformer模塊
    self.transformer = transformer
    # 存儲掩碼預(yù)測的輸出數(shù)量
    self.num_multimask_outputs = num_multimask_outputs
    # 用于表示IoU(Intersection over Union)的嵌入層,大小為1×transformer_dim
    # 可學(xué)習(xí)的iou tokens:1x256
    self.iou_token = nn.Embedding(1, transformer_dim)
    # 包含IoU token在內(nèi)的總mask token數(shù)量
    # # num_mask_tokens = 3 + 1 = 4, transformer_dim = 256
    # 輸出一個4x256的矩陣
    self.num_mask_tokens = num_multimask_outputs + 1
    # 存儲所有mask token的嵌入層,大小為num_mask_tokens×transformer_dim
    self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

    #----- upscaled -----
    # 用于4倍上采樣的序列,包含兩個轉(zhuǎn)置卷積層,每個上采樣2倍,中間夾著LayerNorm和激活函數(shù)
    self.output_upscaling = nn.Sequential(
        nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),     #轉(zhuǎn)置卷積 上采樣2倍
        LayerNorm2d(transformer_dim // 4),
        activation(),
        nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        activation(),
    )
    # ----- upscaled -----

    # 多層感知機(MLP)模塊
    #  一個模塊列表,包含了num_mask_tokens個MLP,每個MLP用于處理不同mask的輸出
    self.output_hypernetworks_mlps = nn.ModuleList(
        [
            MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
            for i in range(self.num_mask_tokens)
        ]
    )
    # ----- MLP -----

    # ----- MLP -----
    # 一個MLP,用于預(yù)測IoU,輸入是transformer_dim,經(jīng)過iou_head_hidden_dim的隱藏層,輸出是num_mask_tokens
    self.iou_prediction_head = MLP(
        transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
    )
    # ----- MLP -----

SAM模型Mask Decoder網(wǎng)絡(luò)結(jié)構(gòu)如下圖所示。

  • spa_pro_emb(sparse embedding)、iou_token、mask_token合并成一個tokens,作為point_embeddings。
  • spa_pro_emb: point、bbox prompt合并后的產(chǎn)物,一般為NxXx256。
  • iou_token:可學(xué)習(xí)參數(shù),大小為1x256。
  • mask_token:可學(xué)習(xí)參數(shù),大小為4x256。

原論文中Mask Decoder模塊各部分結(jié)構(gòu)示意圖如下。

Mask Decoder網(wǎng)絡(luò)在特征提取中的基本步驟如下:

  • transformer:將來自編碼器的圖像特征與額外的提示信息(如掩碼提示或查詢向量)融合,以捕捉目標(biāo)區(qū)域的上下文信息。
  • upscaled:對粗略mask src進行上采樣,使其與原始圖像尺寸相匹配,以便進行更精細(xì)的mask預(yù)測。
  • mask_MLP:通過一系列全連接層,對上采樣后的特征進行變換,計算出針對每個像素的mask概率。這些層可以設(shè)計為學(xué)習(xí)如何為每個mask通道分配權(quán)重,從而生成最終的mask輸出。
  • iou_MLP:評估生成的mask與真實mask之間的重疊程度,即預(yù)測mask的質(zhì)量。
def forward(
    self,
    # image encoder 圖像特征
    image_embeddings: torch.Tensor,
    # 位置編碼
    # 256x64x64
    image_pe: torch.Tensor,
    # 標(biāo)記點和標(biāo)記框的嵌入編碼
    sparse_prompt_embeddings: torch.Tensor,
    # 輸入mask的嵌入編碼
    dense_prompt_embeddings: torch.Tensor,
    # 是否輸出多個mask
    multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # 將這些特征融合,通過Transformer和后續(xù)的上采樣及MLP層,生成掩膜預(yù)測和IoU分?jǐn)?shù)
    masks, iou_pred = self.predict_masks(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompt_embeddings,
        dense_prompt_embeddings=dense_prompt_embeddings,
    )
    # 如果multimask_output為True,表示需要輸出多個掩模,選取索引為1到num_multimask_outputs的所有掩模
    if multimask_output:
        mask_slice = slice(1, None)
    # 否則,如果multimask_output為False,僅輸出第一個掩模(通常是最高得分的掩模)
    else:
        mask_slice = slice(0, 1)
    # 根據(jù)multimask_output選擇后的掩模,維度調(diào)整為(batch_size, num_selected_masks, height, width)
    masks = masks[:, mask_slice, :, :]
    # 根據(jù)multimask_output選擇后的IoU預(yù)測,維度調(diào)整為(batch_size, num_selected_masks)
    iou_pred = iou_pred[:, mask_slice]
    return masks, iou_pred
def predict_masks(
    self,
    # image embedding: 是image encoder的輸出,大小為為1x256x64x64
    image_embeddings: torch.Tensor,
    # image_pe位置編碼也拓展成Nx256x64x64的矩陣
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # 首先將iou token和mask token 拼接得到一個5x256的矩陣,再將其拓展到與sparse embedding一個維度Nx5x256
    # 1,E and 4,E --> 5,E
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
    # 再將拓展后的矩陣與sparse embedding拼接得到tokens,其大小Nx(5+X)x256
    # 5,E --> B,5,E
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
    # 再與稀疏矩陣拼接,假設(shè)稀疏矩陣只有point為Nx2x256,拼接之后則為Nx(5+2)x256
    # B,5,E and B,N,E -->B,5+N,E       N是點的個數(shù)(標(biāo)記點和標(biāo)記框的點)
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

    # 將image embedding(1x256x64x64)拓展成稠密prompt的維度:Nx256x64x64
    # B,C,H,W
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    #將拓展后的image embedding直接與稠密prompt相加:Nx256x64x64
    # B,C,H,W + 1,C,H,W ---> B,C,H,W
    src = src + dense_prompt_embeddings
    # # 將256x64x64的位置編碼,拓展成Nx256x64x64
    # 1,C,H,W---> B,C,H,W
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    b, c, h, w = src.shape

    # ----- transformer -----
    # Run the transformer:這里使用的TwoWayTransformer,有必要對輸入再說明一下
    # src:image_bedding + dense_prompt(mask),Nx256x64x64
    # pos_src: 位置編碼,Nx256x64x64
    # tokens: iou_tokens + mask_tokens + sparse_prompt(point/bbox),Nx(5+x)x256
    # B,N,C
    hs, src = self.transformer(src, pos_src, tokens)
    # ----- transformer -----
    # # 后處理
    iou_token_out = hs[:, 0, :]
    mask_tokens_out = hs[:, 1: (1 + self.num_mask_tokens), :]

    # 通過上采樣層將Transformer輸出的掩模部分恢復(fù)到(batch_size, channels, height, width)的形狀
    # B,N,C-->B,C,H,W
    src = src.transpose(1, 2).view(b, c, h, w)
    # ----- upscaled -----
    # 4倍上采樣
    upscaled_embedding = self.output_upscaling(src)
    # ----- upscaled -----
    
    # 對每個mask token,通過其對應(yīng)的MLP得到一個權(quán)重張量,使用這些權(quán)重與上采樣后的特征張量進行點乘,得到掩模預(yù)測(batch_size, num_mask_tokens, height, width)
    hyper_in_list: List[torch.Tensor] = []
    
    # ----- mlp -----
    for i in range(self.num_mask_tokens):
        # mask_tokens_out[:, i, :]: B,1,C
        # output_hypernetworks_mlps: B,1,c
        hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
    # B,n,c
    hyper_in = torch.stack(hyper_in_list, dim=1)
    # ----- mlp -----
    
    b, c, h, w = upscaled_embedding.shape
    # B,n,c × B,c,N-->B,n,h,w
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
    
    # ----- mlp -----
    # 通過IoU預(yù)測頭(MLP)對IoU token的輸出進行處理,得到(batch_size, num_mask_tokens)的IoU分?jǐn)?shù)
    # iou_token_out: B,1,n
    iou_pred = self.iou_prediction_head(iou_token_out)
    # ----- mlp -----
    # 返回預(yù)測的掩模和IoU分?jǐn)?shù)
    # masks: B,n,h,w
    # iou_pred: B,1,n
    return masks, iou_pred

(1) transformer

Mask Decoder由多個重復(fù)堆疊TwoWayAttention Block和1個Multi-Head Attention組成。

① TwoWayAttention Block

TwoWayAttention Block由LayerNorm 、Multi-Head Attention和MLP構(gòu)成。所謂的TwoWay:即是兩輪次循環(huán),第一次point_embedding自注意,第二次則加上上一輪輸出的queries進行attention。

原論文中TwoWayAttention部分示意圖。

class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,         # 輸入特征維度
        num_heads: int,             # 注意力頭的數(shù)量,決定了注意力機制的并行度
        mlp_dim: int = 2048,        # MLP(多層感知機)中間層的維度,用于特征變換和非線性增強
        activation: Type[nn.Module] = nn.ReLU,      # 激活函數(shù)類型,默認(rèn)為ReLU
        attention_downsample_rate: int = 2,         # 下采樣比率
        # 是否在第一層自注意力中跳過位置編碼的殘差連接
        skip_first_layer_pe: bool = False,
    ) -> None:
        super().__init__()
        # 自注意力模塊,用于增強queries內(nèi)部的信息交互
        self.self_attn = Attention(embedding_dim, num_heads)
        # norm1/2/3/4: LayerNorm層,用于穩(wěn)定訓(xùn)練和加速收斂
        self.norm1 = nn.LayerNorm(embedding_dim)
        # cross_attn_token_to_image和cross_attn_image_to_token: 交叉注意力模塊,分別讓標(biāo)記點特征關(guān)注圖像特征,以及圖像特征反過來關(guān)注標(biāo)記點特征
        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)
        # mlp: 多層感知機模塊,增加模型的表達能力
        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.skip_first_layer_pe = skip_first_layer_pe
    # 前向傳播
    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:

        # queries:標(biāo)記點編碼相關(guān)(原始標(biāo)記點編碼經(jīng)過一系列特征提取)
        # keys:原始圖像編碼相關(guān)(原始圖像編碼經(jīng)過一系列特征提取)
        # query_pe:原始標(biāo)記點編碼
        # key_pe:原始圖像位置編碼
        # 第一輪本身queries==query_pe沒比較再"殘差"

        # 首先對queries應(yīng)用自注意力,若skip_first_layer_pe=True,直接使用queries進行自注意力計算;否則,將queries與query_pe相加后進行自注意力計算,并殘差連接回queries,之后進行LayerNorm
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # 調(diào)整queries和keys(圖像特征)加上各自的位置編碼,然后通過cross_attn_token_to_image交叉注意力層,使標(biāo)記點特征關(guān)注圖像特征,結(jié)果與原始queries殘差連接并進行LayerNorm
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block:將更新后的queries通過MLP模塊進行非線性變換,結(jié)果與原queries殘差連接并進行LayerNorm
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # 交叉注意力(圖像到標(biāo)記點):再次調(diào)整queries和keys加上位置編碼,但這次通過cross_attn_image_to_token讓圖像特征關(guān)注標(biāo)記點特征,更新后的keys與原始keys殘差連接并進行LayerNorm
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)
        return queries, keys

② Attention

Mask Decoder的Attention與ViT的Attention有些細(xì)微的不同:

  • Mask Decoder的Attention是3個FC層分別接受3個輸入獲得q、k和v。
  • ViT的Attention是1個FC層接受1個輸入后將結(jié)果均拆分獲得q、k和v。

如下圖所示。

原論文中Attention部分示意圖。

class Attention(nn.Module):

    def __init__(
        self,
        embedding_dim: int,         # 輸入特征的維度
        num_heads: int,             # attention的head數(shù)
        downsample_rate: int = 1,   # 下采樣
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        # 內(nèi)部維度
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
        # 四個線性層(全連接層):用于生成query向量、key向量、value向量
        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        # 用于將注意力機制后的輸出投影回原始的特征維度
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
    # 將輸入張量分解為多頭注意力所需的形狀
    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
    # 在注意力計算后重新組合這些頭部
    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # 輸入投影:分別使用q_proj、k_proj和v_proj對query、key和value進行線性變換
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # 分離頭部:將變換后的query、key和value張量按照num_heads進行重塑,以便進行多頭注意力計算
        # B,N_heads,N_tokens,C_per_head
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # 注意力計算:
        # 計算query和key的點積,然后除以c_per_head的平方根進行歸一化,以防止數(shù)值過大
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B,N_heads,N_tokens,C_per_head
        # 歸一化Scale
        attn = attn / math.sqrt(c_per_head)
        # 應(yīng)用softmax函數(shù)得到注意力權(quán)重
        attn = torch.softmax(attn, dim=-1)
        # 使用注意力權(quán)重對value進行加權(quán)求和,得到注意力輸出
        out = attn @ v
        # # B,N_tokens,C
        # 重新組合頭部:將多頭注意力輸出合并回原始的特征維度。
        out = self._recombine_heads(out)
        # 輸出投影:最后,通過out_proj將輸出投影回原始的embedding_dim
        out = self.out_proj(out)
        return out

③ transformer_MLP

transformer中MLP的結(jié)構(gòu)如下圖所示。

# MLPBlock類是一個簡單的多層感知機(MLP)模塊,由兩個全連接層(Linear)和一個激活函數(shù)組成
class MLPBlock(nn.Module):
    def __init__(
        self,
        # 輸入的維度,通常是特征向量的長度
        embedding_dim: int,
        # MLP中間層的寬度,可以設(shè)置為比輸入維度更大的值以增加模型的表達能力
        mlp_dim: int,
        # 激活函數(shù),這里默認(rèn)使用GELU
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # 第一個全連接層,將輸入從embedding_dim維度變換到mlp_dim維度
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        # 第二個全連接層,將mlp_dim維度的結(jié)果變換回embedding_dim維度,以保持與輸入相同的維度
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        # 激活函數(shù)實例,用于在全連接層之間引入非線性
        self.act = act()
    # 接收輸入張量x,將其傳遞給lin1,然后應(yīng)用激活函數(shù)act。
    # 將激活函數(shù)的輸出傳遞給lin2,得到最終的輸出張量
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))

④upscaled

這個上采樣過程將Transformer的輸出特征圖恢復(fù)到更接近輸入圖像的分辨率,以便于生成掩模預(yù)測。upscaled的結(jié)構(gòu)如下圖所示。

# 在MaskDecoder的__init__定義
# output_upscaling是一個序列模塊,用于上采樣Transformer輸出的特征圖
self.output_upscaling = nn.Sequential(
    # 使用nn.ConvTranspose2d,輸入通道數(shù)為transformer_dim,輸出通道數(shù)為transformer_dim // 4,內(nèi)核大小為2,步長為2
    # 將特征圖的尺寸放大兩倍,同時將通道數(shù)減半
    # 內(nèi)核大小為2的轉(zhuǎn)置卷積相當(dāng)于上采樣2倍,步長為2確保輸出尺寸翻倍
    nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),     #轉(zhuǎn)置卷積 上采樣2倍
    # 層歸一化(LayerNorm2d)
    LayerNorm2d(transformer_dim // 4),
    # 激活函數(shù)
    activation(),
    # 再次使用nn.ConvTranspose2d,輸入通道數(shù)為transformer_dim // 4,輸出通道數(shù)為transformer_dim // 8,內(nèi)核大小為2,步長為2。這一步繼續(xù)將特征圖的尺寸放大兩倍,同時通道數(shù)再次減半
    nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
    # 重復(fù)激活函數(shù)的過程,以進一步增強非線性表達
    activation(),
)
# 在MaskDecoder的predict_masks添加位置編碼
upscaled_embedding = self.output_upscaling(src)

⑤ mask_MLP

此處的MLP基礎(chǔ)模塊不同于ViT的MLP(transformer_MLP)基礎(chǔ)模塊。

# 在MaskDecoder的__init__定義
# output_hypernetworks_mlps是一個nn.ModuleList,包含了多個多層感知機(MLP)。每個MLP的目的是根據(jù)輸入的mask_tokens_out生成特定掩模的超網(wǎng)絡(luò)權(quán)重
self.output_hypernetworks_mlps = nn.ModuleList(
    [
        # transformer_dim: Transformer的輸出維度,也是輸入到MLP的通道數(shù)
        # transformer_dim // 8: MLP的輸出通道數(shù),用于生成超網(wǎng)絡(luò)的權(quán)重
        # 3: MLP的中間層維度,用于增加模型的表達能力
        MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
        for i in range(self.num_mask_tokens)
    ]
)
# 在MaskDecoder的predict_masks添加位置編碼
# 對于self.num_mask_tokens個掩模token,遍歷output_hypernetworks_mlps列表
for i in range(self.num_mask_tokens):
    # mask_tokens_out[:, i, :]: B,1,C
    # output_hypernetworks_mlps: B,1,c
    # 對每個掩模token,應(yīng)用對應(yīng)的MLP,輸入是mask_tokens_out中對應(yīng)位置的特征,輸出為B, 1, c形狀的張量,其中c是超網(wǎng)絡(luò)的輸出通道數(shù)
    # 將每個MLP的輸出收集到hyper_in_list列表中
    hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
# B,n,c
# 將hyper_in_list堆疊成一個B, n, c形狀的張量hyper_in,其中n是掩模token的數(shù)量
hyper_in = torch.stack(hyper_in_list, dim=1)
# 獲取upscaled_embedding的形狀b, c, h, w,其中b是批次大小,c是通道數(shù),h和w是高度和寬度
b, c, h, w = upscaled_embedding.shape
# B,n,c × B,c,N-->B,n,h,w
# 執(zhí)行矩陣乘法(@運算符)將hyper_in(B, n, c)與upscaled_embedding(在通道維度上展平為B, c, h * w)相結(jié)合
# 計算每個掩模token的超網(wǎng)絡(luò)權(quán)重與上采樣特征圖的點積,得到B, n, h * w形狀的張量
# 通過view操作將結(jié)果轉(zhuǎn)換回B, n, h, w形狀,生成了masks張量,表示每個掩模token對應(yīng)的預(yù)測掩模
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

⑥ iou_MLP

此處的MLP基礎(chǔ)模塊不同于ViT的MLP(transformer_MLP)基礎(chǔ)模塊。

# 在MaskDecoder的__init__定義
# 一個多層感知機(MLP)模塊,其目的是預(yù)測每個掩模token對應(yīng)的IoU(Intersection over Union,交并比)值,以評估預(yù)測掩模與真實掩模的重合程度
self.iou_prediction_head = MLP(
    # transformer_dim: 輸入到MLP的特征維度,通常與Transformer的輸出維度相同
    # iou_head_hidden_dim: MLP中間層的維度,用于增強模型的表達能力
    # self.num_mask_tokens: 輸出維度,即預(yù)測的掩模令牌數(shù)量,每個令牌對應(yīng)一個IoU預(yù)測值
    transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
)
# 在MaskDecoder的predict_masks添加位置編碼
iou_pred = self.iou_prediction_head(iou_token_out)

⑦ MaskDeco_MLP

Mask Decoder中MLP的結(jié)構(gòu)如下圖所示。

'''
定義了一個多層感知機,它包含一個可配置的隱藏層數(shù)目、輸入和輸出維度,并可以選擇是否在輸出層應(yīng)用Sigmoid激活函數(shù)
'''
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,         # 輸入特征的維度,即輸入張量的通道數(shù)
        hidden_dim: int,        # 隱藏層的通道數(shù),中間層的寬度
        output_dim: int,        # 輸出特征的維度,即輸出張量的通道數(shù)
        num_layers: int,        # 多層感知機的層數(shù),包括輸入層和輸出層
        sigmoid_output: bool = False, #  一個布爾值,表示是否在輸出層應(yīng)用Sigmoid激活函數(shù),默認(rèn)為False
    ) -> None:
        '''
        內(nèi)部組件
        '''
        super().__init__()
        # 存儲輸入的層數(shù)
        self.num_layers = num_layers
        # 一個列表,包含num_layers - 1個hidden_dim,用于構(gòu)建中間層的線性變換
        h = [hidden_dim] * (num_layers - 1)
        #  一個nn.ModuleList,包含num_layers個線性層(全連接層),每個層的輸入和輸出通道數(shù)由h和input_dim、output_dim決定
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        # 對輸入張量x,遍歷layers列表中的每個線性層
        for i, layer in enumerate(self.layers):
            # 如果當(dāng)前層不是最后一層,應(yīng)用ReLU激活函數(shù)(F.relu)
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        # 如果sigmoid_output為True,最后對輸出應(yīng)用Sigmoid激活函數(shù)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x

責(zé)任編輯:趙寧寧 來源: 小喵學(xué)AI
相關(guān)推薦

2025-02-28 09:25:03

2024-06-18 12:36:08

2016-05-27 08:23:33

數(shù)據(jù)分析數(shù)據(jù)科學(xué)數(shù)據(jù)思維

2019-01-18 12:59:46

智能養(yǎng)老IOT智能

2024-01-16 10:54:14

2022-04-20 10:33:59

人工智能數(shù)字經(jīng)濟互聯(lián)網(wǎng) 文章鏈接:智

2023-02-11 12:47:07

2020-10-12 17:21:21

IPv6互聯(lián)網(wǎng)技術(shù)

2021-02-05 22:47:01

物聯(lián)網(wǎng)IOT物聯(lián)網(wǎng)技術(shù)

2014-08-11 14:36:42

2021-12-27 10:16:06

AI 數(shù)據(jù)人工智能

2020-05-09 13:00:08

AI 工具自動化

2020-12-18 09:32:03

Wi-Fi計算機隱私

2023-06-08 10:28:13

2023-09-03 16:17:25

物聯(lián)網(wǎng)6G

2018-12-22 19:46:45

2017-11-30 13:29:39

邊緣智算ECC

2009-09-09 11:14:16

Scala對象
點贊
收藏

51CTO技術(shù)棧公眾號