使用 PyTorch 從頭開始構建 CLIP | 對比語言圖像預訓練
在2021年,OpenAI發(fā)布了一篇論文《從自然語言監(jiān)督中學習可轉移的視覺模型》(https://arxiv.org/pdf/2103.00020),提出了CLIP(對比語言圖像預訓練),這是一個強大的深度學習模型,旨在以統(tǒng)一的方式理解和解釋圖像和文本。它結合了視覺和語言編碼器,將文本描述與視覺內容聯(lián)系起來。CLIP模型本身不生成圖像的描述,但可以用來評估文本和圖像之間的關系。例如,你可以提供一張貓的圖片,以及一個標簽列表,如“貓”和“狗”,以確定哪個標簽與圖片匹配的可能性最高。今天,這篇文章將涵蓋使用PyTorch從頭開始實現(xiàn)CLIP的過程。
CLIP(對比學習-圖像預訓練)
傳統(tǒng)的機器學習模型通常需要大量特定任務的標記數(shù)據(jù)集進行微調。例如,一個訓練用來識別狗的模型可能在識別貓方面表現(xiàn)不佳,除非它專門針對貓的圖片進行了微調。
CLIP的架構支持零樣本學習,這意味著它可以執(zhí)行它沒有直接訓練過的任務,通過利用其在圖像和文本之間學到的廣泛關聯(lián)。例如,基于它們的文本描述,它可以對它在訓練期間從未見過的圖片進行分類。引用他們的論文:“我們在零樣本的情況下匹配原始ResNet-50在ImageNet上的準確性,而不需要使用它訓練時的128萬個訓練樣本?!?/p>
CLIP有以下我們需要構建的組件:
- 文本編碼器
- 圖像編碼器
- 自定義數(shù)據(jù)集(如果你正在訓練)
- 對稱損失
文本編碼器
由于我們的主要目標是使文本和視覺表示的嵌入對齊,我們將需要一個文本編碼器模型來為圖像的文本描述創(chuàng)建特征。本文不會涵蓋如何從頭開始構建文本編碼器,而是直接使用變換器庫來創(chuàng)建編碼器,盡管這將涵蓋CLIP實現(xiàn)的主要思想。為了簡單起見,使用Distil Bert模型是一個不錯的選擇,因為它輕量級,性能幾乎和標準BERT模型一樣好,具有類似的基礎架構。這是需要記住的一點,我們不是加載預訓練版本。
class TextEncoder(nn.Module):
def __init__(self, embed_dim, proj_dim):
super().__init__()
self.model = DistilBertModel(config=DistilBertConfig())
self.layer_norm = nn.LayerNorm(proj_dim)
def forward(self, input_ids, attention_mask):
x = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
return self.layer_norm(x)
TextEncoder()類將期望兩個輸入,input_ids和attention_mask,這兩個都將通過分詞器生成。
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
texts = ["This is a sample sentence.", "This is another example."]
inputs= tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
encoder = ImageEncoder(embed_dim=768, proj_dim=256)
inputs = encoder(inputs['input_ids'], inputs['mask'])
現(xiàn)在,TextEncoder的前向傳遞輸出將是(Batch_Size, Token_Size + 1, Embed_Size),在標準BERT架構中,模型的目標是兩個任務,輸出一個額外的CLS_Token,附加到原始令牌前面,通常用于進一步微調分類任務,以及預測掩蔽令牌,使用掩蔽令牌前后的所有令牌的信息。
由于我們關心的是為我們的文本數(shù)據(jù)獲取特征嵌入,我們將只取[CLS]令牌,并將其投影到一個共同的空間,與圖像編碼器的視覺嵌入具有相同的嵌入大小。
class TextEncoder(nn.Module):
def __init__(self, embed_dim, proj_dim):
super().__init__()
self.model = DistilBertModel(config=DistilBertConfig())
self.projection = nn.Linear(embed_dim, proj_dim)
self.layer_norm = nn.LayerNorm(proj_dim)
def forward(self, input_ids, attention_mask):
x = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
x = x[:, 0, :] # B, T[cls], E
x = self.projection(x)
return self.layer_norm(x)
層歸一化是深度學習中非常常見的概念,這不是我第一次解釋它,但讓我們再次解釋一下,我們有一個網(wǎng)絡的輸入,其中包含來自不同類別或特征的數(shù)據(jù),因為在每個訓練周期中批次會變化,數(shù)據(jù)的分布也會變化,在一批中分布可能在[0, 2)范圍內,而在下一批中它可能有樣本分布在[0, 100]范圍內。在訓練過程中數(shù)據(jù)分布的變化被稱為協(xié)變量偏移。由于輸入的劇烈變化,輸出也會變化,損失也會變化,如果損失劇烈變化,那么在反向傳播過程中權重將以更高的幅度更新,導致梯度不平滑。簡而言之,歸一化輸入將限制其在整個訓練批次中的分布,因此,損失不會有劇烈變化,將導致更平滑的梯度和更快的訓練,幫助模型更多地關注學習特征。
圖像編碼器
CLIP有兩個圖像編碼器選項,ResNet或視覺變換器。我們已經(jīng)開發(fā)了各種視覺變換器,因此將使用標準實現(xiàn)。如果你想使用ResNet作為圖像編碼器,你可以簡單地用視覺變換器模型替換它,你可以使用PyTorch自己的ResNet模型或timm。
class ImageEncoder(nn.Module):
def __init__(self, base_model, embed_dim, proj_dim):
super().__init__()
self.model = base_model
for param in self.model.parameters():
param.requires_grad = True
self.projection = nn.Linear(embed_dim, proj_dim)
self.layer_norm = nn.LayerNorm(proj_dim)
def forward(self, x):
x = self.projection(self.model(x))
return self.layer_norm(x)
上面的編碼器類將圖像張量傳遞給模型,然后將其投影到與文本編碼器輸出相同的共同嵌入空間,后面是一個歸一化層。
自定義數(shù)據(jù)集
現(xiàn)在CLIP是一個(相當)密集的模型,所以如果你想從頭開始訓練它,你必須在一個小數(shù)據(jù)集上訓練它。由于本文只涉及如何從頭開始實現(xiàn)架構,我們將不會進一步詳細說明如何創(chuàng)建數(shù)據(jù)集,但為了示例,這可能是你想要做的。
class CustomDataset(Dataset):
def __init__(self, texts, image_paths):
self.image_paths = image_paths
self.texts = texts
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
self.inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
self.transform = torchvision.transforms.ToTensor()
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
img_path = self.image_paths[idx]
image = Image.open(img_path)
image = self.transform(image)
caption, mask = self.inputs[idx].items()
return {
"image": image,
"input_ids": caption["input_ids"],
"mask": mask["attention_mask"]
}
- image_paths:你選擇的數(shù)據(jù)集中圖像的路徑列表。
- texts:數(shù)據(jù)集中每張圖片的標題或文本描述。
自定義數(shù)據(jù)集類在調用Dataset類時創(chuàng)建分詞器并分詞所有文本,我們使用的是distillbert分詞器,這也是我們的文本編碼器模型。
把所有放在一起
class CLIPModel(nn.Module):
def __init__(self):
super().__init__()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
ViT = VissionTransformer(
num_layers=8,
img_size=224,
emb_size=768,
patch_size=16,
num_head=6,
num_class=768).to(self.device)
self.image_encoder = ImageEncoder(base_model=ViT, embed_dim=768, proj_dim=256)
self.text_encoder = TextEncoder(embed_dim=768, proj_dim=256)
ClipModel()類是我們把所有放在一起的地方,架構將包括來自圖像和文本編碼器的嵌入,然后用于計算對稱損失。這是核心實現(xiàn)的NumPy樣式偽代碼。
在我們的實現(xiàn)中,我們將在CLIPModel類的前向函數(shù)中計算損失。第一步是獲取圖像和文本嵌入,然后進行交叉乘法以獲得相似性矩陣或logits?;氐轿覀兊牡诙垐D。
logits是通過取圖像和文本嵌入的點積來創(chuàng)建的,由于這篇論文基于對比學習,我們的主要目標是將文本表示與視覺對齊。那么計算相似性矩陣有什么幫助呢?
答案是每個從圖像編碼器接收到的圖像令牌(圖5:I_1,I_2,.., I_n; 其中I是嵌入,n是批次大小)乘以文本編碼器接收到的每個令牌。得到最終矩陣(B, Token, Embed)@(B, Embed, Token)→(B, Token, Token)?,F(xiàn)在我們的任務是最大化每個對角線元素(I1T1, I2T2,…, InTn)的值。由于我們想要對齊我們的文本和視覺表示,相應的圖像令牌應該與其相應的文本最高相關。這就是我們將如何為批次中的所有圖像完成的,但讓我們看看單個令牌。
這里,圖并不是真的不同,我們取圖像嵌入I,并與批次中的每個文本嵌入計算點積。例如,當我們使用I3時,我們希望它與批次中相應的文本嵌入T3最強地對齊。理想情況下,I3行中最高的值應該是點積I3?T3,以同樣的方式批量處理,看起來就像我們在最大化所有對角線元素,其中每個In與其相應的Tn最佳對齊。為了實現(xiàn)這一點,我們使用一個損失函數(shù)來衡量每一行中最大值與其他值的突出程度。這實際上是通過取行和列的交叉熵損失來實現(xiàn)的。
from vit import VissionTransformer # Importing ViT from previous implementaton (GitHub: Ml-Models)
import numpy as np
import torch.nn.functional as F
class CLIPModel(nn.Module):
def __init__(self):
super().__init__()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
ViT = VissionTransformer(
num_layers=8,
img_size=224,
emb_size=768,
patch_size=16,
num_head=6,
num_class=False).to(self.device)
self.image_encoder = ImageEncoder(base_model=ViT, embed_dim=768, proj_dim=256)
self.text_encoder = TextEncoder(embed_dim=768, proj_dim=256)
self.temperature = nn.Parameter(torch.ones([])*np.log(1/7)).to(self.device)
def forward(self, x):
I_t = self.image_encoder(x["image"])
T_t = self.text_encoder(x["input_ids"], x["mask"])
logits = I_t@T_t.T * torch.exp(self.temperature)
labels = torch.arange(I_t.size(0)).to(self.device)
loss_I = F.cross_entropy(logits.T, labels)
loss_T = F.cross_entropy(logits, labels)
loss = (loss_I + loss_T)/2.0
return loss, logits
- 我們得到I_t和T_t(大?。築, Token_Size, Embed_Size)
- 我們通過取點積來計算logits,如前所述,然后乘以溫度參數(shù)的指數(shù)。如果你熟悉對比學習或讀過我關于DINO(無標簽蒸餾)的文章,你可能知道通常使用除以溫度來銳化輸出分布。然而,我們不直接除以溫度,而是乘以一個可訓練的張量,該張量使用nn.Parameter()設置,并初始化為log(1/7)。由于eln(x)=x,那么exp(log(1/T))應該是1/T,你可能會想知道我們?yōu)槭裁床缓唵蔚爻艘?/T。原因是使用log(1/T)可以讓優(yōu)化器在訓練期間更容易計算和更新梯度。這種方法在深度學習中是一種常見的做法,因為它可以帶來更平滑的訓練和更穩(wěn)定的模型權重更新。
- 標簽簡單地用批次大小生成([0, 1,..N])。正如我們之前討論的,目標是最大化每個對角線元素(i1T1, i2T2,..inTn),因此整個矩陣中每一行的標簽是[0, 1, 2, ..N],對應于哪一行的元素應該是最大的。
- 如偽代碼所述,嵌入已經(jīng)歸一化,但我們不需要這樣做,因為我們在返回圖像和文本編碼器的輸出時已經(jīng)應用了層歸一化。
- 按照偽代碼,計算行和列的交叉熵損失。我們通過傳遞logits的轉置和正常的logits以及標簽,取兩個損失的平均值,現(xiàn)在我們有最終的損失結果。
設置模型
texts = ["This is a sample sentence.", "This is another example."]
# You can Use a CustomDataset as we Implemented above for training
train_data = CustomDataset(texts, image_path)
train_loader = DataLoader(train_data, batch_size, shuffle=True)
# Example Usage
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
inputs= tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
test = {
"image" : torch.rand(2, 3, 224, 224).to(device),
"input_ids" : inputs["input_ids"],
"mask" : inputs["attention_mask"]
}
model = CLIPModel().to(device)
loss, logits = model(test)
print("Loss:", loss, "Logits:", logits)
源碼鏈接:https://github.com/mishra-18/ML-Models/blob/main/clip.py